"""PPO parameter server.

Clients POST rollout batches (observations, actions, old log-probs, returns,
advantages) to ``/push_grad``. The server runs one PPO update on the shared
policy/value networks, logs the losses to TensorBoard, periodically writes a
zstd-compressed checkpoint, and responds with the current policy weights.
``/restore`` reloads the last checkpoint.
"""
import os, io, tempfile, asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import zstandard as zstd
import uvicorn

# ——— Hyperparams & Paths ———
OBS_DIM, ACT_DIM = 4, 2
LR, CLIP_EPS, VF_COEFF, ENT_COEFF = 3e-4, 0.2, 0.5, 0.01
CKPT_PATH = "ckpt.zst"
CKPT_FREQ = 100  # write a checkpoint every CKPT_FREQ updates

LOG_DIR = "logs"
# Probe that the preferred log dir is writable; fall back to /tmp/logs
# (e.g. when the working directory is read-only).
try:
    os.makedirs(LOG_DIR, exist_ok=True)
    _probe_path = os.path.join(LOG_DIR, "test_write")
    with open(_probe_path, "w") as f:
        f.write("ok")
    os.remove(_probe_path)
except OSError:
    LOG_DIR = "/tmp/logs"
    os.makedirs(LOG_DIR, exist_ok=True)

# One TensorBoard writer shared by the whole process.
writer = SummaryWriter(LOG_DIR)

DEVICE = "cpu"
MANAGER_HOST = "0.0.0.0"
MANAGER_PORT = 7860


