# alphabypass3 / environment.py
# NickupAI's picture
# Upload 6 files
# c7c0a38 verified
from __future__ import annotations
import math
import time
import uuid
import numpy as np
from typing import Optional ,Tuple
from shared .schemas import (
VlessConfig ,
EpisodeMetrics ,
TRANSPORT_TYPES ,
CANDIDATE_PORTS ,
SNI_DOMAINS ,
FINGERPRINTS ,
ALPN_OPTIONS ,
FRAGMENT_STRATEGIES ,
MUX_CONCURRENCY_VALUES ,
SHORT_ID_LENGTHS ,
XHTTP_MODES ,
)
from server .rl .reward import compute_reward
# Number of most recent episodes encoded into every observation.
HISTORY_LEN = 10
# Features contributed per episode (length of metrics_to_obs_vector's output).
OBS_PER_EPISODE = 7
# Episode-history features plus 5 global features: step progress,
# sin/cos of local hour-of-day and sin/cos of day-of-week.
TOTAL_OBS_DIM = HISTORY_LEN * OBS_PER_EPISODE + 5

# Cardinality of each categorical action head, in the order decode_action
# reads them (discrete[0] .. discrete[9]).  The trailing 100 is the number of
# distinct integer seeds handed to random_service_name for the gRPC name.
DISCRETE_NVEC = [
    len(choices)
    for choices in (
        TRANSPORT_TYPES,
        CANDIDATE_PORTS,
        SNI_DOMAINS,
        FINGERPRINTS,
        ALPN_OPTIONS,
        FRAGMENT_STRATEGIES,
        MUX_CONCURRENCY_VALUES,
        SHORT_ID_LENGTHS,
        XHTTP_MODES,
    )
] + [100]

# Number of continuous action components consumed by decode_action (c[0]..c[4]).
N_CONTINUOUS = 5
def decode_action(discrete: np.ndarray, continuous: np.ndarray) -> VlessConfig:
    """Translate one policy action into a concrete ``VlessConfig``.

    Args:
        discrete: Ten integer-valued entries, read in ``DISCRETE_NVEC`` order:
            indices into ``TRANSPORT_TYPES``, ``CANDIDATE_PORTS``,
            ``SNI_DOMAINS``, ``FINGERPRINTS``, ``ALPN_OPTIONS``,
            ``FRAGMENT_STRATEGIES``, ``MUX_CONCURRENCY_VALUES``,
            ``SHORT_ID_LENGTHS`` and ``XHTTP_MODES``, followed by the seed
            passed to ``random_service_name`` for the gRPC service name.
            The values are used directly as indices; no bounds check is
            done here.
        continuous: Five values intended to lie in ``[0, 1]``.  They are
            clipped to that interval before scaling, so fragment lengths,
            fragment intervals and padding sizes can never go negative.

    Returns:
        The decoded configuration.  Non-TCP transports always get fragment
        strategy ``"none"`` and padding disabled; for TCP, padding is
        enabled only when the scaled ``padding_min`` is positive.
    """
    transport = TRANSPORT_TYPES[int(discrete[0])]
    port = CANDIDATE_PORTS[int(discrete[1])]
    dest = SNI_DOMAINS[int(discrete[2])]
    fingerprint = FINGERPRINTS[int(discrete[3])]
    alpn = ALPN_OPTIONS[int(discrete[4])]
    frag_strat = FRAGMENT_STRATEGIES[int(discrete[5])]
    mux_conc = MUX_CONCURRENCY_VALUES[int(discrete[6])]
    sid_len = SHORT_ID_LENGTHS[int(discrete[7])]
    xhttp_mode = XHTTP_MODES[int(discrete[8])]
    grpc_seed = int(discrete[9])

    # Out-of-range continuous outputs (e.g. unsquashed policy samples) would
    # otherwise yield negative lengths/intervals/padding.  For array inputs
    # np.clip keeps the dtype, so in-range values decode exactly as before.
    c = np.clip(continuous, 0.0, 1.0)
    frag_len_min = int(10 + c[0] * 190)            # [10, 200]
    frag_len_max = frag_len_min + int(c[1] * 100)  # min + [0, 100]
    frag_interval_min = int(c[2] * 50)             # [0, 50]
    frag_interval_max = frag_interval_min + 5
    padding_min = int(c[3] * 500)                  # [0, 500]
    padding_max = padding_min + int(c[4] * 500)    # min + [0, 500]

    # Imported lazily, presumably to avoid an import cycle with the server
    # package — TODO confirm before hoisting to module level.
    from server.config_generator import random_service_name, generate_short_id

    grpc_name = random_service_name(grpc_seed)
    short_id = generate_short_id(sid_len)

    # Fragmentation and padding are only applied on the TCP transport.
    if transport != "tcp":
        frag_strat = "none"
    padding_enabled = (transport == "tcp") and (padding_min > 0)

    return VlessConfig(
        transport_type=transport,
        proxy_port=port,
        dest_domain=dest,
        short_id=short_id,
        spider_x="/",
        fingerprint=fingerprint,
        alpn=alpn,
        grpc_service_name=grpc_name,
        xhttp_mode=xhttp_mode,
        fragment_strategy=frag_strat,
        fragment_length_min=frag_len_min,
        fragment_length_max=frag_len_max,
        fragment_interval_min=frag_interval_min,
        fragment_interval_max=frag_interval_max,
        padding_enabled=padding_enabled,
        padding_min=padding_min,
        padding_max=padding_max,
        mux_concurrency=mux_conc,
    )
