File size: 3,829 Bytes
b883a05
 
 
 
 
 
 
 
 
 
 
 
 
 
d8a701a
 
 
 
 
 
 
 
 
 
 
 
b883a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os, io, tempfile, asyncio
from fastapi import FastAPI
from pydantic import BaseModel
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import zstandard as zstd
import uvicorn

# β€”β€”β€” Hyperparams & Paths β€”β€”β€”
OBS_DIM, ACT_DIM = 4, 2
LR, CLIP_EPS, VF_COEFF, ENT_COEFF = 3e-4, 0.2, 0.5, 0.01
CKPT_PATH = "ckpt.zst"
CKPT_FREQ = 100
LOG_DIR = "logs"

# Try main logs dir
try:
    os.makedirs(LOG_DIR, exist_ok=True)
    with open(os.path.join(LOG_DIR, "test_write"), "w") as f:
        f.write("ok")
    os.remove(os.path.join(LOG_DIR, "test_write"))
except Exception:
    LOG_DIR = "/tmp/logs"
    os.makedirs(LOG_DIR, exist_ok=True)

writer = SummaryWriter(LOG_DIR)
DEVICE = "cpu"
MANAGER_HOST = "0.0.0.0"
MANAGER_PORT = 7860

# β€”β€”β€” Networks β€”β€”β€”
class PolicyNet(nn.Module):
    """Two-hidden-layer MLP mapping an observation to action logits."""

    def __init__(self):
        super().__init__()
        # Shared trunk; the head stays a separate attribute so state_dict
        # keys ("net.*", "logits.*") match existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.logits = nn.Linear(256, ACT_DIM)

    def forward(self, x):
        # Trunk features -> unnormalized action logits.
        return self.logits(self.net(x))
class ValueNet(nn.Module):
    """Two-hidden-layer MLP estimating the scalar state value V(s)."""

    def __init__(self):
        super().__init__()
        # Same trunk shape as the policy; attribute names preserved so
        # state_dict keys ("net.*", "val.*") match existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(OBS_DIM, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
        )
        self.val = nn.Linear(256, 1)

    def forward(self, x):
        # Squeeze the trailing singleton so callers get shape (batch,).
        return self.val(self.net(x)).squeeze(-1)

# β€”β€”β€” PPO Implementation β€”β€”β€”
class PPO:
    """Clipped-surrogate PPO step over jointly-optimized policy/value nets."""

    def __init__(self, policy, val, lr):
        self.policy, self.value = policy, val
        # One optimizer over both networks so a single step updates both.
        self.opt = torch.optim.Adam(
            list(policy.parameters()) + list(val.parameters()), lr=lr
        )
        self.clip, self.vf_c, self.ent_c = CLIP_EPS, VF_COEFF, ENT_COEFF

    def update(self, obs, act, old_lp, ret, adv):
        """Run one gradient step on a rollout batch.

        obs/act/old_lp/ret/adv: batched tensors (act integer actions,
        old_lp log-probs recorded at rollout time).
        Returns a dict of float diagnostics: policy loss "pl",
        value loss "vl", mean entropy "en".
        """
        logits = self.policy(obs)
        dist = torch.distributions.Categorical(logits=logits)
        lp = dist.log_prob(act)
        r = torch.exp(lp - old_lp)  # importance ratio pi_new / pi_old
        # Clipped surrogate objective, negated because we minimize.
        pg = -torch.min(r * adv, torch.clamp(r, 1 - self.clip, 1 + self.clip) * adv).mean()
        vloss = F.mse_loss(self.value(obs), ret)
        # Compute mean entropy once and reuse it for both the loss term and
        # the returned diagnostics (original recomputed it for logging).
        ent_mean = dist.entropy().mean()
        loss = pg + self.vf_c * vloss - self.ent_c * ent_mean
        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return {"pl": pg.item(), "vl": vloss.item(), "en": ent_mean.item()}

# β€”β€”β€” Utils β€”β€”β€”
def save_ckpt(step):
    """Serialize policy/value/optimizer state and write it zstd-compressed.

    Serializes to an in-memory buffer instead of a temp file: the original
    used NamedTemporaryFile(delete=False) and never removed it, leaking one
    temp file per checkpoint, and left both file handles unclosed.
    """
    data = {
        "step": step,
        "ps": policy.state_dict(),
        "vs": value.state_dict(),
        "opt": ppo.opt.state_dict(),
    }
    buf = io.BytesIO()
    torch.save(data, buf)
    compressed = zstd.ZstdCompressor().compress(buf.getvalue())
    with open(CKPT_PATH, "wb") as f:
        f.write(compressed)
def load_ckpt():
    """Restore policy/value/optimizer state from CKPT_PATH; return saved step."""
    with open(CKPT_PATH, "rb") as f:
        raw = zstd.ZstdDecompressor().decompress(f.read())
    # map_location keeps loading robust if the checkpoint was written on a
    # different device than this (CPU) manager.
    data = torch.load(io.BytesIO(raw), map_location=DEVICE)
    policy.load_state_dict(data["ps"])
    value.load_state_dict(data["vs"])
    ppo.opt.load_state_dict(data["opt"])
    return data["step"]

# β€”β€”β€” FastAPI Server β€”β€”β€”
app = FastAPI(); policy=PolicyNet().to(DEVICE); value=ValueNet().to(DEVICE); ppo=PPO(policy,value,LR)
writer=SummaryWriter(LOG_DIR); step=0

class Grad(BaseModel):
    obs:list; actions:list; old_logps:list; returns:list; advs:list

@app.post("/push_grad")
async def push(g: Grad):
    """Apply one PPO update from a worker rollout, then return fresh weights.

    Logs diagnostics to TensorBoard and checkpoints every CKPT_FREQ updates.
    """
    global step
    obs = torch.tensor(g.obs, device=DEVICE)
    act = torch.tensor(g.actions, device=DEVICE)
    olp = torch.tensor(g.old_logps, device=DEVICE)
    ret = torch.tensor(g.returns, device=DEVICE)
    adv = torch.tensor(g.advs, device=DEVICE)
    # NOTE(review): the optimizer step blocks the event loop; acceptable for
    # a single worker, consider run_in_executor if many workers push at once.
    metrics = ppo.update(obs, act, olp, ret, adv)
    step += 1
    for name, val in metrics.items():
        writer.add_scalar(f"train/{name}", val, step)
    if step % CKPT_FREQ == 0:
        save_ckpt(step)
    # tolist(), not numpy(): numpy arrays are not JSON-serializable by
    # FastAPI's default response encoder, so the original 500'd here.
    return {"weights": {k: v.cpu().tolist() for k, v in policy.state_dict().items()}}

@app.post("/restore")
async def restore():
    """Reload the latest checkpoint and report which step it held."""
    restored = load_ckpt()
    return {"restored_step": restored}

if __name__=="__main__":
    # Serve the manager API; binds all interfaces on the fixed manager port.
    uvicorn.run(app, host=MANAGER_HOST, port=MANAGER_PORT)