# ——— Networks ———
# NOTE: attribute names (net / logits / val) and layer order define the
# state_dict keys stored in checkpoints; keep them stable.
class PolicyNet(nn.Module):
    """MLP mapping observations (last dim OBS_DIM) to ACT_DIM action logits."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.logits = nn.Linear(256, ACT_DIM)

    def forward(self, x):
        return self.logits(self.net(x))


class ValueNet(nn.Module):
    """MLP mapping observations (last dim OBS_DIM) to a scalar value per row."""

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.val = nn.Linear(256, 1)

    def forward(self, x):
        # Drop the trailing size-1 dim so the output lines up with returns.
        return self.val(self.net(x)).squeeze(-1)


# ——— PPO Implementation ———
class PPO:
    """Clipped-surrogate PPO with separate policy and value networks."""

    def __init__(self, policy, val, lr):
        self.policy, self.value = policy, val
        # A single Adam optimizer steps both networks together.
        self.opt = torch.optim.Adam(
            list(policy.parameters()) + list(val.parameters()), lr=lr
        )
        self.clip, self.vf_c, self.ent_c = CLIP_EPS, VF_COEFF, ENT_COEFF

    def update(self, obs, act, old_lp, ret, adv):
        """Apply one optimizer step on the combined PPO loss.

        Args:
            obs: float observations fed to both networks.
            act: integer action indices.
            old_lp: log-probs of ``act`` under the behaviour policy.
            ret: value-function targets.
            adv: advantage estimates.

        Returns:
            dict with ``pl`` (policy loss), ``vl`` (value loss) and
            ``en`` (mean policy entropy) as Python floats.
        """
        logits = self.policy(obs)
        dist = torch.distributions.Categorical(logits=logits)
        lp = dist.log_prob(act)
        ratio = torch.exp(lp - old_lp)
        unclipped = ratio * adv
        clipped = torch.clamp(ratio, 1 - self.clip, 1 + self.clip) * adv
        pg = -torch.min(unclipped, clipped).mean()
        v = self.value(obs)
        vloss = F.mse_loss(v, ret)
        entropy = dist.entropy().mean()
        # Subtracting the entropy bonus maximizes entropy (exploration).
        loss = pg + self.vf_c * vloss - self.ent_c * entropy
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return {"pl": pg.item(), "vl": vloss.item(), "en": entropy.item()}


# ——— Utils ———
def save_ckpt(step: int) -> None:
    """Atomically write a zstd-compressed checkpoint of all training state.

    The payload is a ``torch.save``'d dict with keys ``step``, ``ps``
    (policy state_dict), ``vs`` (value state_dict) and ``opt`` (optimizer
    state_dict), compressed with zstd and stored at CKPT_PATH.
    """
    data = {
        "step": step,
        "ps": policy.state_dict(),
        "vs": value.state_dict(),
        "opt": ppo.opt.state_dict(),
    }
    buf = io.BytesIO()
    torch.save(data, buf)
    payload = zstd.ZstdCompressor().compress(buf.getvalue())
    # Write to a temp file in the same directory, then rename over
    # CKPT_PATH, so a crash mid-write never leaves a truncated checkpoint.
    ckpt_dir = os.path.dirname(os.path.abspath(CKPT_PATH))
    fd, tmp_path = tempfile.mkstemp(dir=ckpt_dir, prefix=".ckpt-", suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as out:
            out.write(payload)
        os.replace(tmp_path, CKPT_PATH)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # best-effort cleanup; re-raise the original error below
        raise


def load_ckpt() -> int:
    """Load CKPT_PATH into the policy, value net and optimizer.

    Returns:
        The update step stored in the checkpoint.

    Raises:
        FileNotFoundError: if no checkpoint has been written yet.
    """
    with open(CKPT_PATH, "rb") as fh:
        raw = zstd.ZstdDecompressor().decompress(fh.read())
    # The checkpoint holds only tensors/primitive containers, so the
    # restricted (non-arbitrary-pickle) loader is sufficient.
    data = torch.load(io.BytesIO(raw), map_location=DEVICE, weights_only=True)
    policy.load_state_dict(data["ps"])
    value.load_state_dict(data["vs"])
    ppo.opt.load_state_dict(data["opt"])
    return data["step"]


# ——— FastAPI Server ———
app = FastAPI()
policy = PolicyNet().to(DEVICE)
value = ValueNet().to(DEVICE)
ppo = PPO(policy, value, LR)
step = 0  # number of PPO updates applied so far


class Grad(BaseModel):
    """Rollout batch posted by a client (raw data, not gradients).

    ``obs`` rows must have OBS_DIM entries (they feed nn.Linear(OBS_DIM, ...));
    the other fields presumably hold one entry per row — TODO confirm.
    """
    obs: list
    actions: list
    old_logps: list
    returns: list
    advs: list


# NOTE: these handlers are ``async`` with no ``await``, so they run on the
# event loop one at a time. That serializes updates to the shared model
# (no lock needed) but blocks the loop during each update.
@app.post("/push_grad")
async def push(g: Grad):
    """Apply one PPO update from ``g`` and return the current policy weights."""
    global step
    # Explicit dtypes: integer-valued JSON would otherwise yield int64 tensors,
    # which nn.Linear rejects; Categorical.log_prob wants integer actions.
    obs = torch.tensor(g.obs, dtype=torch.float32, device=DEVICE)
    a = torch.tensor(g.actions, dtype=torch.long, device=DEVICE)
    olp = torch.tensor(g.old_logps, dtype=torch.float32, device=DEVICE)
    ret = torch.tensor(g.returns, dtype=torch.float32, device=DEVICE)
    adv = torch.tensor(g.advs, dtype=torch.float32, device=DEVICE)
    m = ppo.update(obs, a, olp, ret, adv)
    step += 1
    for k, v in m.items():
        writer.add_scalar(f"train/{k}", v, step)
    if step % CKPT_FREQ == 0:
        save_ckpt(step)
    # Nested lists are JSON-serializable; numpy arrays are not accepted by
    # FastAPI's response encoder.
    return {"weights": {k: v.cpu().tolist() for k, v in policy.state_dict().items()}}


@app.post("/restore")
async def restore():
    """Reload the last checkpoint and resume the update counter from it."""
    global step
    try:
        s = load_ckpt()
    except FileNotFoundError:
        raise HTTPException(status_code=404, detail=f"no checkpoint at {CKPT_PATH}") from None
    # Keep logging and checkpoint cadence consistent with the restored model.
    step = s
    return {"restored_step": s}


if __name__ == "__main__":
    uvicorn.run(app, host=MANAGER_HOST, port=MANAGER_PORT)