from __future__ import annotations

import os
import shutil

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical, Normal
from typing import List, Optional, Tuple

from server.rl.environment import (
    DISCRETE_NVEC,
    N_CONTINUOUS,
    TOTAL_OBS_DIM,
    AlphaBypassEnv,
)
from server.rl.reward import reward_to_label
class PolicyNetwork(nn.Module):
    """Actor-critic network: a shared trunk, one categorical head per
    discrete action dimension, and a Gaussian head for continuous actions."""

    def __init__(
        self,
        obs_dim: int = TOTAL_OBS_DIM,
        hidden: int = 512,
        discrete_nvec: List[int] = DISCRETE_NVEC,
        n_continuous: int = N_CONTINUOUS,
    ):
        super().__init__()
        self.discrete_nvec = discrete_nvec
        self.n_continuous = n_continuous
        # Shared feature trunk: three Linear + LayerNorm + ReLU blocks.
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
        )
        # One logit head per discrete action dimension.
        self.discrete_heads = nn.ModuleList([
            nn.Linear(hidden, n) for n in discrete_nvec
        ])
        # Gaussian policy over continuous actions: state-dependent mean,
        # state-independent learned log-std.
        self.cont_mu = nn.Linear(hidden, n_continuous)
        self.cont_log_std = nn.Parameter(torch.zeros(n_continuous))
        self.value_head = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, obs: torch.Tensor):
        h = self.trunk(obs)
        logits = [head(h) for head in self.discrete_heads]
        # Sigmoid keeps the continuous mean inside the [0, 1] action range.
        mu = torch.sigmoid(self.cont_mu(h))
        # Clamp log-std to [-4, 0] so std stays in roughly [0.018, 1].
        log_std = self.cont_log_std.clamp(-4, 0)
        value = self.value_head(h).squeeze(-1)
        return logits, mu, log_std, value
    def get_action_and_log_prob(
        self,
        obs: torch.Tensor,
        action_masks: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
        """Sample one action for a single (unbatched) observation."""
        logits, mu, log_std, value = self.forward(obs)
        discrete_actions = []
        log_probs_discrete = []
        for i, lg in enumerate(logits):
            if action_masks and action_masks[i] is not None:
                # Invalid actions get -inf logits so they are never sampled.
                mask = action_masks[i].to(lg.device)
                lg = lg.masked_fill(~mask, float("-inf"))
            dist = Categorical(logits=lg)
            a = dist.sample()
            discrete_actions.append(a.item())
            log_probs_discrete.append(dist.log_prob(a))
        log_prob_discrete = torch.stack(log_probs_discrete).sum()
        std = log_std.exp()
        dist_cont = Normal(mu, std)
        cont_sample = dist_cont.sample()
        # Log-prob is taken on the raw sample; the clamp is a post-hoc
        # projection onto the valid [0, 1] action range.
        cont_action = cont_sample.clamp(0.0, 1.0)
        log_prob_cont = dist_cont.log_prob(cont_sample).sum()
        total_log_prob = log_prob_discrete + log_prob_cont
        return (
            np.array(discrete_actions, dtype=np.int32),
            cont_action.detach().cpu().numpy(),
            total_log_prob,
            value,
        )
    def evaluate_actions(
        self,
        obs: torch.Tensor,
        discrete_actions: torch.Tensor,
        cont_actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Re-evaluate stored actions on a batch; returns (log_prob, entropy, value)."""
        logits, mu, log_std, value = self.forward(obs)
        log_prob_d = torch.zeros(obs.shape[0], device=obs.device)
        entropy_d = torch.zeros(obs.shape[0], device=obs.device)
        for i, lg in enumerate(logits):
            dist = Categorical(logits=lg)
            log_prob_d += dist.log_prob(discrete_actions[:, i])
            entropy_d += dist.entropy()
        std = log_std.exp()
        dist_c = Normal(mu, std)
        log_prob_c = dist_c.log_prob(cont_actions).sum(-1)
        entropy_c = dist_c.entropy().sum(-1)
        # Entropy bonus averages the discrete and continuous terms.
        return log_prob_d + log_prob_c, (entropy_d + entropy_c) / 2, value
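
# A minimal shape-check sketch for PolicyNetwork (illustrative only; the random
# observations and the batch size of 4 are assumptions, not part of training):
#
#   net = PolicyNetwork()
#   obs = torch.randn(4, TOTAL_OBS_DIM)
#   logits, mu, log_std, value = net(obs)
#   assert all(lg.shape == (4, n) for lg, n in zip(logits, DISCRETE_NVEC))
#   assert mu.shape == (4, N_CONTINUOUS) and value.shape == (4,)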
class RolloutBuffer:
    """Flat storage for one rollout; cleared after every PPO update."""

    def __init__(self):
        self.clear()

    def clear(self):
        self.obs: List[np.ndarray] = []
        self.discrete_actions: List[np.ndarray] = []
        self.cont_actions: List[np.ndarray] = []
        self.rewards: List[float] = []
        self.log_probs: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []
        self.dones: List[bool] = []

    def add(self, obs, d_action, c_action, reward, log_prob, value, done):
        self.obs.append(obs)
        self.discrete_actions.append(d_action)
        self.cont_actions.append(c_action)
        self.rewards.append(reward)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.dones.append(done)

    def compute_returns(self, last_value: float, gamma: float = 0.99, gae_lambda: float = 0.95):
        """Generalized Advantage Estimation (GAE-lambda), computed backwards."""
        advantages = []
        gae = 0.0
        values = [v.item() for v in self.values] + [last_value]
        for t in reversed(range(len(self.rewards))):
            delta = self.rewards[t] + gamma * values[t + 1] * (1 - self.dones[t]) - values[t]
            gae = delta + gamma * gae_lambda * (1 - self.dones[t]) * gae
            advantages.insert(0, gae)
        returns = [a + v.item() for a, v in zip(advantages, self.values)]
        return advantages, returns

    def to_tensors(self, device: torch.device):
        obs = torch.FloatTensor(np.stack(self.obs)).to(device)
        d_act = torch.LongTensor(np.stack(self.discrete_actions)).to(device)
        c_act = torch.FloatTensor(np.stack(self.cont_actions)).to(device)
        return obs, d_act, c_act
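
# The GAE recurrence implemented by compute_returns, written out for reference:
#
#   delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
#   A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
#   R_t     = A_t + V(s_t)          # returns used as value-function targets
#
# The (1 - done_t) factor zeroes the bootstrap across episode boundaries.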
class PPOTrainer:
    def __init__(
        self,
        env: AlphaBypassEnv,
        device_str: str = "cuda",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        steps_per_update: int = 8,
        checkpoint_dir: str = "checkpoints",
        checkpoint_every: int = 100,
    ):
        self.env = env
        # Fall back to CPU when CUDA is unavailable.
        self.device = torch.device(device_str if torch.cuda.is_available() else "cpu")
        print(f"[PPO] device: {self.device}")
        self.policy = PolicyNetwork().to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.999)
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_coef = entropy_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.steps_per_update = steps_per_update
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every
        os.makedirs(checkpoint_dir, exist_ok=True)
        self.total_episodes = 0
        self.best_reward = -float("inf")
        self.reward_history: List[float] = []

    def _build_action_masks(self, obs_tensor: torch.Tensor) -> List[Optional[torch.Tensor]]:
        # No masking yet: every discrete action is always valid.
        return [None] * len(DISCRETE_NVEC)
    def collect_rollout(self) -> RolloutBuffer:
        buffer = RolloutBuffer()
        obs = self.env._build_obs()
        for _ in range(self.steps_per_update):
            obs_t = torch.FloatTensor(obs).to(self.device)
            with torch.no_grad():
                masks = self._build_action_masks(obs_t)
                d_action, c_action, log_prob, value = self.policy.get_action_and_log_prob(
                    obs_t, masks
                )
            next_obs, reward, done, info = self.env.step(d_action, c_action)
            self.total_episodes += 1
            self.reward_history.append(reward)
            print(
                f"[Ep {self.total_episodes:04d}] "
                f"reward={reward:+.4f} {reward_to_label(reward)} | "
                f"transport={info['transport']:5s} dest={info['dest']:30s} | "
                f"stable={info['stability']:.2f} "
                f"speed={info['throughput_mbps']:.2f}Mbps"
            )
            buffer.add(obs, d_action, c_action, reward, log_prob, value, done)
            obs = next_obs
            if done:
                obs = self.env.reset()
            if self.total_episodes % self.checkpoint_every == 0:
                self.save_checkpoint()
        return buffer
    def update(self, buffer: RolloutBuffer):
        print(f"\n[PPO] ── Update #{self.total_episodes // self.steps_per_update} ──────────────────────────")
        print(f"[PPO] Buffer: {len(buffer.rewards)} episodes | "
              f"mean_reward={sum(buffer.rewards) / len(buffer.rewards):+.4f} | "
              f"positive={sum(1 for r in buffer.rewards if r > 0)}/{len(buffer.rewards)}")
        obs_t, d_act_t, c_act_t = buffer.to_tensors(self.device)
        # Bootstrap the value of the state the rollout stopped in.
        with torch.no_grad():
            last_obs = torch.FloatTensor(self.env._build_obs()).to(self.device)
            _, _, _, last_val = self.policy.forward(last_obs.unsqueeze(0))
            last_value = last_val.item()
        advantages, returns = buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        adv_t = torch.FloatTensor(advantages).to(self.device)
        ret_t = torch.FloatTensor(returns).to(self.device)
        # Normalize advantages for a better-conditioned policy gradient.
        adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)
        old_log_probs = torch.stack(buffer.log_probs).to(self.device).detach()
        for _ in range(self.update_epochs):
            log_probs, entropy, values = self.policy.evaluate_actions(obs_t, d_act_t, c_act_t)
            # PPO clipped surrogate objective.
            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * adv_t
            surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * adv_t
            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, ret_t)
            entropy_loss = -entropy.mean()
            loss = policy_loss + self.vf_coef * value_loss + self.entropy_coef * entropy_loss
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()
        self.scheduler.step()
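
    # The clipped surrogate objective used above, written out for reference:
    #
    #   r_t(theta) = exp(log pi_theta(a_t|s_t) - log pi_old(a_t|s_t))
    #   L_clip     = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
    #   loss       = -L_clip + vf_coef * MSE(V, R) - entropy_coef * H(pi)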
    def train(self, total_episodes: int = 10000):
        print(f"\n{'=' * 60}")
        print(" AlphaBypass - PPO Training")
        print(f" Target: {total_episodes} episodes")
        print(f" Device: {self.device}")
        print(f"{'=' * 60}\n")
        self.env.reset()
        while self.total_episodes < total_episodes:
            buffer = self.collect_rollout()
            self.update(buffer)
            if len(self.reward_history) >= 20:
                recent = self.reward_history[-20:]
                print(
                    f"\n[Stats] last 20 episodes: "
                    f"mean={np.mean(recent):+.4f} "
                    f"max={np.max(recent):+.4f} "
                    f"min={np.min(recent):+.4f}\n"
                )
    def save_checkpoint(self, tag: str = ""):
        path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint_ep{self.total_episodes:05d}{tag}.pt"
        )
        torch.save({
            "episode": self.total_episodes,
            "policy_state": self.policy.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "reward_history": self.reward_history,
            "best_reward": self.best_reward,
        }, path)
        print(f"[Checkpoint] saved → {path}")
        # Track the best checkpoint by the mean reward of the last 10 episodes.
        r = np.mean(self.reward_history[-10:]) if len(self.reward_history) >= 10 else -999
        if r > self.best_reward:
            self.best_reward = r
            best_path = os.path.join(self.checkpoint_dir, "best.pt")
            shutil.copyfile(path, best_path)
            print(f"[Checkpoint] new best ({r:+.4f}) → {best_path}")
    def load_checkpoint(self, path: str):
        ck = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(ck["policy_state"])
        self.optimizer.load_state_dict(ck["optimizer_state"])
        self.total_episodes = ck["episode"]
        self.reward_history = ck.get("reward_history", [])
        self.best_reward = ck.get("best_reward", -float("inf"))
        print(f"[Checkpoint] loaded from ep {self.total_episodes}")