"""
Inference Script — PLL Cyberattack Detection OpenEnv
=====================================================
MANDATORY environment variables:
  API_BASE_URL   The API endpoint for the LLM
  MODEL_NAME     The model identifier to use
  HF_TOKEN       Your Hugging Face / API key

Uses a HYBRID approach:
  - A fast rule-based heuristic agent runs by default (no LLM needed)
  - The heuristic analyzes vq/omega_deviation windows to detect attacks
  - Set USE_LLM=1 env var to use the LLM instead (slower, may fail)

Must be named inference.py and placed at the project root.
Uses OpenAI client for LLM calls when enabled.
"""

import os
import json
import time
from typing import Optional

import requests
from openai import OpenAI

API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
HF_TOKEN = os.getenv("HF_TOKEN")
ENV_URL = os.getenv("ENV_URL", "https://krishuggingface-cyberattack-pll.hf.space")
USE_LLM = os.environ.get("USE_LLM", "0") == "1"

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

SYSTEM_PROMPT = """You are an AI agent monitoring a power grid inverter's Phase-Locked Loop (PLL).
You receive time-windowed sensor readings each step and must detect cyberattacks.

vq_window: q-axis voltage error (should be ~0 when healthy)
vd_window: d-axis voltage
omega_window: estimated frequency (normalized, nominal=0)
omega_deviation_window: frequency deviation from nominal in rad/s (useful for detecting slow phase drift)
raw_voltages: [va, vb, vc] at current step
task_id: 0=detect only, 1=classify type, 2=detect stealthy attack

For task_id=0: Focus on detecting any attack (attack_detected=True/False).
For task_id=1: Also classify the attack type (1=sinusoidal, 2=ramp, 3=pulse).
For task_id=2: Detect very subtle attacks before the PLL loses lock. Look for slow drifts in omega_deviation and vq.

Analysis tips:
- In healthy state, vq values should be near 0 and stable.
- Sinusoidal attacks cause oscillating patterns in vq.
- Ramp attacks cause steadily increasing vq magnitude.
- Pulse attacks cause sudden step changes in vq.
- Stealthy attacks cause very slow, gradual drift in omega_deviation_window.
- Look at trends across the full window, not just the latest value.

Respond ONLY with valid JSON, no explanation:
{
  "attack_detected": <bool>,
  "attack_type": <int 0-4>,
  "confidence": <float 0.0-1.0>,
  "protective_action": <int 0-3>
}"""

TASK_NAMES = {
    0: "Sinusoidal FDI Detection (Easy)",
    1: "Multi-Attack Classification (Medium)",
    2: "Stealthy Attack Detection (Hard)",
}

DEFAULT_ACTION = {
    "attack_detected": False,
    "attack_type": 0,
    "confidence": 0.5,
    "protective_action": 0,
}


# =====================================================================
# Logging Helpers (OpenEnv compliance)
# =====================================================================

def log_start(task: str, env: str, model: str) -> None:
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: dict, reward: float, done: bool, error) -> None:
    action_str = json.dumps(action, separators=(',', ':'))
    error_val = error if error else "null"
    print(f"[STEP] step={step} action={action_str} reward={reward:.2f} done={str(done).lower()} error={error_val}", flush=True)


def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)


# =====================================================================
# Detector-Based Agent
# =====================================================================

def detector_agent(prev_info: dict) -> Optional[dict]:
    """Reads the environment's adaptive detector output from the previous step."""
    det = prev_info.get("detector", {})
    if not det or "attack_detected" not in det:
        return None
        
    return {
        "attack_detected": det.get("attack_detected", False),
        "attack_type": det.get("attack_type", 0),
        "confidence": det.get("confidence", 0.5),
        "protective_action": det.get("protective_action", 0),
    }
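
# Assumed shape of the environment's detector payload read above from
# prev_info["detector"] (keys match what detector_agent uses; values are hypothetical):
#   {"attack_detected": True, "attack_type": 2, "confidence": 0.85, "protective_action": 1}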


# =====================================================================
# Rule-Based Heuristic Agent
# =====================================================================

class HeuristicState:
    """Tracks running state for the heuristic agent across steps."""
    def __init__(self):
        self.reset()

    def reset(self):
        self.vq_history = []           # all vq_mean(abs) values
        self.omega_dev_history = []    # all omega_dev_mean(abs) values
        self.attack_detected = False   # latched detection flag
        self.predicted_type = 0        # latched classification
        self.settled_baseline = None   # omega_dev baseline when PLL settles
        self.peak_vq = 0.0            # highest vq_mean seen


_hstate = HeuristicState()


def heuristic_agent(obs: dict) -> dict:
    """
    Rule-based attack detector using cumulative state tracking.
    No LLM needed — runs instantly.

    The key insight is that the PLL's closed-loop response transforms
    attack signals, so we track statistics over time rather than
    trying to classify from a single 20-step vq window shape.
    """
    global _hstate
    vq = obs["vq_window"]
    omega_dev = obs["omega_deviation_window"]
    task_id = obs["task_id"]
    step = obs["step"]

    if step == 0:
        _hstate.reset()

    # --- Compute per-step features ---
    vq_abs = [abs(v) for v in vq]
    vq_mean = sum(vq_abs) / len(vq_abs)
    vq_max = max(vq_abs)
    vq_latest = abs(vq[-1])

    omega_dev_abs = [abs(v) for v in omega_dev]
    omega_dev_mean = sum(omega_dev_abs) / len(omega_dev_abs)

    # Track history
    _hstate.vq_history.append(vq_mean)
    _hstate.omega_dev_history.append(omega_dev_mean)
    _hstate.peak_vq = max(_hstate.peak_vq, vq_mean)

    # Record baseline at step 50, once the PLL has settled
    if step == 50:
        _hstate.settled_baseline = omega_dev_mean

    # -----------------------------------------------------------------
    # Detection: is vq significantly elevated?
    # After PLL warm-start settles (~step 20-30), healthy vq < 0.005
    # -----------------------------------------------------------------
    if step < 25:
        # PLL still settling, don't detect
        detected = False
    else:
        detected = vq_mean > 0.01 or vq_max > 0.025

    # Latch detection on
    if detected:
        _hstate.attack_detected = True

    # -----------------------------------------------------------------
    # Task 0: Binary detection only
    # -----------------------------------------------------------------
    if task_id == 0:
        return {
            "attack_detected": _hstate.attack_detected,
            "attack_type": 1 if _hstate.attack_detected else 0,
            "confidence": min(1.0, vq_mean * 50) if _hstate.attack_detected else 0.8,
            "protective_action": 1 if _hstate.attack_detected else 0,
        }

    # -----------------------------------------------------------------
    # Task 1: Classification using cumulative patterns
    # -----------------------------------------------------------------
    if task_id == 1:
        if not _hstate.attack_detected:
            return {
                "attack_detected": False,
                "attack_type": 0,
                "confidence": 0.7,
                "protective_action": 0,
            }

        # Classify using cumulative vq_history
        # Only classify once there is enough attack data (5+ steps of elevated vq)
        n_elevated = sum(1 for v in _hstate.vq_history if v > 0.01)

        if n_elevated < 5:
            # Not enough data yet, use simple guess
            attack_type = 1
        else:
            # Get the recent vq trend (up to the last 20 elevated values)
            elevated = [v for v in _hstate.vq_history if v > 0.005]
            recent = elevated[-min(20, len(elevated)):]

            # Feature 1: Is vq currently high or has it decayed?
            current_vs_peak = vq_mean / _hstate.peak_vq if _hstate.peak_vq > 0 else 0

            # Feature 2: How many zero crossings in current window
            zero_crossings = sum(1 for i in range(1, len(vq)) if vq[i] * vq[i-1] < 0)

            # Feature 3: Is vq growing or shrinking over recent history
            if len(recent) >= 6:
                first_third = sum(recent[:len(recent)//3]) / (len(recent)//3)
                last_third = sum(recent[-len(recent)//3:]) / (len(recent)//3)
                growth = last_third / first_third if first_third > 0.001 else 1.0
            else:
                growth = 1.0

            # Classification logic:
            # Sinusoidal: persistent oscillation, zero crossings, stable amplitude
            # Ramp: growing vq over time (growth > 1)
            # Pulse: high initial vq that decays to near zero (current_vs_peak < 0.3)

            if current_vs_peak < 0.15 and _hstate.peak_vq > 0.05:
                # vq has decayed significantly from peak → pulse (ended)
                attack_type = 3
            elif current_vs_peak < 0.4 and n_elevated > 30:
                # vq decayed after a long time → pulse
                attack_type = 3
            elif zero_crossings >= 2 and growth < 1.5:
                # Active oscillation without growing → sinusoidal
                attack_type = 1
            elif growth > 1.3:
                # Growing signal → ramp
                attack_type = 2
            elif zero_crossings >= 1:
                # Some oscillation → sinusoidal
                attack_type = 1
            else:
                # Default: if mono-decrease, pulse; else sinusoidal
                vq_diffs = [vq[i] - vq[i-1] for i in range(1, len(vq))]
                neg = sum(1 for d in vq_diffs if d < 0)
                if neg > 14:  # at least 15 of the 19 diffs decreasing (~79%)
                    attack_type = 3
                else:
                    attack_type = 1

            _hstate.predicted_type = attack_type

        return {
            "attack_detected": True,
            "attack_type": _hstate.predicted_type,
            "confidence": 0.8,
            "protective_action": 1,
        }

    # -----------------------------------------------------------------
    # Task 2: Stealthy attack — detect omega_dev rising above baseline
    # -----------------------------------------------------------------
    if task_id == 2:
        drift_detected = False
        confidence = 0.3

        if step > 50 and _hstate.settled_baseline is not None:
            baseline = _hstate.settled_baseline

            # Compare current to baseline
            ratio = omega_dev_mean / baseline if baseline > 0.01 else omega_dev_mean * 100

            # Check if omega_dev is rising relative to recent history
            if len(_hstate.omega_dev_history) > 10:
                recent_10 = _hstate.omega_dev_history[-10:]
                old_10 = _hstate.omega_dev_history[-20:-10] if len(_hstate.omega_dev_history) > 20 else _hstate.omega_dev_history[:10]
                recent_avg = sum(recent_10) / len(recent_10)
                old_avg = sum(old_10) / len(old_10)
                rising = recent_avg > old_avg * 1.1
            else:
                rising = False

            if ratio > 2.0:
                drift_detected = True
                confidence = 0.9
            elif ratio > 1.3 and rising:
                drift_detected = True
                confidence = 0.8
            elif rising and vq_mean > 0.1:
                drift_detected = True
                confidence = 0.6
            elif vq_mean > 0.2:
                drift_detected = True
                confidence = 0.5

        if drift_detected:
            _hstate.attack_detected = True

        return {
            "attack_detected": drift_detected,
            "attack_type": 4 if drift_detected else 0,
            "confidence": confidence,
            "protective_action": 2 if drift_detected else 0,
        }

    return DEFAULT_ACTION.copy()


# =====================================================================
# LLM Agent (optional, set USE_LLM=1)
# =====================================================================

def parse_llm_response(response_text: str) -> dict:
    """Parse LLM response JSON, returning default action on failure."""
    try:
        text = response_text.strip()
        if text.startswith("```"):
            lines = text.split("\n")
            json_lines = []
            in_block = False
            for line in lines:
                if line.strip().startswith("```") and not in_block:
                    in_block = True
                    continue
                elif line.strip().startswith("```") and in_block:
                    break
                elif in_block:
                    json_lines.append(line)
            text = "\n".join(json_lines)

        parsed = json.loads(text)
        action = {
            "attack_detected": bool(parsed.get("attack_detected", False)),
            "attack_type": max(0, min(4, int(parsed.get("attack_type", 0)))),
            "confidence": max(0.0, min(1.0, float(parsed.get("confidence", 0.5)))),
            "protective_action": max(0, min(3, int(parsed.get("protective_action", 0)))),
        }
        return action
    except (json.JSONDecodeError, KeyError, TypeError, ValueError):
        return DEFAULT_ACTION.copy()
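
# Usage sketch (inputs/outputs are illustrative): the parser strips a Markdown
# code fence if present and clamps out-of-range fields, e.g.
#   parse_llm_response('```json\n{"attack_detected": true, "attack_type": 7, '
#                      '"confidence": 1.4, "protective_action": 1}\n```')
#   -> {"attack_detected": True, "attack_type": 4, "confidence": 1.0, "protective_action": 1}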


def format_observation(obs: dict) -> str:
    """Format observation dict into a concise string for the LLM."""
    parts = [
        f"Step: {obs['step']}",
        f"Task: {obs['task_id']}",
        f"vq_window (last 20): {[round(v, 6) for v in obs['vq_window']]}",
        f"vd_window (last 20): {[round(v, 6) for v in obs['vd_window']]}",
        f"omega_window (last 20): {[round(v, 6) for v in obs['omega_window']]}",
        f"omega_deviation_window (last 20): {[round(v, 6) for v in obs['omega_deviation_window']]}",
        f"raw_voltages: {[round(v, 6) for v in obs['raw_voltages']]}",
    ]
    return "\n".join(parts)


def llm_agent(obs: dict) -> dict:
    """Call the LLM to decide an action. Falls back to heuristic on error."""
    try:
        obs_text = format_observation(obs)
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": obs_text},
            ],
            temperature=0.1,
            max_tokens=200,
        )
        llm_response = completion.choices[0].message.content
        return parse_llm_response(llm_response)
    except Exception as e:
        print(f"    LLM error ({type(e).__name__}: {e}), falling back to heuristic")
        return heuristic_agent(obs)


# =====================================================================
# Episode Runner
# =====================================================================

def run_episode(task_id: int) -> float:
    log_start(task=TASK_NAMES[task_id], env="pll-cyberattack-detection", model=MODEL_NAME if USE_LLM else "rule-based-heuristic")

    print(f"\n{'='*60}")
    print(f"Task {task_id}: {TASK_NAMES[task_id]}")
    print(f"Agent: {'LLM (' + MODEL_NAME + ')' if USE_LLM else 'Rule-Based Heuristic'}")
    print(f"{'='*60}")

    step_count = 0
    grader_score = 0.0
    rewards = []

    try:
        # Reset environment
        reset_response = requests.post(
            f"{ENV_URL}/reset",
            json={"task_id": task_id},
            timeout=30,
        )
        reset_response.raise_for_status()
        obs = reset_response.json()
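        # Assumed observation schema (these are the keys consumed by
        # heuristic_agent / format_observation; the exact payload comes from the env):
        #   {"step": 0, "task_id": 0, "vq_window": [...], "vd_window": [...],
        #    "omega_window": [...], "omega_deviation_window": [...],
        #    "raw_voltages": [va, vb, vc]}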

        _hstate.reset()  # clear latched heuristic state from any previous episode

        done = False
        total_reward = 0.0
        prev_info = {}

        while not done:
            # Choose agent
            if USE_LLM:
                action = llm_agent(obs)
            else:
                if step_count == 0:
                    action = DEFAULT_ACTION.copy()
                else:
                    det_action = detector_agent(prev_info)
                    # Always run the heuristic so its cumulative state stays current.
                    heur_action = heuristic_agent(obs)

                    # Prefer the environment's adaptive detector when it is
                    # available and confident; otherwise use the heuristic.
                    if det_action and det_action["confidence"] >= 0.5:
                        action = det_action
                    else:
                        action = heur_action

            # Step environment
            step_response = requests.post(
                f"{ENV_URL}/step",
                json=action,
                timeout=30,
            )
            step_response.raise_for_status()
            result = step_response.json()
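            # Assumed /step response shape (fields read below; values illustrative):
            #   {"observation": {...}, "reward": {"total": 0.85, ...},
            #    "done": false, "info": {"detector": {...}, "grader_score": 0.81}}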

            obs = result["observation"]
            reward = result["reward"]
            done = result["done"]
            info = result["info"]
            total_reward += reward["total"]
            rewards.append(reward["total"])
            log_step(step=step_count, action=action, reward=reward["total"], done=done, error=None)

            prev_info = info
            step_count += 1

            # Print progress every 50 steps
            if step_count % 50 == 0:
                print(f"  Step {step_count:3d} | Reward: {reward['total']:+.4f} | "
                      f"Cumulative: {total_reward:+.4f} | "
                      f"Detected: {action['attack_detected']} | "
                      f"Type: {action['attack_type']}")

        # Extract grader score
        grader_score = info.get("grader_score", 0.0)
        print(f"\n  Episode complete: {step_count} steps")
        print(f"  Total reward: {total_reward:+.4f}")
        print(f"  Grader score: {grader_score:.4f}")
    finally:
        log_end(success=grader_score > 0.0, steps=step_count, score=grader_score, rewards=rewards)

    return grader_score


if __name__ == "__main__":
    agent_name = f"LLM ({MODEL_NAME})" if USE_LLM else "Rule-Based Heuristic"
    print("PLL Cyberattack Detection — Agentic Inference")
    print(f"Agent: {agent_name}")
    print(f"Environment: {ENV_URL}")
    if not USE_LLM:
        print("(Set USE_LLM=1 to use LLM agent instead of heuristic)")

    start_time = time.time()
    scores = []

    for task_id in range(3):
        score = run_episode(task_id)
        print(f"Task {task_id} score: {score:.4f}")
        scores.append(score)

    elapsed = time.time() - start_time

    print(f"\n{'='*60}")
    print("FINAL RESULTS")
    print(f"{'='*60}")
    for i, score in enumerate(scores):
        print(f"  Task {i} ({TASK_NAMES[i]}): {score:.4f}")
    print(f"\n  Average score: {sum(scores)/len(scores):.4f}")
    print(f"  Total time: {elapsed:.1f}s ({elapsed/60:.1f} min)")
    print(f"{'='*60}")