krishuggingface commited on
Commit
be25efb
·
1 Parent(s): e897efa

fix(inference): align OpenEnv strict proxy compliance overrides and remove hardcoded API validations

Browse files
Files changed (1) hide show
  1. inference.py +258 -432
inference.py CHANGED
@@ -1,54 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
4
- import logging
5
- import traceback
6
- import threading
7
- from typing import Optional, Dict, Any
8
-
9
  import requests
 
10
  from openai import OpenAI
11
 
12
- # ---------------------------------------------------------------------
13
- # 1. SETUP LOGGING
14
- # ---------------------------------------------------------------------
15
- # Ensure logs look like: [TIMESTAMP] [STAGE] message
16
- class StageFormatter(logging.Formatter):
17
- def format(self, record):
18
- # We manually use the prefix if provided in extra
19
- stage = getattr(record, 'stage', 'SYSTEM')
20
- self._style._fmt = f"[%(asctime)s] [{stage}] %(message)s"
21
- # Ensure fast formatting matching standard requirements
22
- return super().format(record)
23
-
24
- logger = logging.getLogger("inference")
25
- logger.setLevel(logging.DEBUG)
26
- handler = logging.StreamHandler()
27
- handler.setFormatter(StageFormatter(datefmt="%Y-%m-%d %H:%M:%S"))
28
- logger.addHandler(handler)
29
-
30
- logger.info("Initializing Agent Scripts", extra={"stage": "APP STARTUP"})
31
-
32
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
33
- MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
34
- API_KEY = os.environ.get("API_KEY")
35
- ENV_URL = os.getenv("ENV_URL", "http://127.0.0.1:7860")
36
- USE_LLM = os.environ.get("USE_LLM", "0") == "1"
37
 
38
- logger.info("Environment variables loaded.", extra={"stage": "APP STARTUP"})
 
39
 
40
- client: Optional[OpenAI] = None
41
- if API_BASE_URL and API_KEY:
42
- try:
43
- logger.info("Initializing OpenAI Client", extra={"stage": "MODEL LOADING"})
44
- _start_time = time.time()
45
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
46
- _end_time = time.time()
47
- logger.info(f"Client Initialized. Duration: {_end_time - _start_time:.4f}s", extra={"stage": "MODEL LOADING"})
48
- except Exception as e:
49
- logger.error(f"Failed to initialize OpenAI client: {e}\n{traceback.format_exc()}", extra={"stage": "APP STARTUP"})
50
- client = None
 
 
 
 
51
 
 
52
  SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
53
  You receive time-windowed sensor readings each step and must detect cyberattacks.
54
 
@@ -63,6 +62,14 @@ For task_id=0: Focus on detecting any attack (attack_detected=True/False).
63
  For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
64
  For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
65
 
 
 
 
 
 
 
 
 
66
  Respond ONLY with valid JSON, no explanation:
