antony647 commited on
Commit
0ddfcb6
·
verified ·
1 Parent(s): 2256614

Upload app.py

Browse files
Files changed (1) hide show
  1. server/app.py +340 -0
server/app.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from typing import Optional
4
+ import uvicorn
5
+ import random
6
+ import math
7
+
8
+ from models import Observation, Action, StepResponse
9
+
10
+ app = FastAPI(
11
+ title="CLAIRS Autonomous Defense Environment",
12
+ description="OpenEnv-compliant RL environment for IoT DDoS mitigation",
13
+ version="2.0.0",
14
+ )
15
+
16
+
17
+ class ResetRequest(BaseModel):
18
+ task_id: str = "task_1_easy"
19
+
20
+
21
+ class ActionPayload(BaseModel):
22
+ decision: str = "monitor"
23
+
24
+
25
+ ATTACK_PROFILES = {
26
+ "task_1_easy": {
27
+ "name": "Normal Traffic Monitoring",
28
+ "phases": [
29
+ {
30
+ "start": 0,
31
+ "end": 10,
32
+ "type": "normal",
33
+ "base_pps": 120,
34
+ "base_cpu": 10.0,
35
+ },
36
+ ],
37
+ },
38
+ "task_2_medium": {
39
+ "name": "Volumetric DDoS Flood",
40
+ "phases": [
41
+ {
42
+ "start": 0,
43
+ "end": 2,
44
+ "type": "normal",
45
+ "base_pps": 200,
46
+ "base_cpu": 15.0,
47
+ },
48
+ {
49
+ "start": 2,
50
+ "end": 10,
51
+ "type": "attack_ramp",
52
+ "pps_start": 5000,
53
+ "pps_end": 50000,
54
+ "cpu_start": 55.0,
55
+ "cpu_end": 99.0,
56
+ },
57
+ ],
58
+ },
59
+ "task_3_hard": {
60
+ "name": "Stealth Low-and-Slow DDoS",
61
+ "phases": [
62
+ {
63
+ "start": 0,
64
+ "end": 2,
65
+ "type": "normal",
66
+ "base_pps": 150,
67
+ "base_cpu": 12.0,
68
+ },
69
+ {
70
+ "start": 2,
71
+ "end": 10,
72
+ "type": "attack_ramp",
73
+ "pps_start": 2000,
74
+ "pps_end": 25000,
75
+ "cpu_start": 30.0,
76
+ "cpu_end": 75.0,
77
+ },
78
+ ],
79
+ },
80
+ "task_4_expert": {
81
+ "name": "Multi-Wave APT Campaign",
82
+ "phases": [
83
+ {
84
+ "start": 0,
85
+ "end": 2,
86
+ "type": "normal",
87
+ "base_pps": 130,
88
+ "base_cpu": 11.0,
89
+ },
90
+ {
91
+ "start": 2,
92
+ "end": 5,
93
+ "type": "attack_ramp",
94
+ "pps_start": 4000,
95
+ "pps_end": 12000,
96
+ "cpu_start": 40.0,
97
+ "cpu_end": 60.0,
98
+ },
99
+ {
100
+ "start": 5,
101
+ "end": 7,
102
+ "type": "normal",
103
+ "base_pps": 180,
104
+ "base_cpu": 13.0,
105
+ },
106
+ {
107
+ "start": 7,
108
+ "end": 10,
109
+ "type": "attack_ramp",
110
+ "pps_start": 15000,
111
+ "pps_end": 45000,
112
+ "cpu_start": 70.0,
113
+ "cpu_end": 99.0,
114
+ },
115
+ ],
116
+ },
117
+ }
118
+
119
+
120
+ class NetworkSimulator:
121
+
122
+ def __init__(self):
123
+ self.task_id = "task_1_easy"
124
+ self.step_count = 0
125
+ self.max_steps = 10
126
+ self.system_health = 100.0
127
+ self.current_pps = 100.0
128
+ self.current_cpu = 10.0
129
+ self.current_connections = 10
130
+ self.current_bandwidth = 1.0
131
+ self.current_memory = 30.0
132
+ self.false_positives = 0
133
+ self.attack_detected_step = None
134
+ self.cumulative_damage = 0.0
135
+
136
+ def reset(self, task_id: str) -> Observation:
137
+ self.task_id = task_id
138
+ self.step_count = 0
139
+ self.system_health = 100.0
140
+ self.false_positives = 0
141
+ self.attack_detected_step = None
142
+ self.cumulative_damage = 0.0
143
+
144
+ first_phase = ATTACK_PROFILES[task_id]["phases"][0]
145
+
146
+ noise = random.uniform(0.88, 1.12)
147
+ self.current_pps = first_phase["base_pps"] * noise
148
+ self.current_cpu = min(100.0, first_phase["base_cpu"] * random.uniform(0.9, 1.1))
149
+ self.current_connections = max(1, int(self.current_pps / 8 + random.randint(-5, 5)))
150
+ self.current_bandwidth = max(0.1, self.current_pps * 0.001 * random.uniform(0.8, 1.2))
151
+ self.current_memory = 25.0 + random.uniform(-3, 8)
152
+
153
+ return self._observation()
154
+
155
+ def step(self, action: str):
156
+ action = action.lower().strip()
157
+ if action not in ("monitor", "rate_limit", "block"):
158
+ action = "monitor"
159
+
160
+ reward = self._compute_reward(action)
161
+ self._advance_traffic(action)
162
+
163
+ self.step_count += 1
164
+ done = self.step_count >= self.max_steps
165
+
166
+ info = {
167
+ "mitigation_applied": action,
168
+ "is_attack_phase": self._is_attack(),
169
+ "attack_severity": round(self._severity(), 2),
170
+ "system_health": round(self.system_health, 1),
171
+ "false_positives": self.false_positives,
172
+ "cumulative_damage": round(self.cumulative_damage, 1),
173
+ }
174
+
175
+ return self._observation(), reward, done, info
176
+
177
+ def get_state(self) -> Observation:
178
+ return self._observation()
179
+
180
+ def _current_phase(self) -> dict:
181
+ for phase in ATTACK_PROFILES[self.task_id]["phases"]:
182
+ if phase["start"] <= self.step_count < phase["end"]:
183
+ return phase
184
+ return ATTACK_PROFILES[self.task_id]["phases"][-1]
185
+
186
+ def _is_attack(self) -> bool:
187
+ return self._current_phase()["type"] == "attack_ramp"
188
+
189
+ def _severity(self) -> float:
190
+ phase = self._current_phase()
191
+ if phase["type"] != "attack_ramp":
192
+ return 0.0
193
+ span = max(1, phase["end"] - phase["start"] - 1)
194
+ return min(1.0, (self.step_count - phase["start"]) / span)
195
+
196
+ def _advance_traffic(self, action: str):
197
+ phase = self._current_phase()
198
+ noise = random.uniform(0.88, 1.12)
199
+
200
+ mitigation = 1.0
201
+ if action == "block":
202
+ mitigation = 0.05 + random.uniform(0, 0.03)
203
+ elif action == "rate_limit":
204
+ mitigation = 0.35 + random.uniform(0, 0.08)
205
+
206
+ if phase["type"] == "normal":
207
+ target_pps = phase["base_pps"] * noise
208
+ target_cpu = phase["base_cpu"] * random.uniform(0.9, 1.1)
209
+ else:
210
+ span = max(1, phase["end"] - phase["start"] - 1)
211
+ progress = (self.step_count - phase["start"]) / span
212
+ ramp = min(1.0, progress ** 1.3)
213
+
214
+ raw_pps = phase["pps_start"] + (phase["pps_end"] - phase["pps_start"]) * ramp
215
+ raw_cpu = phase["cpu_start"] + (phase["cpu_end"] - phase["cpu_start"]) * ramp
216
+
217
+ target_pps = raw_pps * noise * mitigation
218
+ target_cpu = min(100.0, raw_cpu * noise * (0.3 + 0.7 * mitigation))
219
+
220
+ alpha = 0.7
221
+ self.current_pps = (1 - alpha) * self.current_pps + alpha * target_pps
222
+ self.current_cpu = (1 - alpha) * self.current_cpu + alpha * target_cpu
223
+ self.current_connections = max(1, int(self.current_pps / 8 + random.randint(-3, 3)))
224
+ self.current_bandwidth = max(0.1, self.current_pps * 0.001 * random.uniform(0.85, 1.15))
225
+
226
+ mem_delta = random.uniform(-2, 3)
227
+ if self._is_attack() and action == "monitor":
228
+ mem_delta += self._severity() * 4
229
+ self.current_memory = max(20.0, min(95.0, self.current_memory + mem_delta))
230
+
231
+ if self._is_attack() and action == "monitor":
232
+ dmg = self._severity() * random.uniform(3.0, 7.0)
233
+ self.system_health = max(0.0, self.system_health - dmg)
234
+ self.cumulative_damage += dmg
235
+ elif self._is_attack() and action == "rate_limit":
236
+ dmg = self._severity() * random.uniform(0.5, 2.0)
237
+ self.system_health = max(0.0, self.system_health - dmg)
238
+ self.cumulative_damage += dmg
239
+ else:
240
+ self.system_health = min(100.0, self.system_health + random.uniform(0.3, 1.0))
241
+
242
+ def _compute_reward(self, action: str) -> float:
243
+ is_attack = self._is_attack()
244
+ severity = self._severity()
245
+ reward = 0.50
246
+
247
+ if not is_attack:
248
+ if action == "monitor":
249
+ reward = 0.90 + random.uniform(0, 0.08)
250
+ elif action == "rate_limit":
251
+ reward = 0.25 + random.uniform(0, 0.08)
252
+ self.false_positives += 1
253
+ elif action == "block":
254
+ reward = 0.08 + random.uniform(0, 0.06)
255
+ self.false_positives += 1
256
+ else:
257
+ if severity > 0.6:
258
+ if action == "block":
259
+ reward = 0.88 + random.uniform(0, 0.09)
260
+ elif action == "rate_limit":
261
+ reward = 0.48 + random.uniform(0, 0.10)
262
+ else:
263
+ reward = 0.03 + random.uniform(0, 0.05)
264
+ elif severity > 0.2:
265
+ if action == "rate_limit":
266
+ reward = 0.85 + random.uniform(0, 0.09)
267
+ elif action == "block":
268
+ reward = 0.58 + random.uniform(0, 0.10)
269
+ else:
270
+ reward = 0.05 + random.uniform(0, 0.07)
271
+ else:
272
+ if action in ("rate_limit", "block"):
273
+ reward = 0.78 + random.uniform(0, 0.10)
274
+ else:
275
+ reward = 0.10 + random.uniform(0, 0.08)
276
+
277
+ if self.attack_detected_step is None and action in ("rate_limit", "block"):
278
+ self.attack_detected_step = self.step_count
279
+ if self.step_count <= 3:
280
+ reward = min(0.99, reward + 0.04)
281
+
282
+ if self.task_id == "task_3_hard" and is_attack:
283
+ if action == "rate_limit":
284
+ reward = min(0.99, reward + 0.04)
285
+ elif action == "block" and severity < 0.5:
286
+ reward = max(0.01, reward - 0.08)
287
+
288
+ if self.system_health > 70:
289
+ reward = min(0.99, reward + 0.02)
290
+
291
+ return round(max(0.01, min(0.99, reward)), 4)
292
+
293
+ def _observation(self) -> Observation:
294
+ return Observation(
295
+ cpu_usage_percent=round(self.current_cpu, 2),
296
+ packet_rate_pps=round(self.current_pps, 2),
297
+ active_connections=max(0, self.current_connections),
298
+ bandwidth_mbps=round(self.current_bandwidth, 2),
299
+ memory_usage_percent=round(self.current_memory, 2),
300
+ system_health=round(self.system_health, 2),
301
+ )
302
+
303
+
304
+ simulator = NetworkSimulator()
305
+
306
+
307
+ @app.post("/reset")
308
+ def reset(req: Optional[ResetRequest] = None):
309
+ task_id = req.task_id if req else "task_1_easy"
310
+ if task_id not in ATTACK_PROFILES:
311
+ task_id = "task_1_easy"
312
+
313
+ obs = simulator.reset(task_id)
314
+ return obs.model_dump()
315
+
316
+
317
+ @app.post("/step", response_model=StepResponse)
318
+ def step(payload: Optional[ActionPayload] = None):
319
+ action = payload.decision.lower() if payload else "monitor"
320
+ obs, reward, done, info = simulator.step(action)
321
+
322
+ return StepResponse(observation=obs, reward=reward, done=done, info=info)
323
+
324
+
325
+ @app.get("/state", response_model=Observation)
326
+ def state():
327
+ return simulator.get_state()
328
+
329
+
330
+ @app.get("/health")
331
+ def health():
332
+ return {"status": "ok"}
333
+
334
+
335
+ def main():
336
+ uvicorn.run("server.app:app", host="0.0.0.0", port=7860)
337
+
338
+
339
+ if __name__ == "__main__":
340
+ main()