"""PPO trainer for the AlphaBypass environment: policy network, rollout buffer, and training loop."""
from __future__ import annotations

import os
import time
import json

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
from typing import Optional, Tuple, List

from server.rl.environment import (
    DISCRETE_NVEC,
    N_CONTINUOUS,
    TOTAL_OBS_DIM,
    AlphaBypassEnv,
)
from server.rl.reward import reward_to_label


class PolicyNetwork(nn.Module):
    """Shared-trunk actor-critic: one categorical head per discrete action
    dimension plus a Gaussian head for the continuous action components."""

    def __init__(
        self,
        obs_dim: int = TOTAL_OBS_DIM,
        hidden: int = 512,
        discrete_nvec: List[int] = DISCRETE_NVEC,
        n_continuous: int = N_CONTINUOUS,
    ):
        super().__init__()
        self.discrete_nvec = discrete_nvec
        self.n_continuous = n_continuous

        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
        )
        # One logits head per discrete action dimension.
        self.discrete_heads = nn.ModuleList([
            nn.Linear(hidden, n) for n in discrete_nvec
        ])
        # Continuous actions: state-dependent mean, state-independent log-std.
        self.cont_mu = nn.Linear(hidden, n_continuous)
        self.cont_log_std = nn.Parameter(torch.zeros(n_continuous))
        self.value_head = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, obs: torch.Tensor):
        h = self.trunk(obs)
        logits = [head(h) for head in self.discrete_heads]
        mu = torch.sigmoid(self.cont_mu(h))
        log_std = self.cont_log_std.clamp(-4, 0)
        value = self.value_head(h).squeeze(-1)
        return logits, mu, log_std, value

    def get_action_and_log_prob(
        self,
        obs: torch.Tensor,
        action_masks: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
        logits, mu, log_std, value = self.forward(obs)

        discrete_actions = []
        log_probs_discrete = []
        for i, (lg, n) in enumerate(zip(logits, self.discrete_nvec)):
            if action_masks and action_masks[i] is not None:
                # Mask out invalid choices by driving their logits to -inf.
                mask = action_masks[i].to(lg.device)
                lg = lg.masked_fill(~mask, float("-inf"))
            dist = Categorical(logits=lg)
            a = dist.sample()
            discrete_actions.append(a.item())
            log_probs_discrete.append(dist.log_prob(a))
        log_prob_discrete = torch.stack(log_probs_discrete).sum()

        std = log_std.exp()
        dist_cont = Normal(mu, std)
        cont_sample = dist_cont.sample()
        cont_action = cont_sample.clamp(0.0, 1.0)
        log_prob_cont = dist_cont.log_prob(cont_sample).sum()

        total_log_prob = log_prob_discrete + log_prob_cont
        return (
            np.array(discrete_actions, dtype=np.int32),
            cont_action.detach().cpu().numpy(),
            total_log_prob,
            value,
        )

    def evaluate_actions(
        self,
        obs: torch.Tensor,
        discrete_actions: torch.Tensor,
        cont_actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        logits, mu, log_std, value = self.forward(obs)

        log_prob_d = torch.zeros(obs.shape[0], device=obs.device)
        entropy_d = torch.zeros(obs.shape[0], device=obs.device)
        for i, lg in enumerate(logits):
            dist = Categorical(logits=lg)
            log_prob_d += dist.log_prob(discrete_actions[:, i])
            entropy_d += dist.entropy()

        std = log_std.exp()
        dist_c = Normal(mu, std)
        log_prob_c = dist_c.log_prob(cont_actions).sum(-1)
        entropy_c = dist_c.entropy().sum(-1)

        # Average the discrete and continuous entropy terms.
        return log_prob_d + log_prob_c, (entropy_d + entropy_c) / 2, value


class RolloutBuffer:
    def __init__(self):
        self.clear()

    def clear(self):
        self.obs: List[np.ndarray] = []
        self.discrete_actions: List[np.ndarray] = []
        self.cont_actions: List[np.ndarray] = []
        self.rewards: List[float] = []
        self.log_probs: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []
        self.dones: List[bool] = []

    def add(self, obs, d_action, c_action, reward, log_prob, value, done):
        self.obs.append(obs)
        self.discrete_actions.append(d_action)
        self.cont_actions.append(c_action)
        self.rewards.append(reward)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.dones.append(done)

    def compute_returns(self, last_value: float, gamma: float = 0.99, gae_lambda: float = 0.95):
        # Generalized Advantage Estimation over the stored trajectory.
        advantages = []
        gae = 0.0
        values = [v.item() for v in self.values] + [last_value]
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + gamma * values[t + 1] * (1 - self.dones[t]) - values[t]
            gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
        returns = [a + v.item() for a, v in zip(advantages, self.values)]
        return advantages, returns

    def to_tensors(self, device: torch.device):
        obs = torch.FloatTensor(np.stack(self.obs)).to(device)
        d_act = torch.LongTensor(np.stack(self.discrete_actions)).to(device)
        c_act = torch.FloatTensor(np.stack(self.cont_actions)).to(device)
        return obs, d_act, c_act


class PPOTrainer:
    def __init__(
        self,
        env: AlphaBypassEnv,
        device_str: str = "cuda",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        steps_per_update: int = 8,
        checkpoint_dir: str = "checkpoints",
        checkpoint_every: int = 100,
    ):
        self.env = env
        self.device = torch.device(device_str if torch.cuda.is_available() else "cpu")
        print(f"[PPO] device: {self.device}")

        self.policy = PolicyNetwork().to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.999)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_coef = entropy_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.steps_per_update = steps_per_update

        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every
        os.makedirs(checkpoint_dir, exist_ok=True)

        self.total_episodes = 0
        self.best_reward = -float("inf")
        self.reward_history: List[float] = []

    def _build_action_masks(self, obs_tensor: torch.Tensor) -> List[Optional[torch.Tensor]]:
        # No masking by default; override to disable invalid discrete choices.
        return [None] * len(DISCRETE_NVEC)

    def collect_rollout(self) -> RolloutBuffer:
        buffer = RolloutBuffer()
        obs = self.env._build_obs()

        for _ in range(self.steps_per_update):
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)
            with torch.no_grad():
                masks = self._build_action_masks(obs_t)
                d_action, c_action, log_prob, value = self.policy.get_action_and_log_prob(
                    obs_t.squeeze(0), masks
                )

            next_obs, reward, done, info = self.env.step(d_action, c_action)
            self.total_episodes += 1
            self.reward_history.append(reward)

            print(
                f"[Ep {self.total_episodes:04d}] "
                f"reward={reward:+.4f} {reward_to_label(reward)} | "
                f"transport={info['transport']:5s} dest={info['dest']:30s} | "
                f"stable={info['stability']:.2f} "
                f"speed={info['throughput_mbps']:.2f}Mbps"
            )

            buffer.add(obs, d_action, c_action, reward, log_prob, value, done)
            obs = next_obs
            if done:
                obs = self.env.reset()

            if self.total_episodes % self.checkpoint_every == 0:
                self.save_checkpoint()

        return buffer

    def update(self, buffer: RolloutBuffer):
        print(f"\n[PPO] ── Update #{self.total_episodes // self.steps_per_update} ──────────────────────────")
