# manager/manager.py
# Author: Nodiw52992
# Last commit: "Update manager.py" (d8a701a, verified)
import os, io, tempfile, asyncio
from fastapi import FastAPI
from pydantic import BaseModel
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import zstandard as zstd
import uvicorn
# ——— Hyperparams & Paths ———
OBS_DIM, ACT_DIM = 4, 2
LR, CLIP_EPS, VF_COEFF, ENT_COEFF = 3e-4, 0.2, 0.5, 0.01
CKPT_PATH = "ckpt.zst"
CKPT_FREQ = 100  # save a checkpoint every CKPT_FREQ PPO updates
LOG_DIR = "logs"


def resolve_log_dir(preferred, fallback="/tmp/logs"):
    """Return a writable log directory, preferring *preferred*.

    Creates *preferred* if needed and proves it is writable by writing and
    deleting a small probe file. If any filesystem step fails (read-only mount,
    permission denied, path blocked by a regular file, ...), *fallback* is
    created and returned instead.

    Args:
        preferred: Directory to use when it is writable.
        fallback: Directory used otherwise (default ``/tmp/logs``).

    Returns:
        The directory path that was created/verified.

    Raises:
        OSError: if the fallback directory cannot be created either.
    """
    probe = os.path.join(preferred, "test_write")
    try:
        os.makedirs(preferred, exist_ok=True)
        with open(probe, "w") as f:
            f.write("ok")
        os.remove(probe)
    except OSError:
        # Only filesystem errors mean "not writable"; anything else is a bug
        # and should surface instead of being silently redirected to /tmp.
        os.makedirs(fallback, exist_ok=True)
        return fallback
    return preferred


# Try the main logs dir first; fall back to /tmp/logs when it is not writable.
LOG_DIR = resolve_log_dir(LOG_DIR)
# The TensorBoard SummaryWriter is created once, together with the server
# state further down, so only a single event file is opened for LOG_DIR.
DEVICE = "cpu"
MANAGER_HOST = "0.0.0.0"
MANAGER_PORT = 7860
# ——— Networks ———
class PolicyNet(nn.Module):
    """Policy network mapping observations to unnormalized action logits.

    A two-layer ReLU MLP trunk (``net``) followed by a linear ``logits`` head.
    The attribute names and layer order are part of the checkpoint format,
    because ``state_dict()`` is saved and restored by parameter name.

    Args:
        obs_dim: Observation feature size; ``None`` means module-level ``OBS_DIM``.
        act_dim: Number of discrete actions; ``None`` means ``ACT_DIM``.
        hidden: Width of both hidden layers (256 by default).
    """

    def __init__(self, obs_dim=None, act_dim=None, hidden=256):
        super().__init__()
        obs_dim = OBS_DIM if obs_dim is None else obs_dim
        act_dim = ACT_DIM if act_dim is None else act_dim
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.logits = nn.Linear(hidden, act_dim)

    def forward(self, x):
        """Return logits over actions; the last dim of *x* must be ``obs_dim``."""
        return self.logits(self.net(x))
