import os, io, tempfile
from fastapi import FastAPI
from pydantic import BaseModel
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import zstandard as zstd
import uvicorn

OBS_DIM, ACT_DIM = 4, 2
LR, CLIP_EPS, VF_COEFF, ENT_COEFF = 3e-4, 0.2, 0.5, 0.01
CKPT_PATH = "ckpt.zst"
CKPT_FREQ = 100  # save a compressed checkpoint every N updates
LOG_DIR = "logs"
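# CLIP_EPS is the PPO ratio-clip epsilon; VF_COEFF and ENT_COEFF weight the
# value loss and the entropy bonus in the combined loss built in PPO.update.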

# Probe LOG_DIR for writability and fall back to /tmp/logs if it is not
# writable (e.g. a read-only container filesystem).
try:
    os.makedirs(LOG_DIR, exist_ok=True)
    probe = os.path.join(LOG_DIR, "test_write")
    with open(probe, "w") as f:
        f.write("ok")
    os.remove(probe)
except OSError:
    LOG_DIR = "/tmp/logs"
    os.makedirs(LOG_DIR, exist_ok=True)

writer = SummaryWriter(LOG_DIR)
DEVICE = "cpu"
MANAGER_HOST = "0.0.0.0"
MANAGER_PORT = 7860
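# 0.0.0.0 binds all interfaces so remote workers can reach the manager;
# the update step itself runs on CPU here (DEVICE = "cpu").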


class PolicyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.logits = nn.Linear(256, ACT_DIM)

    def forward(self, x):
        return self.logits(self.net(x))


class ValueNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.val = nn.Linear(256, 1)

    def forward(self, x):
        return self.val(self.net(x)).squeeze(-1)
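
# Note: the actor and critic use identical 2x256 MLP trunks but share no
# parameters, so value-loss gradients never touch the policy weights.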


class PPO:
    def __init__(self, policy, val, lr):
        self.policy, self.value = policy, val
        self.opt = torch.optim.Adam(list(policy.parameters()) + list(val.parameters()), lr=lr)
        self.clip, self.vf_c, self.ent_c = CLIP_EPS, VF_COEFF, ENT_COEFF

    def update(self, obs, act, old_lp, ret, adv):
        logits = self.policy(obs)
        dist = torch.distributions.Categorical(logits=logits)
        lp = dist.log_prob(act)
        r = torch.exp(lp - old_lp)  # importance ratio pi_new / pi_old
        # Clipped surrogate objective, negated for gradient descent.
        pg = -torch.min(r * adv, torch.clamp(r, 1 - self.clip, 1 + self.clip) * adv).mean()
        vloss = F.mse_loss(self.value(obs), ret)
        ent = dist.entropy().mean()
        # Subtracting the entropy term rewards more exploratory policies.
        loss = pg + self.vf_c * vloss - self.ent_c * ent
        self.opt.zero_grad(); loss.backward(); self.opt.step()
        return {"pl": pg.item(), "vl": vloss.item(), "en": ent.item()}
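
# For reference, update() minimizes
#     -E[min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t)]
#         + VF_COEFF * MSE(V(s_t), R_t) - ENT_COEFF * H(pi(.|s_t))
# where r_t = pi_new(a_t|s_t) / pi_old(a_t|s_t) is the importance ratio.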


def save_ckpt(step):
    data = {"step": step, "ps": policy.state_dict(), "vs": value.state_dict(), "opt": ppo.opt.state_dict()}
    buf = io.BytesIO()
    torch.save(data, buf)
    compressed = zstd.ZstdCompressor().compress(buf.getvalue())
    # Write to a temp file in the same directory, then rename atomically so a
    # crash mid-write cannot leave a truncated checkpoint at CKPT_PATH.
    with tempfile.NamedTemporaryFile(dir=os.path.dirname(CKPT_PATH) or ".", delete=False) as t:
        t.write(compressed)
    os.replace(t.name, CKPT_PATH)


def load_ckpt():
    with open(CKPT_PATH, "rb") as f:
        raw = zstd.ZstdDecompressor().decompress(f.read())
    data = torch.load(io.BytesIO(raw), map_location=DEVICE)
    policy.load_state_dict(data["ps"]); value.load_state_dict(data["vs"]); ppo.opt.load_state_dict(data["opt"])
    return data["step"]
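
# On disk the checkpoint is a zstd frame wrapping an ordinary torch.save
# payload, so it remains readable by standard tools outside this service.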


app = FastAPI()
policy = PolicyNet().to(DEVICE)
value = ValueNet().to(DEVICE)
ppo = PPO(policy, value, LR)
step = 0


# Despite the name, this model carries a rollout batch (transitions plus
# precomputed returns and advantages), not raw gradients.
class Grad(BaseModel):
    obs: list[list[float]]
    actions: list[int]
    old_logps: list[float]
    returns: list[float]
    advs: list[float]
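
# Example of the JSON a worker might POST (shapes assume OBS_DIM = 4):
# {"obs": [[0.1, 0.2, 0.3, 0.4]], "actions": [1],
#  "old_logps": [-0.69], "returns": [1.0], "advs": [0.5]}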


@app.post("/push_grad")
async def push(g: Grad):
    global step
    obs = torch.tensor(g.obs, dtype=torch.float32, device=DEVICE)
    a = torch.tensor(g.actions, dtype=torch.long, device=DEVICE)
    olp = torch.tensor(g.old_logps, dtype=torch.float32, device=DEVICE)
    ret = torch.tensor(g.returns, dtype=torch.float32, device=DEVICE)
    adv = torch.tensor(g.advs, dtype=torch.float32, device=DEVICE)
    m = ppo.update(obs, a, olp, ret, adv)
    step += 1
    for k, v in m.items():
        writer.add_scalar(f"train/{k}", v, step)
    if step % CKPT_FREQ == 0:
        save_ckpt(step)
    # tolist() rather than numpy(): FastAPI's JSON encoder cannot serialize
    # numpy arrays, and nested Python lists round-trip cleanly.
    return {"weights": {k: v.cpu().tolist() for k, v in policy.state_dict().items()}}
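
# A worker round-trip might look roughly like this (sketch; collect_rollout
# and apply_weights are hypothetical helpers on the worker side):
#
#   import requests
#   batch = collect_rollout(local_policy)          # dict matching Grad
#   resp = requests.post("http://<manager-host>:7860/push_grad", json=batch)
#   apply_weights(local_policy, resp.json()["weights"])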


@app.post("/restore")
async def restore():
    # Guard against a missing checkpoint instead of surfacing a 500.
    if not os.path.exists(CKPT_PATH):
        return {"error": f"no checkpoint at {CKPT_PATH}"}
    return {"restored_step": load_ckpt()}


if __name__ == "__main__":
    uvicorn.run(app, host=MANAGER_HOST, port=MANAGER_PORT)
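
# Assuming this file is saved as manager.py, it can equivalently be served
# with the uvicorn CLI: `uvicorn manager:app --host 0.0.0.0 --port 7860`.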