print (f"[PPO] Buffer: {len (buffer .rewards )} episodes | " f"mean_reward={sum (buffer .rewards )/len (buffer .rewards ):+.4f} | " f"positive={sum (1 for r in buffer .rewards if r >0 )}/{len (buffer .rewards )}") obs_t ,d_act_t ,c_act_t =buffer .to_tensors (self .device ) with torch .no_grad (): last_obs =torch .FloatTensor (self .env ._build_obs ()).to (self .device ) _ ,_ ,_ ,last_val =self .policy .forward (last_obs .unsqueeze (0 )) last_value =last_val .item () advantages ,returns =buffer .compute_returns (last_value ,self .gamma ,self .gae_lambda ) adv_t =torch .FloatTensor (advantages ).to (self .device ) ret_t =torch .FloatTensor (returns ).to (self .device ) adv_t =(adv_t -adv_t .mean ())/(adv_t .std ()+1e-8 ) old_log_probs =torch .stack (buffer .log_probs ).to (self .device ).detach () for _ in range (self .update_epochs ): log_probs ,entropy ,values =self .policy .evaluate_actions (obs_t ,d_act_t ,c_act_t ) ratio =(log_probs -old_log_probs ).exp () surr1 =ratio *adv_t surr2 =ratio .clamp (1 -self .clip_eps ,1 +self .clip_eps )*adv_t policy_loss =-torch .min (surr1 ,surr2 ).mean () value_loss =F .mse_loss (values ,ret_t ) entropy_loss =-entropy .mean () loss =policy_loss +self .vf_coef *value_loss +self .entropy_coef *entropy_loss self .optimizer .zero_grad () loss .backward () nn .utils .clip_grad_norm_ (self .policy .parameters (),self .max_grad_norm ) self .optimizer .step () self .scheduler .step () def train (self ,total_episodes :int =10000 ): print (f"\n{'='*60 }") print (f" AlphaBypass — PPO Training") print (f" Target: {total_episodes } episodes") print (f" Device: {self .device }") print (f"{'='*60 }\n") obs =self .env .reset () while self .total_episodes =20 : recent =self .reward_history [-20 :] print ( f"\n[Stats] last 20 episodes: " f"mean={np .mean (recent ):+.4f} " f"max={np .max (recent ):+.4f} " f"min={np .min (recent ):+.4f}\n" ) def save_checkpoint (self ,tag :str =""): path =os .path .join ( self .checkpoint_dir , f"checkpoint_ep{self .total_episodes :05d}{tag }.pt" ) torch .save ({ "episode":self .total_episodes , "policy_state":self .policy .state_dict (), "optimizer_state":self .optimizer .state_dict (), "reward_history":self .reward_history , "best_reward":self .best_reward , },path ) print (f"[Checkpoint] saved → {path }") r =np .mean (self .reward_history [-10 :])if len (self .reward_history )>=10 else -999 if r >self .best_reward : self .best_reward =r best_path =os .path .join (self .checkpoint_dir ,"best.pt") torch .save (torch .load (path ),best_path ) print (f"[Checkpoint] 🏆 new best ({r :+.4f}) → {best_path }") def load_checkpoint (self ,path :str ): ck =torch .load (path ,map_location =self .device ) self .policy .load_state_dict (ck ["policy_state"]) self .optimizer .load_state_dict (ck ["optimizer_state"]) self .total_episodes =ck ["episode"] self .reward_history =ck .get ("reward_history",[]) self .best_reward =ck .get ("best_reward",-float ("inf")) print (f"[Checkpoint] loaded from ep {self .total_episodes }")