67
  {
68
  "attack_detected": <bool>,
@@ -71,322 +78,190 @@ Respond ONLY with valid JSON, no explanation:
71
  "protective_action": <int 0-3>
72
  }"""
73
 
74
- TASK_NAMES = {
75
- 0: "Sinusoidal FDI Detection (Easy)",
76
- 1: "Multi-Attack Classification (Medium)",
77
- 2: "Stealthy Attack Detection (Hard)",
78
- }
79
-
80
- DEFAULT_ACTION = {
81
- "attack_detected": False,
82
- "attack_type": 0,
83
- "confidence": 0.5,
84
- "protective_action": 0,
85
- }
86
-
87
 
88
  def log_start(task: str, env: str, model: str) -> None:
89
- logger.info(f"task={task} env={env} model={model}", extra={"stage": "EPISODE START"})
90
 
91
 
92
  def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
93
- action_str = json.dumps(action, separators=(",", ":"))
94
- error_val = error if error else "null"
95
- logger.debug(
96
- f"step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}",
97
- extra={"stage": "EPISODE STEP"}
 
98
  )
99
 
100
 
101
- def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
102
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
103
- logger.info(
104
- f"success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
105
- extra={"stage": "EPISODE END"}
 
106
  )
107
 
108
-
109
- def safe_action(action: Dict[str, Any]) -> Dict[str, Any]:
110
- try:
111
- return {
112
- "attack_detected": bool(action.get("attack_detected", False)),
113
- "attack_type": max(0, min(4, int(action.get("attack_type", 0)))),
114
- "confidence": max(0.0, min(1.0, float(action.get("confidence", 0.5)))),
115
- "protective_action": max(0, min(3, int(action.get("protective_action", 0)))),
116
- }
117
- except Exception as e:
118
- logger.error(f"Action constraint failed: {e}\n{traceback.format_exc()}", extra={"stage": "POSTPROCESSING"})
119
- return DEFAULT_ACTION.copy()
120
-
121
-
122
- def safe_post_json(
123
- url: str,
124
- payload: Dict[str, Any],
125
- timeout: int = 10,
126
- retries: int = 2,
127
- ) -> Optional[Dict[str, Any]]:
128
- last_error = None
129
- logger.debug(f"Calling endpoint {url}", extra={"stage": "API CALL (REQ)"})
130
- _start_t = time.time()
131
-
132
- for attempt in range(retries + 1):
133
- try:
134
- response = requests.post(url, json=payload, timeout=timeout)
135
- response.raise_for_status()
136
- logger.debug(f"Response ok from {url} in {time.time()-_start_t:.4f}s", extra={"stage": "API CALL (RES)"})
137
- return response.json()
138
- except Exception as e:
139
- last_error = e
140
- logger.warning(
141
- f"HTTP error calling {url} (attempt {attempt + 1}/{retries + 1}): {e}",
142
- extra={"stage": "API CALL (ERR)"}
143
- )
144
- time.sleep(0.5)
145
-
146
- logger.error(f"Giving up on {url}: {last_error}\n{traceback.format_exc()}", extra={"stage": "API CALL (ERR)"})
147
- return None
148
-
149
-
150
- def _warmup_worker() -> None:
151
- """Non-blocking LLM warmup executed inside a thread."""
152
- if client is None:
153
- logger.info("LLM proxy warmup skipped (client unavailable).", extra={"stage": "MODEL LOADING"})
154
- return
155
-
156
- logger.info("Initializing LLM Proxy Warmup Thread...", extra={"stage": "MODEL LOADING"})
157
- _req_t = time.time()
158
- try:
159
- _ = client.chat.completions.create(
160
- model=MODEL_NAME,
161
- messages=[{"role": "user", "content": "ping"}],
162
- max_tokens=1,
163
- temperature=0,
164
- )
165
- logger.info(f"LLM proxy warmup successful in {time.time() - _req_t:.4f}s.", extra={"stage": "MODEL LOADING"})
166
- except Exception as e:
167
- logger.error(f"LLM proxy warmup failed: {e}\n{traceback.format_exc()}", extra={"stage": "MODEL LOADING (ERR)"})
168
-
169
- def warmup_proxy() -> None:
170
- """Make one tiny proxy call gracefully via threading to avoid app blocking"""
171
- t = threading.Thread(target=_warmup_worker, daemon=True)
172
- t.start()
173
-
174
-
175
- # ---------------------------------------------------------------------
176
- # ZERO-DEPENDENCY HEALTHCHECK SERVER
177
- # ---------------------------------------------------------------------
178
- from http.server import BaseHTTPRequestHandler, HTTPServer
179
-
180
- class FastHealthcheck(BaseHTTPRequestHandler):
181
- def do_GET(self):
182
- logger.info(f"Healthcheck triggered at {self.path}", extra={"stage": "HEALTHCHECK"})
183
- self.send_response(200)
184
- self.send_header("Content-type", "application/json")
185
- self.end_headers()
186
- self.wfile.write(b'{"status":"ok"}')
187
- logger.info("Healthcheck returned 200 OK immediately", extra={"stage": "HEALTHCHECK"})
188
-
189
- def log_message(self, format, *args):
190
- pass # disable default stdout spam from simple server
191
-
192
- def _run_healthcheck() -> None:
193
- try:
194
- # Binding to 7860 as Spaces default checks it
195
- server = HTTPServer(('0.0.0.0', 7860), FastHealthcheck)
196
- logger.info("Background Healthcheck server bound to 0.0.0.0:7860", extra={"stage": "APP STARTUP"})
197
- server.serve_forever()
198
- except Exception as e:
199
- logger.error(f"Healthcheck server crash: {e}\n{traceback.format_exc()}", extra={"stage": "APP STARTUP (ERR)"})
200
-
201
- # Start Healthcheck Thread instantly
202
- t_health = threading.Thread(target=_run_healthcheck, daemon=True)
203
- t_health.start()
204
-
205
-
206
- def detector_agent(prev_info: dict) -> Optional[dict]:
207
- det = (prev_info or {}).get("detector", {})
208
- if not isinstance(det, dict) or "attack_detected" not in det:
209
- return None
210
- return {
211
- "attack_detected": det.get("attack_detected", False),
212
- "attack_type": det.get("attack_type", 0),
213
- "confidence": det.get("confidence", 0.5),
214
- "protective_action": det.get("protective_action", 0),
215
- }
216
-
217
 
218
  class HeuristicState:
 
 
219
  def __init__(self):
220
  self.reset()
221
 
222
  def reset(self):
223
- self.vq_history = []
224
  self.omega_dev_history = []
225
- self.attack_detected = False
226
- self.predicted_type = 0
227
- self.settled_baseline = None
228
- self.peak_vq = 0.0
229
 
230
 
231
  _hstate = HeuristicState()
232
 
233
 
234
  def heuristic_agent(obs: dict) -> dict:
 
235
  global _hstate
236
 
237
- try:
238
- vq = obs["vq_window"]
239
- omega_dev = obs["omega_deviation_window"]
240
- task_id = int(obs["task_id"])
241
- step = int(obs["step"])
242
- except Exception:
243
- return DEFAULT_ACTION.copy()
244
 
245
  if step == 0:
246
  _hstate.reset()
247
 
248
- try:
249
- vq_abs = [abs(v) for v in vq]
250
- vq_mean = sum(vq_abs) / len(vq_abs)
251
- vq_max = max(vq_abs)
252
- vq_latest = abs(vq[-1]) if vq else 0.0
253
-
254
- omega_dev_abs = [abs(v) for v in omega_dev]
255
- omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs) if omega_dev_abs else 0.0
256
 
257
- _hstate.vq_history.append(vq_mean)
258
- _hstate.omega_dev_history.append(omega_dev_mean)
259
- _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
260
 
261
- if step == 50:
262
- _hstate.settled_baseline = omega_dev_mean
263
 
264
- if step < 25:
265
- detected = False
266
- else:
267
- detected = vq_mean > 0.01 or vq_max > 0.025
268
 
269
- if detected:
270
- _hstate.attack_detected = True
 
 
 
 
 
 
271
 
272
- if task_id == 0:
 
 
273
  return {
274
- "attack_detected": _hstate.attack_detected,
275
- "attack_type": 1 if _hstate.attack_detected else 0,
276
- "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
277
- "protective_action": 1 if _hstate.attack_detected else 0,
278
  }
279
 
280
- if task_id == 1:
281
- if not _hstate.attack_detected:
282
- return {
283
- "attack_detected": False,
284
- "attack_type": 0,
285
- "confidence": 0.7,
286
- "protective_action": 0,
287
- }
288
 
289
- n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
- if n_elevated < 5:
 
 
 
 
 
 
 
 
292
  attack_type = 1
293
  else:
294
- elevated = [v for v in _hstate.vq_history if v > 0.005]
295
- recent = elevated[-min(20, len(elevated)):]
296
-
297
- current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0.0
298
- zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i - 1] < 0)
299
-
300
- if len(recent) >= 6:
301
- third = max(1, len(recent) // 3)
302
- first_third = sum(recent[:third]) / third
303
- last_third = sum(recent[-third:]) / third
304
- growth = last_third / first_third if first_third > 0.001 else 1.0
305
- else:
306
- growth = 1.0
307
-
308
- if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
309
- attack_type = 3
310
- elif current_vs_peak < 0.4 and n_elevated > 30:
311
- attack_type = 3
312
- elif zero_crossings >= 2 and growth < 1.5:
313
- attack_type = 1
314
- elif growth > 1.3:
315
- attack_type = 2
316
- elif zero_crossings >= 1:
317
- attack_type = 1
318
- else:
319
- vq_diffs = [vq[i] - vq[i - 1] for i in range(1, len(vq))]
320
- neg = sum(1 for d in vq_diffs if d < 0)
321
- attack_type = 3 if neg > 14 else 1
322
-
323
- _hstate.predicted_type = attack_type
324
 
325
- return {
326
- "attack_detected": True,
327
- "attack_type": _hstate.predicted_type,
328
- "confidence": 0.8,
329
- "protective_action": 1,
330
- }
331
 
332
- if task_id == 2:
333
- drift_detected = False
334
- confidence = 0.3
335
-
336
- if step > 50 and _hstate.settled_baseline is not None:
337
- baseline = _hstate.settled_baseline
338
- ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100.0
339
-
340
- if len(_hstate.omega_dev_history) > 10:
341
- recent_10 = _hstate.omega_dev_history[-10:]
342
- old_10 = (
343
- _hstate.omega_dev_history[-20:-10]
344
- if len(_hstate.omega_dev_history) > 20
345
- else _hstate.omega_dev_history[:10]
346
- )
347
- recent_avg = sum(recent_10) / len(recent_10)
348
- old_avg = sum(old_10) / len(old_10)
349
- rising = recent_avg > old_avg * 1.1
350
- else:
351
- rising = False
352
-
353
- if ratio > 2.0:
354
- drift_detected = True
355
- confidence = 0.9
356
- elif ratio > 1.3 and rising:
357
- drift_detected = True
358
- confidence = 0.8
359
- elif rising and vq_mean > 0.1:
360
- drift_detected = True
361
- confidence = 0.6
362
- elif vq_mean > 0.2:
363
- drift_detected = True
364
- confidence = 0.5
365
-
366
- if drift_detected:
367
- _hstate.attack_detected = True
368
 
369
- return {
370
- "attack_detected": drift_detected,
371
- "attack_type": 4 if drift_detected else 0,
372
- "confidence": confidence,
373
- "protective_action": 2 if drift_detected else 0,
374
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
375
 
376
- return DEFAULT_ACTION.copy()
 
 
 
 
 
377
 
378
- except Exception as e:
379
- logger.warning(f"heuristic_agent failed: {e}\n{traceback.format_exc()}", extra={"stage": "HEURISTIC AGENT (ERR)"})
380
- return DEFAULT_ACTION.copy()
381
 
 
382
 
383
  def parse_llm_response(response_text: str) -> dict:
384
  try:
385
- text = (response_text or "").strip()
386
  if text.startswith("```"):
387
- lines = text.split("\n")
388
- json_lines = []
389
- in_block = False
390
  for line in lines:
391
  if line.strip().startswith("```") and not in_block:
392
  in_block = True
@@ -398,190 +273,141 @@ def parse_llm_response(response_text: str) -> dict:
398
  text = "\n".join(json_lines)
399
 
400
  parsed = json.loads(text)
401
- return safe_action(
402
- {
403
- "attack_detected": parsed.get("attack_detected", False),
404
- "attack_type": parsed.get("attack_type", 0),
405
- "confidence": parsed.get("confidence", 0.5),
406
- "protective_action": parsed.get("protective_action", 0),
407
- }
408
- )
409
- except Exception:
410
  return DEFAULT_ACTION.copy()
411
 
412
 
413
  def format_observation(obs: dict) -> str:
414
- try:
415
- parts = [
416
- f"Step: {obs['step']}",
417
- f"Task: {obs['task_id']}",
418
- f"vq_window (last 20): {[round(v, 6) for v in obs['vq_window']]}",
419
- f"vd_window (last 20): {[round(v, 6) for v in obs['vd_window']]}",
420
- f"omega_window (last 20): {[round(v, 6) for v in obs['omega_window']]}",
421
- f"omega_deviation_window (last 20): {[round(v, 6) for v in obs['omega_deviation_window']]}",
422
- f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
423
- ]
424
- return "\n".join(parts)
425
- except Exception:
426
- return ""
427
 
428
 
429
  def llm_agent(obs: dict) -> dict:
430
- if client is None:
431
- return heuristic_agent(obs)
432
-
433
  try:
434
- obs_text = format_observation(obs)
435
  completion = client.chat.completions.create(
436
  model=MODEL_NAME,
437
  messages=[
438
  {"role": "system", "content": SYSTEM_PROMPT},
439
- {"role": "user", "content": obs_text},
440
  ],
441
  temperature=0.1,
442
  max_tokens=200,
443
  )
444
- llm_response = completion.choices[0].message.content if completion and completion.choices else ""
445
- return parse_llm_response(llm_response)
446
  except Exception as e:
447
- logger.warning(f"LLM error ({type(e).__name__}: {e})\n{traceback.format_exc()}", extra={"stage": "LLM AGENT (ERR)"})
448
  return heuristic_agent(obs)
449
 
450
-
451
- def choose_action(obs: dict, prev_info: dict) -> dict:
452
- # Preserve the baseline heuristic behavior by default.
453
- try:
454
- if USE_LLM and client is not None:
455
- return safe_action(llm_agent(obs))
456
- except Exception:
457
- pass
458
- return safe_action(heuristic_agent(obs))
459
-
460
 
461
  def run_episode(task_id: int) -> float:
462
- log_start(
463
- task=TASK_NAMES[task_id],
464
- env="pll-cyberattack-detection",
465
- model=MODEL_NAME if USE_LLM else "rule-based-heuristic",
466
- )
467
 
468
- print(f"\n{'=' * 60}")
469
- print(f"Task {task_id}: {TASK_NAMES[task_id]}")
470
- print(f"Agent: {'LLM (' + MODEL_NAME + ')' if USE_LLM else 'Rule-Based Heuristic'}")
471
- print(f"{'=' * 60}")
472
 
473
- step_count = 0
 
 
 
 
474
  grader_score = 0.0
475
- rewards = []
476
- info: Dict[str, Any] = {}
477
- prev_info: Dict[str, Any] = {}
478
 
479
  try:
480
- reset_result = safe_post_json(
481
  f"{ENV_URL}/reset",
482
- {"task_id": task_id},
483
- timeout=10,
484
- retries=2,
485
  )
486
- if not isinstance(reset_result, dict):
487
- logger.error("Reset failed; skipping episode.", extra={"stage": "ENV RESET"})
488
- return 0.0
489
 
490
- obs = reset_result
491
- done = False
492
  total_reward = 0.0
 
493
 
494
  while not done:
495
- try:
496
- action = choose_action(obs, prev_info)
497
- except Exception as e:
498
- logger.warning(f"Action selection failed: {e}\n{traceback.format_exc()}", extra={"stage": "ACTION SELECTION"})
499
- action = DEFAULT_ACTION.copy()
500
 
501
- result = safe_post_json(
502
  f"{ENV_URL}/step",
503
- action,
504
- timeout=10,
505
- retries=2,
506
  )
507
- if not isinstance(result, dict):
508
- logger.error("Step failed; ending episode early.", extra={"stage": "ENV STEP"})
509
- break
510
-
511
- obs = result.get("observation", obs)
512
- reward = result.get("reward", {})
513
- done = bool(result.get("done", False))
514
- info = result.get("info", {})
515
-
516
- step_reward = 0.0
517
- if isinstance(reward, dict):
518
- try:
519
- step_reward = float(reward.get("total", 0.0))
520
- except Exception:
521
- step_reward = 0.0
522
 
 
523
  total_reward += step_reward
524
  rewards.append(step_reward)
525
- log_step(step=step_count, action=action, reward=step_reward, done=done, error=None)
526
 
527
- prev_info = info if isinstance(info, dict) else {}
528
  step_count += 1
 
529
 
530
  if step_count % 50 == 0:
531
  print(
532
- f" Step {step_count:3d} | Reward: {step_reward:+.4f} | "
533
- f"Cumulative: {total_reward:+.4f} | "
534
- f"Detected: {action.get('attack_detected', False)} | "
535
- f"Type: {action.get('attack_type', 0)}",
536
  flush=True,
537
  )
538
 
539
- if isinstance(info, dict):
540
- try:
541
- grader_score = float(info.get("grader_score", 0.0))
542
- except Exception:
543
- grader_score = 0.0
544
-
545
- print(f"\n Episode complete: {step_count} steps")
546
- print(f" Total reward: {total_reward:+.4f}")
547
- print(f" Grader score: {grader_score:.4f}")
548
 
549
- return grader_score
550
-
551
- except Exception as e:
552
- logger.error(f"Episode crashed safely: {e}\n{traceback.format_exc()}", extra={"stage": "EPISODE SEVERE ERR"})
553
- return 0.0
554
 
555
  finally:
556
- log_end(success=grader_score > 0.0, steps=step_count, score=grader_score, rewards=rewards)
557
 
 
558
 
559
- if __name__ == "__main__":
560
- agent_name = f"LLM ({MODEL_NAME})" if USE_LLM else "Rule-Based Heuristic"
561
- logger.info("PLL Cyberattack Detection — Agentic Inference", extra={"stage": "APP STARTUP"})
562
- logger.info(f"Agent: {agent_name}", extra={"stage": "APP STARTUP"})
563
- logger.info(f"Environment: {ENV_URL}", extra={"stage": "APP STARTUP"})
564
- if not USE_LLM:
565
- logger.info("(Set USE_LLM=1 to use LLM agent instead of heuristic)", extra={"stage": "APP STARTUP"})
566
 
567
- warmup_proxy()
 
568
 
569
  start_time = time.time()
570
- scores = []
571
 
572
  for task_id in range(3):
573
- score = run_episode(task_id)
574
- print(f"Task {task_id} score: {score:.4f}")
 
 
 
575
  scores.append(score)
 
576
 
577
  elapsed = time.time() - start_time
 
 
 
578
 
579
- print(f"\n{'=' * 60}")
580
- print("FINAL RESULTS")
581
- print(f"{'=' * 60}")
582
- for i, score in enumerate(scores):
583
- print(f" Task {i} ({TASK_NAMES[i]}): {score:.4f}")
584
- if scores:
585
- print(f"\n Average score: {sum(scores) / len(scores):.4f}")
586
- print(f" Total time: {elapsed:.1f}s ({elapsed / 60:.1f} min)")
587
- print(f"{'=' * 60}")
 
1
+ """
2
+ Inference Script — PLL Cyberattack Detection OpenEnv
3
+ =====================================================
4
+ Environment variables (injected by the judging sandbox):
5
+ API_BASE_URL LiteLLM proxy endpoint (MUST be used for all LLM calls)
6
+ API_KEY LiteLLM proxy key (MUST be used — do not hardcode keys)
7
+ MODEL_NAME Model identifier
8
+ ENV_URL Environment server URL (default: http://localhost:7860)
9
+
10
+ STDOUT FORMAT (OpenEnv compliance):
11
+ [START] task=<task_name> env=<benchmark> model=<model_name>
12
+ [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
13
+ [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
14
+ """
15
+
16
  import os
17
  import json
18
  import time
 
 
 
 
 
19
  import requests
20
+ from typing import List, Optional
21
  from openai import OpenAI
22
 
23
+ # ── Config — always read from environment, never hardcode ─────────────────────
24
+ # The judging sandbox injects API_BASE_URL and API_KEY via their LiteLLM proxy.
25
+ # All LLM calls MUST go through these values or the submission will be rejected.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
27
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
28
+ API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN", "dummy")
29
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860")
 
30
 
31
+ # OpenAI client pointed at the proxy — never bypass this
32
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
33
 
34
+ # ── Task metadata ─────────────────────────────────────────────────────────────
35
+ TASK_NAMES = {
36
+ 0: "Sinusoidal FDI Detection (Easy)",
37
+ 1: "Multi-Attack Classification (Medium)",
38
+ 2: "Stealthy Attack Detection (Hard)",
39
+ }
40
+
41
+ BENCHMARK = "pll-cyberattack-detection"
42
+
43
+ DEFAULT_ACTION = {
44
+ "attack_detected": False,
45
+ "attack_type": 0,
46
+ "confidence": 0.5,
47
+ "protective_action": 0,
48
+ }
49
 
50
+ # ── System prompt ─────────────────────────────────────────────────────────────
51
  SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
52
  You receive time-windowed sensor readings each step and must detect cyberattacks.
53
 
 
62
  For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
63
  For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.
64
 
65
+ Analysis tips:
66
+ - In healthy state, vq values should be near 0 and stable.
67
+ - Sinusoidal attacks cause oscillating patterns in vq.
68
+ - Ramp attacks cause steadily increasing vq magnitude.
69
+ - Pulse attacks cause sudden step changes in vq.
70
+ - Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
71
+ - Look at trends across the full window, not just the latest value.
72
+
73
  Respond ONLY with valid JSON, no explanation:
74
  {
75
  "attack_detected": <bool>,
 
78
  "protective_action": <int 0-3>
79
  }"""
80
 
81
+ # ── Logging helpers ───────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  def log_start(task: str, env: str, model: str) -> None:
84
+ print(f"[START] task={task} env={env} model={model}", flush=True)
85
 
86
 
87
  def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
88
+ action_str = json.dumps(action, separators=(',', ':'))
89
+ error_val = error if error else "null"
90
+ print(
91
+ f"[STEP] step={step} action={action_str} "
92
+ f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
93
+ flush=True,
94
  )
95
 
96
 
97
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
98
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
99
+ print(
100
+ f"[END] success={str(success).lower()} steps={steps} "
101
+ f"score={score:.3f} rewards={rewards_str}",
102
+ flush=True,
103
  )
104
 
105
+ # ── Heuristic agent (FALLBACK ONLY — used when LLM call fails) ────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
  class HeuristicState:
108
+ """Tracks running state for the heuristic agent across steps."""
109
+
110
  def __init__(self):
111
  self.reset()
112
 
113
  def reset(self):
114
+ self.vq_history = []
115
  self.omega_dev_history = []
116
+ self.attack_detected = False
117
+ self.predicted_type = 0
118
+ self.settled_baseline = None
119
+ self.peak_vq = 0.0
120
 
121
 
122
  _hstate = HeuristicState()
123
 
124
 
125
  def heuristic_agent(obs: dict) -> dict:
126
+ """Rule-based fallback — only called when the LLM request fails."""
127
  global _hstate
128
 
129
+ vq = obs["vq_window"]
130
+ omega_dev = obs["omega_deviation_window"]
131
+ task_id = obs["task_id"]
132
+ step = obs["step"]
 
 
 
133
 
134
  if step == 0:
135
  _hstate.reset()
136
 
137
+ vq_abs = [abs(v) for v in vq]
138
+ vq_mean = sum(vq_abs) / len(vq_abs)
139
+ vq_max = max(vq_abs)
140
+ omega_dev_abs = [abs(v) for v in omega_dev]
141
+ omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs)
 
 
 
142
 
143
+ _hstate.vq_history.append(vq_mean)
144
+ _hstate.omega_dev_history.append(omega_dev_mean)
145
+ _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)
146
 
147
+ if step == 50:
148
+ _hstate.settled_baseline = omega_dev_mean
149
 
150
+ detected = False if step < 25 else (vq_mean > 0.01 or vq_max > 0.025)
151
+ if detected:
152
+ _hstate.attack_detected = True
 
153
 
154
+ # ── Task 0: binary detection ──────────────────────────────────────────────
155
+ if task_id == 0:
156
+ return {
157
+ "attack_detected": _hstate.attack_detected,
158
+ "attack_type": 1 if _hstate.attack_detected else 0,
159
+ "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
160
+ "protective_action": 1 if _hstate.attack_detected else 0,
161
+ }
162
 
163
+ # ── Task 1: classification ────────────────────────────────────────────────
164
+ if task_id == 1:
165
+ if not _hstate.attack_detected:
166
  return {
167
+ "attack_detected": False,
168
+ "attack_type": 0,
169
+ "confidence": 0.7,
170
+ "protective_action": 0,
171
  }
172
 
173
+ n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)
 
 
 
 
 
 
 
