| from __future__ import annotations | |
| import math | |
| import time | |
| import uuid | |
| import numpy as np | |
| from typing import Optional ,Tuple | |
| from shared .schemas import ( | |
| VlessConfig , | |
| EpisodeMetrics , | |
| TRANSPORT_TYPES , | |
| CANDIDATE_PORTS , | |
| SNI_DOMAINS , | |
| FINGERPRINTS , | |
| ALPN_OPTIONS , | |
| FRAGMENT_STRATEGIES , | |
| MUX_CONCURRENCY_VALUES , | |
| SHORT_ID_LENGTHS , | |
| XHTTP_MODES , | |
| ) | |
| from server .rl .reward import compute_reward | |
# Number of most recent episodes encoded in the observation vector.
HISTORY_LEN = 10
# Features produced per episode (the length of metrics_to_obs_vector's output).
OBS_PER_EPISODE = 7
# Global features appended after the per-episode history: a step-progress
# ratio plus sin/cos encodings of local hour-of-day and day-of-week.
N_GLOBAL_OBS = 5
TOTAL_OBS_DIM = HISTORY_LEN * OBS_PER_EPISODE + N_GLOBAL_OBS

# Number of distinct gRPC service-name seeds the policy can choose from.
GRPC_SEED_CHOICES = 100

# Cardinality of each discrete action dimension, in the order decode_action
# reads them: transport, port, SNI domain, fingerprint, ALPN, fragment
# strategy, mux concurrency, short-id length, XHTTP mode, gRPC name seed.
DISCRETE_NVEC = [
    len(TRANSPORT_TYPES),
    len(CANDIDATE_PORTS),
    len(SNI_DOMAINS),
    len(FINGERPRINTS),
    len(ALPN_OPTIONS),
    len(FRAGMENT_STRATEGIES),
    len(MUX_CONCURRENCY_VALUES),
    len(SHORT_ID_LENGTHS),
    len(XHTTP_MODES),
    GRPC_SEED_CHOICES,
]

# Continuous action dimensions consumed by decode_action: fragment length,
# fragment-length span, fragment interval, padding minimum, padding span.
# Each is expected to lie in [0, 1].
N_CONTINUOUS = 5
def decode_action(discrete: np.ndarray, continuous: np.ndarray) -> VlessConfig:
    """Translate a raw policy action into a concrete ``VlessConfig``.

    Args:
        discrete: Ten integer-valued entries, one index per dimension of
            ``DISCRETE_NVEC``, in this order: transport, port, SNI/dest
            domain, fingerprint, ALPN, fragment strategy, mux concurrency,
            short-id length, XHTTP mode, and gRPC service-name seed.
        continuous: ``N_CONTINUOUS`` (5) values, each expected in ``[0, 1]``.
            Values outside that range are clipped, so derived lengths,
            intervals and padding sizes stay non-negative and ``max >= min``.

    Returns:
        The decoded ``VlessConfig``. For any transport other than ``"tcp"``,
        fragmentation is forced to ``"none"`` and padding is disabled.
    """
    transport = TRANSPORT_TYPES[int(discrete[0])]
    port = CANDIDATE_PORTS[int(discrete[1])]
    dest = SNI_DOMAINS[int(discrete[2])]
    fingerprint = FINGERPRINTS[int(discrete[3])]
    alpn = ALPN_OPTIONS[int(discrete[4])]
    frag_strat = FRAGMENT_STRATEGIES[int(discrete[5])]
    mux_conc = MUX_CONCURRENCY_VALUES[int(discrete[6])]
    sid_len = SHORT_ID_LENGTHS[int(discrete[7])]
    xhttp_mode = XHTTP_MODES[int(discrete[8])]
    grpc_seed = int(discrete[9])

    # Clamp continuous actions to [0, 1]. Without this, a policy that emits
    # out-of-range values would yield negative fragment lengths, intervals or
    # padding, and a padding/fragment maximum below its minimum.
    # In-range values pass through unchanged.
    c = np.clip(np.asarray(continuous), 0.0, 1.0)
    frag_len_min = int(10 + c[0] * 190)  # 10..200
    frag_len_max = frag_len_min + int(c[1] * 100)  # min..min+100
    frag_interval_min = int(c[2] * 50)  # 0..50
    frag_interval_max = frag_interval_min + 5  # fixed 5-unit window
    padding_min = int(c[3] * 500)  # 0..500
    padding_max = padding_min + int(c[4] * 500)  # min..min+500

    # Imported lazily. NOTE(review): presumably this avoids an import cycle
    # with server.config_generator; confirm before hoisting to module level.
    from server.config_generator import random_service_name, generate_short_id

    grpc_name = random_service_name(grpc_seed)
    short_id = generate_short_id(sid_len)

    # Fragmentation and padding are only applied to the plain TCP transport.
    if transport != "tcp":
        frag_strat = "none"
    padding_enabled = (transport == "tcp") and (padding_min > 0)

    return VlessConfig(
        transport_type=transport,
        proxy_port=port,
        dest_domain=dest,
        short_id=short_id,
        spider_x="/",
        fingerprint=fingerprint,
        alpn=alpn,
        grpc_service_name=grpc_name,
        xhttp_mode=xhttp_mode,
        fragment_strategy=frag_strat,
        fragment_length_min=frag_len_min,
        fragment_length_max=frag_len_max,
        fragment_interval_min=frag_interval_min,
        fragment_interval_max=frag_interval_max,
        padding_enabled=padding_enabled,
        padding_min=padding_min,
        padding_max=padding_max,
        mux_concurrency=mux_conc,
    )
