from __future__ import annotations import math import time import uuid import numpy as np from typing import Optional ,Tuple from shared .schemas import ( VlessConfig , EpisodeMetrics , TRANSPORT_TYPES , CANDIDATE_PORTS , SNI_DOMAINS , FINGERPRINTS , ALPN_OPTIONS , FRAGMENT_STRATEGIES , MUX_CONCURRENCY_VALUES , SHORT_ID_LENGTHS , XHTTP_MODES , ) from server .rl .reward import compute_reward HISTORY_LEN =10 OBS_PER_EPISODE =7 TOTAL_OBS_DIM =HISTORY_LEN *OBS_PER_EPISODE +5 DISCRETE_NVEC =[ len (TRANSPORT_TYPES ), len (CANDIDATE_PORTS ), len (SNI_DOMAINS ), len (FINGERPRINTS ), len (ALPN_OPTIONS ), len (FRAGMENT_STRATEGIES ), len (MUX_CONCURRENCY_VALUES ), len (SHORT_ID_LENGTHS ), len (XHTTP_MODES ), 100 , ] N_CONTINUOUS =5 def decode_action (discrete :np .ndarray ,continuous :np .ndarray )->VlessConfig : import secrets as _sec transport =TRANSPORT_TYPES [int (discrete [0 ])] port =CANDIDATE_PORTS [int (discrete [1 ])] dest =SNI_DOMAINS [int (discrete [2 ])] fingerprint =FINGERPRINTS [int (discrete [3 ])] alpn =ALPN_OPTIONS [int (discrete [4 ])] frag_strat =FRAGMENT_STRATEGIES [int (discrete [5 ])] mux_conc =MUX_CONCURRENCY_VALUES [int (discrete [6 ])] sid_len =SHORT_ID_LENGTHS [int (discrete [7 ])] xhttp_mode =XHTTP_MODES [int (discrete [8 ])] grpc_seed =int (discrete [9 ]) c =continuous frag_len_min =int (10 +c [0 ]*190 ) frag_len_max =frag_len_min +int (c [1 ]*100 ) frag_interval_min =int (c [2 ]*50 ) frag_interval_max =frag_interval_min +5 padding_min =int (c [3 ]*500 ) padding_max =padding_min +int (c [4 ]*500 ) from server .config_generator import random_service_name ,generate_short_id grpc_name =random_service_name (grpc_seed ) short_id =generate_short_id (sid_len ) if transport !="tcp": frag_strat ="none" padding_enabled =(transport =="tcp")and (padding_min >0 ) return VlessConfig ( transport_type =transport , proxy_port =port , dest_domain =dest , short_id =short_id , spider_x ="/", fingerprint =fingerprint , alpn =alpn , grpc_service_name =grpc_name , xhttp_mode =xhttp_mode , fragment_strategy =frag_strat , fragment_length_min =frag_len_min , fragment_length_max =frag_len_max , fragment_interval_min =frag_interval_min , fragment_interval_max =frag_interval_max , padding_enabled =padding_enabled , padding_min =padding_min , padding_max =padding_max , mux_concurrency =mux_conc , ) def metrics_to_obs_vector (m :EpisodeMetrics )->np .ndarray : return np .array ([ float (m .connected ), min (m .stability_ratio ,1.0 ), min (m .throughput_ratio ,1.0 ), max (0.0 ,1.0 -m .avg_ping_ms /1000.0 ), 1.0 -min (m .packet_loss_ratio ,1.0 ), max (0.0 ,1.0 -m .connect_time_ms /5000.0 ), max (0.0 ,1.0 -m .reconnect_count /5.0 ), ],dtype =np .float32 ) class AlphaBypassEnv : def __init__ ( self , bridge , episode_duration :int =90 , baseline_mbps :float =1.0 , max_steps :int =0 , fail_streak_warn :int =10 , ): self .bridge =bridge self .episode_duration =episode_duration self .baseline_mbps =baseline_mbps self .max_steps =max_steps self .fail_streak_warn =fail_streak_warn self .history :list [EpisodeMetrics ]=[] self .step_count :int =0 self ._fail_streak :int =0 self .obs_dim =TOTAL_OBS_DIM self .discrete_nvec =DISCRETE_NVEC self .n_continuous =N_CONTINUOUS def _build_obs (self )->np .ndarray : obs =np .zeros (self .obs_dim ,dtype =np .float32 ) relevant =self .history [-HISTORY_LEN :] for i ,m in enumerate (reversed (relevant )): start =i *OBS_PER_EPISODE obs [start :start +OBS_PER_EPISODE ]=metrics_to_obs_vector (m ) base =HISTORY_LEN *OBS_PER_EPISODE t =time .localtime () hour =t .tm_hour +t .tm_min /60.0 obs [base ]=min (self .step_count /1000.0 ,1.0 ) obs [base +1 ]=math .sin (2 *math .pi *hour /24 ) obs [base +2 ]=math .cos (2 *math .pi *hour /24 ) dow =t .tm_wday obs [base +3 ]=math .sin (2 *math .pi *dow /7 ) obs [base +4 ]=math .cos (2 *math .pi *dow /7 ) return obs def reset (self )->np .ndarray : self .history =[] self .step_count =0 self ._fail_streak =0 return self ._build_obs () def step ( self , discrete_action :np .ndarray , continuous_action :np .ndarray , )->Tuple [np .ndarray ,float ,bool ,dict ]: cfg =decode_action (discrete_action ,continuous_action ) episode_id =str (uuid .uuid4 ())[:8 ] metrics =self .bridge .run_episode ( cfg =cfg , episode_id =episode_id , duration =self .episode_duration , ) reward =compute_reward (metrics ,self .baseline_mbps ) self .bridge .report_reward (episode_id ,reward ) self .history .append (metrics ) self .step_count +=1 if not metrics .connected : self ._fail_streak +=1 if self ._fail_streak ==self .fail_streak_warn : print ( f"\n⚠️ [Degradation] {self ._fail_streak } FAIL подряд! " f"Возможно РКН изменил политику или проблема с сетью." ) else : self ._fail_streak =0 done =(self .max_steps >0 and self .step_count >=self .max_steps ) obs =self ._build_obs () info ={ "episode_id":episode_id , "reward":reward , "connected":metrics .connected , "stability":metrics .stability_ratio , "throughput_mbps":metrics .throughput_mbps , "transport":cfg .transport_type , "dest":cfg .dest_domain , "fail_streak":self ._fail_streak , } return obs ,reward ,done ,info