NickupAI committed on
Commit
c7c0a38
·
verified ·
1 Parent(s): 96d3759

Upload 6 files

Browse files
Files changed (6) hide show
  1. agent.py +368 -0
  2. config_generator.py +272 -0
  3. crypto.py +105 -0
  4. environment.py +214 -0
  5. reward.py +54 -0
  6. schemas.py +182 -0
agent.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ import os
5
+ import time
6
+ import json
7
+ import numpy as np
8
+ import torch
9
+ import torch .nn as nn
10
+ import torch .nn .functional as F
11
+ from torch .distributions import Categorical ,Normal
12
+ from typing import Optional ,Tuple ,List
13
+
14
+ from server .rl .environment import (
15
+ DISCRETE_NVEC ,
16
+ N_CONTINUOUS ,
17
+ TOTAL_OBS_DIM ,
18
+ AlphaBypassEnv ,
19
+ )
20
+ from server .rl .reward import reward_to_label
21
+
22
+
23
+
24
+
25
+
26
+
27
class PolicyNetwork(nn.Module):
    """Actor-critic network for a hybrid multi-discrete + continuous action space.

    A shared three-layer MLP trunk feeds:
      * one linear head per discrete action dimension (categorical logits),
      * a Gaussian head for the continuous parameters (sigmoid-squashed mean,
        state-independent learned log-std),
      * a small MLP value head producing one scalar per observation.
    """

    def __init__(
        self,
        obs_dim: int = TOTAL_OBS_DIM,
        hidden: int = 512,
        discrete_nvec: List[int] = DISCRETE_NVEC,
        n_continuous: int = N_CONTINUOUS,
    ):
        """Build the trunk and all heads.

        Args:
            obs_dim: Size of the flat observation vector.
            hidden: Width of every trunk layer.
            discrete_nvec: Number of choices for each discrete action dimension.
            n_continuous: Number of continuous action components.
        """
        super().__init__()
        # Copy so later mutation of the (shared) default list cannot desync the heads.
        self.discrete_nvec = list(discrete_nvec)
        self.n_continuous = n_continuous

        # Shared feature extractor.
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.LayerNorm(hidden),
            nn.ReLU(),
        )

        # One categorical head per discrete dimension.
        self.discrete_heads = nn.ModuleList(
            [nn.Linear(hidden, n) for n in self.discrete_nvec]
        )

        # Gaussian head: mean depends on the state, log-std is a free parameter.
        self.cont_mu = nn.Linear(hidden, n_continuous)
        self.cont_log_std = nn.Parameter(torch.zeros(n_continuous))

        # State-value estimate.
        self.value_head = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
        )

    def forward(self, obs: torch.Tensor):
        """Run the network.

        Returns:
            ``(logits, mu, log_std, value)`` where ``logits`` is a list with one
            tensor per discrete head, ``mu`` is squashed into (0, 1) by a
            sigmoid, ``log_std`` is clamped to [-4, 0], and ``value`` has its
            trailing singleton dimension removed.
        """
        h = self.trunk(obs)
        logits = [head(h) for head in self.discrete_heads]
        mu = torch.sigmoid(self.cont_mu(h))
        log_std = self.cont_log_std.clamp(-4, 0)
        value = self.value_head(h).squeeze(-1)
        return logits, mu, log_std, value

    def get_action_and_log_prob(
        self,
        obs: torch.Tensor,
        action_masks: Optional[List[Optional[torch.Tensor]]] = None,
    ) -> Tuple[np.ndarray, np.ndarray, torch.Tensor, torch.Tensor]:
        """Sample one action for a single (unbatched) observation.

        Args:
            obs: Observation tensor for one step; discrete samples are read
                with ``.item()``, so a batch is not supported here.
            action_masks: Optional per-head boolean masks; ``False`` entries are
                excluded by setting their logit to ``-inf``. ``None`` (or a
                ``None`` entry) leaves that head unmasked.

        Returns:
            ``(discrete_actions, continuous_action, total_log_prob, value)``.
            The continuous action is clamped to [0, 1].
        """
        logits, mu, log_std, value = self.forward(obs)

        discrete_actions = []
        log_probs_discrete = []

        for i, lg in enumerate(logits):
            if action_masks and action_masks[i] is not None:
                mask = action_masks[i].to(lg.device)
                lg = lg.masked_fill(~mask, float("-inf"))
            dist = Categorical(logits=lg)
            a = dist.sample()
            discrete_actions.append(a.item())
            log_probs_discrete.append(dist.log_prob(a))

        log_prob_discrete = torch.stack(log_probs_discrete).sum()

        std = log_std.exp()
        dist_cont = Normal(mu, std)
        cont_action = dist_cont.sample().clamp(0.0, 1.0)
        # Score the *clamped* action: that is the value returned to the caller
        # and later re-scored by evaluate_actions(). Scoring the raw sample
        # instead would make the stored "old" log-prob refer to a different
        # action than the one evaluated during the PPO update.
        log_prob_cont = dist_cont.log_prob(cont_action).sum()

        total_log_prob = log_prob_discrete + log_prob_cont

        return (
            np.array(discrete_actions, dtype=np.int32),
            cont_action.detach().cpu().numpy(),
            total_log_prob,
            value,
        )

    def evaluate_actions(
        self,
        obs: torch.Tensor,
        discrete_actions: torch.Tensor,
        cont_actions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Score a batch of previously taken actions under the current policy.

        Args:
            obs: Batched observations (first dimension is the batch).
            discrete_actions: Discrete choices, indexed as ``[:, head]``.
            cont_actions: Continuous actions, summed over the last dimension.

        Returns:
            ``(log_prob, entropy, value)``, each with one entry per batch row.
            ``entropy`` is the mean of the summed discrete and summed
            continuous entropies. Action masks are not applied here.
        """
        logits, mu, log_std, value = self.forward(obs)

        batch = obs.shape[0]
        log_prob_d = torch.zeros(batch, device=obs.device)
        entropy_d = torch.zeros(batch, device=obs.device)
        for i, lg in enumerate(logits):
            dist = Categorical(logits=lg)
            log_prob_d += dist.log_prob(discrete_actions[:, i])
            entropy_d += dist.entropy()

        dist_c = Normal(mu, log_std.exp())
        log_prob_c = dist_c.log_prob(cont_actions).sum(-1)
        entropy_c = dist_c.entropy().sum(-1)

        return log_prob_d + log_prob_c, (entropy_d + entropy_c) / 2, value
137
+
138
+
139
+
140
+
141
+
142
+
143
class RolloutBuffer:
    """Parallel per-step lists holding one on-policy rollout."""

    def __init__(self):
        self.clear()

    def clear(self):
        """Drop all stored transitions."""
        self.obs: List[np.ndarray] = []
        self.discrete_actions: List[np.ndarray] = []
        self.cont_actions: List[np.ndarray] = []
        self.rewards: List[float] = []
        self.log_probs: List[torch.Tensor] = []
        self.values: List[torch.Tensor] = []
        self.dones: List[bool] = []

    def add(self, obs, d_action, c_action, reward, log_prob, value, done):
        """Append one transition; all seven lists stay index-aligned."""
        self.obs.append(obs)
        self.discrete_actions.append(d_action)
        self.cont_actions.append(c_action)
        self.rewards.append(reward)
        self.log_probs.append(log_prob)
        self.values.append(value)
        self.dones.append(done)

    def compute_returns(self, last_value: float, gamma: float = 0.99, gae_lambda: float = 0.95):
        """Compute GAE advantages and the matching value targets.

        Args:
            last_value: Bootstrap value estimate for the state after the final
                stored step.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.

        Returns:
            ``(advantages, returns)`` as two lists of floats, one entry per
            stored step; ``returns[t] = advantages[t] + value[t]``.
        """
        # Convert each stored value tensor to a float once and reuse it.
        step_values = [v.item() for v in self.values]
        n_steps = len(self.rewards)

        # Filled back-to-front by index: linear time, unlike insert(0, ...).
        advantages = [0.0] * n_steps
        gae = 0.0
        next_value = last_value
        for t in range(n_steps - 1, -1, -1):
            not_done = 1 - self.dones[t]  # a terminal step cuts the bootstrap
            delta = self.rewards[t] + gamma * next_value * not_done - step_values[t]
            gae = delta + gamma * gae_lambda * not_done * gae
            advantages[t] = gae
            next_value = step_values[t]

        returns = [a + v for a, v in zip(advantages, step_values)]
        return advantages, returns

    def to_tensors(self, device: torch.device):
        """Stack observations and actions into tensors on ``device``.

        Returns:
            ``(obs, discrete_actions, continuous_actions)`` as float, long and
            float tensors respectively.
        """
        obs = torch.FloatTensor(np.stack(self.obs)).to(device)
        d_act = torch.LongTensor(np.stack(self.discrete_actions)).to(device)
        c_act = torch.FloatTensor(np.stack(self.cont_actions)).to(device)
        return obs, d_act, c_act
183
+
184
+
185
+
186
+
187
+
188
+
189
class PPOTrainer:
    """PPO training loop that drives a ``PolicyNetwork`` against an env.

    Every environment step is counted and logged as one "episode"; a rollout
    of ``steps_per_update`` steps is collected and then used for
    ``update_epochs`` full-batch PPO gradient steps.
    """

    def __init__(
        self,
        env: AlphaBypassEnv,
        device_str: str = "cuda",
        lr: float = 3e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_eps: float = 0.2,
        entropy_coef: float = 0.01,
        vf_coef: float = 0.5,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        steps_per_update: int = 8,
        checkpoint_dir: str = "checkpoints",
        checkpoint_every: int = 100,
    ):
        """Create the policy, optimizer and LR scheduler.

        Args:
            env: Environment providing ``reset``, ``step`` and ``_build_obs``.
            device_str: Device to use when CUDA is available; otherwise the
                trainer falls back to CPU.
            lr: Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.
            clip_eps: PPO ratio clipping range.
            entropy_coef: Weight of the entropy bonus.
            vf_coef: Weight of the value loss.
            max_grad_norm: Gradient-norm clipping threshold.
            update_epochs: Gradient steps per collected rollout.
            steps_per_update: Environment steps per rollout.
            checkpoint_dir: Directory for checkpoints (created if missing).
            checkpoint_every: Save a checkpoint every this many episodes;
                values <= 0 disable periodic saving.
        """
        self.env = env
        self.device = torch.device(device_str if torch.cuda.is_available() else "cpu")
        print(f"[PPO] device: {self.device}")

        self.policy = PolicyNetwork().to(self.device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.ExponentialLR(self.optimizer, gamma=0.999)

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_eps = clip_eps
        self.entropy_coef = entropy_coef
        self.vf_coef = vf_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.steps_per_update = steps_per_update
        self.checkpoint_dir = checkpoint_dir
        self.checkpoint_every = checkpoint_every

        os.makedirs(checkpoint_dir, exist_ok=True)

        self.total_episodes = 0
        self.best_reward = -float("inf")
        self.reward_history: List[float] = []

    def _build_action_masks(self, obs_tensor: torch.Tensor) -> List[Optional[torch.Tensor]]:
        """Return per-head action masks; currently no head is masked."""
        return [None] * len(DISCRETE_NVEC)

    def collect_rollout(self) -> RolloutBuffer:
        """Run ``steps_per_update`` environment steps and record them.

        Each step is logged, appended to ``reward_history`` and counted in
        ``total_episodes``. The environment is reset whenever it reports
        ``done``.
        """
        buffer = RolloutBuffer()
        obs = self.env._build_obs()

        for _ in range(self.steps_per_update):
            obs_t = torch.FloatTensor(obs).unsqueeze(0).to(self.device)

            with torch.no_grad():
                masks = self._build_action_masks(obs_t)
                d_action, c_action, log_prob, value = self.policy.get_action_and_log_prob(
                    obs_t.squeeze(0), masks
                )

            next_obs, reward, done, info = self.env.step(d_action, c_action)

            self.total_episodes += 1
            self.reward_history.append(reward)

            print(
                f"[Ep {self.total_episodes:04d}] "
                f"reward={reward:+.4f} {reward_to_label(reward)} | "
                f"transport={info['transport']:5s} dest={info['dest']:30s} | "
                f"stable={info['stability']:.2f} "
                f"speed={info['throughput_mbps']:.2f}Mbps"
            )

            buffer.add(obs, d_action, c_action, reward, log_prob, value, done)
            obs = next_obs

            if done:
                obs = self.env.reset()

            # Guard the modulo so checkpoint_every <= 0 disables saving
            # instead of raising ZeroDivisionError.
            if self.checkpoint_every > 0 and self.total_episodes % self.checkpoint_every == 0:
                self.save_checkpoint()

        return buffer

    def update(self, buffer: RolloutBuffer):
        """Run the PPO update on one rollout.

        Does nothing when the buffer is empty. Advantages are normalized only
        when the rollout has more than one step, because the sample standard
        deviation of a single value is NaN and would poison every weight.
        """
        n_steps = len(buffer.rewards)
        if n_steps == 0:
            return

        print(f"\n[PPO] ── Update #{self.total_episodes // self.steps_per_update} ──────────────────────────")
        print(f"[PPO] Buffer: {n_steps} episodes | "
              f"mean_reward={sum(buffer.rewards) / n_steps:+.4f} | "
              f"positive={sum(1 for r in buffer.rewards if r > 0)}/{n_steps}")
        obs_t, d_act_t, c_act_t = buffer.to_tensors(self.device)

        # Bootstrap value for the state following the last stored step.
        with torch.no_grad():
            last_obs = torch.FloatTensor(self.env._build_obs()).to(self.device)
            _, _, _, last_val = self.policy.forward(last_obs.unsqueeze(0))
            last_value = last_val.item()

        advantages, returns = buffer.compute_returns(last_value, self.gamma, self.gae_lambda)
        adv_t = torch.FloatTensor(advantages).to(self.device)
        ret_t = torch.FloatTensor(returns).to(self.device)

        if adv_t.numel() > 1:
            adv_t = (adv_t - adv_t.mean()) / (adv_t.std() + 1e-8)

        old_log_probs = torch.stack(buffer.log_probs).to(self.device).detach()

        for _ in range(self.update_epochs):
            log_probs, entropy, values = self.policy.evaluate_actions(obs_t, d_act_t, c_act_t)

            ratio = (log_probs - old_log_probs).exp()
            surr1 = ratio * adv_t
            surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * adv_t

            policy_loss = -torch.min(surr1, surr2).mean()
            value_loss = F.mse_loss(values, ret_t)
            entropy_loss = -entropy.mean()

            loss = policy_loss + self.vf_coef * value_loss + self.entropy_coef * entropy_loss

            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            self.optimizer.step()

        # One LR decay step per rollout, not per gradient step.
        self.scheduler.step()

    def train(self, total_episodes: int = 10000):
        """Reset the env and alternate rollout/update until the episode target is hit."""
        print(f"\n{'=' * 60}")
        print(f" AlphaBypass — PPO Training")
        print(f" Target: {total_episodes} episodes")
        print(f" Device: {self.device}")
        print(f"{'=' * 60}\n")

        self.env.reset()

        while self.total_episodes < total_episodes:
            buffer = self.collect_rollout()
            self.update(buffer)

            if len(self.reward_history) >= 20:
                recent = self.reward_history[-20:]
                print(
                    f"\n[Stats] last 20 episodes: "
                    f"mean={np.mean(recent):+.4f} "
                    f"max={np.max(recent):+.4f} "
                    f"min={np.min(recent):+.4f}\n"
                )

    def save_checkpoint(self, tag: str = ""):
        """Write a checkpoint and refresh ``best.pt`` when the recent mean improves.

        The "recent mean" is the average of the last 10 rewards, or -999 when
        fewer than 10 rewards exist. ``best.pt`` is written directly from the
        in-memory state (no save/load round-trip) and records the updated
        ``best_reward``.
        """
        path = os.path.join(
            self.checkpoint_dir,
            f"checkpoint_ep{self.total_episodes:05d}{tag}.pt"
        )
        state = {
            "episode": self.total_episodes,
            "policy_state": self.policy.state_dict(),
            "optimizer_state": self.optimizer.state_dict(),
            "scheduler_state": self.scheduler.state_dict(),
            "reward_history": self.reward_history,
            "best_reward": self.best_reward,
        }
        torch.save(state, path)
        print(f"[Checkpoint] saved → {path}")

        # Plain float keeps numpy scalars out of the checkpoint, so it stays
        # loadable under torch.load's restricted (weights_only) unpickler.
        if len(self.reward_history) >= 10:
            r = float(np.mean(self.reward_history[-10:]))
        else:
            r = -999.0
        if r > self.best_reward:
            self.best_reward = r
            best_path = os.path.join(self.checkpoint_dir, "best.pt")
            state["best_reward"] = r
            torch.save(state, best_path)
            print(f"[Checkpoint] 🏆 new best ({r:+.4f}) → {best_path}")

    def load_checkpoint(self, path: str):
        """Restore policy, optimizer and counters from ``path``.

        Scheduler state is restored when present; checkpoints written before
        it was stored load without it.
        """
        ck = torch.load(path, map_location=self.device)
        self.policy.load_state_dict(ck["policy_state"])
        self.optimizer.load_state_dict(ck["optimizer_state"])
        if "scheduler_state" in ck:
            self.scheduler.load_state_dict(ck["scheduler_state"])
        self.total_episodes = ck["episode"]
        self.reward_history = ck.get("reward_history", [])
        self.best_reward = ck.get("best_reward", -float("inf"))
        print(f"[Checkpoint] loaded from ep {self.total_episodes}")
config_generator.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import json
3
+ import uuid
4
+ import secrets
5
+ import string
6
+ from typing import Optional
7
+ from shared .schemas import VlessConfig
8
+
9
+
10
+
11
+
12
+
13
+
14
def generate_uuid() -> str:
    """Return a new random (version 4) UUID in its canonical string form."""
    fresh = uuid.uuid4()
    return str(fresh)
16
+
17
+
18
def generate_short_id(length: int = 8) -> str:
    """Return a random lowercase hex string built from ``length // 2`` bytes.

    The result therefore has ``length`` characters for even ``length`` and
    ``length - 1`` characters for odd ``length``.
    """
    n_bytes = length // 2
    return secrets.token_bytes(n_bytes).hex()
21
+
22
+
23
def random_service_name(seed: int) -> str:
    """Return a deterministic pseudo-random service name for ``seed``.

    The name is 6 to 20 characters long and drawn from lowercase ASCII
    letters and digits. The same seed always yields the same name, so this
    must not be used where unpredictability matters.
    """
    # Local import: this module does not import ``random`` at file level.
    import random

    alphabet = string.ascii_lowercase + string.digits
    rng = random.Random(seed)
    length = rng.randint(6, 20)
    # One rng.choice() per character keeps the draw sequence (and therefore
    # the generated names) stable for any given seed.
    return "".join(rng.choice(alphabet) for _ in range(length))
29
+
30
+
31
+
32
+
33
+
34
+
35
def build_server_config(
    cfg: VlessConfig,
    vless_uuid: str,
    private_key: str,
    public_key: str,
    listen_ip: str = "0.0.0.0",
) -> dict:
    """Assemble the server-side configuration dict for one VLESS inbound.

    Args:
        cfg: Transport/REALITY parameters for this configuration.
        vless_uuid: Client id accepted by the inbound.
        private_key: REALITY private key placed in ``realitySettings``.
        public_key: Accepted for signature symmetry; not used in this function.
        listen_ip: Address the inbound binds to.

    Returns:
        A dict with ``log``, ``inbounds``, ``outbounds`` and ``routing`` keys.
    """
    # Transport-specific stream settings first, then the REALITY section.
    stream_settings = _build_stream_settings_server(cfg)

    client_entry = {
        "id": vless_uuid,
        "flow": _pick_flow(cfg),
    }

    stream_settings["realitySettings"] = {
        "show": False,
        "dest": f"{cfg.dest_domain}:443",
        "xver": 0,
        "serverNames": [cfg.dest_domain],
        "privateKey": private_key,
        "shortIds": [cfg.short_id],
        "spiderX": cfg.spider_x,
    }

    vless_inbound = {
        "tag": "vless-in",
        "listen": listen_ip,
        "port": cfg.proxy_port,
        "protocol": "vless",
        "settings": {
            "clients": [client_entry],
            "decryption": "none",
        },
        "streamSettings": stream_settings,
        "sniffing": {
            "enabled": True,
            "destOverride": ["http", "tls", "quic"],
        },
    }

    # Traffic to private address ranges is routed to the blackhole outbound.
    block_private_rule = {
        "type": "field",
        "ip": ["geoip:private"],
        "outboundTag": "block",
    }

    return {
        "log": {"loglevel": "warning"},
        "inbounds": [vless_inbound],
        "outbounds": [
            {"tag": "direct", "protocol": "freedom"},
            {"tag": "block", "protocol": "blackhole"},
        ],
        "routing": {"rules": [block_private_rule]},
    }
94
+
95
+
96
def _pick_flow(cfg: VlessConfig) -> str:
    """Return the VLESS flow: ``xtls-rprx-vision`` for TCP, empty string otherwise."""
    return "xtls-rprx-vision" if cfg.transport_type == "tcp" else ""
100
+
101
+
102
def _build_stream_settings_server(cfg: VlessConfig) -> dict:
    """Build the server ``streamSettings`` without the REALITY section.

    Adds ``grpcSettings`` for the gRPC transport or ``xhttpSettings`` for the
    XHTTP transport; any other transport gets only network and security.
    """
    transport = cfg.transport_type
    settings = {
        "network": transport,
        "security": "reality",
    }

    if transport == "grpc":
        settings["grpcSettings"] = {
            "serviceName": cfg.grpc_service_name,
            "multiMode": False,
        }
        return settings

    if transport == "xhttp":
        settings["xhttpSettings"] = {
            "mode": cfg.xhttp_mode,
            "path": cfg.spider_x or "/",  # empty path falls back to root
            "host": cfg.dest_domain,
        }

    return settings
121
+
122
+
123
+
124
+
125
+
126
+
127
def build_client_config(
    cfg: VlessConfig,
    vless_uuid: str,
    server_ip: str,
    public_key: str,
    socks_port: int = 10808,
    http_port: int = 10809,
) -> dict:
    """Assemble the client-side configuration dict.

    The result exposes local SOCKS and HTTP inbounds on 127.0.0.1 and a
    ``proxy`` VLESS outbound to ``server_ip``. Mux is enabled when
    ``cfg.mux_concurrency`` is positive. For the TCP transport with a
    fragment strategy other than ``"none"``, the proxy outbound dials through
    an extra ``fragment`` outbound appended to the list.

    Args:
        cfg: Transport/REALITY parameters.
        vless_uuid: User id sent to the server.
        server_ip: Server address for the VLESS outbound.
        public_key: REALITY public key.
        socks_port: Local SOCKS inbound port.
        http_port: Local HTTP inbound port.
    """
    stream_settings = _build_stream_settings_client(cfg, public_key)

    user_entry = {
        "id": vless_uuid,
        "encryption": "none",
        "flow": _pick_flow(cfg),
    }
    server_entry = {
        "address": server_ip,
        "port": cfg.proxy_port,
        "users": [user_entry],
    }
    proxy_outbound = {
        "tag": "proxy",
        "protocol": "vless",
        "settings": {"vnext": [server_entry]},
        "streamSettings": stream_settings,
    }

    if cfg.mux_concurrency > 0:
        proxy_outbound["mux"] = {
            "enabled": True,
            "concurrency": cfg.mux_concurrency,
            "xudpConcurrency": cfg.mux_concurrency,
        }

    # Fragmentation applies only to the TCP transport.
    wants_fragment = cfg.transport_type == "tcp" and cfg.fragment_strategy != "none"
    if wants_fragment:
        proxy_outbound["streamSettings"]["sockopt"] = {
            "dialerProxy": "fragment",
        }

    local_inbounds = [
        {
            "tag": "socks",
            "listen": "127.0.0.1",
            "port": socks_port,
            "protocol": "socks",
            "settings": {"auth": "noauth", "udp": True},
        },
        {
            "tag": "http",
            "listen": "127.0.0.1",
            "port": http_port,
            "protocol": "http",
        },
    ]

    outbounds = [proxy_outbound, {"tag": "direct", "protocol": "freedom"}]
    if wants_fragment:
        outbounds.append(_build_fragment_outbound(cfg))

    return {
        "log": {"loglevel": "warning"},
        "inbounds": local_inbounds,
        "outbounds": outbounds,
    }
198
+
199
+
200
def _build_stream_settings_client(cfg: VlessConfig, public_key: str) -> dict:
    """Build the client ``streamSettings`` including the REALITY section.

    ``alpn`` is added to ``realitySettings`` only when ``cfg.alpn`` is
    non-empty. gRPC and XHTTP transports get their own settings sub-dict; for
    XHTTP, ``cfg.extra_headers`` is merged over a default ``Host`` header.
    """
    reality = {
        "fingerprint": cfg.fingerprint,
        "serverName": cfg.dest_domain,
        "publicKey": public_key,
        "shortId": cfg.short_id,
        "spiderX": cfg.spider_x,
    }
    if cfg.alpn:
        reality["alpn"] = cfg.alpn

    transport = cfg.transport_type
    settings: dict = {
        "network": transport,
        "security": "reality",
        "realitySettings": reality,
    }

    if transport == "grpc":
        settings["grpcSettings"] = {"serviceName": cfg.grpc_service_name}
    elif transport == "xhttp":
        request_headers = {"Host": cfg.dest_domain}
        request_headers.update(cfg.extra_headers)  # caller headers win over Host
        settings["xhttpSettings"] = {
            "mode": cfg.xhttp_mode,
            "path": cfg.spider_x or "/",
            "headers": request_headers,
        }

    return settings
230
+
231
+
232
def _build_fragment_outbound(cfg: VlessConfig) -> dict:
    """Build the ``fragment`` freedom outbound used as a dialer proxy.

    The fragment parameters are placed under ``settings`` because, per the
    Xray freedom-outbound documentation, that is where ``fragment`` is read
    from; a top-level ``fragment`` key on the outbound has no effect. The
    top-level copy is retained only so existing readers of the returned dict
    keep working.

    Returns:
        An outbound dict tagged ``"fragment"`` with ``packets``, ``length``
        (``"min-max"``) and ``interval`` (``"min-max"``) taken from ``cfg``.
    """
    fragment = {
        "packets": cfg.fragment_strategy,
        "length": f"{cfg.fragment_length_min}-{cfg.fragment_length_max}",
        "interval": f"{cfg.fragment_interval_min}-{cfg.fragment_interval_max}",
    }
    return {
        "tag": "fragment",
        "protocol": "freedom",
        "settings": {
            "domainStrategy": "AsIs",
            "fragment": fragment,
        },
        "streamSettings": {
            "sockopt": {
                "dialerProxy": "",
                "tcpKeepAliveInterval": 0,
            }
        },
        # Legacy location, kept for backward compatibility (separate copy so
        # mutating one does not silently change the other).
        "fragment": dict(fragment),
    }
251
+
252
+
253
+
254
+
255
+
256
+
257
def generate_reality_keys() -> tuple[str, str]:
    """Generate an X25519 key pair by running ``xray x25519``.

    Returns:
        ``(private_key, public_key)`` as printed by the tool.

    Raises:
        FileNotFoundError: The ``xray`` executable is not on PATH.
        subprocess.TimeoutExpired: The tool did not finish within 10 seconds.
        RuntimeError: The tool exited with a non-zero status or its output
            could not be parsed.
    """
    import subprocess

    result = subprocess.run(
        ["xray", "x25519"],
        capture_output=True, text=True, timeout=10
    )
    if result.returncode != 0:
        raise RuntimeError(
            f"`xray x25519` exited with code {result.returncode}: {result.stderr.strip()}"
        )

    lines = [ln for ln in result.stdout.strip().splitlines() if ln.strip()]

    # Prefer matching by label so extra or reordered output lines are harmless.
    labelled = {}
    for ln in lines:
        label, sep, value = ln.partition(":")
        if sep:
            labelled[label.replace(" ", "").lower()] = value.strip()

    priv = labelled.get("privatekey")
    # NOTE(review): some newer Xray builds appear to print the public key
    # under a "Password" label — confirm against the deployed version.
    pub = labelled.get("publickey") or labelled.get("password")
    if priv and pub:
        return priv, pub

    # Fall back to the positional "<label>: <value>" layout (line 1 private,
    # line 2 public).
    try:
        priv = lines[0].split(": ")[1].strip()
        pub = lines[1].split(": ")[1].strip()
    except IndexError as exc:
        raise RuntimeError("Could not parse key pair from `xray x25519` output") from exc
    return priv, pub
268
+
269
+
270
def save_config(config: dict, path: str):
    """Write ``config`` to ``path`` as JSON indented by two spaces."""
    with open(path, "w") as out_file:
        json.dump(config, out_file, indent=2)
crypto.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import hmac
3
+ import hashlib
4
+ import time
5
+ import json
6
+ import os
7
+ import base64
8
+ from typing import Tuple
9
+
10
+
11
+
12
+
13
+
14
def sign_payload(payload: dict, secret: str) -> Tuple[str, str]:
    """Serialize ``payload`` with a timestamp and sign it with HMAC-SHA256.

    A ``_ts`` field holding the current Unix time (whole seconds) is added to
    the serialized body, replacing any ``_ts`` already present. The caller's
    dict is not modified.

    Args:
        payload: JSON-serializable mapping to sign.
        secret: Shared secret used as the HMAC key.

    Returns:
        ``(body, signature)``: the compact, key-sorted JSON text and the hex
        digest of its HMAC.
    """
    # Work on a copy so signing has no side effect on the caller's data.
    stamped = {**payload, "_ts": int(time.time())}
    body = json.dumps(stamped, separators=(",", ":"), sort_keys=True)
    sig = hmac.new(
        secret.encode(),
        body.encode(),
        hashlib.sha256
    ).hexdigest()
    return body, sig
24
+
25
+
26
def verify_payload(body: str, sig: str, secret: str, max_age_seconds: int = 300) -> dict:
    """Check the HMAC-SHA256 signature and freshness of a signed body.

    Args:
        body: JSON text exactly as it was signed.
        sig: Hex digest supplied by the sender.
        secret: Shared secret used as the HMAC key.
        max_age_seconds: Maximum allowed distance (past or future) between
            the body's ``_ts`` and the current time.

    Returns:
        The decoded JSON object.

    Raises:
        ValueError: The signature does not match, the body is not a JSON
            object, ``_ts`` is not a number, or the request is stale. Every
            rejection is a ``ValueError`` so callers need one handler.
    """
    expected = hmac.new(
        secret.encode(),
        body.encode(),
        hashlib.sha256
    ).hexdigest()

    # compare_digest raises TypeError for non-ASCII strings or non-str input;
    # treat such a signature as simply invalid.
    try:
        signature_ok = hmac.compare_digest(expected, sig)
    except TypeError:
        signature_ok = False
    if not signature_ok:
        raise ValueError("Invalid HMAC signature")

    data = json.loads(body)
    if not isinstance(data, dict):
        raise ValueError("Payload is not a JSON object")

    # A missing timestamp defaults to 0 and is therefore rejected as stale.
    ts = data.get("_ts", 0)
    if isinstance(ts, bool) or not isinstance(ts, (int, float)):
        raise ValueError("Invalid timestamp in payload")

    age = abs(time.time() - ts)
    if age > max_age_seconds:
        raise ValueError(f"Stale request: {age:.0f}s old")

    return data
43
+
44
+
45
+
46
+
47
+
48
+
49
def generate_self_signed_cert(cert_path: str, key_path: str, cn: str = "localhost"):
    """Create a self-signed RSA-2048 certificate and write both PEM files.

    The certificate uses ``cn`` as common name and as its single DNS subject
    alternative name, is signed with SHA-256 and is valid for 3650 days from
    now.

    Args:
        cert_path: Destination of the PEM certificate.
        key_path: Destination of the unencrypted PEM private key. A newly
            created file gets mode 0o600; the mode of an existing file is not
            changed.
        cn: Common name / DNS name placed in the certificate.
    """
    from cryptography import x509
    from cryptography.x509.oid import NameOID
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.hazmat.backends import default_backend
    import datetime

    key = rsa.generate_private_key(
        public_exponent=65537,
        key_size=2048,
        backend=default_backend()
    )

    # Self-signed: subject and issuer are the same name.
    subject = issuer = x509.Name([
        x509.NameAttribute(NameOID.COMMON_NAME, cn),
    ])

    # One timezone-aware timestamp so both validity bounds derive from the
    # same instant (datetime.utcnow() is deprecated).
    now = datetime.datetime.now(datetime.timezone.utc)

    cert = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(issuer)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(now)
        .not_valid_after(now + datetime.timedelta(days=3650))
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(cn)]),
            critical=False,
        )
        .sign(key, hashes.SHA256(), default_backend())
    )

    with open(cert_path, "wb") as f:
        f.write(cert.public_bytes(serialization.Encoding.PEM))

    # The key is stored unencrypted, so create it owner-only rather than with
    # the process-default permissions.
    key_fd = os.open(key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(key_fd, "wb") as f:
        f.write(key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.TraditionalOpenSSL,
            serialization.NoEncryption()
        ))
92
+
93
+
94
+
95
+
96
+
97
+
98
def load_or_create_secret(path: str = "shared_secret.key") -> str:
    """Return the shared secret stored at ``path``, creating it if needed.

    An existing non-empty file is read and returned with surrounding
    whitespace stripped. If the file is missing or contains only whitespace,
    a new secret (32 random bytes, URL-safe base64) is generated, written to
    ``path`` and returned. A newly created file gets mode 0o600.
    """
    if os.path.exists(path):
        with open(path, "r") as f:
            existing = f.read().strip()
        # An empty file would otherwise hand back "" and silently make every
        # HMAC keyed with the empty string.
        if existing:
            return existing

    secret = base64.urlsafe_b64encode(os.urandom(32)).decode()
    # Owner-only permissions: the file holds the signing key in clear text.
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
    with os.fdopen(fd, "w") as f:
        f.write(secret)
    return secret
environment.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+
4
+ import math
5
+ import time
6
+ import uuid
7
+ import numpy as np
8
+ from typing import Optional ,Tuple
9
+
10
+ from shared .schemas import (
11
+ VlessConfig ,
12
+ EpisodeMetrics ,
13
+ TRANSPORT_TYPES ,
14
+ CANDIDATE_PORTS ,
15
+ SNI_DOMAINS ,
16
+ FINGERPRINTS ,
17
+ ALPN_OPTIONS ,
18
+ FRAGMENT_STRATEGIES ,
19
+ MUX_CONCURRENCY_VALUES ,
20
+ SHORT_ID_LENGTHS ,
21
+ XHTTP_MODES ,
22
+ )
23
+ from server .rl .reward import compute_reward
24
+
25
# Observation layout: the HISTORY_LEN most recent episodes contribute
# OBS_PER_EPISODE features each, followed by 5 context features (step
# progress plus sin/cos encodings of time of day and day of week).
HISTORY_LEN = 10
OBS_PER_EPISODE = 7

TOTAL_OBS_DIM = HISTORY_LEN * OBS_PER_EPISODE + 5

# Number of choices per discrete action dimension, in the order consumed by
# decode_action(). The final entry (100) is the seed range used for the gRPC
# service name.
DISCRETE_NVEC = [
    len(options)
    for options in (
        TRANSPORT_TYPES,
        CANDIDATE_PORTS,
        SNI_DOMAINS,
        FINGERPRINTS,
        ALPN_OPTIONS,
        FRAGMENT_STRATEGIES,
        MUX_CONCURRENCY_VALUES,
        SHORT_ID_LENGTHS,
        XHTTP_MODES,
    )
] + [100]

# Number of continuous action components.
N_CONTINUOUS = 5
48
+
49
+
50
def decode_action(discrete: np.ndarray, continuous: np.ndarray) -> VlessConfig:
    """Translate a raw policy action into a concrete ``VlessConfig``.

    Args:
        discrete: Ten indices, one per entry of ``DISCRETE_NVEC``; the first
            nine select from the option tables and the tenth is the seed for
            the gRPC service name.
        continuous: Five values nominally in [0, 1]. They are clipped to that
            range here so an out-of-range input cannot produce negative
            lengths, intervals or padding sizes.

    Returns:
        The decoded configuration. Fragmentation is forced to ``"none"`` and
        padding is disabled for every transport other than TCP. The short id
        is freshly random on every call.
    """
    transport = TRANSPORT_TYPES[int(discrete[0])]
    port = CANDIDATE_PORTS[int(discrete[1])]
    dest = SNI_DOMAINS[int(discrete[2])]
    fingerprint = FINGERPRINTS[int(discrete[3])]
    alpn = ALPN_OPTIONS[int(discrete[4])]
    frag_strat = FRAGMENT_STRATEGIES[int(discrete[5])]
    mux_conc = MUX_CONCURRENCY_VALUES[int(discrete[6])]
    sid_len = SHORT_ID_LENGTHS[int(discrete[7])]
    xhttp_mode = XHTTP_MODES[int(discrete[8])]
    grpc_seed = int(discrete[9])

    c = np.clip(continuous, 0.0, 1.0)
    frag_len_min = int(10 + c[0] * 190)             # 10..200
    frag_len_max = frag_len_min + int(c[1] * 100)   # min..min+100
    frag_interval_min = int(c[2] * 50)              # 0..50
    frag_interval_max = frag_interval_min + 5
    padding_min = int(c[3] * 500)                   # 0..500
    padding_max = padding_min + int(c[4] * 500)     # min..min+500

    # Imported at call time rather than at module level.
    from server.config_generator import random_service_name, generate_short_id
    grpc_name = random_service_name(grpc_seed)
    short_id = generate_short_id(sid_len)

    if transport != "tcp":
        frag_strat = "none"

    padding_enabled = (transport == "tcp") and (padding_min > 0)

    return VlessConfig(
        transport_type=transport,
        proxy_port=port,
        dest_domain=dest,
        short_id=short_id,
        spider_x="/",
        fingerprint=fingerprint,
        alpn=alpn,
        grpc_service_name=grpc_name,
        xhttp_mode=xhttp_mode,
        fragment_strategy=frag_strat,
        fragment_length_min=frag_len_min,
        fragment_length_max=frag_len_max,
        fragment_interval_min=frag_interval_min,
        fragment_interval_max=frag_interval_max,
        padding_enabled=padding_enabled,
        padding_min=padding_min,
        padding_max=padding_max,
        mux_concurrency=mux_conc,
    )
101
+
102
+
103
def metrics_to_obs_vector(m: EpisodeMetrics) -> np.ndarray:
    """Encode one episode's metrics as seven float32 features.

    Order: connected flag, stability (capped at 1), throughput ratio (capped
    at 1), ping score (1 at 0 ms, 0 at >= 1000 ms), one minus packet loss
    (loss capped at 1), connect-time score (0 at >= 5000 ms), reconnect score
    (0 at >= 5 reconnects).
    """
    connected = float(m.connected)
    stability = min(m.stability_ratio, 1.0)
    throughput = min(m.throughput_ratio, 1.0)
    ping_score = max(0.0, 1.0 - m.avg_ping_ms / 1000.0)
    delivery = 1.0 - min(m.packet_loss_ratio, 1.0)
    connect_score = max(0.0, 1.0 - m.connect_time_ms / 5000.0)
    reconnect_score = max(0.0, 1.0 - m.reconnect_count / 5.0)

    features = [
        connected,
        stability,
        throughput,
        ping_score,
        delivery,
        connect_score,
        reconnect_score,
    ]
    return np.array(features, dtype=np.float32)
113
+
114
+
115
class AlphaBypassEnv:
    """Environment in which one step runs one proxy-configuration episode.

    Each ``step`` decodes the action into a ``VlessConfig``, asks the bridge
    to run it, scores the resulting metrics and appends them to the history
    that forms the next observation.
    """

    def __init__(
        self,
        bridge,
        episode_duration: int = 90,
        baseline_mbps: float = 1.0,
        max_steps: int = 0,
        fail_streak_warn: int = 10,
    ):
        """Store settings and initialise empty episode state.

        Args:
            bridge: Object providing ``run_episode(cfg=, episode_id=,
                duration=)`` and ``report_reward(episode_id, reward)``.
            episode_duration: Duration passed to the bridge for every episode.
            baseline_mbps: Throughput baseline passed to the reward function.
            max_steps: Steps after which ``done`` is reported; 0 means never.
            fail_streak_warn: Consecutive failures that trigger the warning.
        """
        self.bridge = bridge
        self.episode_duration = episode_duration
        self.baseline_mbps = baseline_mbps
        self.max_steps = max_steps
        self.fail_streak_warn = fail_streak_warn

        self.history: list[EpisodeMetrics] = []
        self.step_count: int = 0
        self._fail_streak: int = 0

        self.obs_dim = TOTAL_OBS_DIM
        self.discrete_nvec = DISCRETE_NVEC
        self.n_continuous = N_CONTINUOUS

    def _build_obs(self) -> np.ndarray:
        """Build the observation vector from history and the local clock.

        The newest episode occupies the first feature slot; unused slots stay
        zero. The last five entries are step progress (capped at 1) and
        sin/cos encodings of the local hour and weekday.
        """
        obs = np.zeros(self.obs_dim, dtype=np.float32)

        newest_first = reversed(self.history[-HISTORY_LEN:])
        for slot, episode_metrics in enumerate(newest_first):
            offset = slot * OBS_PER_EPISODE
            obs[offset:offset + OBS_PER_EPISODE] = metrics_to_obs_vector(episode_metrics)

        base = HISTORY_LEN * OBS_PER_EPISODE

        now = time.localtime()
        hour = now.tm_hour + now.tm_min / 60.0
        day_angle = 2 * math.pi * hour / 24
        week_angle = 2 * math.pi * now.tm_wday / 7

        obs[base] = min(self.step_count / 1000.0, 1.0)
        obs[base + 1] = math.sin(day_angle)
        obs[base + 2] = math.cos(day_angle)
        obs[base + 3] = math.sin(week_angle)
        obs[base + 4] = math.cos(week_angle)

        return obs

    def reset(self) -> np.ndarray:
        """Clear history and counters and return the initial observation."""
        self.history = []
        self.step_count = 0
        self._fail_streak = 0
        return self._build_obs()

    def step(
        self,
        discrete_action: np.ndarray,
        continuous_action: np.ndarray,
    ) -> Tuple[np.ndarray, float, bool, dict]:
        """Run one episode with the configuration encoded by the action.

        Returns:
            ``(obs, reward, done, info)``. ``done`` is true only when
            ``max_steps`` is positive and has been reached. ``info`` carries
            the episode id, reward, connection flag, stability, throughput,
            transport, destination domain and current failure streak.
        """
        cfg = decode_action(discrete_action, continuous_action)
        episode_id = str(uuid.uuid4())[:8]

        metrics = self.bridge.run_episode(
            cfg=cfg,
            episode_id=episode_id,
            duration=self.episode_duration,
        )

        reward = compute_reward(metrics, self.baseline_mbps)

        self.bridge.report_reward(episode_id, reward)
        self.history.append(metrics)
        self.step_count += 1

        if metrics.connected:
            self._fail_streak = 0
        else:
            self._fail_streak += 1
            # Warn exactly once, when the streak first reaches the threshold.
            if self._fail_streak == self.fail_streak_warn:
                print(
                    f"\n⚠️ [Degradation] {self._fail_streak} FAIL подряд! "
                    f"Возможно РКН изменил политику или проблема с сетью."
                )

        done = (self.max_steps > 0 and self.step_count >= self.max_steps)
        obs = self._build_obs()

        info = {
            "episode_id": episode_id,
            "reward": reward,
            "connected": metrics.connected,
            "stability": metrics.stability_ratio,
            "throughput_mbps": metrics.throughput_mbps,
            "transport": cfg.transport_type,
            "dest": cfg.dest_domain,
            "fail_streak": self._fail_streak,
        }

        return obs, reward, done, info
reward.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from shared .schemas import EpisodeMetrics
3
+
4
+
5
def compute_reward(metrics: EpisodeMetrics, baseline_mbps: float = 1.0) -> float:
    """Score one episode in the range [-1.0, 1.0].

    A failed connection scores -1.0. Otherwise the reward is a weighted sum,
    rounded to four decimals:

      * 50% connection quality (30% connect time, 40% packet delivery,
        30% ping),
      * 35% stability (70% stability ratio, 30% reconnect score),
      * 15% throughput relative to ``baseline_mbps`` on a log scale.

    Every sub-score is clamped to [0, 1], so out-of-range metrics (packet
    loss above 1, stability above 1, negative throughput) can neither push
    the reward outside its range nor trigger a math domain error.

    Args:
        metrics: Measured episode metrics.
        baseline_mbps: Throughput that earns the full speed score; values
            below 0.1 are treated as 0.1.
    """
    import math

    def clamp01(x: float) -> float:
        return min(max(x, 0.0), 1.0)

    if not metrics.connected:
        return -1.0

    r = 0.0

    # Connection quality.
    connect_score = clamp01(1.0 - metrics.connect_time_ms / 5000.0)
    loss_score = clamp01(1.0 - metrics.packet_loss_ratio)
    ping_score = clamp01(1.0 - metrics.avg_ping_ms / 1000.0)
    connection_component = connect_score * 0.3 + loss_score * 0.4 + ping_score * 0.3
    r += 0.50 * connection_component

    # Stability.
    stability_score = clamp01(metrics.stability_ratio)
    reconnect_score = clamp01(1.0 - metrics.reconnect_count / 5.0)
    stability_component = stability_score * 0.7 + reconnect_score * 0.3
    r += 0.35 * stability_component

    # Throughput: logarithmic so early gains count more than late ones.
    safe_baseline = max(baseline_mbps, 0.1)
    speed_ratio = clamp01(metrics.throughput_mbps / safe_baseline)
    speed_score = math.log1p(speed_ratio * 9) / math.log1p(9)
    r += 0.15 * speed_score

    return round(r, 4)
41
+
42
+
43
def reward_to_label(r: float) -> str:
    """Map a reward to a short human-readable tier label for log output."""
    # (exclusive upper bound, label), checked in ascending order.
    tiers = (
        (-0.5, "💀 FAIL"),
        (0.0, "❌ bad"),
        (0.3, "⚠️ weak"),
        (0.6, "🟡 ok"),
        (0.8, "🟢 good"),
    )
    for upper_bound, label in tiers:
        if r < upper_bound:
            return label
    return "🏆 great"
schemas.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from __future__ import annotations
3
+ from dataclasses import dataclass ,field ,asdict
4
+ from typing import Optional ,List ,Dict ,Any
5
+ import json
6
+ import time
7
+
8
+
9
+
10
+
11
+
12
+
13
# Option tables for the discrete action space. An action index selects one
# entry from each table, so order matters: reordering or resizing a table
# changes what previously learned indices mean.

# Transport protocols the agent may choose.
TRANSPORT_TYPES = ["tcp", "grpc"]

# Candidate listening ports for the proxy.
CANDIDATE_PORTS = [443, 80, 8443, 2053, 2083, 2087, 9443]

# Domains used as the REALITY destination / server name.
SNI_DOMAINS = [
    "download.nvidia.com",
    "swscan.apple.com",
    "updates.cdn-apple.com",
    "steamcdn-a.akamaihd.net",
    "dl.delivery.mp.microsoft.com",
    "download.windowsupdate.com",
    "cdn.cloudflare.steamstatic.com",
    "origin-a.akamaihd.net",
    "pkg-containers.githubusercontent.com",
    "download.jetbrains.com",
    "packages.ubuntu.com",
    "ajax.aspnetcdn.com",
    "github-releases.githubusercontent.com",
    "objects.githubusercontent.com",
    "software.download.prss.microsoft.com",
]

# Client fingerprint names.
FINGERPRINTS = ["chrome", "firefox", "edge", "safari", "ios", "random", "randomized"]

# ALPN protocol lists.
ALPN_OPTIONS = [
    ["h2", "http/1.1"],
    ["h2"],
    ["http/1.1"],
]

# Fragmentation strategies; "none" disables fragmentation.
FRAGMENT_STRATEGIES = ["none", "tlshello", "all"]

# Mux concurrency levels; 0 leaves mux disabled.
MUX_CONCURRENCY_VALUES = [0, 1, 2, 4, 8, 16, 32]

# Short-id lengths, in hex characters.
SHORT_ID_LENGTHS = [4, 8, 16]

# XHTTP modes.
XHTTP_MODES = ["packet-up", "streaming"]
58
+
59
+
60
+
61
+
62
+
63
+
64
@dataclass
class VlessConfig:
    """One complete proxy configuration, serializable to and from dict/JSON."""

    # Transport
    transport_type: str = "tcp"
    proxy_port: int = 443

    # REALITY parameters
    dest_domain: str = "download.nvidia.com"
    short_id: str = "abcdef01"
    spider_x: str = "/"

    # Client fingerprint and ALPN
    fingerprint: str = "chrome"
    alpn: List[str] = field(default_factory=lambda: ["h2", "http/1.1"])

    # Transport-specific options
    grpc_service_name: str = "grpc"
    xhttp_mode: str = "packet-up"

    # Fragmentation
    fragment_strategy: str = "none"
    fragment_length_min: int = 50
    fragment_length_max: int = 100
    fragment_interval_min: int = 1
    fragment_interval_max: int = 5

    # Padding
    padding_enabled: bool = False
    padding_min: int = 0
    padding_max: int = 0

    # Mux; 0 leaves it disabled
    mux_concurrency: int = 0

    # Additional request headers
    extra_headers: Dict[str, str] = field(default_factory=dict)

    def to_dict(self) -> dict:
        """Return all fields as a plain dict (deep-copied by ``asdict``)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "VlessConfig":
        """Build a config from a field-name -> value mapping.

        Keys that are not field names raise ``TypeError``.
        """
        return cls(**d)

    def to_json(self) -> str:
        """Return the config as a JSON string."""
        as_mapping = self.to_dict()
        return json.dumps(as_mapping)

    @classmethod
    def from_json(cls, s: str) -> "VlessConfig":
        """Build a config from a JSON string produced by ``to_json``."""
        decoded = json.loads(s)
        return cls.from_dict(decoded)
114
+
115
+
116
+
117
+
118
+
119
+
120
@dataclass
class EpisodeMetrics:
    """Measurements collected for one episode."""

    episode_id: str = ""
    timestamp: float = field(default_factory=time.time)  # creation time, Unix seconds

    # Connection
    connected: bool = False
    connect_time_ms: float = 0.0

    # Stability
    stability_ratio: float = 0.0
    reconnect_count: int = 0
    drop_count: int = 0

    # Throughput
    throughput_mbps: float = 0.0
    throughput_ratio: float = 0.0

    # Latency and loss
    avg_ping_ms: float = 0.0
    max_ping_ms: float = 0.0
    packet_loss_ratio: float = 0.0

    # Diagnostics
    error_message: Optional[str] = None
    samples: int = 0

    def to_dict(self) -> dict:
        """Return all fields as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "EpisodeMetrics":
        """Build metrics from a field-name -> value mapping.

        Keys that are not field names raise ``TypeError``.
        """
        return cls(**d)
153
+
154
+
155
+
156
+
157
+
158
+
159
@dataclass
class EpisodeCommand:
    """Instruction describing one episode to run."""

    episode_id: str = ""
    config: Optional[dict] = None   # configuration as a plain dict
    duration_seconds: int = 90
    server_ip: str = ""
    server_port: int = 443
    uuid: str = ""

    def to_dict(self) -> dict:
        """Return all fields as a plain dict."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d: dict) -> "EpisodeCommand":
        """Build a command from a mapping; ``config`` is attached as-is.

        A missing ``config`` key leaves ``config`` as ``None``.
        """
        remaining = {key: value for key, value in d.items()}
        config_value = remaining.pop("config", None)
        command = cls(**remaining)
        command.config = config_value
        return command
176
+
177
+
178
@dataclass
class ClientStatus:
    """Status snapshot tied to one episode."""

    episode_id: str = ""
    partial_metrics: Optional[dict] = None  # metrics gathered so far, if any
    phase: str = "idle"                     # current phase name; starts as "idle"