def metrics_to_obs_vector(m: EpisodeMetrics) -> np.ndarray:
    """Encode one episode's metrics as a 7-element float32 feature vector.

    Feature order:
      0. connected flag (1.0 / 0.0)
      1. stability ratio, capped at 1.0
      2. throughput ratio, capped at 1.0
      3. ``1 - avg_ping_ms / 1000``, floored at 0.0 (0 once ping >= 1000)
      4. ``1 - packet_loss_ratio``, with the loss capped at 1.0
      5. ``1 - connect_time_ms / 5000``, floored at 0.0
      6. ``1 - reconnect_count / 5``, floored at 0.0

    Features 3-6 are inverted, so larger values mean lower ping, loss,
    connect time and reconnect count. Only the caps and floors listed above
    are applied. NOTE(review): a ``[0, 1]`` range for every feature assumes
    non-negative metric inputs; confirm.
    """
    capped_loss = min(m.packet_loss_ratio, 1.0)
    ping_score = max(0.0, 1.0 - m.avg_ping_ms / 1000.0)
    connect_score = max(0.0, 1.0 - m.connect_time_ms / 5000.0)
    reconnect_score = max(0.0, 1.0 - m.reconnect_count / 5.0)

    features = (
        float(m.connected),
        min(m.stability_ratio, 1.0),
        min(m.throughput_ratio, 1.0),
        ping_score,
        1.0 - capped_loss,
        connect_score,
        reconnect_score,
    )
    return np.array(features, dtype=np.float32)