def metrics_to_obs_vector(m: EpisodeMetrics) -> np.ndarray:
    """Encode one episode's metrics as a length-7 float32 feature vector.

    Feature order: connected flag, stability ratio, throughput ratio,
    ping score, delivery (1 - packet loss), connect-time score, reconnect
    score.  Ratios are capped from above at 1.0 only.  Latency-like values
    are inverted against fixed scales (1000 ms ping, 5000 ms connect time,
    5 reconnects) and floored at 0, so higher means lower ping, faster
    connects and fewer reconnects.
    """
    connected = float(m.connected)
    stability = min(m.stability_ratio, 1.0)
    throughput = min(m.throughput_ratio, 1.0)
    ping_score = max(0.0, 1.0 - m.avg_ping_ms / 1000.0)
    delivery = 1.0 - min(m.packet_loss_ratio, 1.0)
    connect_score = max(0.0, 1.0 - m.connect_time_ms / 5000.0)
    reconnect_score = max(0.0, 1.0 - m.reconnect_count / 5.0)

    features = (
        connected,
        stability,
        throughput,
        ping_score,
        delivery,
        connect_score,
        reconnect_score,
    )
    return np.array(features, dtype=np.float32)
class AlphaBypassEnv:
    """Environment that trials one proxy configuration per ``step``.

    Each step decodes the policy action with ``decode_action``, asks
    ``bridge`` to run an episode with that configuration, scores the
    returned metrics with ``compute_reward`` and reports the reward back
    through the bridge.  Observations hold the metrics of the last
    ``HISTORY_LEN`` episodes (most recent first, zero-padded) followed by a
    step-progress feature and cyclic local-time features.

    Args:
        bridge: Object providing ``run_episode(cfg=..., episode_id=...,
            duration=...)``, whose result is used as episode metrics
            (presumably ``EpisodeMetrics``), and
            ``report_reward(episode_id, reward)``.
        episode_duration: Forwarded as ``duration`` to ``bridge.run_episode``
            (units not visible here — presumably seconds, TODO confirm).
        baseline_mbps: Reference throughput forwarded to ``compute_reward``.
        max_steps: Step budget after which ``step`` reports ``done``;
            ``0`` (or any non-positive value) disables the limit.
        fail_streak_warn: Consecutive failed connections at which a
            warning is printed; it fires once, when the streak reaches
            exactly this count.
    """

    def __init__(
        self,
        bridge,
        episode_duration: int = 90,
        baseline_mbps: float = 1.0,
        max_steps: int = 0,
        fail_streak_warn: int = 10,
    ):
        # Configuration, fixed for the lifetime of the environment.
        self.bridge = bridge
        self.episode_duration = episode_duration
        self.baseline_mbps = baseline_mbps
        self.max_steps = max_steps
        self.fail_streak_warn = fail_streak_warn

        # Per-run state; reset() starts these over.
        self.history: list[EpisodeMetrics] = []
        self.step_count: int = 0
        self._fail_streak: int = 0

        # Observation/action space descriptors for the agent.
        self.obs_dim = TOTAL_OBS_DIM
        self.discrete_nvec = DISCRETE_NVEC
        self.n_continuous = N_CONTINUOUS

    def _build_obs(self) -> np.ndarray:
        """Assemble the current float32 observation of length ``obs_dim``.

        Layout: ``HISTORY_LEN`` slots of ``OBS_PER_EPISODE`` features (newest
        episode in slot 0, unused slots left at zero), then
        ``min(step_count / 1000, 1)``, sin/cos of the local hour-of-day and
        sin/cos of the local day-of-week.
        """
        obs = np.zeros(self.obs_dim, dtype=np.float32)

        newest_first = self.history[-HISTORY_LEN:][::-1]
        for slot, episode in enumerate(newest_first):
            lo = slot * OBS_PER_EPISODE
            obs[lo:lo + OBS_PER_EPISODE] = metrics_to_obs_vector(episode)

        now = time.localtime()
        hour_of_day = now.tm_hour + now.tm_min / 60.0
        day_of_week = now.tm_wday
        tail = HISTORY_LEN * OBS_PER_EPISODE
        obs[tail] = min(self.step_count / 1000.0, 1.0)
        # Sine/cosine pairs keep 23:59 next to 00:00 and Sunday next to Monday.
        obs[tail + 1] = math.sin(2 * math.pi * hour_of_day / 24)
        obs[tail + 2] = math.cos(2 * math.pi * hour_of_day / 24)
        obs[tail + 3] = math.sin(2 * math.pi * day_of_week / 7)
        obs[tail + 4] = math.cos(2 * math.pi * day_of_week / 7)
        return obs

    def reset(self) -> np.ndarray:
        """Clear history and counters and return the initial observation."""
        self.history = []
        self.step_count = 0
        self._fail_streak = 0
        return self._build_obs()

    def step(
        self,
        discrete_action: np.ndarray,
        continuous_action: np.ndarray,
    ) -> Tuple[np.ndarray, float, bool, dict]:
        """Run one episode with the configuration encoded by the action.

        Returns:
            ``(obs, reward, done, info)``.  ``done`` is true once a positive
            ``max_steps`` has been reached.  ``info`` carries the episode id,
            reward, connected flag, stability, throughput, chosen transport
            and destination, and the current consecutive-failure count.
        """
        config = decode_action(discrete_action, continuous_action)
        run_id = str(uuid.uuid4())[:8]

        outcome = self.bridge.run_episode(
            cfg=config,
            episode_id=run_id,
            duration=self.episode_duration,
        )
        reward = compute_reward(outcome, self.baseline_mbps)
        self.bridge.report_reward(run_id, reward)

        self.history.append(outcome)
        self.step_count += 1

        if outcome.connected:
            self._fail_streak = 0
        else:
            self._fail_streak += 1
            # Warn once when the streak first hits the threshold.  The message
            # reads: "N FAILs in a row! Possibly RKN changed its policy or
            # there is a network problem."
            if self._fail_streak == self.fail_streak_warn:
                print(
                    f"\n⚠️ [Degradation] {self._fail_streak} FAIL подряд! "
                    f"Возможно РКН изменил политику или проблема с сетью."
                )

        # Equivalent to: max_steps > 0 and step_count >= max_steps.
        done = 0 < self.max_steps <= self.step_count
        obs = self._build_obs()
        info = {
            "episode_id": run_id,
            "reward": reward,
            "connected": outcome.connected,
            "stability": outcome.stability_ratio,
            "throughput_mbps": outcome.throughput_mbps,
            "transport": config.transport_type,
            "dest": config.dest_domain,
            "fail_streak": self._fail_streak,
        }
        return obs, reward, done, info