# alphabypass3/agent.py
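"""PPO agent for AlphaBypass.

Defines a shared-trunk actor-critic policy over a hybrid discrete/continuous
action space, a GAE rollout buffer, and a PPO training loop with periodic
checkpointing.
"""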
from __future__ import annotations

import os

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
from typing import Optional, Tuple, List

from server.rl.environment import (
    DISCRETE_NVEC,
    N_CONTINUOUS,
    TOTAL_OBS_DIM,
    AlphaBypassEnv,
)
from server.rl.reward import reward_to_label

class PolicyNetwork(nn.Module):
    """Shared-trunk actor-critic over a hybrid discrete/continuous action space."""

    def __init__(
        self,
        obs_dim: int = TOTAL_OBS_DIM,
        hidden: int = 512,
        discrete_nvec: List[int] = DISCRETE_NVEC,
        n_continuous: int = N_CONTINUOUS,
    ):
        super().__init__()
        self.discrete_nvec = discrete_nvec
        self.n_continuous = n_continuous
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
        )
        # One categorical head per discrete action dimension.
        self.discrete_heads = nn.ModuleList([
            nn.Linear(hidden, n) for n in discrete_nvec
        ])
        # Continuous actions: state-dependent mean, state-independent log-std.
        self.cont_mu = nn.Linear(hidden, n_continuous)
        self.cont_log_std = nn.Parameter(torch.zeros(n_continuous))
        self.value_head = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, obs: torch.Tensor):
        h = self.trunk(obs)
        logits = [head(h) for head in self.discrete_heads]
        # Sigmoid keeps the continuous means inside the [0, 1] action range.
        mu = torch.sigmoid(self.cont_mu(h))
        log_std = self.cont_log_std.clamp(-4, 0)
        value = self.value_head(h).squeeze(-1)
        return logits, mu, log_std, value
    def get_action_and_log_prob(
        self,
        obs: torch.Tensor,
        action_masks: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
        logits, mu, log_std, value = self.forward(obs)
        discrete_actions = []
        log_probs_discrete = []
        for i, lg in enumerate(logits):
            if action_masks and action_masks[i] is not None:
                # Masked-out choices get -inf logits so they are never sampled.
                mask = action_masks[i].to(lg.device)
                lg = lg.masked_fill(~mask, float("-inf"))
            dist = Categorical(logits=lg)
            a = dist.sample()
            discrete_actions.append(a.item())
            log_probs_discrete.append(dist.log_prob(a))
        log_prob_discrete = torch.stack(log_probs_discrete).sum()
        std = log_std.exp()
        dist_cont = Normal(mu, std)
        cont_action = dist_cont.sample().clamp(0.0, 1.0)
        # Score the clamped action, so the stored log-prob matches what
        # evaluate_actions() later computes for the same stored action.
        log_prob_cont = dist_cont.log_prob(cont_action).sum()
        total_log_prob = log_prob_discrete + log_prob_cont
        return (
            np.array(discrete_actions, dtype=np.int32),
            cont_action.detach().cpu().numpy(),
            total_log_prob,
            value,
        )
    def evaluate_actions(
        self,
        obs: torch.Tensor,
        discrete_actions: torch.Tensor,
        cont_actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits, mu, log_std, value = self.forward(obs)
        log_prob_d = torch.zeros(obs.shape[0], device=obs.device)
        entropy_d = torch.zeros(obs.shape[0], device=obs.device)
        for i, lg in enumerate(logits):
            dist = Categorical(logits=lg)
            log_prob_d += dist.log_prob(discrete_actions[:, i])
            entropy_d += dist.entropy()
        std = log_std.exp()
        dist_c = Normal(mu, std)
        log_prob_c = dist_c.log_prob(cont_actions).sum(-1)
        entropy_c = dist_c.entropy().sum(-1)
        # Halve the combined entropy so the bonus averages the two branches.
        return log_prob_d + log_prob_c, (entropy_d + entropy_c) / 2, value
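
# Illustrative only: a quick, self-contained shape check for PolicyNetwork.
# The batch size of 4 is arbitrary; nothing beyond the constants imported
# above is assumed. Not called anywhere; run it manually if useful.
def _demo_policy_forward() -> None:
    net = PolicyNetwork()
    obs = torch.randn(4, TOTAL_OBS_DIM)          # fake batch of observations
    logits, mu, _log_std, value = net(obs)
    assert len(logits) == len(DISCRETE_NVEC)     # one head per discrete dim
    assert mu.shape == (4, N_CONTINUOUS)         # means squashed into [0, 1]
    assert value.shape == (4,)                   # one value estimate per row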
class RolloutBuffer:
    """Stores one rollout's transitions for a single PPO update."""

    def __init__(self):
        self.clear()

    def clear(self):
        self.obs: List[np.ndarray] = []
        self.discrete_actions: List[np.ndarray] = []
        self.cont_actions: List[np.ndarray] = []
        self.rewards: List[float] = []
        self.log_probs: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []
        self.dones: List[bool] = []

    def add(self, obs, d_action, c_action, reward, log_prob, value, done):
        self.obs.append(obs)
        self.discrete_actions.append(d_action)
        self.cont_actions.append(c_action)
        self.rewards.append(reward)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.dones.append(done)
    def compute_returns(self, last_value: float, gamma: float = 0.99, gae_lambda: float = 0.95):
        # Generalized Advantage Estimation, run backwards over the rollout:
        #   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
        #   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
        advantages = []
        gae = 0.0
        values = [v.item() for v in self.values] + [last_value]
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + gamma * values[t + 1] * (1 - self.dones[t]) - values[t]
            gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
        returns = [a + v.item() for a, v in zip(advantages, self.values)]
        return advantages, returns
    def to_tensors(self, device: torch.device):
        obs = torch.FloatTensor(np.stack(self.obs)).to(device)
        d_act = torch.LongTensor(np.stack(self.discrete_actions)).to(device)
        c_act = torch.FloatTensor(np.stack(self.cont_actions)).to(device)
        return obs, d_act, c_act
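
# Illustrative only: a tiny worked GAE example on hand-picked numbers. The
# rewards/values below are made up; it just exercises the recursion above on
# a three-step rollout with no terminal state. Not called anywhere.
def _demo_gae() -> None:
    buf = RolloutBuffer()
    for reward, value in [(1.0, 0.5), (0.0, 0.4), (1.0, 0.3)]:
        buf.add(
            obs=np.zeros(TOTAL_OBS_DIM, dtype=np.float32),
            d_action=np.zeros(len(DISCRETE_NVEC), dtype=np.int32),
            c_action=np.zeros(N_CONTINUOUS, dtype=np.float32),
            reward=reward,
            log_prob=torch.tensor(0.0),
            value=torch.tensor(value),
            done=False,
        )
    advantages, returns = buf.compute_returns(last_value=0.2)
    # With the defaults gamma=0.99, lambda=0.95, the last step's advantage is
    # delta_2 = 1.0 + 0.99 * 0.2 - 0.3 = 0.898; earlier steps fold in the
    # discounted tail of that value.
    print(advantages, returns)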
class PPOTrainer:
    def __init__(
        self,
        env: AlphaBypassEnv,
        device_str: str = "cuda",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        steps_per_update: int = 8,
        checkpoint_dir: str = "checkpoints",
        checkpoint_every: int = 100,
    ):
        self.env = env
        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device_str if torch.cuda.is_available() else "cpu")
        print(f"[PPO] device: {self.device}")
        self.policy = PolicyNetwork().to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.999)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_coef = entropy_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.steps_per_update = steps_per_update
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.total_episodes = 0
        self.best_reward = -float("inf")
        self.reward_history: List[float] = []

    def _build_action_masks(self, obs_tensor: torch.Tensor) -> List[Optional[torch.Tensor]]:
        # Placeholder hook: no masking yet. A real mask would be one bool
        # tensor per discrete head, e.g. torch.tensor([True, False, True])
        # to forbid the middle choice of a 3-way head.
        return [None] * len(DISCRETE_NVEC)
    def collect_rollout(self) -> RolloutBuffer:
        buffer = RolloutBuffer()
        obs = self.env._build_obs()
        for _ in range(self.steps_per_update):
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            with torch.no_grad():
                masks = self._build_action_masks(obs_t)
                d_action, c_action, log_prob, value = self.policy.get_action_and_log_prob(
                    obs_t.squeeze(0), masks
                )
            next_obs, reward, done, info = self.env.step(d_action, c_action)
            self.total_episodes += 1
            self.reward_history.append(reward)
            print(
                f"[Ep {self.total_episodes:04d}] "
                f"reward={reward:+.4f} {reward_to_label(reward)} | "
                f"transport={info['transport']:5s} dest={info['dest']:30s} | "
                f"stable={info['stability']:.2f} "
                f"speed={info['throughput_mbps']:.2f}Mbps"
            )
            buffer.add(obs, d_action, c_action, reward, log_prob, value, done)
            obs = next_obs
            if done:
                obs = self.env.reset()
            if self.total_episodes % self.checkpoint_every == 0:
                self.save_checkpoint()
        return buffer
    def update(self, buffer: RolloutBuffer):
        print(f"\n[PPO] ── Update #{self.total_episodes // self.steps_per_update} ──────────────────────────")
        print(f"[PPO] Buffer: {len(buffer.rewards)} episodes | "
              f"mean_reward={sum(buffer.rewards) / len(buffer.rewards):+.4f} | "
              f"positive={sum(1 for r in buffer.rewards if r > 0)}/{len(buffer.rewards)}")
        obs_t, d_act_t, c_act_t = buffer.to_tensors(self.device)
        with torch.no_grad():
            last_obs = torch.FloatTensor(self.env._build_obs()).to(self.device)
            _, _, _, last_val = self.policy.forward(last_obs.unsqueeze(0))
            last_value = last_val.item()
        advantages, returns = buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        adv_t = torch.FloatTensor(advantages).to(self.device)
        ret_t = torch.FloatTensor(returns).to(self.device)
        # Normalize advantages for a better-conditioned policy gradient.
        adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)
        old_log_probs = torch.stack(buffer.log_probs).to(self.device).detach()
        for _ in range(self.update_epochs):
            log_probs, entropy, values = self.policy.evaluate_actions(obs_t, d_act_t, c_act_t)
            # Clipped PPO surrogate objective.
            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * adv_t
            surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * adv_t
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, ret_t)
            entropy_loss = -entropy.mean()
            loss = policy_loss + self.vf_coef * value_loss + self.entropy_coef * entropy_loss
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()
        # One LR decay step per update.
        self.scheduler.step()
    def train(self, total_episodes: int = 10000):
        print(f"\n{'=' * 60}")
        print(" AlphaBypass — PPO Training")
        print(f" Target: {total_episodes} episodes")
        print(f" Device: {self.device}")
        print(f"{'=' * 60}\n")
        self.env.reset()
        while self.total_episodes < total_episodes:
            buffer = self.collect_rollout()
            self.update(buffer)
            if len(self.reward_history) >= 20:
                recent = self.reward_history[-20:]
                print(
                    f"\n[Stats] last 20 episodes: "
                    f"mean={np.mean(recent):+.4f} "
                    f"max={np.max(recent):+.4f} "
                    f"min={np.min(recent):+.4f}\n"
                )
    def save_checkpoint(self, tag: str = ""):
        path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint_ep{self.total_episodes:05d}{tag}.pt"
        )
        torch.save({
            "episode": self.total_episodes,
            "policy_state": self.policy.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "reward_history": self.reward_history,
            "best_reward": self.best_reward,
        }, path)
        print(f"[Checkpoint] saved → {path}")
        # Track the trailing 10-episode mean reward and keep the best as best.pt.
        r = np.mean(self.reward_history[-10:]) if len(self.reward_history) >= 10 else -999
        if r > self.best_reward:
            self.best_reward = r
            best_path = os.path.join(self.checkpoint_dir, "best.pt")
            torch.save(torch.load(path), best_path)
            print(f"[Checkpoint] 🏆 new best ({r:+.4f}) → {best_path}")
    def load_checkpoint(self, path: str):
        ck = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(ck["policy_state"])
        self.optimizer.load_state_dict(ck["optimizer_state"])
        self.total_episodes = ck["episode"]
        self.reward_history = ck.get("reward_history", [])
        self.best_reward = ck.get("best_reward", -float("inf"))
        print(f"[Checkpoint] loaded from ep {self.total_episodes}")