174
 
175
+ if n_elevated < 5:
176
+ attack_type = 1
177
+ else:
178
+ elevated = [v for v in _hstate.vq_history if v > 0.005]
179
+ recent = elevated[-min(20, len(elevated)):]
180
+
181
+ current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0
182
+ zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i - 1] < 0)
183
+
184
+ if len(recent) >= 6:
185
+ first_third = sum(recent[: len(recent) // 3]) / (len(recent) // 3)
186
+ last_third = sum(recent[-len(recent) // 3 :]) / (len(recent) // 3)
187
+ growth = last_third / first_third if first_third > 0.001 else 1.0
188
+ else:
189
+ growth = 1.0
190
 
191
+ if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
192
+ attack_type = 3
193
+ elif current_vs_peak < 0.4 and n_elevated > 30:
194
+ attack_type = 3
195
+ elif zero_crossings >= 2 and growth < 1.5:
196
+ attack_type = 1
197
+ elif growth > 1.3:
198
+ attack_type = 2
199
+ elif zero_crossings >= 1:
200
  attack_type = 1
201
  else:
202
+ vq_diffs = [vq[i] - vq[i - 1] for i in range(1, len(vq))]
203
+ neg = sum(1 for d in vq_diffs if d < 0)
204
+ attack_type = 3 if neg > 14 else 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
+ _hstate.predicted_type = attack_type
 
 
 
 
 
207
 
208
+ return {
209
+ "attack_detected": True,
210
+ "attack_type": _hstate.predicted_type,
211
+ "confidence": 0.8,
212
+ "protective_action": 1,
213
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
214
 
215
+ # ── Task 2: stealthy attack ───────────────────────────────────────────────
216
+ if task_id == 2:
217
+ drift_detected = False
218
+ confidence = 0.3
219
+
220
+ if step > 50 and _hstate.settled_baseline is not None:
221
+ baseline = _hstate.settled_baseline
222
+ ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100
223
+
224
+ if len(_hstate.omega_dev_history) > 10:
225
+ recent_10 = _hstate.omega_dev_history[-10:]
226
+ old_10 = (_hstate.omega_dev_history[-20:-10]
227
+ if len(_hstate.omega_dev_history) > 20
228
+ else _hstate.omega_dev_history[:10])
229
+ recent_avg = sum(recent_10) / len(recent_10)
230
+ old_avg = sum(old_10) / len(old_10)
231
+ rising = recent_avg > old_avg * 1.1
232
+ else:
233
+ rising = False
234
+
235
+ if ratio > 2.0:
236
+ drift_detected, confidence = True, 0.9
237
+ elif ratio > 1.3 and rising:
238
+ drift_detected, confidence = True, 0.8
239
+ elif rising and vq_mean > 0.1:
240
+ drift_detected, confidence = True, 0.6
241
+ elif vq_mean > 0.2:
242
+ drift_detected, confidence = True, 0.5
243
+
244
+ if drift_detected:
245
+ _hstate.attack_detected = True
246
 
247
+ return {
248
+ "attack_detected": drift_detected,
249
+ "attack_type": 4 if drift_detected else 0,
250
+ "confidence": confidence,
251
+ "protective_action": 2 if drift_detected else 0,
252
+ }
253
 
254
+ return DEFAULT_ACTION.copy()
 
 
255
 
256
+ # ── LLM agent (PRIMARY — always called first) ─────────────────────────────────
257
 
258
  def parse_llm_response(response_text: str) -> dict:
259
  try:
260
+ text = response_text.strip()
261
  if text.startswith("```"):
262
+ lines = text.split("\n")
263
+ in_block = False
264
+ json_lines: List[str] = []
265
  for line in lines:
266
  if line.strip().startswith("```") and not in_block:
267
  in_block = True
 
273
  text = "\n".join(json_lines)
274
 
275
  parsed = json.loads(text)
276
+ return {
277
+ "attack_detected": bool(parsed.get("attack_detected", False)),
278
+ "attack_type": max(0, min(4, int(parsed.get("attack_type", 0)))),
279
+ "confidence": max(0.0, min(1.0, float(parsed.get("confidence", 0.5)))),
280
+ "protective_action": max(0, min(3, int(parsed.get("protective_action", 0)))),
281
+ }
282
+ except (json.JSONDecodeError, KeyError, TypeError, ValueError):
 
 
283
  return DEFAULT_ACTION.copy()
284
 
285
 
286
  def format_observation(obs: dict) -> str:
287
+ return "\n".join([
288
+ f"Step: {obs['step']}",
289
+ f"Task: {obs['task_id']}",
290
+ f"vq_window (last 20): {[round(v, 6) for v in obs['vq_window']]}",
291
+ f"vd_window (last 20): {[round(v, 6) for v in obs['vd_window']]}",
292
+ f"omega_window (last 20): {[round(v, 6) for v in obs['omega_window']]}",
293
+ f"omega_deviation_window (last 20): {[round(v, 6) for v in obs['omega_deviation_window']]}",
294
+ f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
295
+ ])
 
 
 
 
296
 
297
 
298
  def llm_agent(obs: dict) -> dict:
299
+ """Primary agent calls the LLM through the injected proxy.
300
+ Falls back to heuristic only if the API call itself raises an exception.
301
+ """
302
  try:
 
303
  completion = client.chat.completions.create(
304
  model=MODEL_NAME,
305
  messages=[
306
  {"role": "system", "content": SYSTEM_PROMPT},
307
+ {"role": "user", "content": format_observation(obs)},
308
  ],
309
  temperature=0.1,
310
  max_tokens=200,
311
  )
312
+ return parse_llm_response(completion.choices[0].message.content or "")
 
313
  except Exception as e:
314
+ print(f"[DEBUG] LLM error ({type(e).__name__}: {e}), falling back to heuristic", flush=True)
315
  return heuristic_agent(obs)
316
 
317
+ # ── Episode runner ────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
318
 
319
  def run_episode(task_id: int) -> float:
320
+ task_name = TASK_NAMES[task_id]
 
 
 
 
321
 
322
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
 
 
 
323
 
324
+ # Reset heuristic state before every episode so stale data from a previous
325
+ # task never bleeds into the next one (also covers the LLM fallback path).
326
+ _hstate.reset()
327
+
328
+ step_count = 0
329
  grader_score = 0.0
330
+ rewards: List[float] = []
331
+ success = False
 
332
 
333
  try:
334
+ reset_resp = requests.post(
335
  f"{ENV_URL}/reset",
336
+ json={"task_id": task_id},
337
+ timeout=60,
 
338
  )
339
+ reset_resp.raise_for_status()
340
+ obs = reset_resp.json()
 
341
 
342
+ done = False
 
343
  total_reward = 0.0
344
+ info = {}
345
 
346
  while not done:
347
+ # LLM is always primary; heuristic is the silent fallback inside llm_agent()
348
+ action = llm_agent(obs)
 
 
 
349
 
350
+ step_resp = requests.post(
351
  f"{ENV_URL}/step",
352
+ json=action,
353
+ timeout=60,
 
354
  )
355
+ step_resp.raise_for_status()
356
+ result = step_resp.json()
357
+
358
+ obs = result["observation"]
359
+ reward = result["reward"]
360
+ done = result["done"]
361
+ info = result.get("info", {})
362
+ error = result.get("error", None)
 
 
 
 
 
 
 
363
 
364
+ step_reward = reward["total"] if isinstance(reward, dict) else float(reward)
365
  total_reward += step_reward
366
  rewards.append(step_reward)
 
367
 
 
368
  step_count += 1
369
+ log_step(step=step_count, action=action, reward=step_reward, done=done, error=error)
370
 
371
  if step_count % 50 == 0:
372
  print(
373
+ f"[DEBUG] step={step_count} cumulative_reward={total_reward:+.4f} "
374
+ f"detected={action['attack_detected']} type={action['attack_type']}",
 
 
375
  flush=True,
376
  )
377
 
378
+ grader_score = info.get("grader_score", 0.0)
379
+ success = grader_score > 0.0
 
 
 
 
 
 
 
380
 
381
+ except Exception as exc:
382
+ print(f"[DEBUG] Episode error: {type(exc).__name__}: {exc}", flush=True)
383
+ success = False
 
 
384
 
385
  finally:
386
+ log_end(success=success, steps=step_count, score=grader_score, rewards=rewards)
387
 
388
+ return grader_score
389
 
390
+ # ── Entry point ───────────────────────────────────────────────────────────────
 
 
 
 
 
 
391
 
392
+ def main() -> None:
393
+ print(f"[DEBUG] PLL Cyberattack Detection — model={MODEL_NAME} env={ENV_URL}", flush=True)
394
 
395
  start_time = time.time()
396
+ scores: List[float] = []
397
 
398
  for task_id in range(3):
399
+ try:
400
+ score = run_episode(task_id)
401
+ except Exception as exc:
402
+ print(f"[DEBUG] run_episode({task_id}) crashed: {exc}", flush=True)
403
+ score = 0.0
404
  scores.append(score)
405
+ print(f"[DEBUG] task={task_id} score={score:.4f}", flush=True)
406
 
407
  elapsed = time.time() - start_time
408
+ avg = sum(scores) / len(scores) if scores else 0.0
409
+ print(f"[DEBUG] avg_score={avg:.4f} elapsed={elapsed:.1f}s", flush=True)
410
+
411
 
412
+ if __name__ == "__main__":
413
+ main()