class AlphaBypassEnv:
    """Sequential environment in which every step runs one live proxy episode.

    Each action is decoded into a ``VlessConfig``, executed through
    ``bridge.run_episode``, scored with ``compute_reward``, and the score is
    reported back via ``bridge.report_reward``.

    Observations hold the most recent ``HISTORY_LEN`` episodes (newest first,
    ``OBS_PER_EPISODE`` features each, zero-filled when history is short),
    followed by five global features: a step-progress ratio capped at 1.0,
    and sin/cos encodings of the local hour-of-day and day-of-week.
    """

    def __init__(
        self,
        bridge,
        episode_duration: int = 90,
        baseline_mbps: float = 1.0,
        max_steps: int = 0,
        fail_streak_warn: int = 10,
    ):
        """Store the bridge and run settings and start with empty run state.

        Args:
            bridge: Object exposing ``run_episode(cfg=..., episode_id=...,
                duration=...)`` and ``report_reward(episode_id, reward)``.
            episode_duration: Forwarded to ``bridge.run_episode`` as
                ``duration``. NOTE(review): presumably seconds; confirm.
            baseline_mbps: Throughput baseline forwarded to ``compute_reward``.
            max_steps: Step budget after which ``step`` reports ``done``.
                A value <= 0 means ``done`` is never set.
            fail_streak_warn: Number of consecutive unconnected episodes at
                which a warning is printed (once per streak).
        """
        # Collaborator and run configuration.
        self.bridge = bridge
        self.episode_duration = episode_duration
        self.baseline_mbps = baseline_mbps
        self.max_steps = max_steps
        self.fail_streak_warn = fail_streak_warn

        # Mutable run state; cleared again by reset().
        self.history: list[EpisodeMetrics] = []
        self.step_count: int = 0
        self._fail_streak: int = 0

        # Action/observation space descriptions exposed to the agent.
        self.obs_dim = TOTAL_OBS_DIM
        self.discrete_nvec = DISCRETE_NVEC
        self.n_continuous = N_CONTINUOUS

    def _build_obs(self) -> np.ndarray:
        """Assemble the flat float32 observation from recent history and the local clock."""
        obs = np.zeros(self.obs_dim, dtype=np.float32)

        # Slot 0 holds the newest episode; slots without history stay zero.
        newest_first = self.history[-HISTORY_LEN:][::-1]
        for slot, episode in enumerate(newest_first):
            lo = slot * OBS_PER_EPISODE
            obs[lo:lo + OBS_PER_EPISODE] = metrics_to_obs_vector(episode)

        # Cyclic encodings keep 23:59 next to 00:00 and Sunday next to Monday.
        now = time.localtime()
        hour = now.tm_hour + now.tm_min / 60.0
        weekday = now.tm_wday
        extras = (
            min(self.step_count / 1000.0, 1.0),
            math.sin(2 * math.pi * hour / 24),
            math.cos(2 * math.pi * hour / 24),
            math.sin(2 * math.pi * weekday / 7),
            math.cos(2 * math.pi * weekday / 7),
        )
        offset = HISTORY_LEN * OBS_PER_EPISODE
        for k, value in enumerate(extras):
            obs[offset + k] = value
        return obs

    def reset(self) -> np.ndarray:
        """Clear history and counters, then return the initial observation."""
        self.history, self.step_count, self._fail_streak = [], 0, 0
        return self._build_obs()

    def step(
        self,
        discrete_action: np.ndarray,
        continuous_action: np.ndarray,
    ) -> Tuple[np.ndarray, float, bool, dict]:
        """Run one live episode with the decoded action and record its outcome.

        Returns:
            ``(obs, reward, done, info)``. ``info`` carries the episode id,
            reward, connection flag, stability, throughput, chosen transport
            and destination, and the current consecutive-failure count.
        """
        cfg = decode_action(discrete_action, continuous_action)
        # Short identifier: the first 8 characters of a random UUID4 string.
        episode_id = str(uuid.uuid4())[:8]

        metrics = self.bridge.run_episode(
            cfg=cfg,
            episode_id=episode_id,
            duration=self.episode_duration,
        )
        reward = compute_reward(metrics, self.baseline_mbps)
        self.bridge.report_reward(episode_id, reward)

        self.history.append(metrics)
        self.step_count += 1

        if metrics.connected:
            self._fail_streak = 0
        else:
            self._fail_streak += 1
            # Fires only when the streak hits the threshold exactly, so each
            # streak warns at most once. Message (Russian): "N FAILs in a row!
            # RKN may have changed its policy, or there is a network problem."
            if self._fail_streak == self.fail_streak_warn:
                print(
                    f"\n⚠️ [Degradation] {self._fail_streak} FAIL подряд! "
                    "Возможно РКН изменил политику или проблема с сетью."
                )

        done = self.max_steps > 0 and self.step_count >= self.max_steps
        obs = self._build_obs()
        info = {
            "episode_id": episode_id,
            "reward": reward,
            "connected": metrics.connected,
            "stability": metrics.stability_ratio,
            "throughput_mbps": metrics.throughput_mbps,
            "transport": cfg.transport_type,
            "dest": cfg.dest_domain,
            "fail_streak": self._fail_streak,
        }
        return obs, reward, done, info