class ValueNet(nn.Module):
    """Value network mapping observations to one scalar estimate each.

    A two-layer ReLU MLP trunk (``net``) followed by a 1-unit linear ``val``
    head whose trailing singleton dim is squeezed away. The attribute names are
    part of the checkpoint format, because ``state_dict()`` is saved by name.

    Args:
        obs_dim: Observation feature size; ``None`` means module-level ``OBS_DIM``.
        hidden: Width of both hidden layers (256 by default).
    """

    def __init__(self, obs_dim=None, hidden=256):
        super().__init__()
        obs_dim = OBS_DIM if obs_dim is None else obs_dim
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.val = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return values shaped like *x* without its last (feature) dim."""
        return self.val(self.net(x)).squeeze(-1)
# ——— PPO Implementation ———
class PPO:
    """Clipped-objective PPO learner over separate policy and value networks.

    A single Adam optimizer updates the parameters of both networks jointly.

    Args:
        policy: Module mapping observations to action logits.
        val: Module mapping observations to value estimates (one per row).
        lr: Adam learning rate.
        clip_eps: Ratio clip range; ``None`` means module-level ``CLIP_EPS``.
        vf_coeff: Value-loss weight; ``None`` means ``VF_COEFF``.
        ent_coeff: Entropy-bonus weight; ``None`` means ``ENT_COEFF``.
    """

    def __init__(self, policy, val, lr, clip_eps=None, vf_coeff=None, ent_coeff=None):
        self.policy, self.value = policy, val
        self.opt = torch.optim.Adam(
            list(policy.parameters()) + list(val.parameters()), lr=lr
        )
        self.clip = CLIP_EPS if clip_eps is None else clip_eps
        self.vf_c = VF_COEFF if vf_coeff is None else vf_coeff
        self.ent_c = ENT_COEFF if ent_coeff is None else ent_coeff

    def update(self, obs, act, old_lp, ret, adv):
        """Take one optimizer step on the combined PPO loss for a batch.

        Args:
            obs: Observation batch fed to both networks.
            act: Integer action indices, one per observation.
            old_lp: Log-probs of ``act`` under the policy that collected the batch.
            ret: Value-function regression targets.
            adv: Advantage estimates.

        Returns:
            dict of Python floats: ``pl`` (clipped policy loss), ``vl`` (value
            MSE) and ``en`` (mean policy entropy, from the pre-step logits).
        """
        dist = torch.distributions.Categorical(logits=self.policy(obs))
        ratio = torch.exp(dist.log_prob(act) - old_lp)
        clipped = torch.clamp(ratio, 1 - self.clip, 1 + self.clip)
        pg = -torch.min(ratio * adv, clipped * adv).mean()
        vloss = F.mse_loss(self.value(obs), ret)
        entropy = dist.entropy().mean()  # computed once; reused for loss and metric
        # Subtracting entropy rewards higher-entropy (more exploratory) policies.
        loss = pg + self.vf_c * vloss - self.ent_c * entropy
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return {"pl": pg.item(), "vl": vloss.item(), "en": entropy.item()}
# ——— Utils ———
def save_ckpt(step):
    """Write a zstd-compressed checkpoint of the training state to CKPT_PATH.

    The checkpoint is a ``torch.save`` dict with keys ``step``, ``ps`` (policy
    state_dict), ``vs`` (value-net state_dict) and ``opt`` (optimizer state),
    the keys ``load_ckpt`` reads back. The compressed bytes are written to a
    temporary file in the same directory, which then replaces CKPT_PATH via
    ``os.replace``. An interrupted write therefore cannot leave a truncated
    checkpoint in place of the previous one.

    Args:
        step: Update counter stored alongside the weights.
    """
    data = {"step": step, "ps": policy.state_dict(), "vs": value.state_dict(),
            "opt": ppo.opt.state_dict()}
    # Serialize in memory: no uncompressed intermediate file to leak on disk.
    buf = io.BytesIO()
    torch.save(data, buf)
    payload = zstd.ZstdCompressor().compress(buf.getvalue())

    ckpt_dir = os.path.dirname(os.path.abspath(CKPT_PATH))
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            f.write(payload)
        os.replace(tmp_path, CKPT_PATH)
    except BaseException:
        # Don't leave the partial temp file behind; re-raise the original error.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
def load_ckpt():
    """Restore policy, value-network and optimizer state from CKPT_PATH.

    Returns:
        The ``step`` counter stored in the checkpoint.

    Raises:
        FileNotFoundError: if no checkpoint file exists at CKPT_PATH.
    """
    with open(CKPT_PATH, "rb") as f:
        raw = zstd.ZstdDecompressor().decompress(f.read())
    # map_location pins tensors to the configured device, whatever device wrote them.
    # NOTE(review): torch.load unpickles its input; only load trusted checkpoints.
    data = torch.load(io.BytesIO(raw), map_location=DEVICE)
    policy.load_state_dict(data["ps"])
    value.load_state_dict(data["vs"])
    ppo.opt.load_state_dict(data["opt"])
    return data["step"]
# ——— FastAPI Server ———
# Process-wide training state, shared by every request handled by `app`.
app = FastAPI()
policy = PolicyNet().to(DEVICE)
value = ValueNet().to(DEVICE)
ppo = PPO(policy, value, LR)
writer = SummaryWriter(LOG_DIR)  # TensorBoard scalars for each update
step = 0  # number of PPO updates applied so far
class Grad(BaseModel):
    """Request body for ``/push_grad``: one batch of rollout data.

    Every field is a plain, untyped list (no element-type or length checks
    here); the endpoint converts each one with ``torch.tensor``.
    """

    obs: list        # observations, one row per sample
    actions: list    # action indices taken
    old_logps: list  # log-probs of `actions` under the collecting policy
    returns: list    # value-function targets
    advs: list       # advantage estimates
@app.post("/push_grad")
async def push(g: Grad):
    """Run one PPO update on a submitted rollout batch and return new weights.

    Despite the route name, the client sends rollout data (observations,
    actions, old log-probs, returns, advantages). The loss and gradient are
    computed here by ``ppo.update``. Metrics are logged to TensorBoard, and a
    checkpoint is written every ``CKPT_FREQ`` updates.

    Returns:
        ``{"weights": {param_name: nested list of floats}}`` for the policy.
    """
    global step
    # Explicit dtypes: integer-valued JSON numbers would otherwise become int64
    # tensors, which nn.Linear and mse_loss reject.
    obs = torch.tensor(g.obs, dtype=torch.float32, device=DEVICE)
    a = torch.tensor(g.actions, dtype=torch.long, device=DEVICE)
    olp = torch.tensor(g.old_logps, dtype=torch.float32, device=DEVICE)
    ret = torch.tensor(g.returns, dtype=torch.float32, device=DEVICE)
    adv = torch.tensor(g.advs, dtype=torch.float32, device=DEVICE)
    # This handler never awaits, so it runs to completion on the event loop:
    # concurrent pushes are serialized, but the loop is blocked while training.
    m = ppo.update(obs, a, olp, ret, adv)
    step += 1
    for k, v in m.items():
        writer.add_scalar(f"train/{k}", v, step)
    if step % CKPT_FREQ == 0:
        save_ckpt(step)
    # FastAPI cannot JSON-encode numpy arrays; nested Python lists serialize cleanly.
    return {"weights": {k: t.cpu().tolist() for k, t in policy.state_dict().items()}}
@app.post("/restore")
async def restore():
    """Reload the checkpoint at CKPT_PATH and resume the update counter from it.

    Returns:
        ``{"restored_step": <step stored in the checkpoint>}``.

    Raises:
        HTTPException: 404 when no checkpoint has been written yet.
    """
    global step
    from fastapi import HTTPException

    try:
        restored = load_ckpt()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail="no checkpoint available") from None
    # Sync the global counter so TensorBoard steps and the CKPT_FREQ cadence
    # continue from the restored point instead of the pre-restore value.
    step = restored
    return {"restored_step": restored}
if __name__ == "__main__":
    # Serve the manager API on the configured interface and port (blocks).
    uvicorn.run(app, port=MANAGER_PORT, host=MANAGER_HOST)