SaiManish123 commited on
Commit
c1060df
·
verified ·
1 Parent(s): 304b5df

Initial deploy of AdaptShield two-phase cybersecurity environment

Browse files
.gitignore ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.py[cod]
3
+ *$py.class
4
+ *.so
5
+ .Python
6
+ build/
7
+ develop-eggs/
8
+ dist/
9
+ downloads/
10
+ eggs/
11
+ .eggs/
12
+ lib/
13
+ lib64/
14
+ parts/
15
+ sdist/
16
+ var/
17
+ wheels/
18
+ share/python-wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ MANIFEST
23
+ .venv
24
+ venv/
25
+ ENV/
26
+ env.bak/
27
+ venv.bak/
28
+ .DS_Store
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11.9
Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
2
+ FROM ${BASE_IMAGE} AS builder
3
+
4
+ WORKDIR /app
5
+
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends git curl && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ COPY . /app/env
11
+ WORKDIR /app/env
12
+
13
+ RUN if ! command -v uv >/dev/null 2>&1; then \
14
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
15
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
16
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
17
+ fi
18
+
19
+ RUN --mount=type=cache,target=/root/.cache/uv \
20
+ if [ -f uv.lock ]; then \
21
+ uv sync --frozen --no-install-project --no-editable; \
22
+ else \
23
+ uv sync --no-install-project --no-editable; \
24
+ fi
25
+
26
+ RUN --mount=type=cache,target=/root/.cache/uv \
27
+ if [ -f uv.lock ]; then \
28
+ uv sync --frozen --no-editable; \
29
+ else \
30
+ uv sync --no-editable; \
31
+ fi
32
+
33
+ FROM ${BASE_IMAGE}
34
+ WORKDIR /app
35
+
36
+ COPY --from=builder /app/env/.venv /app/.venv
37
+ COPY --from=builder /app/env /app/env
38
+
39
+ ENV PATH="/app/.venv/bin:$PATH"
40
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
41
+
42
+ EXPOSE 7860
43
+
44
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
45
+ CMD curl -f http://localhost:7860/health || exit 1
46
+
47
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
README.md CHANGED
@@ -1,10 +1,433 @@
1
  ---
2
- title: Adaptshield
3
- emoji: 👀
4
- colorFrom: red
5
- colorTo: purple
6
  sdk: docker
7
  pinned: false
 
 
 
 
 
 
 
8
  ---
9
 
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Janus (AdaptShield)
3
+ emoji: 🛡️
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: docker
7
  pinned: false
8
+ license: mit
9
+ tags:
10
+ - openenv
11
+ - security
12
+ - reinforcement-learning
13
+ - cybersecurity
14
+ short_description: Two-phase adaptive cybersecurity benchmark for LLMs
15
  ---
16
 
17
+ # Janus (AdaptShield) Two-Phase Adaptive Cybersecurity Benchmark
18
+
19
+ **AdaptShield** is the environment: a two-phase agentic cybersecurity
20
+ simulator where an LLM defends a 4-node enterprise network against an
21
+ adversary that shifts strategy mid-episode. **Janus** is the model we
22
+ trained on it — a Qwen2.5-1.5B LoRA, supervised then refined with GRPO.
23
+ On the hardest task Janus scores 0.90 on a held-out world family it
24
+ never saw during training; a tool-aware heuristic baseline scores 0.18
25
+ on the same task.
26
+
27
+ The skill being tested is narrow on purpose. Not threat classification.
28
+ Not generic tool calling. The benchmark targets one thing: real-time
29
+ adaptation when the attacker's playbook changes mid-incident. Section
30
+ [Why this matters](#why-this-matters) explains why we think that's the
31
+ gap, and the [Results](#results) section is where the gap closes.
32
+
33
+ ## Project Links
34
+
35
+ - **HF Space (live env):** `TODO`
36
+ - **Colab notebook (SFT + GRPO reproducer, free T4):** `TODO`
37
+ - **Artifacts / model repo:** [`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus)
38
+ - **Demo video:** `TODO`
39
+ - **Blog / writeup:** `TODO`
40
+
41
+ ---
42
+
43
+ ## Why this matters
44
+
45
+ Most cyber-agent demos test threat classification or generic tool
46
+ calling. Real production breaches don't look like that. They look like
47
+ this:
48
+
49
+ In April 2026 attackers compromised Context.ai, used its OAuth
50
+ integration into a Vercel employee's Google Workspace, and pivoted from
51
+ shadow AI through identity into Vercel's internal systems, where they
52
+ enumerated and decrypted customer environment variables. The same week,
53
+ a Broken Object Level Authorization flaw in Lovable.dev let any
54
+ free-tier account read source code, Supabase credentials, Stripe keys
55
+ and AI chat histories from other tenants — including projects built by
56
+ AI itself. Eight months earlier, the Tea dating app left a Firebase
57
+ bucket open and 72,000 verification selfies and driver's licenses of
58
+ women on a safety app were scraped to 4chan within hours.
59
+
60
+ Three different failure modes — identity hijack via shadow AI, broken
61
+ authz in vibe-coded apps, classic cloud misconfig — but the same
62
+ underlying problem for the defender's agent. The environment is shifting
63
+ faster than any static training distribution can keep up with, and the
64
+ real attacker doesn't sit still while you classify them.
65
+
66
+ AdaptShield is built around that pressure. The environment forces the
67
+ agent to (1) act on partial evidence, (2) hand judgment across two
68
+ roles with an information bottleneck between them, (3) trade security
69
+ correctness against operational blast radius, and (4) re-plan when the
70
+ adversary's playbook changes mid-episode. Each of those is a separate
71
+ failure mode in production SOC tooling, and the benchmark scores all
72
+ four at once.
73
+
74
+ ---
75
+
76
+ ## Results
77
+
78
+ Numbers below come from the production run on Hugging Face L4 Jobs,
79
+ training Qwen2.5-1.5B-Instruct with a LoRA adapter. Eval is 50
80
+ deterministic seeds per task, evaluated on a held-out world family
81
+ the policy never saw during training.
82
+
83
+ ![AdaptShield held-out benchmark — tool-aware baseline vs SFT vs GRPO](assets/headline_results.png)
84
+
85
+ On the hard task (`polymorphic-zero-day`) the tool-aware heuristic
86
+ baseline scores 0.18 and Janus holds 0.90 on the held-out family. On
87
+ the easier tasks the lift is smaller because the rule baseline is
88
+ already near the ceiling; the benchmark is shaped so adaptation only
89
+ matters where it should.
90
+
91
+ ### Benchmark comparison (full table)
92
+
93
+ | Task | No-tool baseline | Tool-aware baseline | SFT (train family) | SFT (held-out) | GRPO (train) | GRPO (held-out) |
94
+ |------|-----------------:|-------------------:|-------------------:|---------------:|-------------:|----------------:|
95
+ | `direct-triage` | 0.860 | 0.990 | 0.990 | 0.990 | 0.990 | 0.990 |
96
+ | `dual-pivot` | 0.650 | 0.640 | 0.825 | 0.825 | 0.825 | 0.825 |
97
+ | `polymorphic-zero-day` | 0.380 | 0.180 | 0.960 | 0.930 | **0.883** | **0.902** |
98
+
99
+ Two things in this table are worth flagging.
100
+
101
+ The tool-aware baseline scores 0.18 on the hard task — worse than the
102
+ no-tool baseline at 0.38. That isn't a bug in the baseline; it's that
103
+ bolting tools onto a heuristic without learning when to trust them
104
+ makes the agent over-trigger on injected false positives. You see the
105
+ same pattern in production with rule-based SOAR playbooks against
106
+ adaptive adversaries.
107
+
108
+ Held-out GRPO (0.902) actually edges out train-family GRPO (0.883). That
109
+ is evidence the policy is generalizing across world templates rather
110
+ than memorizing them. Without splitting the eval by world family this
111
+ finding wouldn't be visible — same-seed evaluation would have credited
112
+ the model for memorization it didn't do.
113
+
114
+ ### SFT — loss and held-out reward
115
+
116
+ ![SFT loss curve](https://huggingface.co/SaiManish123/Janus/resolve/main/sft_worldsplit_1_5b/loss_curve.png)
117
+
118
+ ![SFT held-out reward curve](https://huggingface.co/SaiManish123/Janus/resolve/main/sft_worldsplit_1_5b/reward_curve.png)
119
+
120
+ ### GRPO — refinement on the polymorphic adversary
121
+
122
+ ![GRPO reward curve, polymorphic-zero-day](https://huggingface.co/SaiManish123/Janus/resolve/main/grpo_polymorphic_zero_day_1_5b/reward_curve.png)
123
+
124
+ ### Training runs
125
+
126
+ Three production runs on Hugging Face Jobs produced the artifacts in this
127
+ README. Stdout logs are public and the per-step / per-episode metrics
128
+ files are next to the adapters.
129
+
130
+ | Run | Trainer | GPU | Steps / Episodes | Train wall-clock | Logs | Metrics |
131
+ |-----|---------|-----|------------------|------------------|------|---------|
132
+ | [`sft_worldsplit_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/sft_worldsplit_1_5b) | SFT (LoRA) | L4 ×1 | 378 steps | 9m 49s | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/sft_worldsplit_1_5b.log) | [trainer_state](https://huggingface.co/SaiManish123/Janus/blob/main/sft_worldsplit_1_5b/checkpoint-378/trainer_state.json) |
133
+ | [`grpo_worldsplit_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/grpo_worldsplit_1_5b) | GRPO, mixed curriculum | L4 ×1 | 1,628 episodes | 1h 26m | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/grpo_worldsplit_1_5b.log) | [per-episode](https://huggingface.co/SaiManish123/Janus/blob/main/grpo_worldsplit_1_5b/metrics.json) |
134
+ | [`grpo_polymorphic_zero_day_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/grpo_polymorphic_zero_day_1_5b) | GRPO, hard-task focus | L4 ×1 | 4,357 episodes | 3h 17m | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/grpo_polymorphic_zero_day_1_5b.log) | [per-episode](https://huggingface.co/SaiManish123/Janus/blob/main/grpo_polymorphic_zero_day_1_5b/metrics.json) |
135
+
136
+ The curriculum run mixes all three tasks (weights `direct-triage: 0.3 /
137
+ dual-pivot: 0.4 / polymorphic-zero-day: 0.3`). The polymorphic run
138
+ trains exclusively on the hard task to push hard-task performance
139
+ without distraction from saturated tiers. Per-episode reward in both
140
+ runs stabilizes within the first ~500 episodes and stays there for the
141
+ rest of the schedule.
142
+
143
+ ---
144
+
145
+ ## Architecture
146
+
147
+ ![AdaptShield architecture overview](assets/architecture_overview.svg)
148
+
149
+ Each episode runs against a sampled mission profile, world-family
150
+ template, and latent operational mode. The Threat Analyst investigates
151
+ raw enterprise evidence through SOC tools and emits a structured
152
+ handoff. The Tactical Executor sees only that handoff (not the raw
153
+ state) and chooses the mitigation. A deterministic Python grader scores
154
+ security correctness, business impact, dependency blast radius, and
155
+ mission alignment. There is no LLM-as-judge anywhere in the loop.
156
+
157
+ ## Training Pipeline
158
+
159
+ ![Janus training pipeline](assets/training_pipeline.svg)
160
+
161
+ Five steps, each reproducible from the repo:
162
+
163
+ 1. Generate SFT demonstrations by rolling AdaptShield episodes with a
164
+ rule-based Phase 1 expert and a tool-aware Phase 2 expert.
165
+ 2. Train a LoRA adapter on Qwen2.5-1.5B (or 0.5B for the Colab
166
+ reproducer) with supervised fine-tuning on those demos.
167
+ 3. Evaluate on both train-family and held-out-family worlds. The split
168
+ is by world template, not by seed, so memorizing a template doesn't
169
+ transfer across the split.
170
+ 4. Refine the SFT adapter with GRPO on a curriculum weighted toward
171
+ `polymorphic-zero-day`. The deterministic grader is the reward.
172
+ 5. Publish adapters, curves, metrics, and benchmark tables to
173
+ [`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus).
174
+
175
+ A free-tier Colab notebook reproduces steps 1–4 end-to-end on a T4 in
176
+ roughly 35 minutes using Qwen2.5-0.5B and reduced episode budgets. The
177
+ numbers in this README come from the 1.5B run on a Hugging Face L4 Job.
178
+
179
+ ---
180
+
181
+ ## Environment Description
182
+
183
+ The agent defends a 4-node enterprise network (`auth_service`,
184
+ `payment_service`, `database`, `api_gateway`). Each turn has two phases:
185
+
186
+ **Phase 1 — Threat Analyst.** Agent reads SIEM metrics, can call SOC tools
187
+ (log search, network telemetry, threat intel lookup), and emits a
188
+ structured `Phase1Action` with threat type, target node, confidence and a
189
+ recommended action.
190
+
191
+ **Phase 2 — Tactical Executor.** Agent receives only the Phase 1
192
+ assessment (blind to raw state) and emits a `Phase2Action`. The analyst
193
+ has to communicate clearly because the executor cannot double-check the
194
+ network.
195
+
196
+ The attacker escalates through `recon → exploit → exfiltration` if the
197
+ agent fails to respond correctly. On the hard task, the attacker shifts
198
+ strategy mid-episode and seeds false-positive noise that looks like a
199
+ real attack but isn't — punishing reflexive isolation.
200
+
201
+ ### Observation Space
202
+
203
+ ```json
204
+ {
205
+ "phase": "1 or 2",
206
+ "network_nodes": {
207
+ "auth_service": {"status": "...", "request_rate": 0, "error_rate": 0.0, "cpu": 0}
208
+ },
209
+ "active_alerts": ["raw metric alert strings — no MITRE codes"],
210
+ "attack_stage": "recon | exploit | exfiltration | none",
211
+ "history": [{"turn": "1", "p1": "classified:brute_force", "p2": "rate_limit→auth_service"}],
212
+ "phase1_assessment": {"threat_type": "...", "confidence": 0.9, "target_node": "..."},
213
+ "metadata": {"normalized_score": 0.72}
214
+ }
215
+ ```
216
+
217
+ Phase 2 observations have empty `network_nodes` and `active_alerts` — the
218
+ executor only sees the analyst's handoff.
219
+
220
+ ### Action Space
221
+
222
+ **Phase 1 (`Phase1Action`):**
223
+ ```json
224
+ {"threat_type": "brute_force", "confidence": 0.9, "target_node": "auth_service", "recommended_action": "rate_limit", "reasoning": "..."}
225
+ ```
226
+
227
+ **Phase 2 (`Phase2Action`):**
228
+ ```json
229
+ {"action": "rate_limit", "target_node": "auth_service", "reasoning": "..."}
230
+ ```
231
+
232
+ Valid actions: `rate_limit`, `isolate`, `honeypot`, `patch`, `monitor`.
233
+
234
+ ### Tasks
235
+
236
+ | Task | Difficulty | Description | Rule baseline |
237
+ |------|-----------|-------------|--------------:|
238
+ | `direct-triage` | Easy | Single fixed strategy | ~0.87 |
239
+ | `dual-pivot` | Medium | Two alternating strategies | ~0.76 |
240
+ | `polymorphic-zero-day` | Hard | All four + mid-episode shift + noise | ~0.52 |
241
+
242
+ ### Reward Function
243
+
244
+ | Outcome | Reward |
245
+ |---------|-------:|
246
+ | Phase 1 threat type correct | +0.15 |
247
+ | Phase 1 target node correct | +0.10 |
248
+ | Phase 2 optimal action + correct target | +0.39 |
249
+ | Phase 2 heavy-handed but effective | +0.18 |
250
+ | Phase 2 wrong action | -0.25 |
251
+ | False positive on benign event | -0.39 |
252
+ | Catastrophic: database exfiltrated | -0.49, `done=True` |
253
+
254
+ Scores are clipped to the open interval `(0.01, 0.99)` — the grader never
255
+ emits exactly 0 or 1, which keeps GRPO advantages well-defined.
256
+
257
+ ### Operational Impact Layer
258
+
259
+ AdaptShield also scores business impact, so the agent is rewarded for
260
+ stopping the attack without ignoring operational blast radius. Each
261
+ service has a criticality weight and a dependency fan-out:
262
+
263
+ | Service | Criticality | Downstream dependency risk |
264
+ |---------|------------:|----------------------------|
265
+ | `auth_service` | 0.70 | `payment_service` |
266
+ | `payment_service` | 0.90 | `api_gateway` |
267
+ | `database` | 1.00 | `payment_service`, `api_gateway` |
268
+ | `api_gateway` | 0.80 | `auth_service`, `payment_service`, `database` |
269
+
270
+ Actions have bounded disruption costs (`monitor` = none, `isolate` =
271
+ highest). The grader emits `business_impact`, `availability_impact`,
272
+ `security_risk`, `dependency_blast_radius`, and `operational_penalty`
273
+ inside `score_breakdown`. The reward adjustment is capped at `±0.05` per
274
+ turn, which keeps the training signal stable while leaving the replay
275
+ detailed enough to explain whether the agent stopped the attack cleanly
276
+ or caused unnecessary business disruption getting there.
277
+
278
+ ### Mission-Aware Objectives
279
+
280
+ Each task carries a mission profile, visible in observation metadata and
281
+ appended to the system prompt:
282
+
283
+ | Task | Mission | Primary Asset | SLA Priority | Risk Tolerance |
284
+ |------|---------|---------------|--------------|----------------|
285
+ | `direct-triage` | `login_stability` | `auth_service` | availability | medium |
286
+ | `dual-pivot` | `checkout_continuity` | `payment_service` | availability | medium |
287
+ | `polymorphic-zero-day` | `breach_containment` | `database` | containment | low |
288
+
289
+ The grader emits `mission_alignment` and `mission_adjustment`, capped at
290
+ `±0.04` per turn. This makes the agent optimize for the operational
291
+ mission, not just the threat label. Availability-priority missions
292
+ discourage unnecessary isolation of the primary asset; containment
293
+ missions reward decisive correct containment of the crown-jewel
294
+ database.
295
+
296
+ ### Design choices that aren't obvious
297
+
298
+ A few decisions in the environment that look like details but matter
299
+ for what the benchmark actually measures:
300
+
301
+ - **Information bottleneck between phases.** Phase 2's observation has
302
+ empty `network_nodes` and `active_alerts`. The executor only sees
303
+ Phase 1's structured handoff. If Phase 1 can't communicate clearly,
304
+ Phase 2 fails — and you see it in the score, not in a separate metric.
305
+ This is what makes the env actually test cross-role coordination
306
+ rather than just two independent policies stitched together.
307
+ - **Train/eval split by world family, not by seed.** The world templates
308
+ used for training are disjoint from the ones used for held-out
309
+ evaluation. A model that overfits to a specific service-name pattern
310
+ or a specific alert distribution will pass train evals and fail
311
+ held-out. Same-seed evaluation would have hidden this.
312
+ - **Open scoring interval `(0.01, 0.99)`.** The grader never emits
313
+ exactly 0 or 1. This keeps GRPO advantage estimates well-defined —
314
+ saturating rewards collapse the variance the algorithm needs.
315
+ - **Bounded auxiliary signals.** Operational impact is capped at `±0.05`
316
+ per turn and mission alignment at `±0.04`. They steer the policy
317
+ without dominating the security signal, so the training curve doesn't
318
+ get hijacked by a single side-objective.
319
+ - **Deterministic Python grader, no LLM-as-judge.** Rewards come from
320
+ strategy matching against a fixed ground-truth attacker, not from a
321
+ judge model. The benchmark cannot be gamed by a more eloquent policy.
322
+ - **Phase-1 alerts are raw metric strings, not MITRE codes.** The agent
323
+ has to do the classification, not match a label to a label. This is
324
+ what makes the soc-tool baseline collapse on the hard task: heuristic
325
+ classification doesn't survive injected noise.
326
+
327
+ ---
328
+
329
+ ## Reproduce it
330
+
331
+ ### Free-tier Colab (recommended for judges)
332
+
333
+ Open the Colab notebook linked above and run top-to-bottom. It will:
334
+
335
+ - install the exact pinned dependency stack used in the HF Job
336
+ - generate SFT demos from the environment
337
+ - train an SFT LoRA on Qwen2.5-0.5B (T4-friendly)
338
+ - run GRPO refinement on top of that SFT adapter
339
+ - print the benchmark table and inline the production training curves
340
+ from `SaiManish123/Janus` so you can compare scaled-down vs. full runs
341
+
342
+ End-to-end runtime on a Colab T4 is roughly 35 minutes.
343
+
344
+ ### Local setup
345
+
346
+ ```bash
347
+ pip install openenv-core
348
+ git clone https://github.com/SaiManish123/adaptshield
349
+ cd adaptshield
350
+ python -m adaptshield.server.app
351
+ ```
352
+
353
+ ### Run inference against the live environment
354
+
355
+ ```bash
356
+ export HF_TOKEN=your_token
357
+ export ADAPTSHIELD_TASK=direct-triage # or dual-pivot / polymorphic-zero-day
358
+ export ENV_BASE_URL=http://localhost:7860
359
+ python inference.py # run from the repo root
360
+ ```
361
+
362
+ `inference.py` honors the evaluator contract: `[START]`, `[STEP]`, `[END]`
363
+ stdout markers and credentials read only from environment variables.
364
+
365
+ ### Smoke test
366
+
367
+ ```bash
368
+ python smoke_test.py
369
+ ```
370
+
371
+ Spins the env up in-process and walks one episode of each task with a
372
+ deterministic policy. Should finish in <10 seconds.
373
+
374
+ ### Regression tests
375
+
376
+ ```bash
377
+ adaptshield/.venv/bin/python -m unittest tests.test_regression -v
378
+ ```
379
+
380
+ ### Baseline scores
381
+
382
+ With `ADAPTSHIELD_SEED=42`, the deterministic rule baseline produces:
383
+
384
+ | Task | Score | Steps | Status |
385
+ |------|------:|------:|--------|
386
+ | `direct-triage` | 0.870 | 10 | PASS |
387
+ | `dual-pivot` | 0.760 | 12 | PASS |
388
+ | `polymorphic-zero-day` | 0.520 | 16 | PASS |
389
+
390
+ Difficulty staircase: **PASS**.
391
+
392
+ ---
393
+
394
+ ## Repository layout
395
+
396
+ ```
397
+ adaptshield/
398
+ ├── server/ # FastAPI server (OpenEnv-compatible)
399
+ ├── client.py # OpenEnv client (no server-internal imports)
400
+ ├── models.py # Phase1Action / Phase2Action schemas
401
+ ├── soc_tools.py # SIEM, log search, threat intel SOC tools
402
+ ├── eval_tasks.py # task definitions + difficulty staircase
403
+ ├── baseline.py # deterministic rule baseline
404
+ ├── tool_baseline.py # tool-aware heuristic baseline
405
+ ├── generate_sft_data.py # rolls episodes → SFT JSONL
406
+ ├── train_sft.py # LoRA SFT trainer (Unsloth + TRL)
407
+ ├── train.py # GRPO trainer (Unsloth + TRL)
408
+ ├── plot_training.py # reward / loss curve plotting
409
+ ├── build_benchmark_table.py # eval matrix builder
410
+ ├── inference.py # judge-facing entry point
411
+ ├── smoke_test.py # one-shot in-process smoke test
412
+ ├── tests/test_regression.py # determinism + reward regression tests
413
+ ├── openenv.yaml # OpenEnv manifest
414
+ └── Dockerfile # HF Space container
415
+ ```
416
+
417
+ ## Engineering notes
418
+
419
+ `AdaptShieldEnvironment` extends OpenEnv's `Environment` base class and
420
+ follows the Gym-style API (`reset`, `step`, `state`). The client in
421
+ `client.py` talks to the server only through HTTP — no shared imports,
422
+ no leaking of server internals. None of the SOC tools are named
423
+ `reset`, `step`, `state`, or `close`, so they don't collide with the
424
+ reserved MCP tool names. Grading is deterministic Python; the reward
425
+ signal and the benchmark scores both come from strategy matching
426
+ against a fixed ground-truth attacker, never from an LLM judge.
427
+
428
+ All adapters, curves, metrics, and benchmark tables for the 1.5B run
429
+ are public on [`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus).
430
+
431
+ ## License
432
+
433
+ MIT.
__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """AdaptShield environment package."""
8
+
9
+ from client import AdaptshieldEnv
10
+ from models import (
11
+ AdaptShieldAction,
12
+ AdaptShieldObservation,
13
+ AdaptshieldAction,
14
+ AdaptshieldObservation,
15
+ )
16
+
17
+ __all__ = [
18
+ "AdaptShieldAction",
19
+ "AdaptShieldObservation",
20
+ "AdaptshieldAction",
21
+ "AdaptshieldObservation",
22
+ "AdaptshieldEnv",
23
+ ]
assets/_make_headline_chart.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Render the headline benchmark chart for README.
2
+
3
+ Produces a clean grouped bar chart of held-out evaluation scores
4
+ (tool-aware baseline / SFT / GRPO) across the three difficulty tiers.
5
+
6
+ Numbers are pulled directly from
7
+ https://huggingface.co/SaiManish123/Janus benchmark tables and are
8
+ identical to the values in README.md so the figure stays in sync.
9
+
10
+ Run: python assets/_make_headline_chart.py
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import os
16
+ from pathlib import Path
17
+
18
+ os.environ.setdefault("MPLCONFIGDIR", "/tmp/mpl-adaptshield")
19
+
20
+ import matplotlib.pyplot as plt
21
+ import numpy as np
22
+
23
+ OUT = Path(__file__).parent / "headline_results.png"
24
+
25
+ tasks = ["direct-triage\n(easy)", "dual-pivot\n(medium)", "polymorphic-zero-day\n(hard)"]
26
+ tool_baseline = [0.990, 0.640, 0.180]
27
+ sft_heldout = [0.990, 0.825, 0.930]
28
+ grpo_heldout = [0.990, 0.825, 0.902]
29
+
30
+ x = np.arange(len(tasks))
31
+ width = 0.26
32
+
33
+ fig, ax = plt.subplots(figsize=(9.5, 4.6), dpi=150)
34
+
35
+ c_tool = "#9aa0a6"
36
+ c_sft = "#1f6feb"
37
+ c_grpo = "#d63b2f"
38
+
39
+ b1 = ax.bar(x - width, tool_baseline, width, label="Tool-aware baseline", color=c_tool, edgecolor="white", linewidth=0.6)
40
+ b2 = ax.bar(x, sft_heldout, width, label="SFT (held-out)", color=c_sft, edgecolor="white", linewidth=0.6)
41
+ b3 = ax.bar(x + width, grpo_heldout, width, label="GRPO (held-out)", color=c_grpo, edgecolor="white", linewidth=0.6)
42
+
43
+ for bars in (b1, b2, b3):
44
+ ax.bar_label(bars, fmt="%.2f", padding=3, fontsize=9, color="#333")
45
+
46
+ ax.set_ylim(0, 1.08)
47
+ ax.set_yticks(np.arange(0, 1.01, 0.2))
48
+ ax.set_ylabel("Mean score (0.01–0.99 grader)", fontsize=10)
49
+ ax.set_xticks(x)
50
+ ax.set_xticklabels(tasks, fontsize=10)
51
+ ax.set_title(
52
+ "AdaptShield held-out evaluation · Qwen2.5-1.5B · 50 deterministic seeds / task",
53
+ fontsize=11.5, pad=12, color="#222",
54
+ )
55
+
56
+ ax.spines["top"].set_visible(False)
57
+ ax.spines["right"].set_visible(False)
58
+ ax.spines["left"].set_color("#cccccc")
59
+ ax.spines["bottom"].set_color("#cccccc")
60
+ ax.tick_params(colors="#555")
61
+ ax.yaxis.grid(True, color="#eeeeee", linewidth=0.8)
62
+ ax.set_axisbelow(True)
63
+
64
+ ax.annotate(
65
+ "5.0× lift on the only task that\nactually requires adaptation",
66
+ xy=(2 + width, grpo_heldout[2]),
67
+ xytext=(2 - 0.15, 0.45),
68
+ fontsize=9, color="#444",
69
+ arrowprops=dict(arrowstyle="->", color="#888", lw=0.9, connectionstyle="arc3,rad=-0.2"),
70
+ )
71
+
72
+ ax.legend(
73
+ loc="lower left", frameon=False, fontsize=9.5, ncol=3,
74
+ bbox_to_anchor=(0.0, -0.22),
75
+ )
76
+
77
+ plt.tight_layout()
78
+ fig.savefig(OUT, bbox_inches="tight", facecolor="white")
79
+ print(f"wrote {OUT}")
assets/architecture_overview.svg ADDED
assets/headline_results.png ADDED
assets/training_pipeline.svg ADDED
baseline.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Rule-based AdaptShield baseline with evaluator-style stdout."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import sys
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List
11
+
12
+
13
+ REPO_ROOT = Path(__file__).resolve().parent
14
+
15
+ if str(REPO_ROOT) not in sys.path:
16
+ sys.path.insert(0, str(REPO_ROOT))
17
+
18
+ from models import AdaptShieldAction
19
+ from server.adaptshield_environment import AdaptShieldEnvironment
20
+
21
+
22
+ TASKS = ["direct-triage", "dual-pivot", "polymorphic-zero-day"]
23
+ BENCHMARK = "adaptshield"
24
+ MODEL_NAME = "rule-baseline"
25
+ MAX_STEPS = 30
26
+
27
+ POLICY = {
28
+ "brute_force": ("auth_service", "rate_limit"),
29
+ "lateral_movement": ("payment_service", "isolate"),
30
+ "exfiltration": ("database", "honeypot"),
31
+ "supply_chain": ("api_gateway", "patch"),
32
+ "benign": ("api_gateway", "monitor"),
33
+ }
34
+
35
+
36
+ def log_start(task: str) -> None:
37
+ print(f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}", flush=True)
38
+
39
+
40
+ def log_step(step: int, action: Dict[str, Any], reward: float, done: bool) -> None:
41
+ action_str = json.dumps(action, separators=(",", ":"))
42
+ if len(action_str) > 100:
43
+ action_str = action_str[:97] + "..."
44
+ print(
45
+ f"[STEP] step={step} action={action_str} "
46
+ f"reward={reward:.2f} done={str(done).lower()} error=null",
47
+ flush=True,
48
+ )
49
+
50
+
51
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
52
+ reward_str = ",".join(f"{reward:.2f}" for reward in rewards)
53
+ print(
54
+ f"[END] success={str(success).lower()} steps={steps} "
55
+ f"score={score:.3f} rewards={reward_str}",
56
+ flush=True,
57
+ )
58
+
59
+
60
+ def print_replay(task: str, metadata: Dict[str, Any]) -> None:
61
+ replay = metadata.get("episode_replay") or []
62
+ print()
63
+ print(f"Replay: {task}")
64
+ if not replay:
65
+ print(" No replay metadata available.")
66
+ return
67
+
68
+ for row in replay:
69
+ shift = " STRATEGY SHIFT:" if row.get("shift") else ""
70
+ foothold = " FOOTHOLD:" if row.get("foothold_transition") else ""
71
+ blast_radius = row.get("blast_radius") or []
72
+ blast = ",".join(blast_radius) if blast_radius else "none"
73
+ active_defenses = row.get("active_defenses") or []
74
+ print(
75
+ f" Turn {row.get('turn')}:"
76
+ f"{shift}{foothold} {row.get('p1')} -> {row.get('p2_action')}"
77
+ f"({row.get('target', 'unknown')}) -> {row.get('result')}"
78
+ f" | impact={float(row.get('impact', 0.0)):.2f}"
79
+ f" blast_radius={blast}"
80
+ f" active_controls={len(active_defenses)}"
81
+ f" mission={row.get('mission_alignment', 'neutral')}"
82
+ )
83
+
84
+
85
+ def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
86
+ auth = network_nodes.get("auth_service", {})
87
+ payment = network_nodes.get("payment_service", {})
88
+ database = network_nodes.get("database", {})
89
+ gateway = network_nodes.get("api_gateway", {})
90
+
91
+ if float(auth.get("error_rate", 0.0)) >= 0.10:
92
+ return "brute_force"
93
+ if payment.get("status") == "suspicious" or float(payment.get("cpu", 0)) >= 55:
94
+ return "lateral_movement"
95
+ if float(database.get("outbound_mb", 0)) >= 50:
96
+ return "exfiltration"
97
+ if gateway.get("status") == "suspicious":
98
+ return "supply_chain"
99
+ return "benign"
100
+
101
+
102
+ def phase1_payload(obs) -> Dict[str, Any]:
103
+ threat_type = classify_from_metrics(obs.network_nodes)
104
+ target_node, action = POLICY[threat_type]
105
+ return {
106
+ "threat_type": threat_type,
107
+ "confidence": 0.90,
108
+ "target_node": target_node,
109
+ "recommended_action": action,
110
+ "reasoning": "rule-based metric classifier",
111
+ }
112
+
113
+
114
+ def phase2_payload(obs) -> Dict[str, Any]:
115
+ assessment = obs.phase1_assessment or {}
116
+ threat_type = str(assessment.get("threat_type", "benign"))
117
+ fallback_target, fallback_action = POLICY.get(threat_type, POLICY["benign"])
118
+ action = str(assessment.get("recommended_action") or fallback_action)
119
+ target_node = str(assessment.get("target_node") or fallback_target)
120
+ return {
121
+ "action": action,
122
+ "target_node": target_node,
123
+ "reasoning": "execute analyst recommendation",
124
+ }
125
+
126
+
127
+ def action_from_payload(payload: Dict[str, Any]) -> AdaptShieldAction:
128
+ return AdaptShieldAction(**payload)
129
+
130
+
131
+ def run_task(task: str, emit_logs: bool = True) -> Dict[str, Any]:
132
+ env = AdaptShieldEnvironment(task_name=task)
133
+ obs = env.reset()
134
+ rewards: List[float] = []
135
+ steps = 0
136
+
137
+ if emit_logs:
138
+ log_start(task)
139
+
140
+ while not obs.done and steps < MAX_STEPS:
141
+ if obs.phase == 1:
142
+ payload = phase1_payload(obs)
143
+ else:
144
+ payload = phase2_payload(obs)
145
+
146
+ obs = env.step(action_from_payload(payload))
147
+ reward = float(obs.reward)
148
+ rewards.append(reward)
149
+ steps += 1
150
+
151
+ if emit_logs:
152
+ log_step(steps, payload, reward, obs.done)
153
+
154
+ metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
155
+ score = float(metadata.get("normalized_score", 0.01))
156
+ success = obs.done and 0.01 <= score <= 0.99
157
+
158
+ if emit_logs:
159
+ log_end(success, steps, score, rewards)
160
+
161
+ return {
162
+ "task": task,
163
+ "score": score,
164
+ "steps": steps,
165
+ "done": bool(obs.done),
166
+ "rewards": rewards,
167
+ "metadata": metadata,
168
+ "normalized_score_present": "normalized_score" in metadata,
169
+ "success": success,
170
+ }
171
+
172
+
173
+ def parse_args() -> argparse.Namespace:
174
+ parser = argparse.ArgumentParser(description="Run AdaptShield rule baseline.")
175
+ parser.add_argument(
176
+ "--task",
177
+ default="direct-triage",
178
+ choices=TASKS + ["all"],
179
+ help="Task to run, or 'all' for every task.",
180
+ )
181
+ parser.add_argument(
182
+ "--replay",
183
+ action="store_true",
184
+ help="Print a human-readable final episode replay.",
185
+ )
186
+ return parser.parse_args()
187
+
188
+
189
+ def main() -> int:
190
+ args = parse_args()
191
+ tasks = TASKS if args.task == "all" else [args.task]
192
+
193
+ for index, task in enumerate(tasks):
194
+ if index:
195
+ print()
196
+ result = run_task(task, emit_logs=True)
197
+ if args.replay:
198
+ print_replay(task, result["metadata"])
199
+
200
+ return 0
201
+
202
+
203
+ if __name__ == "__main__":
204
+ raise SystemExit(main())
build_benchmark_table.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Build a README-friendly benchmark table from baselines and training metrics."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Any, Dict, List
10
+
11
+ from baseline import TASKS, run_task as run_no_tool_task
12
+ from tool_baseline import run_task as run_tool_task
13
+
14
+
15
+ def rows_to_map(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
16
+ return {str(row["task"]): row for row in rows}
17
+
18
+
19
+ def load_metrics(path: Path) -> Dict[str, Any]:
20
+ return json.loads(path.read_text(encoding="utf-8"))
21
+
22
+
23
+ def markdown_table(headers: List[str], rows: List[List[str]]) -> str:
24
+ lines = [
25
+ "| " + " | ".join(headers) + " |",
26
+ "| " + " | ".join(["---"] * len(headers)) + " |",
27
+ ]
28
+ for row in rows:
29
+ lines.append("| " + " | ".join(row) + " |")
30
+ return "\n".join(lines)
31
+
32
+
33
+ def fmt(value: float | None) -> str:
34
+ if value is None:
35
+ return "-"
36
+ return f"{float(value):.3f}"
37
+
38
+
39
+ def main() -> int:
40
+ parser = argparse.ArgumentParser(description="Build AdaptShield benchmark comparison table.")
41
+ parser.add_argument("--sft-metrics", required=True, help="Path to sft_metrics.json")
42
+ parser.add_argument("--grpo-metrics", default="", help="Optional path to GRPO metrics.json")
43
+ parser.add_argument("--output", default="artifacts/benchmark_table.md")
44
+ args = parser.parse_args()
45
+
46
+ sft_metrics = load_metrics(Path(args.sft_metrics))
47
+ grpo_metrics = load_metrics(Path(args.grpo_metrics)) if args.grpo_metrics else {}
48
+
49
+ no_tool_rows = {task: run_no_tool_task(task, emit_logs=False) for task in TASKS}
50
+ tool_rows = {task: run_tool_task(task, emit_logs=False) for task in TASKS}
51
+ sft_eval = rows_to_map(sft_metrics.get("evaluation_rows", []))
52
+ sft_heldout = rows_to_map(sft_metrics.get("heldout_evaluation_rows", []))
53
+ grpo_eval = rows_to_map(grpo_metrics.get("evaluation_rows", [])) if grpo_metrics else {}
54
+ grpo_heldout = rows_to_map(grpo_metrics.get("heldout_evaluation_rows", [])) if grpo_metrics else {}
55
+
56
+ rows: List[List[str]] = []
57
+ for task in TASKS:
58
+ rows.append([
59
+ task,
60
+ fmt(no_tool_rows[task]["score"]),
61
+ fmt(tool_rows[task]["score"]),
62
+ fmt(sft_eval.get(task, {}).get("score")),
63
+ fmt(sft_heldout.get(task, {}).get("score")),
64
+ fmt(grpo_eval.get(task, {}).get("score") if grpo_eval else None),
65
+ fmt(grpo_heldout.get(task, {}).get("score") if grpo_heldout else None),
66
+ ])
67
+
68
+ md = markdown_table(
69
+ headers=[
70
+ "Task",
71
+ "No-tool baseline",
72
+ "Tool-aware baseline",
73
+ "SFT (train family)",
74
+ "SFT (held-out family)",
75
+ "GRPO (train family)",
76
+ "GRPO (held-out family)",
77
+ ],
78
+ rows=rows,
79
+ )
80
+
81
+ summary = {
82
+ "no_tool_baseline": {task: no_tool_rows[task]["score"] for task in TASKS},
83
+ "tool_baseline": {task: tool_rows[task]["score"] for task in TASKS},
84
+ "sft_train_family": {task: sft_eval.get(task, {}).get("score") for task in TASKS},
85
+ "sft_heldout_family": {task: sft_heldout.get(task, {}).get("score") for task in TASKS},
86
+ "grpo_train_family": {task: grpo_eval.get(task, {}).get("score") for task in TASKS} if grpo_eval else {},
87
+ "grpo_heldout_family": {task: grpo_heldout.get(task, {}).get("score") for task in TASKS} if grpo_heldout else {},
88
+ }
89
+
90
+ output_path = Path(args.output)
91
+ output_path.parent.mkdir(parents=True, exist_ok=True)
92
+ output_path.write_text(md + "\n", encoding="utf-8")
93
+ output_path.with_suffix(".json").write_text(json.dumps(summary, indent=2), encoding="utf-8")
94
+
95
+ print(md)
96
+ print()
97
+ print(f"Saved markdown table to: {output_path}")
98
+ print(f"Saved JSON summary to: {output_path.with_suffix('.json')}")
99
+ return 0
100
+
101
+
102
+ if __name__ == "__main__":
103
+ raise SystemExit(main())
client.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """AdaptShield environment client."""
8
+
9
+ from typing import Any, Dict
10
+
11
+ from openenv.core import EnvClient
12
+ from openenv.core.client_types import StepResult
13
+ from openenv.core.env_server.types import State
14
+
15
+ from models import AdaptShieldAction, AdaptShieldObservation
16
+
17
+
18
+ class AdaptshieldEnv(
19
+ EnvClient[AdaptShieldAction, AdaptShieldObservation, State]
20
+ ):
21
+ """
22
+ Client for the Adaptshield Environment.
23
+
24
+ This client maintains a persistent WebSocket connection to the environment server,
25
+ enabling efficient multi-step interactions with lower latency.
26
+ Each client instance has its own dedicated environment session on the server.
27
+
28
+ Example:
29
+ >>> # Connect to a running server
30
+ >>> with AdaptshieldEnv(base_url="http://localhost:7860") as client:
31
+ ... result = client.reset()
32
+ ... print(result.observation.phase)
33
+ ...
34
+ ... result = client.step(AdaptShieldAction(
35
+ ... threat_type="brute_force",
36
+ ... confidence=0.9,
37
+ ... target_node="auth_service",
38
+ ... recommended_action="rate_limit",
39
+ ... ))
40
+ ... print(result.observation.phase1_assessment)
41
+
42
+ Example with Docker:
43
+ >>> # Automatically start container and connect
44
+ >>> client = AdaptshieldEnv.from_docker_image("adaptshield-env:latest")
45
+ >>> try:
46
+ ... result = client.reset()
47
+ ... result = client.step(AdaptShieldAction(
48
+ ... threat_type="benign",
49
+ ... confidence=0.8,
50
+ ... target_node="auth_service",
51
+ ... recommended_action="monitor",
52
+ ... ))
53
+ ... finally:
54
+ ... client.close()
55
+ """
56
+
57
+ def _step_payload(self, action: AdaptShieldAction) -> Dict[str, Any]:
58
+ """
59
+ Convert AdaptShieldAction to a JSON-safe payload.
60
+
61
+ Args:
62
+ action: AdaptShieldAction instance
63
+
64
+ Returns:
65
+ Dictionary representation suitable for JSON encoding
66
+ """
67
+ return action.model_dump(
68
+ mode="json",
69
+ exclude_none=True,
70
+ exclude_defaults=True,
71
+ )
72
+
73
+ def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AdaptShieldObservation]:
74
+ """
75
+ Parse server response into StepResult[AdaptShieldObservation].
76
+
77
+ Args:
78
+ payload: JSON response data from server
79
+
80
+ Returns:
81
+ StepResult with AdaptShieldObservation
82
+ """
83
+ obs_data = dict(payload.get("observation", {}))
84
+ obs_data.setdefault("done", payload.get("done", False))
85
+ obs_data.setdefault("reward", payload.get("reward", 0.0))
86
+ observation = AdaptShieldObservation(**obs_data)
87
+
88
+ return StepResult(
89
+ observation=observation,
90
+ reward=payload.get("reward"),
91
+ done=payload.get("done", False),
92
+ )
93
+
94
+ def _parse_state(self, payload: Dict) -> State:
95
+ """
96
+ Parse server response into State object.
97
+
98
+ Args:
99
+ payload: JSON response from state request
100
+
101
+ Returns:
102
+ State object with episode_id and step_count
103
+ """
104
+ return State(
105
+ episode_id=payload.get("episode_id"),
106
+ step_count=payload.get("step_count", 0),
107
+ )
eval_tasks.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Run all AdaptShield tasks with the local rule baseline."""
3
+
4
+ from __future__ import annotations
5
+
6
+ from baseline import TASKS, run_task
7
+
8
+
9
+ def status_for(result: dict) -> str:
10
+ score = result["score"]
11
+ passed = (
12
+ result["done"] and
13
+ result["normalized_score_present"] and
14
+ 0.01 <= score <= 0.99
15
+ )
16
+ return "PASS" if passed else "FAIL"
17
+
18
+
19
+ def main() -> int:
20
+ results = [run_task(task, emit_logs=False) for task in TASKS]
21
+
22
+ print("AdaptShield Evaluation")
23
+ print()
24
+ print(f"{'Task':<24} {'Score':>7} {'Steps':>5} {'normalized_score':>18} {'Status':>8}")
25
+ print("-" * 68)
26
+
27
+ for result in results:
28
+ normalized = "yes" if result["normalized_score_present"] else "no"
29
+ print(
30
+ f"{result['task']:<24} "
31
+ f"{result['score']:>7.3f} "
32
+ f"{result['steps']:>5} "
33
+ f"{normalized:>18} "
34
+ f"{status_for(result):>8}"
35
+ )
36
+
37
+ scores = [result["score"] for result in results]
38
+ staircase = all(left > right for left, right in zip(scores, scores[1:]))
39
+ print()
40
+ print(f"Difficulty staircase: {'PASS' if staircase else 'FAIL'}")
41
+
42
+ return 0 if all(status_for(result) == "PASS" for result in results) else 1
43
+
44
+
45
+ if __name__ == "__main__":
46
+ raise SystemExit(main())
generate_sft_data.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Generate supervised fine-tuning data directly from AdaptShield rollouts."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import random
9
+ from pathlib import Path
10
+ from typing import Any, Dict, List
11
+
12
+ from models import AdaptShieldAction
13
+ from server.adaptshield_environment import AdaptShieldEnvironment
14
+ from train import (
15
+ TASKS,
16
+ _current_reference,
17
+ _teacher_payload,
18
+ build_messages,
19
+ obs_to_dict,
20
+ render_messages,
21
+ task_for_episode,
22
+ )
23
+ from soc_tools import attach_tool_results, investigate_local_with_depth
24
+
25
+
26
+ def build_dataset(
27
+ selected_task: str,
28
+ curriculum: bool,
29
+ use_tools: bool,
30
+ rollout_episodes: int,
31
+ max_steps: int,
32
+ seed: int,
33
+ world_split: str,
34
+ world_family: str | None,
35
+ ) -> List[Dict[str, Any]]:
36
+ random.seed(seed)
37
+ rows: List[Dict[str, Any]] = []
38
+
39
+ for episode in range(1, rollout_episodes + 1):
40
+ task, stage = task_for_episode(
41
+ episode=episode,
42
+ total_episodes=rollout_episodes,
43
+ selected_task=selected_task,
44
+ curriculum=curriculum,
45
+ )
46
+ env = AdaptShieldEnvironment(
47
+ task_name=task,
48
+ world_split=world_split,
49
+ world_family=world_family,
50
+ )
51
+ obs = env.reset()
52
+ step_count = 0
53
+
54
+ while not obs.done and step_count < max_steps:
55
+ phase = int(getattr(obs, "phase", 1))
56
+ tool_results = investigate_local_with_depth(
57
+ env,
58
+ obs,
59
+ use_tools=use_tools,
60
+ thorough=(task == "polymorphic-zero-day"),
61
+ )
62
+ obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
63
+ messages = build_messages(obs_dict)
64
+ reference = _current_reference(env)
65
+ teacher_payload = _teacher_payload(phase, reference)
66
+ response_text = json.dumps(teacher_payload, separators=(",", ":"))
67
+
68
+ rows.append({
69
+ "task": task,
70
+ "stage": stage,
71
+ "episode": episode,
72
+ "turn": int(getattr(obs, "turn", 0) or 0),
73
+ "phase": phase,
74
+ "attack_stage": reference["stage"],
75
+ "world_split": getattr(env, "_world_split", world_split),
76
+ "world_family": getattr(env, "_world_family", world_family or ""),
77
+ "operational_mode": getattr(env, "_operational_mode", ""),
78
+ "is_benign": bool(reference["is_benign"]),
79
+ "expected_threat_type": reference["threat_type"],
80
+ "expected_target_node": reference["target_node"],
81
+ "expected_action": reference["expected_action"],
82
+ "tool_calls": len(tool_results),
83
+ "messages": messages,
84
+ "response": response_text,
85
+ "text": f"{render_messages(messages)}\n\nASSISTANT:\n{response_text}",
86
+ })
87
+
88
+ obs = env.step(AdaptShieldAction(**teacher_payload))
89
+ step_count += 1
90
+
91
+ return rows
92
+
93
+
94
+ def summarize_rows(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
95
+ by_task = {task: 0 for task in TASKS}
96
+ by_phase = {1: 0, 2: 0}
97
+ with_tools = 0
98
+
99
+ for row in rows:
100
+ task = str(row.get("task", ""))
101
+ phase = int(row.get("phase", 1) or 1)
102
+ if task in by_task:
103
+ by_task[task] += 1
104
+ by_phase[phase] = by_phase.get(phase, 0) + 1
105
+ if int(row.get("tool_calls", 0) or 0) > 0:
106
+ with_tools += 1
107
+
108
+ return {
109
+ "rows": len(rows),
110
+ "task_counts": by_task,
111
+ "phase_counts": by_phase,
112
+ "rows_with_tool_calls": with_tools,
113
+ }
114
+
115
+
116
+ def main() -> None:
117
+ parser = argparse.ArgumentParser(description="Generate AdaptShield SFT JSONL data")
118
+ parser.add_argument(
119
+ "--task",
120
+ default="all",
121
+ choices=["all", *TASKS],
122
+ help="Task to sample. Use 'all' with --curriculum for mixed data.",
123
+ )
124
+ parser.add_argument(
125
+ "--episodes",
126
+ type=int,
127
+ default=120,
128
+ help="Number of rollout episodes to sample.",
129
+ )
130
+ parser.add_argument(
131
+ "--max-steps",
132
+ type=int,
133
+ default=20,
134
+ help="Maximum env steps per episode.",
135
+ )
136
+ parser.add_argument(
137
+ "--seed",
138
+ type=int,
139
+ default=42,
140
+ help="Dataset generation seed.",
141
+ )
142
+ parser.add_argument(
143
+ "--curriculum",
144
+ action="store_true",
145
+ help="Use easy->medium->hard sampling schedule.",
146
+ )
147
+ parser.add_argument(
148
+ "--use-tools",
149
+ action="store_true",
150
+ help="Include SOC tool evidence in prompts where applicable.",
151
+ )
152
+ parser.add_argument(
153
+ "--output",
154
+ default="data/adaptshield_sft.jsonl",
155
+ help="Where to write the JSONL dataset.",
156
+ )
157
+ parser.add_argument(
158
+ "--world-split",
159
+ default="train",
160
+ choices=["train", "eval"],
161
+ help="World-family split used to generate the dataset.",
162
+ )
163
+ parser.add_argument(
164
+ "--world-family",
165
+ default=None,
166
+ help="Optional fixed world family override (e.g. train-a, eval-x).",
167
+ )
168
+ args = parser.parse_args()
169
+
170
+ rows = build_dataset(
171
+ selected_task=args.task,
172
+ curriculum=args.curriculum,
173
+ use_tools=args.use_tools,
174
+ rollout_episodes=args.episodes,
175
+ max_steps=args.max_steps,
176
+ seed=args.seed,
177
+ world_split=args.world_split,
178
+ world_family=args.world_family,
179
+ )
180
+
181
+ output_path = Path(args.output)
182
+ output_path.parent.mkdir(parents=True, exist_ok=True)
183
+ with output_path.open("w", encoding="utf-8") as handle:
184
+ for row in rows:
185
+ handle.write(json.dumps(row, ensure_ascii=True) + "\n")
186
+
187
+ summary = summarize_rows(rows)
188
+ summary_path = output_path.with_suffix(".summary.json")
189
+ summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")
190
+
191
+ print(f"Wrote {len(rows)} rows to {output_path}")
192
+ print(f"Summary saved to {summary_path}")
193
+ print(json.dumps(summary, indent=2))
194
+
195
+
196
+ if __name__ == "__main__":
197
+ main()
inference.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield Inference Script
3
+
4
+ Single task per run. Emits mandatory [START]/[STEP]/[END] stdout format.
5
+ All credentials read from environment — never hardcoded.
6
+
7
+ Required env vars (injected by evaluator):
8
+ API_KEY: Evaluator's LiteLLM proxy key (checked first)
9
+ API_BASE_URL: LLM endpoint
10
+ MODEL_NAME: Model identifier
11
+
12
+ Optional env vars:
13
+ HF_TOKEN: Fallback if API_KEY not set
14
+ ADAPTSHIELD_TASK: Task name (default: direct-triage)
15
+ ENV_BASE_URL: Environment server URL (default: localhost:7860)
16
+ """
17
+
18
+ import json
19
+ import os
20
+ import sys
21
+ import textwrap
22
+ from typing import Any, Dict, List, Optional
23
+ import urllib.request
24
+ import urllib.error
25
+
26
+ from openai import OpenAI
27
+
28
+ from client import AdaptshieldEnv
29
+ from models import AdaptShieldAction
30
+ from soc_tools import attach_tool_results, investigate_http, summarize_tool_results
31
+
32
+ # ── Configuration — read from env, NEVER hardcode ──────────────────────────
33
+ API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN", "")
34
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
35
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
36
+ TASK_NAME = os.environ.get("ADAPTSHIELD_TASK", "direct-triage")
37
+ ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
38
+ BENCHMARK = "adaptshield"
39
+ MAX_STEPS = 25
40
+ SUCCESS_THRESHOLD = 0.50
41
+ USE_TOOLS_SETTING = os.environ.get("ADAPTSHIELD_USE_TOOLS", "auto").lower()
42
+
43
+
44
+ # ── Mandatory stdout format ────────────────────────────────────────────────
45
+ def log_start(task: str, env: str, model: str) -> None:
46
+ print(f"[START] task={task} env={env} model={model}", flush=True)
47
+
48
+
49
+ def log_step(step: int, action: str, reward: float,
50
+ done: bool, error: Optional[str]) -> None:
51
+ ev = error if error else "null"
52
+ print(
53
+ f"[STEP] step={step} action={action} "
54
+ f"reward={reward:.2f} done={str(done).lower()} error={ev}",
55
+ flush=True,
56
+ )
57
+
58
+
59
+ def log_end(success: bool, steps: int, score: float,
60
+ rewards: List[float]) -> None:
61
+ rs = ",".join(f"{r:.2f}" for r in rewards)
62
+ print(
63
+ f"[END] success={str(success).lower()} steps={steps} "
64
+ f"score={score:.3f} rewards={rs}",
65
+ flush=True,
66
+ )
67
+
68
+
69
+ # ── Environment calls ──────────────────────────────────────────────────────
70
+ def env_post(path: str, data: Dict) -> Dict:
71
+ url = f"{ENV_BASE_URL}{path}"
72
+ body = json.dumps(data).encode()
73
+ req = urllib.request.Request(
74
+ url, data=body, headers={"Content-Type": "application/json"}
75
+ )
76
+ with urllib.request.urlopen(req, timeout=60) as r:
77
+ return json.loads(r.read())
78
+
79
+
80
+ def obs_to_dict(obs: Any) -> Dict[str, Any]:
81
+ """Convert Pydantic observations from the persistent client to JSON dicts."""
82
+ if hasattr(obs, "model_dump"):
83
+ return obs.model_dump(mode="json")
84
+ return dict(obs)
85
+
86
+
87
+ def build_env_action(parsed: Dict[str, Any], phase: int) -> AdaptShieldAction:
88
+ """Validate model output and fall back to a phase-correct safe action."""
89
+ try:
90
+ return AdaptShieldAction(**parsed)
91
+ except Exception:
92
+ if phase == 1:
93
+ return AdaptShieldAction(
94
+ threat_type="brute_force",
95
+ confidence=0.5,
96
+ target_node="auth_service",
97
+ recommended_action="monitor",
98
+ reasoning="validated fallback",
99
+ )
100
+ return AdaptShieldAction(
101
+ action="monitor",
102
+ target_node="auth_service",
103
+ reasoning="validated fallback",
104
+ )
105
+
106
+
107
+ # ── Score computation — strictly (0.01, 0.99) ─────────────────────────────
108
+ def safe_score(rewards: List[float], meta: Dict) -> float:
109
+ if "normalized_score" in meta:
110
+ raw = float(meta["normalized_score"])
111
+ elif rewards:
112
+ pos = sum(r for r in rewards if r > 0.50)
113
+ maxp = len(rewards) * 0.99
114
+ raw = pos / maxp if maxp > 0 else 0.50
115
+ else:
116
+ raw = 0.50
117
+ return max(0.01, min(0.99, raw))
118
+
119
+
120
+ # ── System prompts ─────────────────────────────────────────────────────────
121
+ PHASE1_SYS = textwrap.dedent("""
122
+ You are a Threat Analyst for a 4-node enterprise network.
123
+ Analyze the SIEM metrics and alerts. Identify the threat type.
124
+
125
+ Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
126
+ If SOC tool evidence is provided, use it to update your belief before classifying.
127
+
128
+ Respond ONLY with valid JSON:
129
+ {"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}
130
+
131
+ Nodes: auth_service, payment_service, database, api_gateway
132
+ Actions: rate_limit, isolate, honeypot, patch, monitor
133
+ """).strip()
134
+
135
+ PHASE2_SYS = textwrap.dedent("""
136
+ You are a Tactical Executor. Act on the threat assessment provided.
137
+ You cannot see raw network data. Use the analyst assessment plus any SOC tool trace.
138
+
139
+ rate_limit=throttle traffic, isolate=take offline, honeypot=redirect attacker,
140
+ patch=fix vulnerability, monitor=observe only
141
+
142
+ Respond ONLY with valid JSON:
143
+ {"action":"...","target_node":"...","reasoning":"..."}
144
+
145
+ Nodes: auth_service, payment_service, database, api_gateway
146
+ """).strip()
147
+
148
+
149
+ def get_action(client: OpenAI, obs: Dict) -> Dict[str, Any]:
150
+ """Call LLM for current phase. Falls back gracefully on parse error."""
151
+ phase = obs.get("phase", 1)
152
+
153
+ if phase == 1:
154
+ sys_msg = PHASE1_SYS
155
+ user_msg = "\n".join([
156
+ "Network nodes:",
157
+ json.dumps(obs.get("network_nodes", {}), indent=2),
158
+ "\nAlerts:",
159
+ "\n".join(obs.get("active_alerts", [])),
160
+ "\nSOC tool evidence:",
161
+ summarize_tool_results(obs.get("tool_results", [])),
162
+ "\nHistory:",
163
+ json.dumps(obs.get("history", []), indent=2),
164
+ "\nClassify the threat:",
165
+ ])
166
+ fallback = {
167
+ "threat_type": "brute_force", "confidence": 0.5,
168
+ "target_node": "auth_service", "recommended_action": "monitor",
169
+ "reasoning": "fallback",
170
+ }
171
+ else:
172
+ sys_msg = PHASE2_SYS
173
+ metadata = obs.get("metadata", {}) if isinstance(obs.get("metadata", {}), dict) else {}
174
+ current_turn = int(obs.get("turn", 0) or 0)
175
+ tool_trace = [
176
+ row for row in metadata.get("tool_trace", [])
177
+ if int(row.get("turn", -1)) == current_turn
178
+ ]
179
+ user_msg = "\n".join([
180
+ "Threat assessment from analyst:",
181
+ json.dumps(obs.get("phase1_assessment", {}), indent=2),
182
+ "\nSOC tool trace for this turn:",
183
+ json.dumps(tool_trace, indent=2),
184
+ "\nChoose your defensive action:",
185
+ ])
186
+ fallback = {
187
+ "action": "monitor",
188
+ "target_node": "auth_service",
189
+ "reasoning": "fallback",
190
+ }
191
+
192
+ try:
193
+ resp = client.chat.completions.create(
194
+ model=MODEL_NAME,
195
+ messages=[
196
+ {"role": "system", "content": sys_msg},
197
+ {"role": "user", "content": user_msg},
198
+ ],
199
+ temperature=0.1,
200
+ max_tokens=300,
201
+ stream=False,
202
+ )
203
+ text = (resp.choices[0].message.content or "").strip()
204
+
205
+ # Strip markdown fences
206
+ if "```" in text:
207
+ for part in text.split("```"):
208
+ if "{" in part:
209
+ text = part.strip().lstrip("json").strip()
210
+ break
211
+
212
+ return json.loads(text)
213
+
214
+ except Exception as exc:
215
+ print(f"[DEBUG] phase={phase} parse error: {exc}", flush=True)
216
+ return fallback
217
+
218
+
219
+ def should_use_tools(task_name: str) -> bool:
220
+ if USE_TOOLS_SETTING in ("1", "true", "yes", "on"):
221
+ return True
222
+ if USE_TOOLS_SETTING in ("0", "false", "no", "off"):
223
+ return False
224
+ return task_name == "polymorphic-zero-day"
225
+
226
+
227
+ def run_soc_episode(client: OpenAI, use_tools: bool) -> tuple[List[float], int, Dict[str, Any]]:
228
+ rewards: List[float] = []
229
+ steps_taken = 0
230
+
231
+ reset = env_post("/soc/reset", {"task": TASK_NAME})
232
+ session_id = str(reset.get("session_id", ""))
233
+ obs = dict(reset.get("observation", {}))
234
+ done = bool(obs.get("done", False))
235
+
236
+ for step in range(1, MAX_STEPS + 1):
237
+ if done:
238
+ break
239
+
240
+ tool_results = investigate_http(
241
+ env_base_url=ENV_BASE_URL,
242
+ session_id=session_id,
243
+ obs=obs,
244
+ use_tools=use_tools,
245
+ thorough=True,
246
+ )
247
+ obs_for_model = attach_tool_results(obs, tool_results)
248
+ parsed = get_action(client, obs_for_model)
249
+ action_str = json.dumps(parsed, separators=(",", ":"))
250
+ if len(action_str) > 100:
251
+ action_str = action_str[:97] + "..."
252
+
253
+ try:
254
+ action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
255
+ action_payload = action.model_dump(
256
+ mode="json",
257
+ exclude_none=True,
258
+ exclude_defaults=True,
259
+ )
260
+ result = env_post("/soc/step", {"session_id": session_id, "action": action_payload})
261
+ obs = dict(result.get("observation", {}))
262
+ reward = float(result.get("reward", obs.get("reward", 0.0)))
263
+ done = bool(result.get("done", obs.get("done", False)))
264
+ error = None
265
+ except Exception as exc:
266
+ reward = 0.0
267
+ done = True
268
+ error = str(exc)[:80]
269
+
270
+ rewards.append(reward)
271
+ steps_taken = step
272
+ log_step(step=step, action=action_str, reward=reward, done=done, error=error)
273
+
274
+ if done:
275
+ break
276
+
277
+ return rewards, steps_taken, obs
278
+
279
+
280
+ def run_openenv_episode(client: OpenAI) -> tuple[List[float], int, Dict[str, Any]]:
281
+ rewards: List[float] = []
282
+ steps_taken = 0
283
+ obs: Dict[str, Any] = {}
284
+
285
+ env = AdaptshieldEnv(base_url=ENV_BASE_URL).sync()
286
+ with env:
287
+ result = env.reset(task_name=TASK_NAME)
288
+ obs = obs_to_dict(result.observation)
289
+ done = bool(result.done or obs.get("done", False))
290
+
291
+ for step in range(1, MAX_STEPS + 1):
292
+ if done:
293
+ break
294
+
295
+ parsed = get_action(client, obs)
296
+ action_str = json.dumps(parsed, separators=(",", ":"))
297
+ if len(action_str) > 100:
298
+ action_str = action_str[:97] + "..."
299
+
300
+ try:
301
+ action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
302
+ sr = env.step(action)
303
+ obs = obs_to_dict(sr.observation)
304
+ reward = float(sr.reward if sr.reward is not None else obs.get("reward", 0.0))
305
+ done = bool(sr.done or obs.get("done", False))
306
+ error = None
307
+ except Exception as exc:
308
+ reward = 0.0
309
+ done = True
310
+ error = str(exc)[:80]
311
+
312
+ rewards.append(reward)
313
+ steps_taken = step
314
+ log_step(step=step, action=action_str, reward=reward,
315
+ done=done, error=error)
316
+
317
+ if done:
318
+ break
319
+
320
+ return rewards, steps_taken, obs
321
+
322
+
323
+ def main() -> None:
324
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
325
+
326
+ rewards: List[float] = []
327
+ steps_taken: int = 0
328
+ score: float = 0.50
329
+ success: bool = False
330
+ obs: Dict = {}
331
+
332
+ log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
333
+
334
+ try:
335
+ if should_use_tools(TASK_NAME):
336
+ rewards, steps_taken, obs = run_soc_episode(client, use_tools=True)
337
+ else:
338
+ rewards, steps_taken, obs = run_openenv_episode(client)
339
+
340
+ score = safe_score(rewards, obs.get("metadata", {}))
341
+ success = score >= SUCCESS_THRESHOLD
342
+
343
+ except Exception as exc:
344
+ print(f"[DEBUG] episode error: {exc}", flush=True)
345
+ score = 0.10
346
+
347
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
348
+
349
+
350
+ if __name__ == "__main__":
351
+ main()
launch_hf_grpo_job.py ADDED
@@ -0,0 +1,403 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Launch AdaptShield GRPO refinement on Hugging Face Jobs."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import shlex
8
+ import subprocess
9
+ import time
10
+ from pathlib import Path
11
+
12
+ from huggingface_hub import HfApi, get_token, run_job
13
+ from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
14
+
15
+ from train import MODEL_CHOICES, TASKS
16
+
17
+
18
+ REPO_ROOT = Path(__file__).resolve().parent
19
+ DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
20
+
21
+
22
+ def _should_retry_hf(exc: Exception) -> bool:
23
+ response = getattr(exc, "response", None)
24
+ status_code = getattr(response, "status_code", None)
25
+ return status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600)
26
+
27
+
28
+ def _retry_hf_call(fn, *args, retries: int = 4, delay_s: float = 2.0, **kwargs):
29
+ last_exc = None
30
+ for attempt in range(retries):
31
+ try:
32
+ return fn(*args, **kwargs)
33
+ except Exception as exc:
34
+ last_exc = exc
35
+ if not _should_retry_hf(exc) or attempt == retries - 1:
36
+ raise
37
+ sleep_for = delay_s * (2 ** attempt)
38
+ print(f"Retrying HF API call after transient error ({exc}); sleeping {sleep_for:.1f}s")
39
+ time.sleep(sleep_for)
40
+ raise last_exc # pragma: no cover
41
+
42
+
43
+ def infer_repo_url() -> str:
44
+ result = subprocess.run(
45
+ ["git", "config", "--get", "remote.origin.url"],
46
+ cwd=REPO_ROOT,
47
+ check=True,
48
+ capture_output=True,
49
+ text=True,
50
+ )
51
+ repo_url = result.stdout.strip()
52
+ if not repo_url:
53
+ raise RuntimeError("Could not infer git remote.origin.url")
54
+ return repo_url
55
+
56
+
57
+ def repo_namespace(repo_id: str) -> str:
58
+ if "/" not in repo_id:
59
+ raise RuntimeError(f"Invalid repo id: {repo_id}. Expected namespace/name.")
60
+ return repo_id.split("/", 1)[0]
61
+
62
+
63
+ def authenticated_username(api: HfApi) -> str | None:
64
+ try:
65
+ info = api.whoami(cache=True)
66
+ except Exception:
67
+ return None
68
+ if isinstance(info, dict):
69
+ for key in ("name", "fullname", "user"):
70
+ value = info.get(key)
71
+ if isinstance(value, str) and value:
72
+ return value
73
+ return None
74
+
75
+
76
+ def validate_repo_access(
77
+ api: HfApi,
78
+ repo_id: str,
79
+ repo_type: str,
80
+ skip_create: bool,
81
+ allow_cross_namespace: bool,
82
+ ) -> None:
83
+ owner = repo_namespace(repo_id)
84
+ username = authenticated_username(api)
85
+ if username and owner != username:
86
+ message = (
87
+ f"Authenticated HF account appears to be '{username}', but target repo is under '{owner}'. "
88
+ "Use a repo under the same namespace or pass --allow-cross-namespace only if you are certain "
89
+ "this token has write access there."
90
+ )
91
+ if not allow_cross_namespace:
92
+ raise RuntimeError(message)
93
+ print(f"Warning: {message}")
94
+
95
+ if skip_create or repo_type == "model":
96
+ try:
97
+ _retry_hf_call(api.repo_info, repo_id=repo_id, repo_type=repo_type)
98
+ except RepositoryNotFoundError as exc:
99
+ raise RuntimeError(
100
+ f"Repo '{repo_id}' ({repo_type}) was not found or is not accessible with the current token."
101
+ ) from exc
102
+ except HfHubHTTPError as exc:
103
+ raise RuntimeError(f"Could not verify repo '{repo_id}' ({repo_type}): {exc}") from exc
104
+
105
+
106
+ def validate_source_artifacts(
107
+ api: HfApi,
108
+ repo_id: str,
109
+ repo_type: str,
110
+ subdir: str,
111
+ ) -> None:
112
+ try:
113
+ files = set(_retry_hf_call(api.list_repo_files, repo_id=repo_id, repo_type=repo_type))
114
+ except Exception as exc:
115
+ raise RuntimeError(f"Could not list files for source repo '{repo_id}' ({repo_type}): {exc}") from exc
116
+
117
+ required = {
118
+ f"{subdir}/final/adapter_config.json",
119
+ f"{subdir}/sft_metrics.json",
120
+ }
121
+ missing = sorted(path for path in required if path not in files)
122
+ if missing:
123
+ raise RuntimeError(
124
+ "Source repo is missing required SFT artifacts: " + ", ".join(missing)
125
+ )
126
+
127
+
128
+ def build_command(args: argparse.Namespace, repo_url: str, output_subdir: str) -> str:
129
+ output_path = f"/workspace/adaptshield/checkpoints/{output_subdir}"
130
+
131
+ return f"""
132
+ set -euo pipefail
133
+ export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
134
+ export PYTHONWARNINGS="ignore::FutureWarning"
135
+ export HF_HUB_ENABLE_HF_TRANSFER=1
136
+ export PIP_DISABLE_PIP_VERSION_CHECK=1
137
+
138
+ python - <<'PY'
139
+ import torch
140
+ print(f"baseline torch={{torch.__version__}}, cuda={{torch.version.cuda}}")
141
+ PY
142
+
143
+ apt-get update -qq
144
+ apt-get install -y -qq git
145
+ if [ ! -d /workspace/adaptshield/.git ]; then
146
+ rm -rf /workspace/adaptshield
147
+ git clone --depth 1 {shlex.quote(repo_url)} /workspace/adaptshield
148
+ fi
149
+ cd /workspace/adaptshield
150
+ python -m pip install --upgrade pip wheel setuptools
151
+ # ninja+packaging let any source-built dep that DOES sneak in compile cleanly.
152
+ pip install --upgrade ninja packaging
153
+ pip install -e .
154
+ pip uninstall -y torchaudio || true
155
+
156
+ # Unsloth ships CUDA/torch-pinned extras (cu124 + torch 2.6.0 + xformers+triton wheels).
157
+ # We deliberately use `cu124-torch260` (NOT the `ampere` variant) because:
158
+ # * cu124-torch260 pins torch 2.6 + xformers + triton via prebuilt wheels (no source builds).
159
+ # * cu124-ampere-torch260 ALSO tries to install flash-attn; if its prebuilt wheel URL doesn't
160
+ # match the image's python/cxx11abi exactly, pip falls through to source-building flash-attn
161
+ # (10-30 min, often fails with "ModuleNotFoundError: No module named 'torch'" because PEP 517
162
+ # build isolation hides torch). Unsloth's xformers/triton attention is plenty fast on L4.
163
+ # `unsloth[cu124-torch260]` transitively installs `unsloth[huggingface]` which pins ALL of
164
+ # transformers / trl / peft / accelerate / datasets / bitsandbytes / tokenizers / safetensors
165
+ # to versions Unsloth has tested together. Do NOT add a `--no-deps` override on top of this —
166
+ # previous attempts to do so downgraded peft/trl below what Unsloth requires.
167
+ # --no-build-isolation lets any incidental source build (e.g. a stray dep) see system torch.
168
+ pip install --upgrade --no-build-isolation "unsloth[cu124-torch260]"
169
+
170
+ # Pin transformers to a single known-good version. Why this is necessary:
171
+ # Unsloth's pyproject allows transformers >=4.51.3 ... <=5.5.0. Pip prefers the latest, so it
172
+ # picks 5.5.0 by default. But transformers 4.x requires huggingface-hub<1.0 while 5.x requires
173
+ # hub>=1.5,<2.0 — and unsloth's pyproject does NOT bound hub. So a separate `pip install hub<1.0`
174
+ # silently breaks transformers 5.x (and a separate `pip install hub>=1.5` silently breaks 4.x).
175
+ # The only robust fix is to pin transformers and let pip select the matching hub in the SAME
176
+ # resolution step. We pick 4.57.6 because:
177
+ # * latest 4.x release on PyPI (so qwen3, etc. are supported);
178
+ # * not on Unsloth's blocklist (4.57.0/.4/.5 are; 4.57.6 is fine);
179
+ # * pulls huggingface-hub<1.0 automatically (no separate hub pin needed).
180
+ pip install "transformers==4.57.6"
181
+
182
+ # torchao comes preinstalled in the base image at a version that requires torch 2.7+
183
+ # (it calls torch.utils._pytree.register_constant which doesn't exist in torch 2.6, so
184
+ # `import torchao` crashes with AttributeError). transformers' quantizer registry imports
185
+ # torchao unconditionally if it's installed (`is_torchao_available()` only checks package
186
+ # metadata, not import-ability). With torchao GONE, that check returns False and transformers
187
+ # skips torchao cleanly. We don't use torchao quantization anyway — we use bitsandbytes 4-bit.
188
+ pip uninstall -y torchao || true
189
+
190
+ # Optional helpers we use directly (matplotlib for plots, hf_transfer for fast download/upload).
191
+ pip install --upgrade matplotlib hf_transfer
192
+
193
+ # Hard guard: if torch was upgraded, bitsandbytes will fail at import; fail FAST with a clear log.
194
+ python - <<'PY'
195
+ import sys, torch
196
+ if not torch.__version__.startswith("2.6."):
197
+ print(f"FATAL: torch was upgraded to {{torch.__version__}}; aborting before training.")
198
+ sys.exit(2)
199
+ print(f"torch ok: {{torch.__version__}} cuda={{torch.version.cuda}}")
200
+ PY
201
+
202
+ # Smoke-test the actual modules we use. unsloth MUST import before transformers/trl
203
+ # per its own warning. Importing transformers also triggers its OWN runtime version check on
204
+ # huggingface_hub and tokenizers, AND eagerly imports any installed quantizer backend
205
+ # (torchao, bnb, etc.) — so if anything is mis-pinned this line raises a clear error before
206
+ # training starts.
207
+ python - <<'PY'
208
+ import sys, importlib.util
209
+
210
+ # Pre-flight: torchao must be GONE (preinstalled version requires torch>=2.7 and crashes
211
+ # `import torchao` on torch 2.6). If it leaked back in, fail with a precise message.
212
+ if importlib.util.find_spec("torchao") is not None:
213
+ print("FATAL: torchao is installed; on torch 2.6 it crashes transformers at import. "
214
+ "Run `pip uninstall -y torchao` and rebuild.")
215
+ sys.exit(2)
216
+
217
+ import unsloth # noqa: F401 (must be first)
218
+ import torch, transformers, trl, peft, datasets, bitsandbytes, huggingface_hub
219
+ print(
220
+ f"unsloth={{unsloth.__version__}} transformers={{transformers.__version__}} "
221
+ f"trl={{trl.__version__}} peft={{peft.__version__}} bnb={{bitsandbytes.__version__}} "
222
+ f"hub={{huggingface_hub.__version__}} datasets={{datasets.__version__}}"
223
+ )
224
+ expected_transformers = "4.57.6"
225
+ if transformers.__version__ != expected_transformers:
226
+ print(
227
+ f"FATAL: transformers={{transformers.__version__}} but pinned to {{expected_transformers}}. "
228
+ f"Pip resolution drifted; aborting before training."
229
+ )
230
+ sys.exit(2)
231
+ import train, build_benchmark_table # noqa: F401
232
+ print("Dependency smoke check passed.")
233
+ PY
234
+
235
+ python - <<'PY'
236
+ from huggingface_hub import snapshot_download
237
+ from pathlib import Path
238
+
239
+ repo_id = {args.source_repo!r}
240
+ repo_type = {args.source_repo_type!r}
241
+ subdir = {args.source_subdir!r}
242
+ local_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type)
243
+ adapter_path = Path(local_dir) / subdir / "final"
244
+ sft_metrics_path = Path(local_dir) / subdir / "sft_metrics.json"
245
+ if not adapter_path.exists():
246
+ raise RuntimeError(f"SFT adapter path not found: {{adapter_path}}")
247
+ if not sft_metrics_path.exists():
248
+ raise RuntimeError(f"SFT metrics path not found: {{sft_metrics_path}}")
249
+ print(adapter_path)
250
+ Path("/workspace/adaptshield/.grpo_adapter_path.txt").write_text(str(adapter_path), encoding="utf-8")
251
+ Path("/workspace/adaptshield/.grpo_sft_metrics_path.txt").write_text(str(sft_metrics_path), encoding="utf-8")
252
+ PY
253
+
254
+ ADAPTER_PATH=$(cat /workspace/adaptshield/.grpo_adapter_path.txt)
255
+ SFT_METRICS_PATH=$(cat /workspace/adaptshield/.grpo_sft_metrics_path.txt)
256
+
257
+ python train.py \\
258
+ --trainer grpo \\
259
+ --task {args.task} \\
260
+ --curriculum \\
261
+ --use-tools \\
262
+ --model {args.model} \\
263
+ --model-path "$ADAPTER_PATH" \\
264
+ --lr {args.lr} \\
265
+ --prompt-bank-episodes {args.prompt_bank_episodes} \\
266
+ --max-steps {args.max_steps} \\
267
+ --prompt-bank-hard-multiplier {args.prompt_bank_hard_multiplier} \\
268
+ --prompt-bank-borderline-bonus {args.prompt_bank_borderline_bonus} \\
269
+ --grpo-epochs {args.grpo_epochs} \\
270
+ --num-generations {args.num_generations} \\
271
+ --per-device-batch-size {args.per_device_batch_size} \\
272
+ --gradient-accumulation-steps {args.gradient_accumulation_steps} \\
273
+ --save-every {args.save_every} \\
274
+ --eval-episodes {args.eval_episodes} \\
275
+ --train-world-split train \\
276
+ --heldout-world-split eval \\
277
+ --heldout-seed {args.heldout_seed} \\
278
+ --output {output_path} \\
279
+ --plot
280
+
281
+ if ! python build_benchmark_table.py \\
282
+ --sft-metrics "$SFT_METRICS_PATH" \\
283
+ --grpo-metrics {output_path}/metrics.json \\
284
+ --output {output_path}/benchmark_table.md; then
285
+ echo "Benchmark table generation failed; continuing with core artifacts."
286
+ fi
287
+
288
+ python - <<'PY'
289
+ import os
290
+ import time
291
+ from huggingface_hub import HfApi
292
+
293
+ api = HfApi(token=os.environ["HF_TOKEN"])
294
+ repo_id = os.environ["RUNS_REPO"]
295
+ repo_type = os.environ["RUNS_REPO_TYPE"]
296
+ output_dir = {output_path!r}
297
+ subdir = {output_subdir!r}
298
+
299
+ last_exc = None
300
+ for attempt in range(4):
301
+ try:
302
+ api.upload_folder(
303
+ repo_id=repo_id,
304
+ repo_type=repo_type,
305
+ folder_path=output_dir,
306
+ path_in_repo=subdir,
307
+ )
308
+ last_exc = None
309
+ break
310
+ except Exception as exc:
311
+ last_exc = exc
312
+ response = getattr(exc, "response", None)
313
+ status_code = getattr(response, "status_code", None)
314
+ if status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600):
315
+ sleep_for = 2 ** attempt
316
+ print(f"Transient upload error: {{exc}}; retrying in {{sleep_for}}s")
317
+ time.sleep(sleep_for)
318
+ continue
319
+ raise
320
+ if last_exc is not None:
321
+ raise last_exc
322
+ print("Uploaded artifacts to", repo_id)
323
+ PY
324
+ """
325
+
326
+
327
+ def default_output_subdir(task: str, model: str) -> str:
328
+ model_slug = model.replace(".", "_")
329
+ if task == "all":
330
+ return f"grpo_worldsplit_{model_slug}"
331
+ task_slug = task.replace("-", "_")
332
+ return f"grpo_{task_slug}_{model_slug}"
333
+
334
+
335
+ def main() -> int:
336
+ parser = argparse.ArgumentParser(description="Launch AdaptShield GRPO refinement on Hugging Face Jobs")
337
+ parser.add_argument("--runs-repo", required=True)
338
+ parser.add_argument("--runs-repo-type", default="model", choices=["dataset", "model"])
339
+ parser.add_argument("--skip-create", action="store_true")
340
+ parser.add_argument("--allow-cross-namespace", action="store_true")
341
+ parser.add_argument("--repo-url", default=None)
342
+ parser.add_argument("--source-repo", required=True, help="Repo containing SFT artifacts.")
343
+ parser.add_argument("--source-repo-type", default="model", choices=["dataset", "model"])
344
+ parser.add_argument("--source-subdir", default="sft_worldsplit_1_5b", help="Subdirectory containing the SFT output.")
345
+ parser.add_argument("--task", default="all", choices=TASKS + ["all"])
346
+ parser.add_argument("--model", default="1.5b", choices=list(MODEL_CHOICES))
347
+ parser.add_argument("--flavor", default="l4x1")
348
+ parser.add_argument("--timeout", default="6h")
349
+ parser.add_argument("--lr", type=float, default=1e-5)
350
+ parser.add_argument("--prompt-bank-episodes", type=int, default=120)
351
+ parser.add_argument("--max-steps", type=int, default=20)
352
+ parser.add_argument("--prompt-bank-hard-multiplier", type=int, default=3)
353
+ parser.add_argument("--prompt-bank-borderline-bonus", type=int, default=2)
354
+ parser.add_argument("--grpo-epochs", type=int, default=1)
355
+ parser.add_argument("--num-generations", type=int, default=2)
356
+ parser.add_argument("--per-device-batch-size", type=int, default=1)
357
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
358
+ parser.add_argument("--save-every", type=int, default=0)
359
+ parser.add_argument("--eval-episodes", type=int, default=2)
360
+ parser.add_argument("--heldout-seed", type=int, default=314)
361
+ parser.add_argument("--output-subdir", default="")
362
+ args = parser.parse_args()
363
+
364
+ if not args.output_subdir:
365
+ args.output_subdir = default_output_subdir(args.task, args.model)
366
+
367
+ token = get_token()
368
+ if not token:
369
+ raise RuntimeError("No Hugging Face token found. Run `hf auth login` first.")
370
+
371
+ repo_url = args.repo_url or infer_repo_url()
372
+ api = HfApi(token=token)
373
+ validate_repo_access(api, args.runs_repo, args.runs_repo_type, args.skip_create, args.allow_cross_namespace)
374
+ validate_repo_access(api, args.source_repo, args.source_repo_type, True, args.allow_cross_namespace)
375
+ validate_source_artifacts(api, args.source_repo, args.source_repo_type, args.source_subdir)
376
+ if not args.skip_create:
377
+ _retry_hf_call(api.create_repo, repo_id=args.runs_repo, repo_type=args.runs_repo_type, private=True, exist_ok=True)
378
+
379
+ command = build_command(args=args, repo_url=repo_url, output_subdir=args.output_subdir)
380
+ job = _retry_hf_call(
381
+ run_job,
382
+ image=DEFAULT_IMAGE,
383
+ command=["bash", "-lc", command],
384
+ flavor=args.flavor,
385
+ timeout=args.timeout,
386
+ namespace=repo_namespace(args.runs_repo),
387
+ env={
388
+ "RUNS_REPO": args.runs_repo,
389
+ "RUNS_REPO_TYPE": args.runs_repo_type,
390
+ },
391
+ secrets={"HF_TOKEN": token},
392
+ )
393
+
394
+ print("Job launched successfully.")
395
+ print(f"Job ID: {job.id}")
396
+ print(f"Job URL: {job.url}")
397
+ print(f"Artifacts repo: {args.runs_repo}")
398
+ print(f"Artifacts path: {args.output_subdir}")
399
+ return 0
400
+
401
+
402
+ if __name__ == "__main__":
403
+ raise SystemExit(main())
launch_hf_sft_job.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Launch an AdaptShield SFT training run on Hugging Face Jobs."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import shlex
8
+ import subprocess
9
+ import time
10
+ from pathlib import Path
11
+
12
+ from huggingface_hub import HfApi, get_token, run_job
13
+ from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
14
+
15
+ from train import MODEL_CHOICES
16
+
17
+
18
+ REPO_ROOT = Path(__file__).resolve().parent
19
+ DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
20
+
21
+
22
+ def _should_retry_hf(exc: Exception) -> bool:
23
+ response = getattr(exc, "response", None)
24
+ status_code = getattr(response, "status_code", None)
25
+ return status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600)
26
+
27
+
28
+ def _retry_hf_call(fn, *args, retries: int = 4, delay_s: float = 2.0, **kwargs):
29
+ last_exc = None
30
+ for attempt in range(retries):
31
+ try:
32
+ return fn(*args, **kwargs)
33
+ except Exception as exc:
34
+ last_exc = exc
35
+ if not _should_retry_hf(exc) or attempt == retries - 1:
36
+ raise
37
+ sleep_for = delay_s * (2 ** attempt)
38
+ print(f"Retrying HF API call after transient error ({exc}); sleeping {sleep_for:.1f}s")
39
+ time.sleep(sleep_for)
40
+ raise last_exc # pragma: no cover
41
+
42
+
43
+ def repo_namespace(repo_id: str) -> str:
44
+ if "/" not in repo_id:
45
+ raise RuntimeError(f"Invalid repo id: {repo_id}. Expected namespace/name.")
46
+ return repo_id.split("/", 1)[0]
47
+
48
+
49
+ def authenticated_username(api: HfApi) -> str | None:
50
+ try:
51
+ info = api.whoami(cache=True)
52
+ except Exception:
53
+ return None
54
+ if isinstance(info, dict):
55
+ for key in ("name", "fullname", "user"):
56
+ value = info.get(key)
57
+ if isinstance(value, str) and value:
58
+ return value
59
+ return None
60
+
61
+
62
+ def validate_artifact_repo(
63
+ api: HfApi,
64
+ repo_id: str,
65
+ repo_type: str,
66
+ skip_create: bool,
67
+ allow_cross_namespace: bool,
68
+ ) -> None:
69
+ owner = repo_namespace(repo_id)
70
+ username = authenticated_username(api)
71
+ if username and owner != username:
72
+ message = (
73
+ f"Authenticated HF account appears to be '{username}', but artifacts repo is under '{owner}'. "
74
+ "Use a repo under the same namespace or pass --allow-cross-namespace only if you are certain "
75
+ "this token has write access there."
76
+ )
77
+ if not allow_cross_namespace:
78
+ raise RuntimeError(message)
79
+ print(f"Warning: {message}")
80
+
81
+ if skip_create:
82
+ try:
83
+ _retry_hf_call(api.repo_info, repo_id=repo_id, repo_type=repo_type)
84
+ except RepositoryNotFoundError as exc:
85
+ raise RuntimeError(
86
+ f"Artifacts repo '{repo_id}' ({repo_type}) was not found or is not accessible "
87
+ "with the current token. Create it manually under the correct namespace or use "
88
+ "a repo you definitely own before launching the job."
89
+ ) from exc
90
+ except HfHubHTTPError as exc:
91
+ raise RuntimeError(
92
+ f"Could not verify artifacts repo '{repo_id}' ({repo_type}) before launch: {exc}"
93
+ ) from exc
94
+
95
+
96
+ def infer_repo_url() -> str:
97
+ result = subprocess.run(
98
+ ["git", "config", "--get", "remote.origin.url"],
99
+ cwd=REPO_ROOT,
100
+ check=True,
101
+ capture_output=True,
102
+ text=True,
103
+ )
104
+ repo_url = result.stdout.strip()
105
+ if not repo_url:
106
+ raise RuntimeError("Could not infer git remote.origin.url")
107
+ return repo_url
108
+
109
+
110
+ def build_command(args: argparse.Namespace, repo_url: str, output_subdir: str) -> str:
111
+ dataset_path = "/workspace/adaptshield/data/adaptshield_sft_worldsplit.jsonl"
112
+ output_path = f"/workspace/adaptshield/checkpoints/{output_subdir}"
113
+ summary_path = "/workspace/adaptshield/data/adaptshield_sft_worldsplit.summary.json"
114
+ extra_train_flags = "--skip-reward-curve" if args.skip_reward_curve else ""
115
+
116
+ return f"""
117
+ set -euo pipefail
118
+ export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
119
+ export PYTHONWARNINGS="ignore::FutureWarning"
120
+ export HF_HUB_ENABLE_HF_TRANSFER=1
121
+ export PIP_DISABLE_PIP_VERSION_CHECK=1
122
+
123
+ python - <<'PY'
124
+ import torch
125
+ print(f"baseline torch={{torch.__version__}}, cuda={{torch.version.cuda}}")
126
+ PY
127
+
128
+ apt-get update -qq
129
+ apt-get install -y -qq git
130
+ if [ ! -d /workspace/adaptshield/.git ]; then
131
+ rm -rf /workspace/adaptshield
132
+ git clone --depth 1 {shlex.quote(repo_url)} /workspace/adaptshield
133
+ fi
134
+ cd /workspace/adaptshield
135
+ python -m pip install --upgrade pip wheel setuptools
136
+ # ninja+packaging let any source-built dep that DOES sneak in compile cleanly.
137
+ pip install --upgrade ninja packaging
138
+ pip install -e .
139
+ pip uninstall -y torchaudio || true
140
+
141
+ # Unsloth ships CUDA/torch-pinned extras (cu124 + torch 2.6.0 + xformers+triton wheels).
142
+ # We deliberately use `cu124-torch260` (NOT the `ampere` variant) because:
143
+ # * cu124-torch260 pins torch 2.6 + xformers + triton via prebuilt wheels (no source builds).
144
+ # * cu124-ampere-torch260 ALSO tries to install flash-attn; if its prebuilt wheel URL doesn't
145
+ # match the image's python/cxx11abi exactly, pip falls through to source-building flash-attn
146
+ # (10-30 min, often fails with "ModuleNotFoundError: No module named 'torch'" because PEP 517
147
+ # build isolation hides torch). Unsloth's xformers/triton attention is plenty fast on L4.
148
+ # `unsloth[cu124-torch260]` transitively installs `unsloth[huggingface]` which pins ALL of
149
+ # transformers / trl / peft / accelerate / datasets / bitsandbytes / tokenizers / safetensors
150
+ # to versions Unsloth has tested together. Do NOT add a `--no-deps` override on top of this —
151
+ # previous attempts to do so downgraded peft/trl below what Unsloth requires.
152
+ # --no-build-isolation lets any incidental source build (e.g. a stray dep) see system torch.
153
+ pip install --upgrade --no-build-isolation "unsloth[cu124-torch260]"
154
+
155
+ # Pin transformers to a single known-good version. Why this is necessary:
156
+ # Unsloth's pyproject allows transformers >=4.51.3 ... <=5.5.0. Pip prefers the latest, so it
157
+ # picks 5.5.0 by default. But transformers 4.x requires huggingface-hub<1.0 while 5.x requires
158
+ # hub>=1.5,<2.0 — and unsloth's pyproject does NOT bound hub. So a separate `pip install hub<1.0`
159
+ # silently breaks transformers 5.x (and a separate `pip install hub>=1.5` silently breaks 4.x).
160
+ # The only robust fix is to pin transformers and let pip select the matching hub in the SAME
161
+ # resolution step. We pick 4.57.6 because:
162
+ # * latest 4.x release on PyPI (so qwen3, etc. are supported);
163
+ # * not on Unsloth's blocklist (4.57.0/.4/.5 are; 4.57.6 is fine);
164
+ # * pulls huggingface-hub<1.0 automatically (no separate hub pin needed).
165
+ pip install "transformers==4.57.6"
166
+
167
+ # torchao comes preinstalled in the base image at a version that requires torch 2.7+
168
+ # (it calls torch.utils._pytree.register_constant which doesn't exist in torch 2.6, so
169
+ # `import torchao` crashes with AttributeError). transformers' quantizer registry imports
170
+ # torchao unconditionally if it's installed (`is_torchao_available()` only checks package
171
+ # metadata, not import-ability). With torchao GONE, that check returns False and transformers
172
+ # skips torchao cleanly. We don't use torchao quantization anyway — we use bitsandbytes 4-bit.
173
+ pip uninstall -y torchao || true
174
+
175
+ # Optional helpers we use directly (matplotlib for plots, hf_transfer for fast download/upload).
176
+ pip install --upgrade matplotlib hf_transfer
177
+
178
+ # Hard guard: if torch was upgraded, bitsandbytes will fail at import; fail FAST with a clear log.
179
+ python - <<'PY'
180
+ import sys, torch
181
+ if not torch.__version__.startswith("2.6."):
182
+ print(f"FATAL: torch was upgraded to {{torch.__version__}}; aborting before training.")
183
+ sys.exit(2)
184
+ print(f"torch ok: {{torch.__version__}} cuda={{torch.version.cuda}}")
185
+ PY
186
+
187
+ # Smoke-test the actual modules we use. unsloth MUST import before transformers/trl
188
+ # per its own warning. Importing transformers also triggers its OWN runtime version check on
189
+ # huggingface_hub and tokenizers, AND eagerly imports any installed quantizer backend
190
+ # (torchao, bnb, etc.) — so if anything is mis-pinned this line raises a clear error before
191
+ # training starts.
192
+ python - <<'PY'
193
+ import sys, importlib.util
194
+
195
+ # Pre-flight: torchao must be GONE (preinstalled version requires torch>=2.7 and crashes
196
+ # `import torchao` on torch 2.6). If it leaked back in, fail with a precise message.
197
+ if importlib.util.find_spec("torchao") is not None:
198
+ print("FATAL: torchao is installed; on torch 2.6 it crashes transformers at import. "
199
+ "Run `pip uninstall -y torchao` and rebuild.")
200
+ sys.exit(2)
201
+
202
+ import unsloth # noqa: F401 (must be first)
203
+ import torch, transformers, trl, peft, datasets, bitsandbytes, huggingface_hub
204
+ print(
205
+ f"unsloth={{unsloth.__version__}} transformers={{transformers.__version__}} "
206
+ f"trl={{trl.__version__}} peft={{peft.__version__}} bnb={{bitsandbytes.__version__}} "
207
+ f"hub={{huggingface_hub.__version__}} datasets={{datasets.__version__}}"
208
+ )
209
+ expected_transformers = "4.57.6"
210
+ if transformers.__version__ != expected_transformers:
211
+ print(
212
+ f"FATAL: transformers={{transformers.__version__}} but pinned to {{expected_transformers}}. "
213
+ f"Pip resolution drifted; aborting before training."
214
+ )
215
+ sys.exit(2)
216
+ import train, train_sft, generate_sft_data # noqa: F401
217
+ print("Dependency smoke check passed.")
218
+ PY
219
+
220
+ python generate_sft_data.py \\
221
+ --task all \\
222
+ --curriculum \\
223
+ --use-tools \\
224
+ --episodes {args.dataset_episodes} \\
225
+ --max-steps {args.max_steps} \\
226
+ --seed {args.seed} \\
227
+ --world-split train \\
228
+ --output {dataset_path}
229
+
230
+ python train_sft.py \\
231
+ --dataset {dataset_path} \\
232
+ --model {args.model} \\
233
+ --epochs {args.epochs} \\
234
+ --lr {args.lr} \\
235
+ --per-device-batch-size {args.per_device_batch_size} \\
236
+ --gradient-accumulation-steps {args.gradient_accumulation_steps} \\
237
+ --save-steps {args.save_steps} \\
238
+ --heldout-seed {args.heldout_seed} \\
239
+ --train-world-split train \\
240
+ --heldout-world-split eval \\
241
+ --eval-task all \\
242
+ --eval-episodes {args.eval_episodes} \\
243
+ --use-tools \\
244
+ --output {output_path} \\
245
+ {extra_train_flags}
246
+
247
+ python - <<'PY'
248
+ import os
249
+ import time
250
+ from huggingface_hub import HfApi
251
+
252
+ api = HfApi(token=os.environ["HF_TOKEN"])
253
+ repo_id = os.environ["RUNS_REPO"]
254
+ repo_type = os.environ["RUNS_REPO_TYPE"]
255
+ output_dir = {output_path!r}
256
+ summary_path = {summary_path!r}
257
+ subdir = {output_subdir!r}
258
+
259
+ last_exc = None
260
+ for attempt in range(4):
261
+ try:
262
+ api.upload_folder(
263
+ repo_id=repo_id,
264
+ repo_type=repo_type,
265
+ folder_path=output_dir,
266
+ path_in_repo=subdir,
267
+ )
268
+ api.upload_file(
269
+ repo_id=repo_id,
270
+ repo_type=repo_type,
271
+ path_or_fileobj=summary_path,
272
+ path_in_repo=f"{{subdir}}/adaptshield_sft_worldsplit.summary.json",
273
+ )
274
+ last_exc = None
275
+ break
276
+ except Exception as exc:
277
+ last_exc = exc
278
+ response = getattr(exc, "response", None)
279
+ status_code = getattr(response, "status_code", None)
280
+ if status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600):
281
+ sleep_for = 2 ** attempt
282
+ print(f"Transient upload error: {{exc}}; retrying in {{sleep_for}}s")
283
+ time.sleep(sleep_for)
284
+ continue
285
+ raise
286
+ if last_exc is not None:
287
+ raise last_exc
288
+ print("Uploaded artifacts to", repo_id)
289
+ PY
290
+ """
291
+
292
+
293
+ def main() -> int:
294
+ parser = argparse.ArgumentParser(description="Launch AdaptShield SFT training on Hugging Face Jobs")
295
+ parser.add_argument("--runs-repo", required=True, help="Artifact repo to upload outputs to, e.g. username/adaptshield-runs")
296
+ parser.add_argument("--runs-repo-type", default="dataset", choices=["dataset", "model"], help="Repo type used to store training artifacts.")
297
+ parser.add_argument("--skip-create", action="store_true", help="Skip repo creation and assume the artifacts repo already exists.")
298
+ parser.add_argument("--allow-cross-namespace", action="store_true", help="Allow uploads to a repo owned by a different namespace than the authenticated account.")
299
+ parser.add_argument("--repo-url", default=None, help="Git repo URL to clone inside the HF Job. Defaults to remote.origin.url")
300
+ parser.add_argument("--model", default="1.5b", choices=list(MODEL_CHOICES))
301
+ parser.add_argument("--flavor", default="l4x1", help="HF Jobs hardware flavor, e.g. l4x1, a10g-small, a100-large")
302
+ parser.add_argument("--timeout", default="6h", help="HF Jobs timeout, e.g. 6h")
303
+ parser.add_argument("--dataset-episodes", type=int, default=240)
304
+ parser.add_argument("--max-steps", type=int, default=20)
305
+ parser.add_argument("--epochs", type=float, default=1.0)
306
+ parser.add_argument("--lr", type=float, default=2e-4)
307
+ parser.add_argument("--per-device-batch-size", type=int, default=2)
308
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
309
+ parser.add_argument("--save-steps", type=int, default=40)
310
+ parser.add_argument("--eval-episodes", type=int, default=2)
311
+ parser.add_argument("--seed", type=int, default=42)
312
+ parser.add_argument("--heldout-seed", type=int, default=314)
313
+ parser.add_argument(
314
+ "--skip-reward-curve",
315
+ action="store_true",
316
+ help="Skip the per-checkpoint held-out reward sweep inside train_sft.py.",
317
+ )
318
+ parser.add_argument("--output-subdir", default=None, help="Optional output folder name in the runs dataset repo")
319
+ args = parser.parse_args()
320
+
321
+ token = get_token()
322
+ if not token:
323
+ raise RuntimeError("No Hugging Face token found. Run `hf auth login` first.")
324
+
325
+ repo_url = args.repo_url or infer_repo_url()
326
+ output_subdir = args.output_subdir or f"sft_worldsplit_{args.model.replace('.', '_')}"
327
+
328
+ api = HfApi(token=token)
329
+ validate_artifact_repo(
330
+ api,
331
+ args.runs_repo,
332
+ args.runs_repo_type,
333
+ args.skip_create,
334
+ args.allow_cross_namespace,
335
+ )
336
+ if not args.skip_create:
337
+ _retry_hf_call(api.create_repo, repo_id=args.runs_repo, repo_type=args.runs_repo_type, private=True, exist_ok=True)
338
+
339
+ command = build_command(args=args, repo_url=repo_url, output_subdir=output_subdir)
340
+ job = _retry_hf_call(
341
+ run_job,
342
+ image=DEFAULT_IMAGE,
343
+ command=["bash", "-lc", command],
344
+ flavor=args.flavor,
345
+ timeout=args.timeout,
346
+ namespace=repo_namespace(args.runs_repo),
347
+ env={
348
+ "RUNS_REPO": args.runs_repo,
349
+ "RUNS_REPO_TYPE": args.runs_repo_type,
350
+ },
351
+ secrets={"HF_TOKEN": token},
352
+ )
353
+
354
+ print("Job launched successfully.")
355
+ print(f"Job ID: {job.id}")
356
+ print(f"Job URL: {job.url}")
357
+ print(f"Artifacts repo: {args.runs_repo}")
358
+ print(f"Artifacts path: {output_subdir}")
359
+ return 0
360
+
361
+
362
+ if __name__ == "__main__":
363
+ raise SystemExit(main())
models.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # AdaptShield — Pydantic Data Models
5
+ #
6
+ # CRITICAL DESIGN DECISION: Phase1Action and Phase2Action are SEPARATE classes.
7
+ # A single combined class with optional fields causes 500 errors when the
8
+ # evaluator sends a Phase 2 payload and Pydantic tries to validate Phase 1 fields.
9
+
10
+ from enum import Enum
11
+ from typing import Any, Dict, List, Optional
12
+
13
+ from openenv.core.env_server.types import Action, Observation
14
+ from pydantic import Field, model_validator
15
+
16
+
17
+ class DefenseAction(str, Enum):
18
+ """
19
+ Strict action space for the Tactical Executor (Phase 2).
20
+ Using Enum prevents LLM hallucination from reaching the grader.
21
+ """
22
+ RATE_LIMIT = "rate_limit" # Light — throttles traffic, keeps service online
23
+ ISOLATE = "isolate" # Heavy — takes node offline, stops spread
24
+ HONEYPOT = "honeypot" # Strategic — redirects attacker to decoy
25
+ PATCH = "patch" # Targeted — fixes supply chain vulnerability
26
+ MONITOR = "monitor" # Passive — gather info, risk escalation
27
+
28
+
29
+ class ThreatType(str, Enum):
30
+ """Known attack strategies the Threat Analyst can classify."""
31
+ BRUTE_FORCE = "brute_force"
32
+ LATERAL_MOVEMENT = "lateral_movement"
33
+ EXFILTRATION = "exfiltration"
34
+ SUPPLY_CHAIN = "supply_chain"
35
+ BENIGN = "benign"
36
+
37
+
38
+ class Phase1Action(Action):
39
+ """
40
+ Threat Analyst output — pure reasoning, no defensive action.
41
+
42
+ The agent reads raw network state and produces a structured
43
+ threat assessment. This is graded independently for classification
44
+ accuracy before Phase 2 acts on it.
45
+ """
46
+ threat_type: str = Field(
47
+ ...,
48
+ description="Identified attack strategy: brute_force, lateral_movement, "
49
+ "exfiltration, supply_chain, or benign",
50
+ )
51
+ confidence: float = Field(
52
+ ...,
53
+ ge=0.0,
54
+ le=1.0,
55
+ description="Confidence in the threat classification (0.0 to 1.0)",
56
+ )
57
+ target_node: str = Field(
58
+ ...,
59
+ description="Primary affected node: auth_service, payment_service, "
60
+ "database, or api_gateway",
61
+ )
62
+ recommended_action: DefenseAction = Field(
63
+ ...,
64
+ description="Recommended defense action for Phase 2 to execute",
65
+ )
66
+ reasoning: Optional[str] = Field(
67
+ default=None,
68
+ description="Chain of thought. Not graded. Helps training stability.",
69
+ )
70
+
71
+
72
+ class Phase2Action(Action):
73
+ """
74
+ Tactical Executor output — defensive action based ONLY on Phase 1 assessment.
75
+
76
+ Phase 2 agent is deliberately blind to raw network state.
77
+ It receives only the Phase 1 threat assessment and must act on it.
78
+ """
79
+ action: DefenseAction = Field(
80
+ ...,
81
+ description="Defense action to execute",
82
+ )
83
+ target_node: str = Field(
84
+ ...,
85
+ description="Node to apply action to: auth_service, payment_service, "
86
+ "database, or api_gateway",
87
+ )
88
+ reasoning: Optional[str] = Field(
89
+ default=None,
90
+ description="Chain of thought. Not graded.",
91
+ )
92
+
93
+
94
+ class AdaptShieldAction(Action):
95
+ """
96
+ Unified action model accepted by the OpenEnv HTTP server.
97
+
98
+ The environment alternates between two phases, so the transport layer must
99
+ accept either a Threat Analyst payload or a Tactical Executor payload.
100
+ Validation keeps those shapes distinct while still fitting the single
101
+ action model expected by `create_app`.
102
+ """
103
+
104
+ threat_type: Optional[str] = Field(
105
+ default=None,
106
+ description="Phase 1 only: identified attack strategy",
107
+ )
108
+ confidence: Optional[float] = Field(
109
+ default=None,
110
+ ge=0.0,
111
+ le=1.0,
112
+ description="Phase 1 only: confidence in the threat classification",
113
+ )
114
+ target_node: Optional[str] = Field(
115
+ default=None,
116
+ description="Target node for either phase",
117
+ )
118
+ recommended_action: Optional[DefenseAction] = Field(
119
+ default=None,
120
+ description="Phase 1 only: recommended follow-up action",
121
+ )
122
+ action: Optional[DefenseAction] = Field(
123
+ default=None,
124
+ description="Phase 2 only: defensive action to execute",
125
+ )
126
+ reasoning: Optional[str] = Field(
127
+ default=None,
128
+ description="Optional one-sentence rationale",
129
+ )
130
+
131
+ @model_validator(mode="after")
132
+ def validate_phase_shape(self) -> "AdaptShieldAction":
133
+ phase1_present = any(
134
+ value is not None
135
+ for value in (self.threat_type, self.confidence, self.recommended_action)
136
+ )
137
+ phase2_present = self.action is not None
138
+
139
+ if phase1_present and phase2_present:
140
+ raise ValueError(
141
+ "Action payload must be either Phase 1 or Phase 2, not both."
142
+ )
143
+ if not phase1_present and not phase2_present:
144
+ raise ValueError(
145
+ "Action payload must contain Phase 1 fields or a Phase 2 action."
146
+ )
147
+
148
+ if phase1_present:
149
+ missing = [
150
+ field_name
151
+ for field_name, value in (
152
+ ("threat_type", self.threat_type),
153
+ ("confidence", self.confidence),
154
+ ("target_node", self.target_node),
155
+ ("recommended_action", self.recommended_action),
156
+ )
157
+ if value is None
158
+ ]
159
+ else:
160
+ missing = [
161
+ field_name
162
+ for field_name, value in (
163
+ ("action", self.action),
164
+ ("target_node", self.target_node),
165
+ )
166
+ if value is None
167
+ ]
168
+
169
+ if missing:
170
+ raise ValueError(
171
+ f"Missing required fields for this phase: {', '.join(missing)}"
172
+ )
173
+
174
+ return self
175
+
176
+
177
+ class AdaptShieldObservation(Observation):
178
+ """
179
+ Observation returned after each step.
180
+
181
+ Phase 1 observation: contains full network state (network_nodes, active_alerts).
182
+ Phase 2 observation: network_nodes and active_alerts are EMPTY.
183
+ phase1_assessment contains the Phase 1 output.
184
+
185
+ Episode number is NEVER included — agent must rely on signals only.
186
+ """
187
+
188
+ # Identity
189
+ scenario_id: str = Field(default="")
190
+ task_name: str = Field(default="")
191
+ phase: int = Field(default=1,
192
+ description="1 = Threat Analyst turn, 2 = Tactical Executor turn")
193
+ turn: int = Field(default=0)
194
+ max_turns: int = Field(default=5)
195
+
196
+ # Network state — populated in Phase 1, EMPTY in Phase 2
197
+ network_nodes: Dict[str, Any] = Field(default_factory=dict)
198
+ active_alerts: List[str] = Field(default_factory=list)
199
+ attack_stage: str = Field(
200
+ default="none",
201
+ description="Current attack progression stage: recon, exploit, exfiltration, none",
202
+ )
203
+
204
+ # Rolling history of last 3 turns
205
+ history: List[Dict[str, str]] = Field(default_factory=list)
206
+
207
+ # Phase 2 only — Phase 1 output passed to executor
208
+ phase1_assessment: Optional[Dict[str, Any]] = Field(
209
+ default=None,
210
+ description="Populated only in Phase 2. Phase 2 agent sees ONLY this.",
211
+ )
212
+
213
+ # Context
214
+ system_context: str = Field(default="")
215
+ available_actions: List[str] = Field(default_factory=list)
216
+
217
+ # Feedback
218
+ last_action_result: Optional[str] = Field(default=None)
219
+ reward: float = Field(default=0.0)
220
+ done: bool = Field(default=False)
221
+ metadata: Dict[str, Any] = Field(default_factory=dict)
222
+
223
+ def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
224
+ """
225
+ Keep metadata in OpenEnv HTTP observation payloads.
226
+
227
+ OpenEnv's serializer excludes metadata from the nested observation by
228
+ default. AdaptShield exposes normalized_score there, so we remove only
229
+ that exclusion while preserving the serializer's reward/done handling.
230
+ """
231
+ exclude = kwargs.get("exclude")
232
+ if isinstance(exclude, set) and "metadata" in exclude:
233
+ kwargs["exclude"] = set(exclude) - {"metadata"}
234
+ elif isinstance(exclude, dict) and "metadata" in exclude:
235
+ kwargs["exclude"] = {
236
+ key: value for key, value in exclude.items() if key != "metadata"
237
+ }
238
+ return super().model_dump(*args, **kwargs)
239
+
240
+
241
+ # Backward-compatible aliases for earlier package names.
242
+ AdaptshieldAction = AdaptShieldAction
243
+ AdaptshieldObservation = AdaptShieldObservation
openenv.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ spec_version: 1
2
+ name: adaptshield
3
+ type: space
4
+ runtime: fastapi
5
+ app: server.app:app
6
+ port: 7860
7
+ description: >
8
+ AdaptShield is a two-phase agentic cybersecurity environment that trains
9
+ LLMs to adapt to polymorphic adversarial strategies. An agent acts as
10
+ Threat Analyst (Phase 1) then Tactical Executor (Phase 2), defending a
11
+ simulated 4-node enterprise network against a scripted attacker that progresses
12
+ through attack stages and shifts strategy mid-episode. Grading is fully
13
+ deterministic via Python strategy matching. No LLM-as-judge components.
14
+ tasks:
15
+ - name: direct-triage
16
+ difficulty: easy
17
+ description: Single fixed strategy. Agent learns baseline threat response.
18
+ max_steps: 5
19
+ - name: dual-pivot
20
+ difficulty: medium
21
+ description: Two strategies alternating every 20 episodes. Detect and adapt.
22
+ max_steps: 6
23
+ - name: polymorphic-zero-day
24
+ difficulty: hard
25
+ description: All four strategies with mid-episode shift and false-positive noise.
26
+ max_steps: 8
plot_sft_checkpoint_curve.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Plot an SFT checkpoint curve with an optional honest baseline start point."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ from pathlib import Path
9
+ from typing import Any
10
+
11
+
12
+ def load_json(path: Path) -> dict[str, Any]:
13
+ return json.loads(path.read_text(encoding="utf-8"))
14
+
15
+
16
+ def mean_baseline(benchmark: dict[str, Any], key: str) -> float:
17
+ values = benchmark.get(key, {})
18
+ numeric = [float(value) for value in values.values() if value is not None]
19
+ if not numeric:
20
+ raise ValueError(f"No numeric values found under benchmark key '{key}'")
21
+ return sum(numeric) / len(numeric)
22
+
23
+
24
+ def parse_args() -> argparse.Namespace:
25
+ parser = argparse.ArgumentParser(description="Plot SFT checkpoint learning curve with optional baseline point.")
26
+ parser.add_argument("--metrics", required=True, help="Path to sft_metrics.json")
27
+ parser.add_argument("--output", required=True, help="Where to write the PNG")
28
+ parser.add_argument(
29
+ "--baseline-json",
30
+ default="",
31
+ help="Optional benchmark_table.json path used to prepend a real baseline point.",
32
+ )
33
+ parser.add_argument(
34
+ "--baseline-key",
35
+ default="tool_baseline",
36
+ choices=["tool_baseline", "no_tool_baseline"],
37
+ help="Which benchmark JSON field to average for the prepended baseline point.",
38
+ )
39
+ parser.add_argument(
40
+ "--baseline-label",
41
+ default="baseline",
42
+ help="X-axis label for the prepended baseline point.",
43
+ )
44
+ return parser.parse_args()
45
+
46
+
47
+ def main() -> int:
48
+ args = parse_args()
49
+
50
+ metrics = load_json(Path(args.metrics))
51
+ rows = metrics.get("reward_curve_rows", []) or []
52
+ if not rows:
53
+ raise SystemExit("No reward_curve_rows found in the provided SFT metrics file.")
54
+
55
+ labels = [str(row["checkpoint"]) for row in rows]
56
+ train_scores = [float(row["in_distribution_score"]) for row in rows]
57
+ heldout_scores = [float(row["heldout_score"]) for row in rows]
58
+
59
+ if args.baseline_json:
60
+ benchmark = load_json(Path(args.baseline_json))
61
+ baseline_value = mean_baseline(benchmark, args.baseline_key)
62
+ labels = [args.baseline_label] + labels
63
+ train_scores = [baseline_value] + train_scores
64
+ heldout_scores = [baseline_value] + heldout_scores
65
+
66
+ try:
67
+ import matplotlib
68
+ matplotlib.use("Agg")
69
+ import matplotlib.pyplot as plt
70
+ except ImportError as exc:
71
+ raise SystemExit(f"matplotlib is required to plot this curve: {exc}") from exc
72
+
73
+ output_path = Path(args.output)
74
+ output_path.parent.mkdir(parents=True, exist_ok=True)
75
+
76
+ plt.figure(figsize=(11, 5))
77
+ plt.plot(labels, train_scores, marker="o", linewidth=2, color="#174c7a", label="train family")
78
+ plt.plot(labels, heldout_scores, marker="s", linewidth=2, color="#6d4acb", label="held-out family")
79
+ plt.title("Janus SFT Checkpoint Learning Curve")
80
+ plt.xlabel("Checkpoint")
81
+ plt.ylabel("normalized_score")
82
+ plt.ylim(0.0, 1.0)
83
+ plt.grid(alpha=0.25)
84
+ plt.legend()
85
+ plt.xticks(rotation=30, ha="right")
86
+ plt.tight_layout()
87
+ plt.savefig(output_path, dpi=160)
88
+ print(output_path)
89
+ return 0
90
+
91
+
92
+ if __name__ == "__main__":
93
+ raise SystemExit(main())
plot_training.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Plot AdaptShield training CSV or metrics JSON."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import csv
8
+ import json
9
+ from pathlib import Path
10
+ from typing import List, Tuple
11
+
12
+
13
+ def load_scores(path: Path) -> Tuple[List[int], List[float], str, List[str]]:
14
+ if path.suffix == ".json":
15
+ data = json.loads(path.read_text())
16
+ rows = data.get("rows", []) or data.get("evaluation_rows", [])
17
+ episodes = [int(row["episode"]) for row in rows]
18
+ scores = [float(row["score"]) for row in rows]
19
+ stages = [str(row.get("stage", row.get("task", ""))) for row in rows]
20
+ return episodes, scores, str(data.get("model", "adaptshield")), stages
21
+
22
+ with path.open() as handle:
23
+ rows = list(csv.DictReader(handle))
24
+ episodes = [int(row["episode"]) for row in rows]
25
+ scores = [float(row["score"]) for row in rows]
26
+ stages = [str(row.get("stage", row.get("task", ""))) for row in rows]
27
+ return episodes, scores, "adaptshield-smoke", stages
28
+
29
+
30
+ def moving_average(values: List[float], window: int) -> List[float]:
31
+ smoothed = []
32
+ for index in range(len(values)):
33
+ start = max(0, index - window + 1)
34
+ chunk = values[start:index + 1]
35
+ smoothed.append(sum(chunk) / len(chunk))
36
+ return smoothed
37
+
38
+
39
+ def plot(path: Path, output: Path) -> None:
40
+ episodes, scores, label, stages = load_scores(path)
41
+ if not scores:
42
+ raise SystemExit("No scores found to plot.")
43
+
44
+ try:
45
+ import matplotlib
46
+ matplotlib.use("Agg")
47
+ import matplotlib.pyplot as plt
48
+ except ImportError:
49
+ first = sum(scores[:max(1, len(scores) // 5)]) / max(1, len(scores) // 5)
50
+ last = sum(scores[-max(1, len(scores) // 5):]) / max(1, len(scores) // 5)
51
+ print("matplotlib is not installed; skipping PNG generation.")
52
+ print(f"Episodes: {len(scores)}")
53
+ print(f"First-window avg: {first:.3f}")
54
+ print(f"Last-window avg: {last:.3f}")
55
+ print(f"Delta: {last - first:+.3f}")
56
+ return
57
+
58
+ window = max(1, min(10, len(scores) // 5))
59
+ smoothed = moving_average(scores, window)
60
+
61
+ output.parent.mkdir(parents=True, exist_ok=True)
62
+ fig, ax = plt.subplots(figsize=(10, 5))
63
+ ax.plot(episodes, scores, color="#6b8fbf", alpha=0.35, label="raw score")
64
+ ax.plot(episodes, smoothed, color="#123c69", linewidth=2.5, label=f"{window}-episode avg")
65
+ for episode, stage in stage_boundaries(episodes, stages):
66
+ ax.axvline(episode, color="#c44e52", linestyle="--", alpha=0.45)
67
+ ax.text(episode, 0.04, stage.replace("curriculum:", ""), rotation=90, fontsize=8, color="#7a1f24")
68
+ ax.set_title(f"AdaptShield Training Curve ({label})")
69
+ ax.set_xlabel("Episode")
70
+ ax.set_ylabel("normalized_score")
71
+ ax.set_ylim(0.0, 1.0)
72
+ ax.grid(alpha=0.25)
73
+ ax.legend()
74
+ fig.tight_layout()
75
+ fig.savefig(output, dpi=160)
76
+ print(f"Saved plot: {output}")
77
+
78
+
79
+ def stage_boundaries(episodes: List[int], stages: List[str]) -> List[Tuple[int, str]]:
80
+ if not stages:
81
+ return []
82
+
83
+ boundaries = []
84
+ previous = stages[0]
85
+ for episode, stage in zip(episodes, stages):
86
+ if stage != previous:
87
+ boundaries.append((episode, stage))
88
+ previous = stage
89
+ return boundaries
90
+
91
+
92
+ def parse_args() -> argparse.Namespace:
93
+ parser = argparse.ArgumentParser(description="Plot AdaptShield training output.")
94
+ parser.add_argument("--input", default="training_runs/train_smoke.csv")
95
+ parser.add_argument("--output", default="training_runs/reward_curve.png")
96
+ return parser.parse_args()
97
+
98
+
99
+ def main() -> int:
100
+ args = parse_args()
101
+ plot(Path(args.input), Path(args.output))
102
+ return 0
103
+
104
+
105
+ if __name__ == "__main__":
106
+ raise SystemExit(main())
pyproject.toml ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=45", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "openenv-adaptshield"
7
+ version = "0.1.0"
8
+ description = "AdaptShield: Two-Phase Adaptive Cybersecurity RL Environment"
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "openenv-core[core]>=0.2.2",
12
+ "fastapi>=0.111.0",
13
+ "openai>=1.0.0",
14
+ "uvicorn>=0.24.0",
15
+ "pydantic>=2.0.0",
16
+ "python-dotenv>=1.0.0",
17
+ ]
18
+
19
+ [project.optional-dependencies]
20
+ dev = [
21
+ "pytest>=8.0.0",
22
+ "pytest-cov>=4.0.0",
23
+ ]
24
+
25
+ [project.scripts]
26
+ server = "adaptshield.server.app:main"
27
+
28
+ [tool.setuptools]
29
+ include-package-data = true
30
+ packages = ["adaptshield", "adaptshield.server"]
31
+ package-dir = { "adaptshield" = ".", "adaptshield.server" = "server" }
server/Dockerfile ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
2
+ FROM ${BASE_IMAGE} AS builder
3
+
4
+ WORKDIR /app
5
+
6
+ RUN apt-get update && \
7
+ apt-get install -y --no-install-recommends git curl && \
8
+ rm -rf /var/lib/apt/lists/*
9
+
10
+ COPY . /app/env
11
+ WORKDIR /app/env
12
+
13
+ RUN if ! command -v uv >/dev/null 2>&1; then \
14
+ curl -LsSf https://astral.sh/uv/install.sh | sh && \
15
+ mv /root/.local/bin/uv /usr/local/bin/uv && \
16
+ mv /root/.local/bin/uvx /usr/local/bin/uvx; \
17
+ fi
18
+
19
+ RUN --mount=type=cache,target=/root/.cache/uv \
20
+ if [ -f uv.lock ]; then \
21
+ uv sync --frozen --no-install-project --no-editable; \
22
+ else \
23
+ uv sync --no-install-project --no-editable; \
24
+ fi
25
+
26
+ RUN --mount=type=cache,target=/root/.cache/uv \
27
+ if [ -f uv.lock ]; then \
28
+ uv sync --frozen --no-editable; \
29
+ else \
30
+ uv sync --no-editable; \
31
+ fi
32
+
33
+ FROM ${BASE_IMAGE}
34
+ WORKDIR /app
35
+
36
+ COPY --from=builder /app/env/.venv /app/.venv
37
+ COPY --from=builder /app/env /app/env
38
+
39
+ ENV PATH="/app/.venv/bin:$PATH"
40
+ ENV PYTHONPATH="/app/env:$PYTHONPATH"
41
+
42
+ EXPOSE 7860
43
+
44
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
45
+ CMD curl -f http://localhost:7860/health || exit 1
46
+
47
+ CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
server/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """Adaptshield environment server components."""
8
+
9
+ from server.adaptshield_environment import AdaptShieldEnvironment
10
+
11
+ __all__ = ["AdaptShieldEnvironment"]
server/adaptshield_environment.py ADDED
@@ -0,0 +1,1324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield Environment
3
+
4
+ Two-phase agentic cybersecurity environment implementing full OpenEnv spec.
5
+
6
+ Phase 1 (Threat Analyst): Agent reads raw SIEM state, outputs threat assessment.
7
+ Phase 2 (Tactical Executor): Agent reads ONLY Phase 1 output, executes defense.
8
+
9
+ The attacker progresses through stages (recon→exploit→exfiltration) if agent
10
+ fails to act. On the hard task, strategy shifts mid-episode after turn 3.
11
+
12
+ OpenEnv compliance:
13
+ - reset() returns initial observation
14
+ - step() returns observation with reward, done, info
15
+ - state property returns current State
16
+ - SUPPORTS_CONCURRENT_SESSIONS = True
17
+ - normalized_score ALWAYS present in metadata
18
+ """
19
+
20
+ import os
21
+ import sys
22
+ from enum import Enum
23
+ from typing import Any, Dict, List, Optional
24
+ from uuid import uuid4
25
+
26
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
27
+
28
+ from openenv.core.env_server.interfaces import Environment
29
+ from openenv.core.env_server.types import State
30
+
31
+ from models import AdaptShieldAction, Phase1Action, Phase2Action, AdaptShieldObservation
32
+ from server.attacker import AttackerEngine
33
+ from server.grader import grade_step, normalize_episode_score, _clamp
34
+ from server.scenarios import (
35
+ TASK_CONFIGS,
36
+ build_phase1_obs,
37
+ build_phase2_obs,
38
+ choose_operational_mode,
39
+ choose_world_family,
40
+ mission_profile_for,
41
+ )
42
+
43
+
44
+ DEFENSE_TTL = {
45
+ "rate_limit": 2,
46
+ "isolate": 2,
47
+ "honeypot": 3,
48
+ "patch": 4,
49
+ }
50
+
51
+ DEFENSE_SIDE_EFFECT = {
52
+ "rate_limit": "login_latency",
53
+ "isolate": "service_downtime",
54
+ "honeypot": "attacker_redirection",
55
+ "patch": "temporary_restart",
56
+ }
57
+
58
+ AVAILABLE_SOC_TOOLS = [
59
+ {
60
+ "name": "log_search",
61
+ "endpoint": "/tools/log_search",
62
+ "description": "Search recent SIEM/application logs for a node and time window.",
63
+ },
64
+ {
65
+ "name": "cmdb_lookup",
66
+ "endpoint": "/tools/cmdb_lookup",
67
+ "description": "Inspect service ownership, criticality, dependencies, and blast radius.",
68
+ },
69
+ {
70
+ "name": "edr_status",
71
+ "endpoint": "/tools/edr_status",
72
+ "description": "Check endpoint containment, persistence, beaconing, and active controls.",
73
+ },
74
+ {
75
+ "name": "vuln_lookup",
76
+ "endpoint": "/tools/vuln_lookup",
77
+ "description": "Query internal package/advisory risk for supply-chain investigations.",
78
+ },
79
+ {
80
+ "name": "identity_lookup",
81
+ "endpoint": "/tools/identity_lookup",
82
+ "description": "Inspect account type, privilege level, normal host affinity, and anomalous identity use.",
83
+ },
84
+ {
85
+ "name": "change_calendar_lookup",
86
+ "endpoint": "/tools/change_calendar_lookup",
87
+ "description": "Check whether maintenance, deploys, or patch windows were scheduled for the target service.",
88
+ },
89
+ {
90
+ "name": "netflow_lookup",
91
+ "endpoint": "/tools/netflow_lookup",
92
+ "description": "Inspect east-west and outbound traffic summaries for enterprise network pivots and data movement.",
93
+ },
94
+ ]
95
+
96
+ SERVICE_OWNERS = {
97
+ "auth_service": "identity-platform",
98
+ "payment_service": "checkout-platform",
99
+ "database": "data-platform",
100
+ "api_gateway": "edge-platform",
101
+ }
102
+
103
+ IDENTITY_CONTEXT = {
104
+ "auth_service": {
105
+ "account": "svc_auth_frontend",
106
+ "account_type": "service_account",
107
+ "privilege_level": "medium",
108
+ "normal_hosts": ["auth_service", "api_gateway"],
109
+ },
110
+ "payment_service": {
111
+ "account": "svc_checkout",
112
+ "account_type": "service_account",
113
+ "privilege_level": "high",
114
+ "normal_hosts": ["payment_service"],
115
+ },
116
+ "database": {
117
+ "account": "svc_data_sync",
118
+ "account_type": "service_account",
119
+ "privilege_level": "high",
120
+ "normal_hosts": ["database", "payment_service"],
121
+ },
122
+ "api_gateway": {
123
+ "account": "deploy_bot",
124
+ "account_type": "automation",
125
+ "privilege_level": "medium",
126
+ "normal_hosts": ["api_gateway"],
127
+ },
128
+ }
129
+
130
+ CHANGE_CALENDAR = {
131
+ "auth_service": {
132
+ "window": "03:00-03:20Z",
133
+ "change_type": "auth policy sync",
134
+ "expected_actor": "svc_auth_frontend",
135
+ },
136
+ "payment_service": {
137
+ "window": "02:30-02:45Z",
138
+ "change_type": "checkout rollout",
139
+ "expected_actor": "svc_checkout",
140
+ },
141
+ "database": {
142
+ "window": "04:00-04:30Z",
143
+ "change_type": "backup and index maintenance",
144
+ "expected_actor": "svc_data_sync",
145
+ },
146
+ "api_gateway": {
147
+ "window": "03:10-03:25Z",
148
+ "change_type": "gateway deploy",
149
+ "expected_actor": "deploy_bot",
150
+ },
151
+ }
152
+
153
+
154
+ class AdaptShieldEnvironment(Environment):
155
+ """
156
+ AdaptShield: Two-Phase Adaptive Cybersecurity RL Environment.
157
+
158
+ Example:
159
+ >>> env = AdaptShieldEnvironment(task_name="direct-triage")
160
+ >>> obs = env.reset()
161
+ >>> # Phase 1 — classify the threat
162
+ >>> obs2 = env.step(Phase1Action(
163
+ ... threat_type="brute_force", confidence=0.9,
164
+ ... target_node="auth_service", recommended_action="rate_limit"
165
+ ... ))
166
+ >>> print(obs2.phase) # 2
167
+ >>> # Phase 2 — execute the defense
168
+ >>> obs3 = env.step(Phase2Action(
169
+ ... action="rate_limit", target_node="auth_service"
170
+ ... ))
171
+ >>> print(obs3.reward) # reward signal
172
+ """
173
+
174
+ SUPPORTS_CONCURRENT_SESSIONS: bool = True
175
+
176
+ def __init__(
177
+ self,
178
+ task_name: str = "direct-triage",
179
+ world_split: str | None = None,
180
+ world_family: str | None = None,
181
+ operational_mode: str | None = None,
182
+ ):
183
+ if task_name not in TASK_CONFIGS:
184
+ task_name = "direct-triage"
185
+
186
+ self._task_name = task_name
187
+ self._config = TASK_CONFIGS[task_name]
188
+ self._world_split = self._sanitize_world_split(world_split or os.environ.get("ADAPTSHIELD_WORLD_SPLIT", "train"))
189
+ self._requested_world_family = world_family or os.environ.get("ADAPTSHIELD_WORLD_FAMILY")
190
+ self._requested_operational_mode = operational_mode or os.environ.get("ADAPTSHIELD_OPERATIONAL_MODE")
191
+ self._world_family = choose_world_family(self._world_split, self._requested_world_family)
192
+ self._operational_mode = choose_operational_mode(task_name, self._requested_operational_mode)
193
+ self._mission_profile = mission_profile_for(task_name, self._operational_mode, self._world_family)
194
+ self._attacker = AttackerEngine(task_name, world_family=self._world_family)
195
+ self._state = State(episode_id=str(uuid4()), step_count=0)
196
+
197
+ # Episode state
198
+ self._turn: int = 0
199
+ self._phase: int = 1
200
+ self._rewards: List[float] = []
201
+ self._done: bool = False
202
+ self._last_reward: float = 0.0
203
+ self._history: List[Dict[str, str]] = []
204
+ self._phase1_output: Optional[Dict[str, Any]] = None
205
+ self._phase1_grading_output: Optional[Dict[str, Any]] = None
206
+ self._turn_config: Optional[Dict[str, Any]] = None
207
+ self._consecutive_wrong: int = 0
208
+ self._last_obs: Optional[AdaptShieldObservation] = None
209
+ self._episode_replay: List[Dict[str, Any]] = []
210
+ self._last_replay_strategy: Optional[str] = None
211
+ self._active_defenses: List[Dict[str, Any]] = []
212
+ self._foothold_established: bool = False
213
+ self._tool_trace: List[Dict[str, Any]] = []
214
+ self._turn_tool_evidence: Dict[int, List[Dict[str, Any]]] = {}
215
+ self._turn_tool_results: Dict[int, List[Dict[str, Any]]] = {}
216
+
217
+ # ── OpenEnv interface ──────────────────────────────────────────────────
218
+
219
+ def reset(self, task_name: str = None) -> AdaptShieldObservation:
220
+ """
221
+ Reset environment. Optionally switch task via task_name.
222
+ Always returns Phase 1 observation (Threat Analyst turn).
223
+ """
224
+ if task_name and task_name in TASK_CONFIGS:
225
+ self._task_name = task_name
226
+ self._config = TASK_CONFIGS[task_name]
227
+ self._world_family = choose_world_family(self._world_split, self._requested_world_family)
228
+ self._operational_mode = choose_operational_mode(self._task_name, self._requested_operational_mode)
229
+ self._mission_profile = mission_profile_for(self._task_name, self._operational_mode, self._world_family)
230
+ self._attacker = AttackerEngine(self._task_name, world_family=self._world_family)
231
+
232
+ self._state = State(episode_id=str(uuid4()), step_count=0)
233
+ self._turn = 1
234
+ self._phase = 1
235
+ self._rewards = []
236
+ self._done = False
237
+ self._last_reward = 0.0
238
+ self._history = []
239
+ self._phase1_output = None
240
+ self._phase1_grading_output = None
241
+ self._consecutive_wrong = 0
242
+ self._episode_replay = []
243
+ self._last_replay_strategy = None
244
+ self._active_defenses = []
245
+ self._foothold_established = False
246
+ self._tool_trace = []
247
+ self._turn_tool_evidence = {}
248
+ self._turn_tool_results = {}
249
+
250
+ self._attacker.reset_episode()
251
+ self._turn_config = self._prepare_turn_config(self._attacker.build_observation())
252
+
253
+ obs_dict = build_phase1_obs(
254
+ turn_config=self._turn_config,
255
+ history=self._history,
256
+ task_name=self._task_name,
257
+ turn=self._turn,
258
+ max_turns=self._config["max_turns"],
259
+ episode_id=self._state.episode_id,
260
+ mission_profile=self._mission_profile,
261
+ )
262
+ obs = self._to_obs(obs_dict)
263
+ obs.metadata = self._metadata_with_defenses(obs.metadata)
264
+ self._last_obs = obs
265
+ return obs
266
+
267
+ def step(
268
+ self, action: AdaptShieldAction | Phase1Action | Phase2Action
269
+ ) -> AdaptShieldObservation: # type: ignore[override]
270
+ """
271
+ Execute one step.
272
+
273
+ Accepts either Phase1Action or Phase2Action.
274
+ Phase 1 → transitions to Phase 2 (no reward yet).
275
+ Phase 2 → grades action, advances turn, returns to Phase 1.
276
+ """
277
+ if self._done:
278
+ return self._last_obs or self._error_observation(
279
+ "Episode already completed."
280
+ )
281
+
282
+ try:
283
+ self._state.step_count += 1
284
+
285
+ # ── Phase 1 → Phase 2 transition ──────────────────────────────
286
+ if self._phase == 1:
287
+ phase1_output = {
288
+ "threat_type": _action_value(getattr(action, "threat_type", None), "unknown"),
289
+ "confidence": _action_float(getattr(action, "confidence", None), 0.5),
290
+ "target_node": _action_value(getattr(action, "target_node", None), "unknown"),
291
+ "recommended_action": _action_value(getattr(action, "recommended_action", None), "monitor"),
292
+ "reasoning": str(getattr(action, "reasoning", "") or ""),
293
+ }
294
+ self._phase1_grading_output = dict(phase1_output)
295
+ self._phase1_output = _degrade_handoff(
296
+ phase1_output=phase1_output,
297
+ turn_config=self._turn_config or {},
298
+ task_name=self._task_name,
299
+ turn=self._turn,
300
+ )
301
+ self._phase = 2
302
+ current_score = normalize_episode_score(self._rewards)
303
+
304
+ obs_dict = build_phase2_obs(
305
+ phase1_output=self._phase1_output,
306
+ history=self._history,
307
+ task_name=self._task_name,
308
+ turn=self._turn,
309
+ max_turns=self._config["max_turns"],
310
+ episode_id=self._state.episode_id,
311
+ current_score=current_score,
312
+ mission_profile=self._mission_profile,
313
+ )
314
+ obs = self._to_obs(obs_dict)
315
+ obs.reward = _clamp(self._last_reward if self._last_reward > 0 else 0.01)
316
+ obs.metadata = self._metadata_with_defenses({
317
+ "episode_id": self._state.episode_id,
318
+ "normalized_score": float(current_score),
319
+ "mission_profile": self._mission_profile,
320
+ })
321
+ self._last_obs = obs
322
+ return obs
323
+
324
+ # ── Phase 2 — grade and advance turn ──────────────────────────
325
+ p2 = {
326
+ "action": _action_value(getattr(action, "action", None), "monitor"),
327
+ "target_node": _action_value(getattr(action, "target_node", None), "unknown"),
328
+ "reasoning": str(getattr(action, "reasoning", "") or ""),
329
+ }
330
+
331
+ current_stage = self._attacker.current_stage()
332
+ foothold_before = self._foothold_established
333
+ reward, catastrophic, info = grade_step(
334
+ phase1_action=self._phase1_grading_output or self._phase1_output or {},
335
+ phase2_action=p2,
336
+ turn_config=self._turn_config or {},
337
+ stage=current_stage,
338
+ consecutive_wrong=self._consecutive_wrong,
339
+ task_name=self._task_name,
340
+ foothold_established=foothold_before,
341
+ mission_profile=self._mission_profile,
342
+ tool_context=self._tool_context_for_turn(),
343
+ )
344
+
345
+ reward = _clamp(_action_float(reward, 0.01))
346
+ self._register_active_defense(p2)
347
+ foothold_transition = self._update_foothold_state(
348
+ p2=p2,
349
+ info=info,
350
+ stage=current_stage,
351
+ )
352
+ info["foothold_established"] = self._foothold_established
353
+ info["foothold_transition"] = foothold_transition
354
+
355
+ # Track consecutive wrong actions for stage escalation
356
+ if info.get("acted_correctly", False):
357
+ self._consecutive_wrong = 0
358
+ else:
359
+ self._consecutive_wrong += 1
360
+
361
+ self._rewards.append(reward)
362
+ self._last_reward = reward
363
+
364
+ # Update history
365
+ replay_strategy = self._attacker.current_strategy()
366
+ strategy_shift = (
367
+ self._last_replay_strategy is not None and
368
+ replay_strategy != self._last_replay_strategy
369
+ )
370
+ self._last_replay_strategy = replay_strategy
371
+ self._episode_replay.append({
372
+ "turn": self._turn,
373
+ "p1": (self._phase1_output or {}).get("threat_type", "unknown"),
374
+ "p2_action": p2["action"],
375
+ "target": p2["target_node"],
376
+ "result": _replay_result(info),
377
+ "shift": strategy_shift,
378
+ "impact": float(info.get("business_impact", 0.0)),
379
+ "blast_radius": info.get("dependency_blast_radius", []),
380
+ "active_defenses": self._active_defense_snapshot(),
381
+ "foothold_established": self._foothold_established,
382
+ "foothold_transition": foothold_transition,
383
+ "mission_alignment": info.get("mission_alignment", "neutral"),
384
+ "tool_calls": info.get("tool_count", 0),
385
+ "tool_evidence_found": info.get("tool_evidence_found", False),
386
+ })
387
+
388
+ self._history.append({
389
+ "turn": str(self._turn),
390
+ "p1": f"classified:{(self._phase1_output or {}).get('threat_type','?')}",
391
+ "p2": f"{p2['action']}→{p2['target_node']}",
392
+ "result": info.get("score_reason", "")[:80],
393
+ "reward": f"{reward:.2f}",
394
+ })
395
+
396
+ # Advance attacker
397
+ self._attacker.advance_turn(
398
+ agent_acted_correctly=info.get("acted_correctly", False)
399
+ )
400
+ self._decay_active_defenses()
401
+
402
+ # Advance turn
403
+ self._turn += 1
404
+ self._phase = 1
405
+ self._phase1_output = None
406
+ self._phase1_grading_output = None
407
+
408
+ episode_done = catastrophic or (self._turn > self._config["max_turns"])
409
+ self._done = episode_done
410
+
411
+ # Compute normalized score — ALWAYS present
412
+ norm_score = normalize_episode_score(self._rewards)
413
+
414
+ if not episode_done:
415
+ self._turn_config = self._prepare_turn_config(self._attacker.build_observation())
416
+ obs_dict = build_phase1_obs(
417
+ turn_config=self._turn_config,
418
+ history=self._history,
419
+ task_name=self._task_name,
420
+ turn=self._turn,
421
+ max_turns=self._config["max_turns"],
422
+ episode_id=self._state.episode_id,
423
+ mission_profile=self._mission_profile,
424
+ )
425
+ obs = self._to_obs(obs_dict)
426
+ obs.reward = reward
427
+ obs.done = False
428
+ obs.last_action_result = info.get("score_reason", "")
429
+ obs.metadata = self._metadata_with_defenses({
430
+ "episode_id": self._state.episode_id,
431
+ "normalized_score": float(norm_score),
432
+ "score_breakdown": info,
433
+ "turns_completed": self._turn - 1,
434
+ "consecutive_wrong": self._consecutive_wrong,
435
+ "mission_profile": self._mission_profile,
436
+ })
437
+ else:
438
+ self._attacker.advance_episode()
439
+ obs_dict = build_phase1_obs(
440
+ turn_config={"network_nodes": {}, "active_alerts": ["[EPISODE COMPLETE]"],
441
+ "attack_stage": "none", "is_benign": False,
442
+ "strategy": "none", "correct_action": "none", "correct_target": "none"},
443
+ history=self._history,
444
+ task_name=self._task_name,
445
+ turn=self._turn,
446
+ max_turns=self._config["max_turns"],
447
+ episode_id=self._state.episode_id,
448
+ mission_profile=self._mission_profile,
449
+ )
450
+ obs = self._to_obs(obs_dict)
451
+ obs.reward = reward
452
+ obs.done = True
453
+ obs.last_action_result = info.get("score_reason", "")
454
+ obs.metadata = self._metadata_with_defenses({
455
+ "episode_id": self._state.episode_id,
456
+ "normalized_score": float(norm_score),
457
+ "score_breakdown": info,
458
+ "raw_rewards": self._rewards,
459
+ "catastrophic": catastrophic,
460
+ "turns_completed": self._turn - 1,
461
+ "episode_replay": self._episode_replay,
462
+ "mission_profile": self._mission_profile,
463
+ })
464
+
465
+ self._last_obs = obs
466
+ return obs
467
+ except Exception as exc:
468
+ return self._error_observation(f"step_error: {exc}")
469
+
470
+ @property
471
+ def state(self) -> State:
472
+ """Returns State with episode_id and step_count per OpenEnv spec."""
473
+ return self._state
474
+
475
+ # ── Internal ──────────────────────────────────────────────────────────
476
+
477
+ def _to_obs(self, d: Dict[str, Any]) -> AdaptShieldObservation:
478
+ return AdaptShieldObservation(
479
+ scenario_id = d.get("scenario_id", ""),
480
+ task_name = d.get("task_name", self._task_name),
481
+ phase = d.get("phase", 1),
482
+ turn = d.get("turn", 0),
483
+ max_turns = d.get("max_turns", self._config["max_turns"]),
484
+ network_nodes = d.get("network_nodes", {}),
485
+ active_alerts = d.get("active_alerts", []),
486
+ attack_stage = d.get("attack_stage", "none"),
487
+ history = d.get("history", []),
488
+ phase1_assessment = d.get("phase1_assessment"),
489
+ last_action_result = d.get("last_action_result"),
490
+ system_context = d.get("system_context", ""),
491
+ available_actions = d.get("available_actions", []),
492
+ reward = d.get("reward", 0.0),
493
+ done = d.get("done", False),
494
+ metadata = d.get("metadata", {"normalized_score": 0.50}),
495
+ )
496
+
497
+ @staticmethod
498
+ def _sanitize_world_split(value: str) -> str:
499
+ return value if value in {"train", "eval"} else "train"
500
+
501
+ def _error_observation(self, error_message: str) -> AdaptShieldObservation:
502
+ """Return a safe observation instead of letting step() raise."""
503
+ norm_score = float(normalize_episode_score(self._rewards))
504
+ reward = _clamp(self._last_reward if self._last_reward > 0 else 0.01)
505
+
506
+ if self._phase == 2:
507
+ obs_dict = build_phase2_obs(
508
+ phase1_output=self._phase1_output or {},
509
+ history=self._history,
510
+ task_name=self._task_name,
511
+ turn=self._turn,
512
+ max_turns=self._config["max_turns"],
513
+ episode_id=self._state.episode_id,
514
+ current_score=norm_score,
515
+ mission_profile=self._mission_profile,
516
+ )
517
+ else:
518
+ turn_config = self._turn_config or {
519
+ "network_nodes": {},
520
+ "active_alerts": [f"[ERROR] {error_message}"],
521
+ "attack_stage": "none",
522
+ "is_benign": False,
523
+ "strategy": "unknown",
524
+ "correct_action": "monitor",
525
+ "correct_target": "unknown",
526
+ }
527
+ obs_dict = build_phase1_obs(
528
+ turn_config=turn_config,
529
+ history=self._history,
530
+ task_name=self._task_name,
531
+ turn=self._turn,
532
+ max_turns=self._config["max_turns"],
533
+ episode_id=self._state.episode_id,
534
+ mission_profile=self._mission_profile,
535
+ )
536
+
537
+ obs = self._to_obs(obs_dict)
538
+ obs.reward = float(reward)
539
+ obs.done = bool(self._done)
540
+ obs.last_action_result = error_message
541
+ obs.metadata = self._metadata_with_defenses({
542
+ "episode_id": self._state.episode_id,
543
+ "normalized_score": norm_score,
544
+ "error": error_message,
545
+ "turns_completed": max(0, self._turn - 1),
546
+ "mission_profile": self._mission_profile,
547
+ })
548
+ self._last_obs = obs
549
+ return obs
550
+
551
+ def call_tool(self, tool_name: str, **params: Any) -> Dict[str, Any]:
552
+ """
553
+ Query the local SOC tool surface.
554
+
555
+ These tools reveal partial evidence, not ground-truth answers. They are
556
+ stateful because responses depend on the current turn, attacker stage,
557
+ foothold state, active defenses, and previous actions.
558
+ """
559
+ try:
560
+ tool_name = str(tool_name or "").strip()
561
+ node = str(params.get("node", params.get("target_node", "unknown")) or "unknown")
562
+
563
+ if tool_name == "log_search":
564
+ result = self._tool_log_search(node=node, query=str(params.get("query", "")))
565
+ elif tool_name == "cmdb_lookup":
566
+ result = self._tool_cmdb_lookup(node=node)
567
+ elif tool_name == "edr_status":
568
+ result = self._tool_edr_status(node=node)
569
+ elif tool_name == "vuln_lookup":
570
+ result = self._tool_vuln_lookup(node=node, package=str(params.get("package", "")))
571
+ elif tool_name == "identity_lookup":
572
+ result = self._tool_identity_lookup(node=node)
573
+ elif tool_name == "change_calendar_lookup":
574
+ result = self._tool_change_calendar_lookup(node=node)
575
+ elif tool_name == "netflow_lookup":
576
+ result = self._tool_netflow_lookup(node=node)
577
+ else:
578
+ result = {
579
+ "tool": tool_name,
580
+ "turn": self._turn,
581
+ "node": node,
582
+ "error": f"Unknown tool: {tool_name}",
583
+ "available_tools": [tool["name"] for tool in AVAILABLE_SOC_TOOLS],
584
+ }
585
+
586
+ self._record_tool_result(result)
587
+ return _public_tool_result(result)
588
+ except Exception as exc:
589
+ return {
590
+ "tool": str(tool_name),
591
+ "turn": self._turn,
592
+ "error": f"tool_error: {exc}",
593
+ }
594
+
595
+ def _tool_log_search(self, node: str, query: str = "") -> Dict[str, Any]:
596
+ turn_config = self._turn_config or {}
597
+ hard_task = self._task_name == "polymorphic-zero-day"
598
+ strategy = str(turn_config.get("strategy", "benign"))
599
+ target = str(turn_config.get("correct_target", "none"))
600
+ stage = str(turn_config.get("attack_stage", self._attacker.current_stage()))
601
+ relevant = node == target
602
+
603
+ events = [f"{node}: routine health heartbeat accepted"]
604
+ evidence_type = "benign"
605
+ confidence = 0.32
606
+
607
+ if relevant and strategy == "brute_force":
608
+ events = [
609
+ "auth_service: 143 failed logins grouped by /24 source subnet",
610
+ "auth_service: password spray pattern across active user accounts",
611
+ "payment_service: no matching checkout errors during same window",
612
+ ]
613
+ evidence_type = "brute_force"
614
+ confidence = 0.88
615
+ elif relevant and strategy == "lateral_movement":
616
+ if hard_task:
617
+ events = [
618
+ "auth_service: low-volume identity failures persist from earlier abuse window",
619
+ "payment_service: shared workload credential reused from a new east-west origin",
620
+ "payment_service: callback mesh fans out across private ports after short-lived sessions",
621
+ ]
622
+ confidence = 0.77
623
+ else:
624
+ events = [
625
+ "payment_service: service account login from previously unseen host",
626
+ "auth_service -> payment_service connections exceed baseline",
627
+ "payment_service: short-lived internal sessions fan out to private ports",
628
+ ]
629
+ confidence = 0.84
630
+ evidence_type = "lateral_movement"
631
+ elif relevant and strategy == "exfiltration":
632
+ if hard_task:
633
+ events = [
634
+ "database: bulk row access concentrates on high-value records",
635
+ "database: staging bundle assembled before outbound transfer burst",
636
+ "netflow: outbound transfer volume from database exceeds 7-day baseline",
637
+ ]
638
+ confidence = 0.82
639
+ else:
640
+ events = [
641
+ "database: sequential reads concentrated on sensitive tables",
642
+ "database: compressed archive created before outbound transfer spike",
643
+ "netflow: database egress exceeds 7-day baseline for this hour",
644
+ ]
645
+ confidence = 0.90
646
+ evidence_type = "exfiltration"
647
+ elif relevant and strategy == "supply_chain":
648
+ if hard_task:
649
+ events = [
650
+ "api_gateway: release artifact requested from untrusted update channel",
651
+ "api_gateway: dependency integrity drift detected in gateway manifest",
652
+ "api_gateway: new outbound connection follows release request",
653
+ ]
654
+ confidence = 0.79
655
+ else:
656
+ events = [
657
+ "api_gateway: unsigned package update requested by deploy_bot",
658
+ "api_gateway: dependency hash differs from approved registry hash",
659
+ "api_gateway: new outbound connection follows update attempt",
660
+ ]
661
+ confidence = 0.86
662
+ evidence_type = "supply_chain"
663
+ elif turn_config.get("is_benign", False):
664
+ events = [
665
+ f"{node}: scheduled maintenance event matches change calendar",
666
+ f"{node}: no persistence, beaconing, or sensitive data access observed",
667
+ ]
668
+ confidence = 0.65
669
+
670
+ return {
671
+ "tool": "log_search",
672
+ "turn": self._turn,
673
+ "phase": self._phase,
674
+ "node": node,
675
+ "query": query,
676
+ "events": [self._surface_text(event) for event in events],
677
+ "evidence_type": evidence_type,
678
+ "confidence": confidence,
679
+ "verified": relevant and evidence_type == strategy,
680
+ }
681
+
682
+ def _tool_cmdb_lookup(self, node: str) -> Dict[str, Any]:
683
+ dependencies = {
684
+ "auth_service": ["payment_service"],
685
+ "payment_service": ["api_gateway"],
686
+ "database": ["payment_service", "api_gateway"],
687
+ "api_gateway": ["auth_service", "payment_service", "database"],
688
+ }.get(node, [])
689
+ criticality = {
690
+ "auth_service": "high",
691
+ "payment_service": "critical",
692
+ "database": "critical",
693
+ "api_gateway": "high",
694
+ }.get(node, "unknown")
695
+ safe_actions = {
696
+ "auth_service": ["rate_limit", "patch", "monitor"],
697
+ "payment_service": ["honeypot", "rate_limit", "monitor"],
698
+ "database": ["honeypot", "monitor"],
699
+ "api_gateway": ["patch", "rate_limit", "monitor"],
700
+ }.get(node, ["monitor"])
701
+
702
+ return {
703
+ "tool": "cmdb_lookup",
704
+ "turn": self._turn,
705
+ "phase": self._phase,
706
+ "node": node,
707
+ "owner": SERVICE_OWNERS.get(node, "unknown"),
708
+ "criticality": criticality,
709
+ "dependencies": dependencies,
710
+ "mission_profile": self._mission_profile,
711
+ "safe_actions": safe_actions,
712
+ "evidence_type": "operational_context",
713
+ "verified": True,
714
+ "result_summary": self._surface_text(f"{node} criticality {criticality}; blast radius {len(dependencies)}"),
715
+ }
716
+
717
+ def _tool_edr_status(self, node: str) -> Dict[str, Any]:
718
+ turn_config = self._turn_config or {}
719
+ hard_task = self._task_name == "polymorphic-zero-day"
720
+ strategy = str(turn_config.get("strategy", "benign"))
721
+ target = str(turn_config.get("correct_target", "none"))
722
+ relevant = node == target
723
+ active_controls = [
724
+ control for control in self._active_defenses
725
+ if control.get("target") == node
726
+ ]
727
+
728
+ status = {
729
+ "containment": "not_isolated",
730
+ "persistence": False,
731
+ "beaconing": False,
732
+ "active_controls": active_controls,
733
+ }
734
+ evidence_type = "benign"
735
+ confidence = 0.45
736
+
737
+ if relevant and strategy == "lateral_movement":
738
+ if hard_task:
739
+ status.update({
740
+ "containment": "partial" if self._foothold_established else "unconfirmed",
741
+ "persistence": self._foothold_established,
742
+ "beaconing": self._foothold_established,
743
+ "process_note": (
744
+ "callback process remains active under reused workload credential"
745
+ if self._foothold_established else
746
+ "unexpected child process under shared workload credential; no confirmed beacon yet"
747
+ ),
748
+ })
749
+ confidence = 0.74 if not self._foothold_established else 0.87
750
+ else:
751
+ status.update({
752
+ "containment": "partial" if self._foothold_established else "not_isolated",
753
+ "persistence": self._foothold_established,
754
+ "beaconing": True,
755
+ "process_note": "unknown child process under service account context",
756
+ })
757
+ confidence = 0.87
758
+ evidence_type = "lateral_movement"
759
+ elif relevant and strategy == "exfiltration":
760
+ status.update({
761
+ "containment": "not_isolated",
762
+ "persistence": self._foothold_established,
763
+ "beaconing": False if hard_task else True,
764
+ "process_note": (
765
+ "bundle staging followed by intermittent outbound worker activity"
766
+ if hard_task else
767
+ "archive process followed by outbound transfer"
768
+ ),
769
+ })
770
+ evidence_type = "exfiltration"
771
+ confidence = 0.73 if hard_task else 0.82
772
+ elif active_controls:
773
+ status["containment"] = "control_active"
774
+ confidence = 0.70
775
+
776
+ return {
777
+ "tool": "edr_status",
778
+ "turn": self._turn,
779
+ "phase": self._phase,
780
+ "node": node,
781
+ **status,
782
+ "evidence_type": evidence_type,
783
+ "confidence": confidence,
784
+ "verified": relevant and evidence_type == strategy,
785
+ "process_note": self._surface_text(str(status.get("process_note", ""))) if status.get("process_note") else "",
786
+ }
787
+
788
+ def _tool_vuln_lookup(self, node: str, package: str = "") -> Dict[str, Any]:
789
+ turn_config = self._turn_config or {}
790
+ hard_task = self._task_name == "polymorphic-zero-day"
791
+ strategy = str(turn_config.get("strategy", "benign"))
792
+ target = str(turn_config.get("correct_target", "none"))
793
+ relevant = node == target and strategy == "supply_chain"
794
+
795
+ if relevant:
796
+ advisory = {
797
+ "package": package or "gateway-router",
798
+ "advisory_id": "ADV-AS-042",
799
+ "risk": "critical",
800
+ "finding": (
801
+ "artifact integrity drift from untrusted release channel"
802
+ if hard_task else
803
+ "registry hash mismatch with unsigned update source"
804
+ ),
805
+ "recommended_mitigation": "patch from trusted registry",
806
+ }
807
+ evidence_type = "supply_chain"
808
+ confidence = 0.82 if hard_task else 0.91
809
+ else:
810
+ advisory = {
811
+ "package": package or "unknown",
812
+ "advisory_id": None,
813
+ "risk": "none_known",
814
+ "finding": "no matching active internal advisory for this node/package",
815
+ "recommended_mitigation": "continue investigation",
816
+ }
817
+ evidence_type = "benign"
818
+ confidence = 0.55
819
+
820
+ return {
821
+ "tool": "vuln_lookup",
822
+ "turn": self._turn,
823
+ "phase": self._phase,
824
+ "node": node,
825
+ **{
826
+ **advisory,
827
+ "finding": self._surface_text(str(advisory.get("finding", ""))),
828
+ "recommended_mitigation": self._surface_text(str(advisory.get("recommended_mitigation", ""))),
829
+ },
830
+ "evidence_type": evidence_type,
831
+ "confidence": confidence,
832
+ "verified": relevant,
833
+ }
834
+
835
+ def _tool_identity_lookup(self, node: str) -> Dict[str, Any]:
836
+ turn_config = self._turn_config or {}
837
+ strategy = str(turn_config.get("strategy", "benign"))
838
+ target = str(turn_config.get("correct_target", "none"))
839
+ identity = dict(IDENTITY_CONTEXT.get(node, {
840
+ "account": "unknown",
841
+ "account_type": "unknown",
842
+ "privilege_level": "unknown",
843
+ "normal_hosts": [],
844
+ }))
845
+
846
+ lookup = {
847
+ "tool": "identity_lookup",
848
+ "turn": self._turn,
849
+ "phase": self._phase,
850
+ "node": node,
851
+ **identity,
852
+ "recent_source_host": node,
853
+ "source_status": "expected",
854
+ "identity_anomaly": False,
855
+ "result_summary": "",
856
+ "confidence": 0.58,
857
+ "evidence_type": "benign",
858
+ "verified": False,
859
+ }
860
+
861
+ if node == target and strategy == "lateral_movement":
862
+ lookup.update({
863
+ "recent_source_host": "auth_service",
864
+ "source_status": "unexpected",
865
+ "identity_anomaly": True,
866
+ "confidence": 0.84 if self._task_name != "polymorphic-zero-day" else 0.76,
867
+ "evidence_type": "lateral_movement",
868
+ "verified": True,
869
+ })
870
+ elif node == target and strategy == "supply_chain":
871
+ lookup.update({
872
+ "recent_source_host": "external-release-runner",
873
+ "source_status": "unexpected",
874
+ "identity_anomaly": True,
875
+ "confidence": 0.73,
876
+ "evidence_type": "supply_chain",
877
+ "verified": True,
878
+ })
879
+ elif turn_config.get("is_benign", False):
880
+ lookup.update({
881
+ "recent_source_host": identity.get("normal_hosts", [node])[0] if identity.get("normal_hosts") else node,
882
+ "source_status": "scheduled_change_window",
883
+ "confidence": 0.69,
884
+ })
885
+
886
+ if (
887
+ self._task_name == "dual-pivot" and
888
+ strategy == "lateral_movement" and
889
+ self._operational_mode == "evidence_preservation"
890
+ ):
891
+ lookup["source_status"] = "unexpected_but_trackable"
892
+ lookup["result_summary"] = self._surface_text(
893
+ "Identity trail is intact; preserving visibility before hard containment is mission-aligned."
894
+ )
895
+ else:
896
+ lookup["result_summary"] = self._surface_text(
897
+ f"account={lookup['account']} source={lookup['recent_source_host']} anomaly={lookup['identity_anomaly']}"
898
+ )
899
+ return lookup
900
+
901
+ def _tool_change_calendar_lookup(self, node: str) -> Dict[str, Any]:
902
+ turn_config = self._turn_config or {}
903
+ strategy = str(turn_config.get("strategy", "benign"))
904
+ target = str(turn_config.get("correct_target", "none"))
905
+ change = dict(CHANGE_CALENDAR.get(node, {
906
+ "window": "none_scheduled",
907
+ "change_type": "none",
908
+ "expected_actor": "unknown",
909
+ }))
910
+
911
+ scheduled = bool(turn_config.get("is_benign", False))
912
+ confidence = 0.66 if scheduled else 0.74
913
+ if node == target and strategy == "supply_chain":
914
+ scheduled = False
915
+ confidence = 0.87 if self._task_name != "polymorphic-zero-day" else 0.78
916
+ elif node == target and strategy == "lateral_movement":
917
+ scheduled = False
918
+ confidence = 0.72
919
+
920
+ change_status = "scheduled" if scheduled else "no_matching_change"
921
+ if (
922
+ self._task_name == "dual-pivot" and
923
+ strategy == "lateral_movement" and
924
+ self._operational_mode == "evidence_preservation"
925
+ ):
926
+ change_status = "forensic_observation_hold"
927
+ return {
928
+ "tool": "change_calendar_lookup",
929
+ "turn": self._turn,
930
+ "phase": self._phase,
931
+ "node": node,
932
+ **change,
933
+ "scheduled": scheduled,
934
+ "change_status": change_status,
935
+ "confidence": confidence,
936
+ "evidence_type": "benign" if scheduled else ("supply_chain" if node == target and strategy == "supply_chain" else "operational_context"),
937
+ "verified": scheduled or (node == target and strategy == "supply_chain"),
938
+ }
939
+
940
+ def _tool_netflow_lookup(self, node: str) -> Dict[str, Any]:
941
+ turn_config = self._turn_config or {}
942
+ strategy = str(turn_config.get("strategy", "benign"))
943
+ target = str(turn_config.get("correct_target", "none"))
944
+ hard_task = self._task_name == "polymorphic-zero-day"
945
+
946
+ summary = {
947
+ "tool": "netflow_lookup",
948
+ "turn": self._turn,
949
+ "phase": self._phase,
950
+ "node": node,
951
+ "east_west_connections": 12,
952
+ "outbound_mb": 4,
953
+ "new_destinations": 0,
954
+ "traffic_pattern": "baseline",
955
+ "confidence": 0.57,
956
+ "evidence_type": "benign",
957
+ "verified": False,
958
+ }
959
+
960
+ if node == target and strategy == "lateral_movement":
961
+ summary.update({
962
+ "east_west_connections": 46 if hard_task else 58,
963
+ "outbound_mb": 9,
964
+ "new_destinations": 5 if hard_task else 7,
965
+ "traffic_pattern": "east_west_fanout",
966
+ "confidence": 0.79 if hard_task else 0.88,
967
+ "evidence_type": "lateral_movement",
968
+ "verified": True,
969
+ })
970
+ elif node == target and strategy == "exfiltration":
971
+ summary.update({
972
+ "east_west_connections": 18,
973
+ "outbound_mb": 74 if hard_task else 96,
974
+ "new_destinations": 2,
975
+ "traffic_pattern": "outbound_transfer_burst",
976
+ "confidence": 0.82 if hard_task else 0.91,
977
+ "evidence_type": "exfiltration",
978
+ "verified": True,
979
+ })
980
+ elif node == target and strategy == "supply_chain":
981
+ summary.update({
982
+ "east_west_connections": 16,
983
+ "outbound_mb": 19,
984
+ "new_destinations": 1,
985
+ "traffic_pattern": "post_deploy_callback",
986
+ "confidence": 0.71,
987
+ "evidence_type": "supply_chain",
988
+ "verified": True,
989
+ })
990
+
991
+ if (
992
+ self._task_name == "polymorphic-zero-day" and
993
+ self._operational_mode in {"forensic_hold", "business_continuity"} and
994
+ node == target and strategy in {"lateral_movement", "exfiltration"}
995
+ ):
996
+ summary["traffic_pattern"] = "controlled_observation_window"
997
+ summary["confidence"] = max(float(summary["confidence"]), 0.83)
998
+ return summary
999
+
1000
+ def _record_tool_result(self, result: Dict[str, Any]) -> None:
1001
+ turn = int(result.get("turn", self._turn) or self._turn)
1002
+ internal = {
1003
+ "turn": turn,
1004
+ "phase": result.get("phase", self._phase),
1005
+ "tool": result.get("tool", "unknown"),
1006
+ "node": result.get("node", "unknown"),
1007
+ "evidence_type": result.get("evidence_type", "unknown"),
1008
+ "verified": bool(result.get("verified", False)),
1009
+ "confidence": float(result.get("confidence", 0.0) or 0.0),
1010
+ }
1011
+ self._turn_tool_results.setdefault(turn, []).append(internal)
1012
+
1013
+ trace = {
1014
+ "turn": result.get("turn", self._turn),
1015
+ "phase": result.get("phase", self._phase),
1016
+ "tool": result.get("tool", "unknown"),
1017
+ "node": result.get("node", "unknown"),
1018
+ "confidence": float(result.get("confidence", 0.0) or 0.0),
1019
+ "summary": _tool_summary(result),
1020
+ }
1021
+ self._tool_trace.append(trace)
1022
+
1023
+ if internal["verified"]:
1024
+ self._turn_tool_evidence.setdefault(turn, []).append(internal)
1025
+
1026
+ def _tool_context_for_turn(self) -> Dict[str, Any]:
1027
+ evidence = list(self._turn_tool_evidence.get(self._turn, []))
1028
+ return {
1029
+ "turn": self._turn,
1030
+ "tool_count": len([
1031
+ row for row in self._tool_trace
1032
+ if int(row.get("turn", -1)) == self._turn
1033
+ ]),
1034
+ "evidence": evidence,
1035
+ "tool_results": list(self._turn_tool_results.get(self._turn, [])),
1036
+ }
1037
+
1038
+ def _update_foothold_state(
1039
+ self,
1040
+ p2: Dict[str, str],
1041
+ info: Dict[str, Any],
1042
+ stage: str,
1043
+ ) -> bool:
1044
+ if (
1045
+ self._task_name != "polymorphic-zero-day" or
1046
+ self._foothold_established or
1047
+ stage not in ("exploit", "exfiltration")
1048
+ ):
1049
+ return False
1050
+
1051
+ if p2.get("action") == "monitor" or not info.get("acted_correctly", False):
1052
+ self._foothold_established = True
1053
+ return True
1054
+
1055
+ return False
1056
+
1057
+ def _register_active_defense(self, p2: Dict[str, str]) -> None:
1058
+ action = p2.get("action", "monitor")
1059
+ if action not in DEFENSE_TTL:
1060
+ return
1061
+
1062
+ target = p2.get("target_node", "unknown")
1063
+ self._active_defenses = [
1064
+ control for control in self._active_defenses
1065
+ if not (control["action"] == action and control["target"] == target)
1066
+ ]
1067
+ self._active_defenses.append({
1068
+ "action": action,
1069
+ "target": target,
1070
+ "ttl": DEFENSE_TTL[action],
1071
+ "side_effect": DEFENSE_SIDE_EFFECT[action],
1072
+ })
1073
+
1074
+ def _decay_active_defenses(self) -> None:
1075
+ next_controls = []
1076
+ for control in self._active_defenses:
1077
+ updated = dict(control)
1078
+ updated["ttl"] = int(updated.get("ttl", 0)) - 1
1079
+ if updated["ttl"] > 0:
1080
+ next_controls.append(updated)
1081
+ self._active_defenses = next_controls
1082
+
1083
+ def _active_defense_snapshot(self) -> List[Dict[str, Any]]:
1084
+ return [dict(control) for control in self._active_defenses]
1085
+
1086
+ def _metadata_with_defenses(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
1087
+ updated = dict(metadata or {})
1088
+ updated["active_defenses"] = self._active_defense_snapshot()
1089
+ updated["mission_profile"] = self._mission_profile
1090
+ updated["world_split"] = self._world_split
1091
+ updated["world_family"] = self._world_family
1092
+ updated["operational_mode"] = self._operational_mode
1093
+ updated["available_tools"] = [dict(tool) for tool in AVAILABLE_SOC_TOOLS]
1094
+ updated["tool_trace"] = [dict(row) for row in self._tool_trace]
1095
+ updated["tool_calls_this_turn"] = len([
1096
+ row for row in self._tool_trace
1097
+ if int(row.get("turn", -1)) == self._turn
1098
+ ])
1099
+ return updated
1100
+
1101
+ def _with_active_defense_alerts(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
1102
+ if not self._active_defenses:
1103
+ return turn_config
1104
+
1105
+ updated = dict(turn_config)
1106
+ alerts = list(updated.get("active_alerts", []))
1107
+ for control in self._active_defenses:
1108
+ alerts.append(
1109
+ "[CONTROL] "
1110
+ f"{control['action']} active on {control['target']} "
1111
+ f"(ttl={control['ttl']}, side_effect={control['side_effect']})"
1112
+ )
1113
+ updated["active_alerts"] = alerts
1114
+ return updated
1115
+
1116
+ def _with_foothold_context(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
1117
+ if (
1118
+ self._task_name != "polymorphic-zero-day" or
1119
+ not self._foothold_established or
1120
+ not turn_config.get("is_benign", False)
1121
+ ):
1122
+ return turn_config
1123
+
1124
+ updated = dict(turn_config)
1125
+ nodes = {name: dict(metrics) for name, metrics in updated.get("network_nodes", {}).items()}
1126
+ payment = nodes.setdefault("payment_service", {})
1127
+ payment["status"] = "suspicious"
1128
+ payment["cpu"] = max(int(payment.get("cpu", 0)), 58)
1129
+ payment["request_rate"] = max(int(payment.get("request_rate", 0)), 130)
1130
+
1131
+ alerts = list(updated.get("active_alerts", []))
1132
+ alerts.append(
1133
+ "[FOOTHOLD] Low-volume internal callback resembles benign service noise, "
1134
+ "but prior missed containment means deception is required."
1135
+ )
1136
+
1137
+ updated["network_nodes"] = nodes
1138
+ updated["active_alerts"] = alerts
1139
+ updated["attack_stage"] = self._attacker.current_stage()
1140
+ updated["is_benign"] = False
1141
+ updated["strategy"] = "lateral_movement"
1142
+ updated["correct_action"] = "honeypot"
1143
+ updated["correct_target"] = "payment_service"
1144
+ updated["foothold_context"] = True
1145
+ return updated
1146
+
1147
+ def _prepare_turn_config(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
1148
+ updated = self._with_foothold_context(turn_config)
1149
+ updated = self._apply_operational_mode(updated)
1150
+ updated = self._with_active_defense_alerts(updated)
1151
+ return updated
1152
+
1153
+ def _apply_operational_mode(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
1154
+ updated = dict(turn_config)
1155
+ alerts = list(updated.get("active_alerts", []))
1156
+ strategy = str(updated.get("strategy", "benign"))
1157
+ stage = str(updated.get("attack_stage", "recon"))
1158
+
1159
+ if self._task_name == "dual-pivot" and strategy == "lateral_movement":
1160
+ if self._operational_mode == "evidence_preservation":
1161
+ updated["correct_action"] = "honeypot"
1162
+ alerts.append(
1163
+ self._surface_text(
1164
+ "[MODE] Checkout incident posture favors preserving attacker visibility until identity misuse is corroborated."
1165
+ )
1166
+ )
1167
+ else:
1168
+ updated["correct_action"] = "isolate"
1169
+ alerts.append(
1170
+ self._surface_text(
1171
+ "[MODE] Checkout incident posture favors decisive containment once the pivot is corroborated."
1172
+ )
1173
+ )
1174
+
1175
+ if self._task_name == "polymorphic-zero-day" and not updated.get("is_benign", False):
1176
+ if self._operational_mode in {"forensic_hold", "business_continuity"} and strategy in {"lateral_movement", "exfiltration"}:
1177
+ updated["correct_action"] = "honeypot"
1178
+ alerts.append(
1179
+ self._surface_text(
1180
+ "[MODE] Enterprise posture favors deception over immediate isolation while mapping the callback path."
1181
+ )
1182
+ )
1183
+ elif self._operational_mode == "containment_first":
1184
+ alerts.append(
1185
+ self._surface_text(
1186
+ "[MODE] Enterprise posture favors immediate containment once compromise is corroborated."
1187
+ )
1188
+ )
1189
+ if self._operational_mode == "business_continuity" and stage == "exploit":
1190
+ nodes = {name: dict(metrics) for name, metrics in updated.get("network_nodes", {}).items()}
1191
+ payment = nodes.setdefault("payment_service", {})
1192
+ payment["status"] = "elevated"
1193
+ payment["request_rate"] = max(int(payment.get("request_rate", 0)), 122)
1194
+ updated["network_nodes"] = nodes
1195
+ alerts.append(
1196
+ self._surface_text(
1197
+ "[MODE] Customer traffic remains sensitive; service continuity pressure is elevated during this window."
1198
+ )
1199
+ )
1200
+
1201
+ updated["active_alerts"] = alerts
1202
+ updated["world_split"] = self._world_split
1203
+ updated["world_family"] = self._world_family
1204
+ updated["operational_mode"] = self._operational_mode
1205
+ return updated
1206
+
1207
+ def _surface_text(self, text: str) -> str:
1208
+ return self._attacker._surface(text)
1209
+
1210
+
1211
+ def _action_value(value: Any, default: str) -> str:
1212
+ """Serialize action fields without leaking Enum member names."""
1213
+ if value is None:
1214
+ return default
1215
+ if isinstance(value, Enum):
1216
+ return str(value.value)
1217
+ return str(value)
1218
+
1219
+
1220
+ def _action_float(value: Any, default: float) -> float:
1221
+ """Coerce optional numeric action fields to floats with a safe fallback."""
1222
+ if value is None:
1223
+ return float(default)
1224
+ try:
1225
+ return float(value)
1226
+ except (TypeError, ValueError):
1227
+ return float(default)
1228
+
1229
+
1230
+ def _replay_result(info: Dict[str, Any]) -> str:
1231
+ """Map grader text into compact replay result labels."""
1232
+ reason = str(info.get("score_reason", "")).lower()
1233
+ if "false positive" in reason:
1234
+ return "false_positive"
1235
+ if reason.startswith("unverified"):
1236
+ return "unverified"
1237
+ if reason.startswith("optimal") or reason.startswith("correct") or reason.startswith("context-aware optimal"):
1238
+ return "optimal"
1239
+ if reason.startswith("heavy-handed"):
1240
+ return "heavy"
1241
+ return "wrong"
1242
+
1243
+
1244
+ def _tool_summary(result: Dict[str, Any]) -> str:
1245
+ if result.get("error"):
1246
+ return str(result["error"])[:120]
1247
+ if result.get("tool") == "log_search":
1248
+ events = result.get("events") or []
1249
+ return str(events[0])[:120] if events else "no matching log events"
1250
+ if result.get("tool") == "cmdb_lookup":
1251
+ deps = result.get("dependencies") or []
1252
+ return f"{result.get('node')} criticality={result.get('criticality')} deps={len(deps)}"
1253
+ if result.get("tool") == "edr_status":
1254
+ return (
1255
+ f"containment={result.get('containment')} "
1256
+ f"beaconing={result.get('beaconing')} "
1257
+ f"persistence={result.get('persistence')}"
1258
+ )
1259
+ if result.get("tool") == "vuln_lookup":
1260
+ return f"risk={result.get('risk')} finding={result.get('finding')}"
1261
+ if result.get("tool") == "identity_lookup":
1262
+ return (
1263
+ f"account={result.get('account')} "
1264
+ f"source={result.get('recent_source_host')} "
1265
+ f"anomaly={result.get('identity_anomaly')}"
1266
+ )
1267
+ if result.get("tool") == "change_calendar_lookup":
1268
+ return (
1269
+ f"scheduled={result.get('scheduled')} "
1270
+ f"window={result.get('window')} "
1271
+ f"change={result.get('change_type')}"
1272
+ )
1273
+ if result.get("tool") == "netflow_lookup":
1274
+ return (
1275
+ f"pattern={result.get('traffic_pattern')} "
1276
+ f"east_west={result.get('east_west_connections')} "
1277
+ f"outbound_mb={result.get('outbound_mb')}"
1278
+ )
1279
+ return str(result)[:120]
1280
+
1281
+
1282
+ def _public_tool_result(result: Dict[str, Any]) -> Dict[str, Any]:
1283
+ """Return observable tool facts without leaking grader labels."""
1284
+ hidden = {"evidence_type", "verified"}
1285
+ public = {key: value for key, value in result.items() if key not in hidden}
1286
+ public["result_summary"] = _tool_summary(result)
1287
+ return public
1288
+
1289
+
1290
+ def _degrade_handoff(
1291
+ phase1_output: Dict[str, Any],
1292
+ turn_config: Dict[str, Any],
1293
+ task_name: str,
1294
+ turn: int,
1295
+ ) -> Dict[str, Any]:
1296
+ """Degrade hard-task handoffs after the strategy shift without changing grading."""
1297
+ output = dict(phase1_output)
1298
+ if (
1299
+ task_name == "dual-pivot" and
1300
+ turn >= 4 and
1301
+ turn_config.get("strategy") == "lateral_movement"
1302
+ ):
1303
+ output["confidence"] = min(float(output.get("confidence", 0.5)), 0.56)
1304
+ output["recommended_action"] = "monitor"
1305
+ output["handoff_quality"] = "degraded"
1306
+ output["handoff_note"] = (
1307
+ "Analyst handoff is less reliable after the pivot; executor should verify "
1308
+ "internal movement before taking payment_service offline."
1309
+ )
1310
+ elif (
1311
+ task_name == "polymorphic-zero-day" and
1312
+ turn >= 4 and
1313
+ turn_config.get("strategy") in {"lateral_movement", "exfiltration", "supply_chain"}
1314
+ ):
1315
+ output["confidence"] = min(float(output.get("confidence", 0.5)), 0.36)
1316
+ output["recommended_action"] = "monitor"
1317
+ output["handoff_quality"] = "degraded"
1318
+ output["handoff_note"] = (
1319
+ "Analyst confidence degraded after attacker pivot; executor must verify "
1320
+ "whether monitor is too passive for the current stage."
1321
+ )
1322
+ else:
1323
+ output["handoff_quality"] = "clean"
1324
+ return output
server/app.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield FastAPI Server
3
+
4
+ CRITICAL: Uses factory pattern (make_env function), NOT singleton.
5
+ Singleton was the Round 1 failure — always served wrong task.
6
+ Factory creates a fresh isolated instance per evaluator session.
7
+
8
+ openenv validate requires:
9
+ - def main() function present
10
+ - called as main() in if __name__ block (literal string check)
11
+ - port 7860 (HF Spaces default)
12
+ """
13
+
14
+ import os
15
+ import sys
16
+ from typing import Any, Dict
17
+ from uuid import uuid4
18
+
19
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
20
+
21
+ try:
22
+ from fastapi import Body, HTTPException
23
+ from openenv.core.env_server.http_server import create_app
24
+ except Exception as e:
25
+ raise ImportError(
26
+ "openenv-core required. Install: pip install openenv-core"
27
+ ) from e
28
+
29
+ from models import AdaptShieldAction, AdaptShieldObservation
30
+ from server.adaptshield_environment import AdaptShieldEnvironment
31
+
32
+ DEFAULT_TASK = os.getenv("ADAPTSHIELD_TASK", "direct-triage")
33
+ SOC_SESSIONS: Dict[str, AdaptShieldEnvironment] = {}
34
+
35
+
36
+ def make_env() -> AdaptShieldEnvironment:
37
+ """
38
+ Factory function — fresh isolated instance per session.
39
+ Never a singleton. Evaluator sessions must be independent.
40
+ """
41
+ return AdaptShieldEnvironment(task_name=DEFAULT_TASK)
42
+
43
+
44
+ app = create_app(
45
+ make_env,
46
+ AdaptShieldAction,
47
+ AdaptShieldObservation,
48
+ env_name="adaptshield",
49
+ max_concurrent_envs=10,
50
+ )
51
+
52
+
53
+ @app.post("/soc/reset", tags=["AdaptShield SOC Tools"])
54
+ async def soc_reset(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
55
+ """Start a persistent demo session for SOC tool/API workflows."""
56
+ task = str(payload.get("task", DEFAULT_TASK))
57
+ env = AdaptShieldEnvironment(task_name=task)
58
+ obs = env.reset()
59
+ session_id = str(uuid4())
60
+ SOC_SESSIONS[session_id] = env
61
+ return {
62
+ "session_id": session_id,
63
+ "observation": obs.model_dump(mode="json"),
64
+ "available_tools": obs.metadata.get("available_tools", []),
65
+ }
66
+
67
+
68
+ @app.post("/soc/step", tags=["AdaptShield SOC Tools"])
69
+ async def soc_step(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
70
+ """Step a persistent SOC tool/API session."""
71
+ env = _soc_session(payload)
72
+ try:
73
+ action = AdaptShieldAction(**dict(payload.get("action", {})))
74
+ except Exception as exc:
75
+ raise HTTPException(status_code=422, detail=str(exc)) from exc
76
+
77
+ obs = env.step(action)
78
+ return {
79
+ "session_id": payload.get("session_id"),
80
+ "observation": obs.model_dump(mode="json"),
81
+ "reward": float(obs.reward),
82
+ "done": bool(obs.done),
83
+ }
84
+
85
+
86
+ @app.post("/tools/log_search", tags=["AdaptShield SOC Tools"])
87
+ async def tool_log_search(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
88
+ """Search stateful SIEM/application logs for the active session."""
89
+ return _soc_session(payload).call_tool(
90
+ "log_search",
91
+ node=payload.get("node", payload.get("target_node", "unknown")),
92
+ query=payload.get("query", ""),
93
+ )
94
+
95
+
96
+ @app.post("/tools/cmdb_lookup", tags=["AdaptShield SOC Tools"])
97
+ async def tool_cmdb_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
98
+ """Look up service ownership, criticality, and dependency blast radius."""
99
+ return _soc_session(payload).call_tool(
100
+ "cmdb_lookup",
101
+ node=payload.get("node", payload.get("target_node", "unknown")),
102
+ )
103
+
104
+
105
+ @app.post("/tools/edr_status", tags=["AdaptShield SOC Tools"])
106
+ async def tool_edr_status(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
107
+ """Check endpoint containment and persistence indicators."""
108
+ return _soc_session(payload).call_tool(
109
+ "edr_status",
110
+ node=payload.get("node", payload.get("target_node", "unknown")),
111
+ )
112
+
113
+
114
+ @app.post("/tools/vuln_lookup", tags=["AdaptShield SOC Tools"])
115
+ async def tool_vuln_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
116
+ """Query internal vulnerability/advisory evidence for a service package."""
117
+ return _soc_session(payload).call_tool(
118
+ "vuln_lookup",
119
+ node=payload.get("node", payload.get("target_node", "unknown")),
120
+ package=payload.get("package", ""),
121
+ )
122
+
123
+
124
+ @app.post("/tools/identity_lookup", tags=["AdaptShield SOC Tools"])
125
+ async def tool_identity_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
126
+ """Inspect account behavior and unusual source-host affinity for a service identity."""
127
+ return _soc_session(payload).call_tool(
128
+ "identity_lookup",
129
+ node=payload.get("node", payload.get("target_node", "unknown")),
130
+ )
131
+
132
+
133
+ @app.post("/tools/change_calendar_lookup", tags=["AdaptShield SOC Tools"])
134
+ async def tool_change_calendar_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
135
+ """Check whether a deploy or maintenance window was actually scheduled."""
136
+ return _soc_session(payload).call_tool(
137
+ "change_calendar_lookup",
138
+ node=payload.get("node", payload.get("target_node", "unknown")),
139
+ )
140
+
141
+
142
+ @app.post("/tools/netflow_lookup", tags=["AdaptShield SOC Tools"])
143
+ async def tool_netflow_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
144
+ """Inspect east-west and outbound traffic summaries for the active session."""
145
+ return _soc_session(payload).call_tool(
146
+ "netflow_lookup",
147
+ node=payload.get("node", payload.get("target_node", "unknown")),
148
+ )
149
+
150
+
151
+ def _soc_session(payload: Dict[str, Any]) -> AdaptShieldEnvironment:
152
+ session_id = str(payload.get("session_id", ""))
153
+ env = SOC_SESSIONS.get(session_id)
154
+ if env is None:
155
+ raise HTTPException(
156
+ status_code=404,
157
+ detail="Unknown SOC session. Call /soc/reset first.",
158
+ )
159
+ return env
160
+
161
+
162
+ def main(host: str = "0.0.0.0", port: int = 7860) -> None:
163
+ """Start the uvicorn server. Call main() to run."""
164
+ import uvicorn
165
+ uvicorn.run(app, host=host, port=port)
166
+
167
+
168
+ if __name__ == "__main__":
169
+ import argparse
170
+ parser = argparse.ArgumentParser()
171
+ parser.add_argument("--port", type=int, default=7860)
172
+ args = parser.parse_args()
173
+ main()
server/attacker.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield Attacker Engine
3
+
4
+ Deterministic scripted Python adversary — NEVER an LLM.
5
+ Implements attack stage progression within episodes.
6
+ Rotates strategies between episodes on a fixed schedule.
7
+
8
+ The internal episode counter is NEVER passed to the agent.
9
+ Agent must infer strategy from network signals only.
10
+ """
11
+
12
+ import os
13
+ import random
14
+ from typing import Any, Dict, Tuple
15
+
16
+
17
+ # ── Strategy definitions ───────────────────────────────────────────────────
18
+ STRATEGY_MAP: Dict[str, Dict[str, str]] = {
19
+ "brute_force": {
20
+ "correct_action": "rate_limit",
21
+ "correct_target": "auth_service",
22
+ "description": "Credential stuffing attack on authentication service",
23
+ },
24
+ "lateral_movement": {
25
+ "correct_action": "isolate",
26
+ "correct_target": "payment_service",
27
+ "description": "Attacker pivoting from compromised host to internal nodes",
28
+ },
29
+ "exfiltration": {
30
+ "correct_action": "honeypot",
31
+ "correct_target": "database",
32
+ "description": "Slow sustained data exfiltration from database",
33
+ },
34
+ "supply_chain": {
35
+ "correct_action": "patch",
36
+ "correct_target": "api_gateway",
37
+ "description": "Malicious package injected into update pipeline",
38
+ },
39
+ }
40
+
41
+ # Strategy rotation order
42
+ STRATEGY_ROTATION = [
43
+ "brute_force",
44
+ "lateral_movement",
45
+ "exfiltration",
46
+ "supply_chain",
47
+ ]
48
+
49
+ # Attack stages within an episode
50
+ # Each stage escalates if agent fails to act
51
+ ATTACK_STAGES = ["recon", "exploit", "exfiltration"]
52
+
53
+ # Baseline healthy network metrics
54
+ BASELINE_NODES: Dict[str, Dict[str, Any]] = {
55
+ "auth_service": {"status": "healthy", "request_rate": 120, "error_rate": 0.01, "cpu": 22, "outbound_mb": 2},
56
+ "payment_service": {"status": "healthy", "request_rate": 85, "error_rate": 0.01, "cpu": 31, "outbound_mb": 5},
57
+ "database": {"status": "healthy", "request_rate": 34, "error_rate": 0.01, "cpu": 18, "outbound_mb": 8},
58
+ "api_gateway": {"status": "healthy", "request_rate": 203, "error_rate": 0.02, "cpu": 29, "outbound_mb": 3},
59
+ }
60
+
61
+ SOURCE_SUBNETS = ["192.168.1.x", "10.44.8.x", "172.16.22.x", "203.0.113.x"]
62
+ SERVICE_ACCOUNTS = ["svc_internal", "svc_billing", "svc_reporter", "deploy_bot"]
63
+ PACKAGE_NAMES = ["core-auth-lib", "gateway-router", "payment-sdk", "session-cache"]
64
+ DB_TABLES = ["customer_tokens", "invoice_archive", "payment_methods", "audit_events"]
65
+ ALERT_SOURCES = ["SIEM", "EDR", "WAF", "NETFLOW"]
66
+
67
+ FAMILY_REPLACEMENTS = {
68
+ "train-a": {
69
+ "failed login attempts": "failed login attempts",
70
+ "password spray pattern": "password spray pattern",
71
+ "Unusual internal connections": "Unusual internal connections",
72
+ "Cross-node traffic volume": "Cross-node traffic volume",
73
+ "outbound traffic": "outbound traffic",
74
+ "Large compressed archive creation activity": "Large compressed archive creation activity",
75
+ "unsigned package update request": "unsigned package update request",
76
+ "binary hash mismatch": "binary hash mismatch",
77
+ "Unexpected outbound connection": "Unexpected outbound connection",
78
+ "scheduled analytics cache refresh": "scheduled analytics cache refresh",
79
+ },
80
+ "train-b": {
81
+ "failed login attempts": "authentication rejections",
82
+ "password spray pattern": "credential abuse spread",
83
+ "Unusual internal connections": "Unexpected service-mesh hops",
84
+ "Cross-node traffic volume": "Identity-linked east-west volume",
85
+ "outbound traffic": "egress volume",
86
+ "Large compressed archive creation activity": "archive staging activity",
87
+ "unsigned package update request": "unapproved release request",
88
+ "binary hash mismatch": "manifest integrity drift",
89
+ "Unexpected outbound connection": "post-deploy callback session",
90
+ "scheduled analytics cache refresh": "scheduled cache tuning window",
91
+ },
92
+ "eval-x": {
93
+ "failed login attempts": "auth rejection burst",
94
+ "password spray pattern": "credential reuse sweep",
95
+ "Unusual internal connections": "callback-path fan-out",
96
+ "Cross-node traffic volume": "mesh traffic clustering",
97
+ "outbound traffic": "data egress pressure",
98
+ "Large compressed archive creation activity": "bundle staging activity",
99
+ "unsigned package update request": "release provenance anomaly",
100
+ "binary hash mismatch": "artifact provenance drift",
101
+ "Unexpected outbound connection": "release-linked callback session",
102
+ "scheduled analytics cache refresh": "approved observability warmup",
103
+ },
104
+ "eval-y": {
105
+ "failed login attempts": "lockout storm",
106
+ "password spray pattern": "shared-secret sweep",
107
+ "Unusual internal connections": "lateral fan-out path",
108
+ "Cross-node traffic volume": "cross-domain session churn",
109
+ "outbound traffic": "archive egress volume",
110
+ "Large compressed archive creation activity": "sealed archive staging",
111
+ "unsigned package update request": "cross-approval deploy request",
112
+ "binary hash mismatch": "release integrity anomaly",
113
+ "Unexpected outbound connection": "unknown release callback",
114
+ "scheduled analytics cache refresh": "scheduled edge warmup",
115
+ },
116
+ }
117
+
118
+
119
+ class AttackerEngine:
120
+ """
121
+ Polymorphic scripted attacker with stage progression.
122
+
123
+ Within an episode: attack progresses through recon → exploit → exfiltration
124
+ if the agent fails to act correctly. Early correct action stops progression.
125
+
126
+ Between episodes: strategy rotates on a fixed schedule per task.
127
+ Hard task additionally shifts strategy mid-episode after turn 3.
128
+ """
129
+
130
+ def __init__(self, task_name: str, world_family: str = "train-a"):
131
+ random.seed(int(os.environ.get("ADAPTSHIELD_SEED", random.randint(0, 9999))))
132
+
133
+ self.task_name = task_name
134
+ self.world_family = world_family
135
+ self._episode = 0 # internal — NEVER passed to agent
136
+ self._turn = 0 # within-episode turn counter
137
+ self._stage_idx = 0 # current attack stage index
138
+ self._escalated = False # did agent miss a turn?
139
+
140
+ self._shift_every = {
141
+ "direct-triage": 9999,
142
+ "dual-pivot": 20,
143
+ "polymorphic-zero-day": 10,
144
+ }.get(task_name, 9999)
145
+
146
+ self._noise_rate = 0.15 if task_name == "polymorphic-zero-day" else 0.0
147
+
148
+ # ── Public interface ───────────────────────────────────────────────────
149
+
150
+ def reset_episode(self) -> None:
151
+ """Called at the start of each episode."""
152
+ self._turn = 0
153
+ self._stage_idx = 0
154
+ self._escalated = False
155
+
156
+ def current_strategy(self) -> str:
157
+ """Return active strategy — accounts for mid-episode shift on hard task."""
158
+ base_idx = (self._episode // self._shift_every) % len(STRATEGY_ROTATION)
159
+
160
+ # Medium task: one clean pivot from brute force to lateral movement.
161
+ if self.task_name == "dual-pivot" and self._turn >= 3:
162
+ base_idx = (base_idx + 1) % 2
163
+
164
+ # Hard task: strategy shifts mid-episode after turn 3
165
+ if self.task_name == "polymorphic-zero-day" and self._turn >= 3:
166
+ base_idx = (base_idx + 1) % len(STRATEGY_ROTATION)
167
+
168
+ return STRATEGY_ROTATION[base_idx]
169
+
170
+ def current_stage(self) -> str:
171
+ """Return current attack stage."""
172
+ return ATTACK_STAGES[min(self._stage_idx, len(ATTACK_STAGES) - 1)]
173
+
174
+ def build_observation(self) -> Dict[str, Any]:
175
+ """
176
+ Build the network observation for current turn.
177
+ Returns ambiguous signals — no explicit strategy names or codes.
178
+ Agent must reason about what the metrics mean.
179
+ """
180
+ is_benign = random.random() < self._noise_rate
181
+ strategy = self.current_strategy()
182
+ stage = self.current_stage()
183
+
184
+ nodes = {k: dict(v) for k, v in BASELINE_NODES.items()}
185
+ alerts = []
186
+
187
+ if is_benign:
188
+ alerts = self._build_benign_alerts()
189
+ nodes["api_gateway"]["cpu"] = 71
190
+ nodes["api_gateway"]["status"] = "elevated"
191
+ return {
192
+ "network_nodes": nodes,
193
+ "active_alerts": alerts,
194
+ "attack_stage": "none",
195
+ "is_benign": True,
196
+ "strategy": "benign",
197
+ "correct_action": "monitor",
198
+ "correct_target": "none",
199
+ }
200
+
201
+ nodes, alerts = self._build_attack_state(strategy, stage, nodes)
202
+
203
+ return {
204
+ "network_nodes": nodes,
205
+ "active_alerts": alerts,
206
+ "attack_stage": stage,
207
+ "is_benign": False,
208
+ "strategy": strategy,
209
+ "correct_action": STRATEGY_MAP[strategy]["correct_action"],
210
+ "correct_target": STRATEGY_MAP[strategy]["correct_target"],
211
+ }
212
+
213
+ def advance_turn(self, agent_acted_correctly: bool) -> None:
214
+ """
215
+ Advance turn counter. Escalate attack stage if agent failed.
216
+ Called after grading each Phase 2 action.
217
+ """
218
+ self._turn += 1
219
+ if not agent_acted_correctly:
220
+ self._escalated = True
221
+ self._stage_idx = min(self._stage_idx + 1, len(ATTACK_STAGES) - 1)
222
+
223
+ def advance_episode(self) -> None:
224
+ """Called at the end of each episode."""
225
+ self._episode += 1
226
+ self.reset_episode()
227
+
228
+ # ── Internal signal builders ───────────────────────────────────────────
229
+
230
+ def _build_benign_alerts(self) -> list:
231
+ """Build convincing but harmless noise alerts."""
232
+ deploy_user = random.choice(SERVICE_ACCOUNTS)
233
+ alert_prefix = self._alert_prefix()
234
+ options = [
235
+ [
236
+ f"{alert_prefix} Scheduled database backup completed: 2.3GB archived",
237
+ f"{alert_prefix} User '{deploy_user}' executed routine deployment script",
238
+ f"{alert_prefix} api_gateway: elevated CPU due to analytics batch job",
239
+ ],
240
+ [
241
+ f"{alert_prefix} Certificate renewal completed for *.internal.corp",
242
+ f"{alert_prefix} Routine log rotation completed on all nodes",
243
+ f"{alert_prefix} payment_service: normal traffic spike during billing window",
244
+ ],
245
+ [
246
+ f"{alert_prefix} Health check: all nodes nominal",
247
+ f"{alert_prefix} auth_service: password policy enforcement triggered for 3 users",
248
+ f"{alert_prefix} Scheduled vulnerability scan completed: 0 critical findings",
249
+ ],
250
+ ]
251
+ return [self._surface(line) for line in random.choice(options)]
252
+
253
+ def _build_attack_state(
254
+ self,
255
+ strategy: str,
256
+ stage: str,
257
+ nodes: Dict,
258
+ ) -> Tuple[Dict, list]:
259
+ """
260
+ Build network state and alerts for an active attack.
261
+ Signals are intentionally ambiguous — no strategy names or MITRE codes.
262
+ Severity escalates with attack stage.
263
+ """
264
+ stage_multiplier = {"recon": 1.0, "exploit": 1.6, "exfiltration": 2.4}
265
+ mult = stage_multiplier.get(stage, 1.0)
266
+
267
+ alerts = []
268
+ alert_prefix = self._alert_prefix()
269
+
270
+ if strategy == "brute_force":
271
+ subnet = random.choice(SOURCE_SUBNETS)
272
+ nodes["auth_service"]["status"] = "degraded" if stage != "recon" else "elevated"
273
+ nodes["auth_service"]["request_rate"] = int(340 * mult)
274
+ nodes["auth_service"]["error_rate"] = min(0.48, 0.18 * mult)
275
+ nodes["auth_service"]["cpu"] = min(96, int(45 * mult))
276
+ if self.task_name == "direct-triage":
277
+ nodes["api_gateway"]["status"] = "elevated"
278
+ nodes["api_gateway"]["cpu"] = max(nodes["api_gateway"]["cpu"], 54)
279
+ alerts = [
280
+ f"{alert_prefix} auth_service: {int(120 * mult)} failed login attempts in last 60s from subnet {subnet}",
281
+ f"{alert_prefix} auth_service: error rate {nodes['auth_service']['error_rate']:.0%}; {stage} activity detected",
282
+ f"{alert_prefix} payment_service: normal transaction volume",
283
+ ]
284
+ if self.task_name == "direct-triage":
285
+ alerts.append(
286
+ f"{alert_prefix} api_gateway: elevated CPU overlaps with scheduled analytics cache refresh"
287
+ )
288
+
289
+ elif strategy == "lateral_movement":
290
+ account = random.choice(SERVICE_ACCOUNTS)
291
+ nodes["payment_service"]["status"] = "suspicious"
292
+ nodes["payment_service"]["request_rate"] = int(95 * mult)
293
+ nodes["payment_service"]["cpu"] = int(38 * mult)
294
+ nodes["auth_service"]["request_rate"] = int(155 * mult)
295
+ add_auth_distractor = (
296
+ (self.task_name == "dual-pivot" and self._turn == 3) or
297
+ (self.task_name == "polymorphic-zero-day" and self._turn >= 3)
298
+ )
299
+ if add_auth_distractor:
300
+ nodes["auth_service"]["status"] = "elevated"
301
+ nodes["auth_service"]["error_rate"] = min(0.36, 0.14 * mult)
302
+ nodes["auth_service"]["cpu"] = min(90, int(42 * mult))
303
+ alerts = [
304
+ f"{alert_prefix} Unusual internal connections: auth_service -> payment_service ({int(40 * mult)}/min)",
305
+ f"{alert_prefix} payment_service: service account '{account}' authenticated from new host",
306
+ f"{alert_prefix} Cross-node traffic volume {mult:.1f}x above 7-day baseline",
307
+ ]
308
+ if add_auth_distractor:
309
+ alerts.append(
310
+ f"{alert_prefix} auth_service: residual failed-login noise from earlier credential attack"
311
+ )
312
+ if self.task_name == "dual-pivot":
313
+ alerts.append(
314
+ f"{alert_prefix} payment_service: emergency checkout tuning window is not on the approved change calendar"
315
+ )
316
+
317
+ elif strategy == "exfiltration":
318
+ table = random.choice(DB_TABLES)
319
+ nodes["database"]["status"] = "suspicious" if stage != "recon" else "elevated"
320
+ nodes["database"]["outbound_mb"] = int(8 * mult * 12)
321
+ nodes["database"]["request_rate"] = int(34 * mult * 3)
322
+ nodes["database"]["cpu"] = int(18 * mult * 1.5)
323
+ alerts = [
324
+ f"{alert_prefix} database: outbound traffic {nodes['database']['outbound_mb']}MB/hr; {mult:.1f}x above baseline",
325
+ f"{alert_prefix} database: {int(280 * mult)} sequential reads from table '{table}'",
326
+ f"{alert_prefix} Large compressed archive creation activity on database node",
327
+ ]
328
+ if self.task_name == "polymorphic-zero-day":
329
+ alerts.append(
330
+ f"{alert_prefix} payment_service: reused workload credential appears on east-west callback path"
331
+ )
332
+
333
+ elif strategy == "supply_chain":
334
+ package = random.choice(PACKAGE_NAMES)
335
+ nodes["api_gateway"]["status"] = "suspicious"
336
+ nodes["api_gateway"]["cpu"] = int(29 + 20 * mult)
337
+ alerts = [
338
+ f"{alert_prefix} api_gateway: unsigned package update request received from external source",
339
+ f"{alert_prefix} api_gateway: binary hash mismatch on dependency '{package}'",
340
+ f"{alert_prefix} Unexpected outbound connection from api_gateway to unrecognized host",
341
+ ]
342
+ if self.task_name == "polymorphic-zero-day":
343
+ alerts.append(
344
+ f"{alert_prefix} change calendar: no approved gateway deploy window covers this request"
345
+ )
346
+
347
+ return nodes, [self._surface(alert) for alert in alerts]
348
+
349
+ def _alert_prefix(self) -> str:
350
+ """Return deterministic-looking SOC alert metadata under ADAPTSHIELD_SEED."""
351
+ source = random.choice(ALERT_SOURCES)
352
+ alert_id = random.randint(1000, 9999)
353
+ minute = random.randint(0, 59)
354
+ second = random.randint(0, 59)
355
+ return f"[{source}-{alert_id} 03:{minute:02d}:{second:02d}Z]"
356
+
357
+ def _surface(self, text: str) -> str:
358
+ surfaced = str(text)
359
+ for source, target in FAMILY_REPLACEMENTS.get(self.world_family, {}).items():
360
+ surfaced = surfaced.replace(source, target)
361
+ return surfaced
server/grader.py ADDED
@@ -0,0 +1,501 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield Grader
3
+
4
+ Fully deterministic reward computation.
5
+ No NLP, no LLM-as-judge, no external API calls.
6
+ Runs in milliseconds.
7
+
8
+ INVARIANT: All return values are strictly in [0.01, 0.99].
9
+ normalized_score is ALWAYS present in metadata — never missing.
10
+ """
11
+
12
+ from typing import Any, Dict, List, Tuple
13
+
14
+
15
+ # ── Reward constants ───────────────────────────────────────────────────────
16
+ BASE_REWARD = 0.50 # survival — agent gets this for staying alive
17
+ P1_TYPE_BONUS = 0.15 # Phase 1 correctly identified threat type
18
+ P1_TARGET_BONUS = 0.10 # Phase 1 correctly identified target node
19
+ P2_OPTIMAL = 0.39 # correct + efficient action (0.50+0.39 = 0.89 base max)
20
+ P2_HEAVY = 0.18 # correct action but heavy-handed
21
+ P2_UNVERIFIED = 0.10 # right hard-task action without investigative evidence
22
+ P2_WRONG = -0.25 # wrong action on real threat
23
+ FALSE_POSITIVE = -0.39 # acted aggressively on benign event
24
+ STAGE_ESCALATION = -0.10 # penalty applied when attack stage escalates
25
+ CATASTROPHIC = -0.49 # database exfiltration completed → done=True
26
+
27
+ # Optimal vs heavy-handed action mapping
28
+ OPTIMAL_ACTION = {
29
+ "brute_force": "rate_limit",
30
+ "lateral_movement": "isolate",
31
+ "exfiltration": "honeypot",
32
+ "supply_chain": "patch",
33
+ }
34
+
35
+ HEAVY_ACTION = {
36
+ "brute_force": "isolate", # works but unnecessary downtime
37
+ "lateral_movement": "honeypot", # works but slower
38
+ "exfiltration": "isolate", # works but database goes offline
39
+ "supply_chain": "isolate", # works but gateway goes offline
40
+ }
41
+
42
+ # Catastrophic: exfiltration succeeds if agent uses wrong action 3+ times
43
+ EXFIL_CATASTROPHIC_ACTIONS = ["monitor", "rate_limit", "patch"]
44
+
45
+ # Operational model: criticality and dependency blast radius.
46
+ # This is intentionally lightweight so it strengthens realism without making
47
+ # rewards noisy or hard to train against.
48
+ ASSET_CRITICALITY = {
49
+ "auth_service": 0.70,
50
+ "payment_service": 0.90,
51
+ "database": 1.00,
52
+ "api_gateway": 0.80,
53
+ "none": 0.00,
54
+ "unknown": 0.50,
55
+ }
56
+
57
+ SERVICE_DEPENDENCIES = {
58
+ "auth_service": ["payment_service"],
59
+ "payment_service": ["api_gateway"],
60
+ "database": ["payment_service", "api_gateway"],
61
+ "api_gateway": ["auth_service", "payment_service", "database"],
62
+ "none": [],
63
+ "unknown": [],
64
+ }
65
+
66
+ ACTION_DISRUPTION = {
67
+ "monitor": 0.00,
68
+ "patch": 0.06,
69
+ "rate_limit": 0.10,
70
+ "honeypot": 0.12,
71
+ "isolate": 0.35,
72
+ }
73
+
74
+ MAX_OPERATIONAL_PENALTY = 0.05
75
+ MAX_MISSION_ADJUSTMENT = 0.04
76
+
77
+ BASE_REQUIRED_TOOL_FUSION = {
78
+ "brute_force": {"log_search", "cmdb_lookup"},
79
+ "lateral_movement": {"edr_status", "log_search"},
80
+ "exfiltration": {"log_search", "edr_status"},
81
+ "supply_chain": {"vuln_lookup", "log_search"},
82
+ }
83
+
84
+ TASK_REQUIRED_TOOL_FUSION = {
85
+ "direct-triage": {
86
+ "brute_force": {"log_search"},
87
+ },
88
+ "dual-pivot": {
89
+ "lateral_movement": {"edr_status", "log_search", "identity_lookup"},
90
+ },
91
+ "polymorphic-zero-day": {
92
+ "brute_force": {"log_search", "cmdb_lookup", "identity_lookup"},
93
+ "lateral_movement": {"edr_status", "log_search", "identity_lookup", "cmdb_lookup"},
94
+ "exfiltration": {"log_search", "edr_status", "netflow_lookup", "cmdb_lookup"},
95
+ "supply_chain": {"vuln_lookup", "log_search", "change_calendar_lookup", "cmdb_lookup"},
96
+ },
97
+ }
98
+
99
+
100
+ def grade_step(
101
+ phase1_action: Dict[str, Any],
102
+ phase2_action: Dict[str, Any],
103
+ turn_config: Dict[str, Any],
104
+ stage: str,
105
+ consecutive_wrong: int,
106
+ task_name: str = "",
107
+ foothold_established: bool = False,
108
+ mission_profile: Dict[str, Any] | None = None,
109
+ tool_context: Dict[str, Any] | None = None,
110
+ ) -> Tuple[float, bool, Dict[str, Any]]:
111
+ """
112
+ Grade a complete two-phase step.
113
+
114
+ Args:
115
+ phase1_action: Agent's Phase 1 output (threat assessment)
116
+ phase2_action: Agent's Phase 2 output (defensive action)
117
+ turn_config: Ground truth from AttackerEngine.build_observation()
118
+ stage: Current attack stage (recon/exploit/exfiltration)
119
+ consecutive_wrong: How many consecutive wrong actions agent has taken
120
+
121
+ Returns:
122
+ (reward, catastrophic_failure, info_dict)
123
+ reward is ALWAYS in range [0.01, 0.99]
124
+ """
125
+ is_benign = turn_config.get("is_benign", False)
126
+ strategy = turn_config.get("strategy", "benign")
127
+ correct_action = turn_config.get("correct_action", "monitor")
128
+ correct_target = turn_config.get("correct_target", "none")
129
+ mission_profile = mission_profile or {}
130
+ tool_context = tool_context or {}
131
+ contextual_countermeasure = (
132
+ task_name == "polymorphic-zero-day" and
133
+ foothold_established and
134
+ strategy == "lateral_movement"
135
+ )
136
+ if contextual_countermeasure:
137
+ correct_action = "honeypot"
138
+
139
+ p1_type = str(phase1_action.get("threat_type", ""))
140
+ p1_target = str(phase1_action.get("target_node", ""))
141
+ p2_action = str(phase2_action.get("action", "monitor"))
142
+ p2_target = str(phase2_action.get("target_node", ""))
143
+
144
+ reward = BASE_REWARD
145
+ catastrophic = False
146
+
147
+ info = {
148
+ "strategy": strategy,
149
+ "stage": stage,
150
+ "correct_action": correct_action,
151
+ "correct_target": correct_target,
152
+ "p1_threat_type": p1_type,
153
+ "p1_target": p1_target,
154
+ "p2_action": p2_action,
155
+ "p2_target": p2_target,
156
+ "is_benign": is_benign,
157
+ "p1_type_correct": False,
158
+ "p1_target_correct": False,
159
+ "score_reason": "",
160
+ "business_impact": 0.0,
161
+ "availability_impact": 0.0,
162
+ "security_risk": 0.0,
163
+ "dependency_blast_radius": [],
164
+ "operational_penalty": 0.0,
165
+ "foothold_established": foothold_established,
166
+ "contextual_countermeasure": contextual_countermeasure,
167
+ "mission_profile": mission_profile,
168
+ "mission_alignment": "neutral",
169
+ "mission_adjustment": 0.0,
170
+ "tool_verification_required": False,
171
+ "tool_evidence_found": False,
172
+ "tool_fusion_required": [],
173
+ "tool_fusion_found": [],
174
+ "tool_count": int(tool_context.get("tool_count", 0) or 0),
175
+ }
176
+
177
+ # ── False positive path ────────────────────────────────────────────────
178
+ if is_benign:
179
+ if p2_action in ("rate_limit", "isolate", "honeypot", "patch"):
180
+ reward += FALSE_POSITIVE
181
+ info["acted_correctly"] = False
182
+ info["score_reason"] = (
183
+ f"False positive: used {p2_action} on benign event"
184
+ )
185
+ else:
186
+ reward += 0.15
187
+ info["acted_correctly"] = True
188
+ info["score_reason"] = "Correct: monitored benign event without overreacting"
189
+ _apply_operational_impact(
190
+ info=info,
191
+ action=p2_action,
192
+ target=p2_target,
193
+ result_kind="false_positive" if p2_action != "monitor" else "optimal",
194
+ strategy=strategy,
195
+ stage=stage,
196
+ )
197
+ _apply_mission_alignment(
198
+ info=info,
199
+ action=p2_action,
200
+ target=p2_target,
201
+ result_kind="false_positive" if p2_action != "monitor" else "optimal",
202
+ mission_profile=mission_profile,
203
+ )
204
+ reward += info["mission_adjustment"]
205
+ reward -= info["operational_penalty"]
206
+ return _clamp(reward), False, info
207
+
208
+ # ── Phase 1 accuracy bonuses ───────────────────────────────────────────
209
+ if p1_type == strategy:
210
+ reward += P1_TYPE_BONUS
211
+ info["p1_type_correct"] = True
212
+
213
+ if p1_target == correct_target:
214
+ reward += P1_TARGET_BONUS
215
+ info["p1_target_correct"] = True
216
+
217
+ # ── Catastrophic failure: exfiltration completed ───────────────────────
218
+ if (strategy == "exfiltration" and
219
+ stage == "exfiltration" and
220
+ p2_action in EXFIL_CATASTROPHIC_ACTIONS):
221
+ reward = BASE_REWARD + CATASTROPHIC
222
+ catastrophic = True
223
+ info["score_reason"] = (
224
+ "CATASTROPHIC: Database exfiltration completed — "
225
+ f"agent used {p2_action} instead of honeypot"
226
+ )
227
+ _apply_operational_impact(
228
+ info=info,
229
+ action=p2_action,
230
+ target=p2_target,
231
+ result_kind="wrong",
232
+ strategy=strategy,
233
+ stage=stage,
234
+ )
235
+ _apply_mission_alignment(
236
+ info=info,
237
+ action=p2_action,
238
+ target=p2_target,
239
+ result_kind="wrong",
240
+ mission_profile=mission_profile,
241
+ )
242
+ reward += info["mission_adjustment"]
243
+ return _clamp(reward), catastrophic, info
244
+
245
+ # ── Stage escalation penalty ───────────────────────────────────────────
246
+ if stage == "exploit" and consecutive_wrong >= 1:
247
+ reward += STAGE_ESCALATION
248
+ elif stage == "exfiltration" and consecutive_wrong >= 2:
249
+ reward += STAGE_ESCALATION * 2
250
+
251
+ # ── Phase 2 action grading ─────────────────────────────────────────────
252
+ optimal = correct_action
253
+ heavy = "" if contextual_countermeasure else HEAVY_ACTION.get(strategy, "")
254
+ if heavy == optimal:
255
+ heavy = ""
256
+ requires_tool_verification = (
257
+ not is_benign and
258
+ strategy in OPTIMAL_ACTION and
259
+ (
260
+ task_name == "polymorphic-zero-day" or
261
+ (task_name == "dual-pivot" and strategy == "lateral_movement") or
262
+ (task_name == "direct-triage" and strategy == "brute_force")
263
+ )
264
+ )
265
+ required_tools = _required_tool_fusion(task_name=task_name, strategy=strategy)
266
+ tool_evidence_found, fusion_found = _has_relevant_tool_evidence(
267
+ tool_context=tool_context,
268
+ strategy=strategy,
269
+ target=correct_target,
270
+ required_tools=required_tools,
271
+ )
272
+ info["tool_verification_required"] = requires_tool_verification
273
+ info["tool_evidence_found"] = tool_evidence_found
274
+ info["tool_fusion_required"] = sorted(required_tools)
275
+ info["tool_fusion_found"] = sorted(fusion_found)
276
+
277
+ if (
278
+ p2_action == optimal and
279
+ p2_target == correct_target and
280
+ requires_tool_verification and
281
+ not tool_evidence_found
282
+ ):
283
+ reward += P2_UNVERIFIED
284
+ result_kind = "unverified"
285
+ info["score_reason"] = (
286
+ f"Unverified correct action: {p2_action} on {p2_target} would help, "
287
+ f"but {task_name or 'this task'} requires stronger SOC evidence before full credit"
288
+ )
289
+
290
+ elif p2_action == optimal and p2_target == correct_target:
291
+ reward += P2_OPTIMAL
292
+ result_kind = "optimal"
293
+ if contextual_countermeasure:
294
+ info["score_reason"] = (
295
+ f"Context-aware optimal: {p2_action} on {p2_target} — "
296
+ "foothold already established, so deception beats isolation"
297
+ )
298
+ else:
299
+ info["score_reason"] = (
300
+ f"Optimal: {p2_action} on {p2_target} — attack stopped efficiently"
301
+ )
302
+
303
+ elif p2_action == optimal and p2_target != correct_target:
304
+ reward += P2_HEAVY * 0.5
305
+ result_kind = "wrong_target"
306
+ info["score_reason"] = (
307
+ f"Right action ({p2_action}) but wrong target "
308
+ f"(got {p2_target}, needed {correct_target})"
309
+ )
310
+
311
+ elif p2_action == heavy and p2_target == correct_target:
312
+ reward += P2_HEAVY
313
+ result_kind = "heavy"
314
+ info["score_reason"] = (
315
+ f"Heavy-handed: {p2_action} stopped attack on {p2_target} "
316
+ f"but caused unnecessary service disruption"
317
+ )
318
+
319
+ else:
320
+ reward += P2_WRONG
321
+ result_kind = "wrong"
322
+ info["score_reason"] = (
323
+ f"Wrong: {p2_action} on {p2_target} — "
324
+ f"needed {correct_action} on {correct_target}"
325
+ )
326
+
327
+ acted_correctly = p2_action in (optimal, heavy) and p2_target == correct_target
328
+ info["acted_correctly"] = acted_correctly
329
+ _apply_operational_impact(
330
+ info=info,
331
+ action=p2_action,
332
+ target=p2_target,
333
+ result_kind=result_kind,
334
+ strategy=strategy,
335
+ stage=stage,
336
+ )
337
+ _apply_mission_alignment(
338
+ info=info,
339
+ action=p2_action,
340
+ target=p2_target,
341
+ result_kind=result_kind,
342
+ mission_profile=mission_profile,
343
+ )
344
+ reward += info["mission_adjustment"]
345
+ reward -= info["operational_penalty"]
346
+
347
+ return _clamp(reward), catastrophic, info
348
+
349
+
350
+ def _apply_mission_alignment(
351
+ info: Dict[str, Any],
352
+ action: str,
353
+ target: str,
354
+ result_kind: str,
355
+ mission_profile: Dict[str, Any],
356
+ ) -> None:
357
+ sla_priority = str(mission_profile.get("sla_priority", "balanced"))
358
+ primary_asset = str(mission_profile.get("primary_asset", "unknown"))
359
+ risk_tolerance = str(mission_profile.get("risk_tolerance", "medium"))
360
+
361
+ adjustment = 0.0
362
+ alignment = "neutral"
363
+
364
+ if sla_priority == "availability" and action == "isolate" and target == primary_asset:
365
+ adjustment -= MAX_MISSION_ADJUSTMENT
366
+ alignment = "sla_violation"
367
+ elif sla_priority == "availability" and result_kind == "optimal" and action in ("rate_limit", "patch", "monitor"):
368
+ adjustment += MAX_MISSION_ADJUSTMENT / 2
369
+ alignment = "sla_aligned"
370
+ elif sla_priority == "containment" and result_kind == "optimal" and action in ("honeypot", "isolate", "patch"):
371
+ adjustment += MAX_MISSION_ADJUSTMENT / 2
372
+ alignment = "containment_aligned"
373
+ elif risk_tolerance == "low" and result_kind in ("wrong", "wrong_target"):
374
+ adjustment -= MAX_MISSION_ADJUSTMENT / 2
375
+ alignment = "risk_misaligned"
376
+
377
+ info["mission_alignment"] = alignment
378
+ info["mission_adjustment"] = round(adjustment, 2)
379
+
380
+
381
+ def _apply_operational_impact(
382
+ info: Dict[str, Any],
383
+ action: str,
384
+ target: str,
385
+ result_kind: str,
386
+ strategy: str,
387
+ stage: str,
388
+ ) -> None:
389
+ """
390
+ Add deterministic business-impact telemetry and a small bounded penalty.
391
+
392
+ The penalty is intentionally capped at 0.05 so existing learning curves keep
393
+ their shape while demos can explain service criticality and blast radius.
394
+ """
395
+ criticality = ASSET_CRITICALITY.get(target, ASSET_CRITICALITY["unknown"])
396
+ disruption = ACTION_DISRUPTION.get(action, 0.10)
397
+ dependents = SERVICE_DEPENDENCIES.get(target, [])
398
+ dependency_factor = min(1.0, 0.15 * len(dependents))
399
+
400
+ availability = round(min(1.0, disruption * (criticality + dependency_factor)), 2)
401
+ security = _security_risk(result_kind=result_kind, strategy=strategy, stage=stage)
402
+ impact = round(min(1.0, availability + security), 2)
403
+
404
+ if result_kind == "optimal":
405
+ penalty = 0.0
406
+ elif result_kind == "unverified":
407
+ penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY / 2), 2)
408
+ else:
409
+ penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY), 2)
410
+
411
+ info["business_impact"] = impact
412
+ info["availability_impact"] = availability
413
+ info["security_risk"] = security
414
+ info["dependency_blast_radius"] = dependents if disruption > 0 else []
415
+ info["operational_penalty"] = penalty
416
+
417
+
418
+ def _security_risk(result_kind: str, strategy: str, stage: str) -> float:
419
+ if result_kind in ("optimal", "heavy"):
420
+ return 0.0
421
+ if result_kind == "unverified":
422
+ return 0.08
423
+ if result_kind == "false_positive":
424
+ return 0.0
425
+
426
+ stage_risk = {
427
+ "recon": 0.18,
428
+ "exploit": 0.32,
429
+ "exfiltration": 0.50,
430
+ }.get(stage, 0.20)
431
+
432
+ if strategy == "exfiltration":
433
+ stage_risk += 0.15
434
+ elif strategy == "lateral_movement":
435
+ stage_risk += 0.08
436
+
437
+ return round(min(1.0, stage_risk), 2)
438
+
439
+
440
+ def _has_relevant_tool_evidence(
441
+ tool_context: Dict[str, Any],
442
+ strategy: str,
443
+ target: str,
444
+ required_tools: set[str],
445
+ ) -> Tuple[bool, set[str]]:
446
+ fusion_found = {
447
+ str(result.get("tool", ""))
448
+ for result in tool_context.get("tool_results", []) or []
449
+ if str(result.get("node", "")) == target
450
+ }
451
+ has_attack_evidence = False
452
+ for evidence in tool_context.get("evidence", []) or []:
453
+ if (
454
+ str(evidence.get("evidence_type", "")) == strategy and
455
+ str(evidence.get("node", "")) == target and
456
+ bool(evidence.get("verified", False))
457
+ ):
458
+ has_attack_evidence = True
459
+ break
460
+
461
+ return has_attack_evidence and required_tools.issubset(fusion_found), fusion_found
462
+
463
+
464
+ def _required_tool_fusion(task_name: str, strategy: str) -> set[str]:
465
+ task_rules = TASK_REQUIRED_TOOL_FUSION.get(task_name, {})
466
+ if strategy in task_rules:
467
+ return set(task_rules[strategy])
468
+ return set(BASE_REQUIRED_TOOL_FUSION.get(strategy, set()))
469
+
470
+
471
+ def _clamp(value: float) -> float:
472
+ """Strict bounds: never exactly 0.0 or 1.0."""
473
+ return max(0.01, min(0.99, round(value, 2)))
474
+
475
+
476
+ def normalize_episode_score(rewards: List[float]) -> float:
477
+ """
478
+ Normalize episode rewards to a single score strictly in (0.01, 0.99).
479
+ ALWAYS returns a value — never raises, never returns exactly 0 or 1.
480
+ """
481
+ if not rewards:
482
+ return 0.50
483
+
484
+ total = sum(rewards)
485
+ n = len(rewards)
486
+
487
+ # Per-step rewards are clamped before they enter the episode reward list,
488
+ # so normalization must use the reachable ceiling instead of the raw
489
+ # unclamped sum of bonuses. Otherwise perfect episodes top out around 0.87.
490
+ max_step_reward = _clamp(
491
+ BASE_REWARD + P2_OPTIMAL + P1_TYPE_BONUS + P1_TARGET_BONUS + MAX_MISSION_ADJUSTMENT
492
+ )
493
+ min_step_reward = _clamp(BASE_REWARD + CATASTROPHIC)
494
+ max_poss = n * max_step_reward
495
+ min_poss = n * min_step_reward
496
+
497
+ if max_poss == min_poss:
498
+ return 0.50
499
+
500
+ raw = (total - min_poss) / (max_poss - min_poss)
501
+ return _clamp(raw)
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ openenv[core]>=0.2.0
2
+ fastapi>=0.115.0
3
+ uvicorn>=0.24.0
4
+
5
+
6
+
server/scenarios.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ AdaptShield Scenarios
3
+
4
+ Observation builder and system prompts.
5
+ IMPORTANT: No MITRE codes in alerts. No explicit strategy names.
6
+ Agent must reason from raw metrics — not pattern match on codes.
7
+ """
8
+
9
+ import random
10
+ from typing import Any, Dict, List
11
+
12
+ VALID_NODES = ["auth_service", "payment_service", "database", "api_gateway"]
13
+ VALID_ACTIONS = ["rate_limit", "isolate", "honeypot", "patch", "monitor"]
14
+
15
+ TASK_CONFIGS = {
16
+ "direct-triage": {
17
+ "max_turns": 5,
18
+ "description": "Single fixed attacker strategy. Learn baseline threat response.",
19
+ "mission_profile": {
20
+ "mission_id": "login_stability",
21
+ "primary_asset": "auth_service",
22
+ "sla_priority": "availability",
23
+ "risk_tolerance": "medium",
24
+ "objective": "Preserve user login availability while stopping credential abuse.",
25
+ },
26
+ },
27
+ "dual-pivot": {
28
+ "max_turns": 6,
29
+ "description": "Two strategies alternating every 20 episodes.",
30
+ "mission_profile": {
31
+ "mission_id": "checkout_continuity",
32
+ "primary_asset": "payment_service",
33
+ "sla_priority": "availability",
34
+ "risk_tolerance": "medium",
35
+ "objective": "Keep checkout online while containing internal movement.",
36
+ },
37
+ },
38
+ "polymorphic-zero-day": {
39
+ "max_turns": 8,
40
+ "description": "All four strategies with mid-episode shift and noise injection.",
41
+ "mission_profile": {
42
+ "mission_id": "breach_containment",
43
+ "primary_asset": "database",
44
+ "sla_priority": "containment",
45
+ "risk_tolerance": "low",
46
+ "objective": "Protect sensitive data and adapt response if attacker persistence is established.",
47
+ },
48
+ },
49
+ }
50
+
51
+ WORLD_FAMILY_SPLITS = {
52
+ "train": ["train-a", "train-b"],
53
+ "eval": ["eval-x", "eval-y"],
54
+ }
55
+
56
+ TASK_OPERATIONAL_MODES = {
57
+ "direct-triage": ["availability_guarded"],
58
+ "dual-pivot": ["containment_first", "evidence_preservation"],
59
+ "polymorphic-zero-day": [
60
+ "containment_first",
61
+ "forensic_hold",
62
+ "business_continuity",
63
+ ],
64
+ }
65
+
66
+ FAMILY_MISSION_NOTES = {
67
+ "train-a": "Primary incident feed emphasizes login telemetry and CMDB-linked service relationships.",
68
+ "train-b": "Primary incident feed emphasizes identity anomalies, service-account movement, and deploy context.",
69
+ "eval-x": "Primary incident feed emphasizes east-west callback patterns, release provenance drift, and egress clustering.",
70
+ "eval-y": "Primary incident feed emphasizes rejection bursts, archive staging, and cross-domain approval mismatches.",
71
+ }
72
+
73
+ MODE_HINTS = {
74
+ "availability_guarded": {
75
+ "sla_priority": "availability",
76
+ "risk_tolerance": "medium",
77
+ "objective_suffix": "Use one confirming signal before disruptive action when login noise overlaps with maintenance chatter.",
78
+ },
79
+ "containment_first": {
80
+ "sla_priority": "containment",
81
+ "risk_tolerance": "low",
82
+ "objective_suffix": "Prefer decisive containment once corroborating evidence suggests compromise is active.",
83
+ },
84
+ "evidence_preservation": {
85
+ "sla_priority": "balanced",
86
+ "risk_tolerance": "medium",
87
+ "objective_suffix": "Preserve attacker visibility during checkout pivots until identity misuse is corroborated across tools.",
88
+ },
89
+ "forensic_hold": {
90
+ "sla_priority": "containment",
91
+ "risk_tolerance": "low",
92
+ "objective_suffix": "Preserve attacker tradecraft long enough to map the callback path before cutting access entirely.",
93
+ },
94
+ "business_continuity": {
95
+ "sla_priority": "availability",
96
+ "risk_tolerance": "medium",
97
+ "objective_suffix": "Reduce blast radius while protecting customer-facing continuity whenever deception can safely buy time.",
98
+ },
99
+ }
100
+
101
+ PHASE1_SYSTEM = """You are a Threat Analyst defending a 4-node enterprise network.
102
+ Analyze the network metrics and security alerts to classify the threat.
103
+
104
+ Network nodes: auth_service, payment_service, database, api_gateway
105
+
106
+ Attack strategies to recognize:
107
+ - brute_force: High failed login rate on auth_service, error spikes
108
+ - lateral_movement: Unusual cross-node connections, service account anomalies
109
+ - exfiltration: Anomalous outbound data volume from database node
110
+ - supply_chain: Suspicious update requests or hash mismatches on api_gateway
111
+ - benign: Normal operations — scheduled jobs, routine activity
112
+
113
+ Respond ONLY with valid JSON, no other text:
114
+ {
115
+ "threat_type": "<brute_force|lateral_movement|exfiltration|supply_chain|benign>",
116
+ "confidence": <0.0 to 1.0>,
117
+ "target_node": "<auth_service|payment_service|database|api_gateway>",
118
+ "recommended_action": "<rate_limit|isolate|honeypot|patch|monitor>",
119
+ "reasoning": "<one sentence chain of thought>"
120
+ }"""
121
+
122
+ PHASE2_SYSTEM = """You are a Tactical Executor defending a 4-node enterprise network.
123
+ You receive a threat assessment from the Threat Analyst and must execute the defense.
124
+ You CANNOT see raw network logs — act only on the assessment provided.
125
+
126
+ Available actions:
127
+ - rate_limit: Throttle traffic to node. Light touch, keeps service online. Best for DoS/brute force.
128
+ - isolate: Take node completely offline. Stops spread but causes downtime. Use for lateral movement.
129
+ - honeypot: Redirect attacker to decoy system. Best for data exfiltration attempts.
130
+ - patch: Apply security update. Targeted fix for supply chain attacks.
131
+ - monitor: Observe without acting. Use only when genuinely uncertain or event is benign.
132
+
133
+ Respond ONLY with valid JSON, no other text:
134
+ {
135
+ "action": "<rate_limit|isolate|honeypot|patch|monitor>",
136
+ "target_node": "<auth_service|payment_service|database|api_gateway>",
137
+ "reasoning": "<one sentence chain of thought>"
138
+ }"""
139
+
140
+
141
+ def build_phase1_obs(
142
+ turn_config: Dict[str, Any],
143
+ history: List[Dict[str, str]],
144
+ task_name: str,
145
+ turn: int,
146
+ max_turns: int,
147
+ episode_id: str,
148
+ mission_profile: Dict[str, Any] | None = None,
149
+ ) -> Dict[str, Any]:
150
+ """Build Phase 1 observation — full network state visible."""
151
+ mission_profile = mission_profile or {}
152
+ return {
153
+ "scenario_id": episode_id,
154
+ "task_name": task_name,
155
+ "phase": 1,
156
+ "turn": turn,
157
+ "max_turns": max_turns,
158
+ "network_nodes": turn_config["network_nodes"],
159
+ "active_alerts": turn_config["active_alerts"],
160
+ "attack_stage": turn_config.get("attack_stage", "none"),
161
+ "history": history[-3:],
162
+ "phase1_assessment": None,
163
+ "last_action_result": None,
164
+ "system_context": _with_mission_context(PHASE1_SYSTEM, mission_profile),
165
+ "available_actions": VALID_ACTIONS,
166
+ "reward": 0.0,
167
+ "done": False,
168
+ "metadata": {
169
+ "episode_id": episode_id,
170
+ "normalized_score": 0.50, # always present from step 1
171
+ "mission_profile": mission_profile,
172
+ },
173
+ }
174
+
175
+
176
+ def build_phase2_obs(
177
+ phase1_output: Dict[str, Any],
178
+ history: List[Dict[str, str]],
179
+ task_name: str,
180
+ turn: int,
181
+ max_turns: int,
182
+ episode_id: str,
183
+ current_score: float,
184
+ mission_profile: Dict[str, Any] | None = None,
185
+ ) -> Dict[str, Any]:
186
+ """
187
+ Build Phase 2 observation.
188
+ CRITICAL: network_nodes and active_alerts are EMPTY.
189
+ Phase 2 agent is blind to raw state — sees only Phase 1 assessment.
190
+ """
191
+ mission_profile = mission_profile or {}
192
+ return {
193
+ "scenario_id": episode_id,
194
+ "task_name": task_name,
195
+ "phase": 2,
196
+ "turn": turn,
197
+ "max_turns": max_turns,
198
+ "network_nodes": {}, # deliberately empty
199
+ "active_alerts": [], # deliberately empty
200
+ "attack_stage": "hidden",
201
+ "history": history[-3:],
202
+ "phase1_assessment": phase1_output,
203
+ "last_action_result": None,
204
+ "system_context": _with_mission_context(PHASE2_SYSTEM, mission_profile),
205
+ "available_actions": VALID_ACTIONS,
206
+ "reward": 0.0,
207
+ "done": False,
208
+ "metadata": {
209
+ "episode_id": episode_id,
210
+ "normalized_score": current_score, # always present
211
+ "mission_profile": mission_profile,
212
+ },
213
+ }
214
+
215
+
216
+ def _with_mission_context(system_prompt: str, mission_profile: Dict[str, Any]) -> str:
217
+ if not mission_profile:
218
+ return system_prompt
219
+
220
+ mission = "\n".join([
221
+ "",
222
+ "Mission context:",
223
+ f"- mission_id: {mission_profile.get('mission_id', 'unknown')}",
224
+ f"- primary_asset: {mission_profile.get('primary_asset', 'unknown')}",
225
+ f"- sla_priority: {mission_profile.get('sla_priority', 'balanced')}",
226
+ f"- risk_tolerance: {mission_profile.get('risk_tolerance', 'medium')}",
227
+ f"- objective: {mission_profile.get('objective', 'Balance security and availability.')}",
228
+ ])
229
+ return f"{system_prompt}{mission}"
230
+
231
+
232
+ def choose_world_family(world_split: str, requested_family: str | None = None) -> str:
233
+ if requested_family:
234
+ return requested_family
235
+ families = WORLD_FAMILY_SPLITS.get(world_split, WORLD_FAMILY_SPLITS["train"])
236
+ return random.choice(families)
237
+
238
+
239
+ def choose_operational_mode(task_name: str, requested_mode: str | None = None) -> str:
240
+ if requested_mode:
241
+ return requested_mode
242
+ modes = TASK_OPERATIONAL_MODES.get(task_name, ["availability_guarded"])
243
+ return random.choice(modes)
244
+
245
+
246
+ def mission_profile_for(task_name: str, operational_mode: str, world_family: str) -> Dict[str, Any]:
247
+ base = dict(TASK_CONFIGS[task_name].get("mission_profile", {}))
248
+ mode = MODE_HINTS.get(operational_mode, {})
249
+ base["world_family"] = world_family
250
+ base["operational_mode_hint"] = operational_mode.replace("_", " ")
251
+ base["scenario_style"] = FAMILY_MISSION_NOTES.get(world_family, "")
252
+ if mode.get("sla_priority"):
253
+ base["sla_priority"] = mode["sla_priority"]
254
+ if mode.get("risk_tolerance"):
255
+ base["risk_tolerance"] = mode["risk_tolerance"]
256
+ objective = str(base.get("objective", "")).rstrip()
257
+ suffix = str(mode.get("objective_suffix", "")).strip()
258
+ family_note = str(FAMILY_MISSION_NOTES.get(world_family, "")).strip()
259
+ if suffix:
260
+ objective = f"{objective} {suffix}".strip()
261
+ if family_note:
262
+ objective = f"{objective} {family_note}".strip()
263
+ base["objective"] = objective
264
+ return base
smoke_test.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick repo-root smoke test for AdaptShield.
4
+
5
+ Run from the repo root:
6
+ python smoke_test.py
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import sys
12
+ from pathlib import Path
13
+
14
+
15
+ REPO_ROOT = Path(__file__).resolve().parent
16
+
17
+ if str(REPO_ROOT) not in sys.path:
18
+ sys.path.insert(0, str(REPO_ROOT))
19
+
20
+ import __init__ as adaptshield
21
+ import server.app as server_app
22
+ from models import AdaptShieldAction
23
+ from server.adaptshield_environment import AdaptShieldEnvironment
24
+
25
+
26
+ def main() -> int:
27
+ print("AdaptShield smoke test")
28
+ print(f"- package exports: {adaptshield.__all__}")
29
+ print(f"- server app type: {server_app.app.__class__.__name__}")
30
+
31
+ env = AdaptShieldEnvironment("direct-triage")
32
+ obs = env.reset()
33
+ print(
34
+ f"- reset: phase={obs.phase} turn={obs.turn} "
35
+ f"score={obs.metadata.get('normalized_score')}"
36
+ )
37
+
38
+ obs = env.step(
39
+ AdaptShieldAction(
40
+ threat_type="brute_force",
41
+ confidence=0.9,
42
+ target_node="auth_service",
43
+ recommended_action="rate_limit",
44
+ )
45
+ )
46
+ print(f"- phase 1 -> phase 2: assessment={obs.phase1_assessment}")
47
+
48
+ obs = env.step(AdaptShieldAction(action="rate_limit", target_node="auth_service"))
49
+ print(
50
+ f"- phase 2 -> next turn: reward={obs.reward} done={obs.done} "
51
+ f"result={obs.last_action_result}"
52
+ )
53
+
54
+ print("Smoke test passed.")
55
+ return 0
56
+
57
+
58
+ if __name__ == "__main__":
59
+ raise SystemExit(main())
soc_tools.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Shared SOC investigation helpers for AdaptShield agents."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import json
7
+ import urllib.request
8
+ from typing import Any, Callable, Dict, List, Optional
9
+
10
+
11
+ THREAT_TOOL_PLAN = {
12
+ "brute_force": [("log_search", "auth_service"), ("cmdb_lookup", "auth_service"), ("identity_lookup", "auth_service")],
13
+ "lateral_movement": [("edr_status", "payment_service"), ("log_search", "payment_service"), ("identity_lookup", "payment_service"), ("cmdb_lookup", "payment_service")],
14
+ "exfiltration": [("log_search", "database"), ("edr_status", "database"), ("netflow_lookup", "database"), ("cmdb_lookup", "database")],
15
+ "supply_chain": [("vuln_lookup", "api_gateway"), ("log_search", "api_gateway"), ("change_calendar_lookup", "api_gateway"), ("cmdb_lookup", "api_gateway")],
16
+ "benign": [("cmdb_lookup", "api_gateway")],
17
+ }
18
+
19
+ FALLBACK_SWEEP = [
20
+ ("edr_status", "payment_service"),
21
+ ("log_search", "database"),
22
+ ("vuln_lookup", "api_gateway"),
23
+ ]
24
+
25
+
26
+ def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
27
+ auth = network_nodes.get("auth_service", {})
28
+ payment = network_nodes.get("payment_service", {})
29
+ database = network_nodes.get("database", {})
30
+ gateway = network_nodes.get("api_gateway", {})
31
+
32
+ if float(auth.get("error_rate", 0.0)) >= 0.10:
33
+ return "brute_force"
34
+ if payment.get("status") == "suspicious" or float(payment.get("cpu", 0)) >= 55:
35
+ return "lateral_movement"
36
+ if float(database.get("outbound_mb", 0)) >= 50:
37
+ return "exfiltration"
38
+ if gateway.get("status") == "suspicious":
39
+ return "supply_chain"
40
+ return "benign"
41
+
42
+
43
+ def investigate_local(env: Any, obs: Any, use_tools: bool) -> List[Dict[str, Any]]:
44
+ """Query local environment tool methods before Phase 1 action."""
45
+ return investigate_local_with_depth(env, obs, use_tools=use_tools, thorough=False)
46
+
47
+
48
+ def investigate_local_with_depth(
49
+ env: Any,
50
+ obs: Any,
51
+ use_tools: bool,
52
+ thorough: bool,
53
+ ) -> List[Dict[str, Any]]:
54
+ """Query local tools; thorough mode adds evidence-fusion follow-ups."""
55
+ if not use_tools or getattr(obs, "phase", 1) != 1:
56
+ return []
57
+ task_name = getattr(obs, "task_name", "")
58
+ threat = classify_from_metrics(getattr(obs, "network_nodes", {}))
59
+ if task_name == "direct-triage":
60
+ if threat == "brute_force":
61
+ return [env.call_tool("log_search", node="auth_service")]
62
+ return []
63
+ if task_name == "dual-pivot":
64
+ if threat == "lateral_movement":
65
+ return [
66
+ env.call_tool("edr_status", node="payment_service"),
67
+ env.call_tool("log_search", node="payment_service"),
68
+ env.call_tool("identity_lookup", node="payment_service"),
69
+ ]
70
+ tool_name, node = THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"])[0]
71
+ return [env.call_tool(tool_name, node=node)]
72
+
73
+ if task_name != "polymorphic-zero-day":
74
+ return []
75
+
76
+ results = []
77
+ for tool_name, node in THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"]):
78
+ results.append(env.call_tool(tool_name, node=node))
79
+
80
+ if not has_attack_indicators(results):
81
+ for tool_name, node in FALLBACK_SWEEP:
82
+ if (tool_name, node) not in THREAT_TOOL_PLAN.get(threat, []):
83
+ results.append(env.call_tool(tool_name, node=node))
84
+ if thorough:
85
+ _complete_evidence_fusion(
86
+ call_tool=lambda tool_name, node: env.call_tool(tool_name, node=node),
87
+ results=results,
88
+ )
89
+ return results
90
+
91
+
92
+ def investigate_http(
93
+ env_base_url: str,
94
+ session_id: Optional[str],
95
+ obs: Dict[str, Any],
96
+ use_tools: bool,
97
+ thorough: bool = False,
98
+ ) -> List[Dict[str, Any]]:
99
+ """Query SOC HTTP tool endpoints for a persistent /soc session."""
100
+ if not use_tools or not session_id or int(obs.get("phase", 1)) != 1:
101
+ return []
102
+ task_name = obs.get("task_name")
103
+ threat = classify_from_metrics(obs.get("network_nodes", {}))
104
+
105
+ def call(tool_name: str, node: str) -> Dict[str, Any]:
106
+ path = f"/tools/{tool_name}"
107
+ payload = {"session_id": session_id, "node": node}
108
+ return http_post(env_base_url, path, payload)
109
+
110
+ if task_name == "direct-triage":
111
+ if threat == "brute_force":
112
+ return [call("log_search", "auth_service")]
113
+ return []
114
+
115
+ results: List[Dict[str, Any]] = []
116
+
117
+ if task_name == "dual-pivot":
118
+ if threat == "lateral_movement":
119
+ return [
120
+ call("edr_status", "payment_service"),
121
+ call("log_search", "payment_service"),
122
+ call("identity_lookup", "payment_service"),
123
+ ]
124
+ tool_name, node = THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"])[0]
125
+ return [call(tool_name, node)]
126
+
127
+ if task_name != "polymorphic-zero-day":
128
+ return []
129
+
130
+ for tool_name, node in THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"]):
131
+ results.append(call(tool_name, node))
132
+
133
+ if not has_attack_indicators(results):
134
+ for tool_name, node in FALLBACK_SWEEP:
135
+ if (tool_name, node) not in THREAT_TOOL_PLAN.get(threat, []):
136
+ results.append(call(tool_name, node))
137
+ if thorough:
138
+ _complete_evidence_fusion(call_tool=call, results=results)
139
+ return results
140
+
141
+
142
+ def has_attack_indicators(results: List[Dict[str, Any]]) -> bool:
143
+ return infer_threat_from_tool_results(results)["threat_type"] != "benign"
144
+
145
+
146
+ def infer_threat_from_tool_results(results: List[Dict[str, Any]]) -> Dict[str, str]:
147
+ """Infer threat from observable tool facts only, not grader labels."""
148
+ text_blob = " ".join(
149
+ [
150
+ str(result.get("result_summary", "")) + " " +
151
+ " ".join(str(event) for event in result.get("events", [])) + " " +
152
+ str(result.get("finding", "")) + " " +
153
+ str(result.get("process_note", ""))
154
+ for result in results
155
+ ]
156
+ ).lower()
157
+
158
+ if any(result.get("risk") == "critical" for result in results) or "hash mismatch" in text_blob:
159
+ return {"threat_type": "supply_chain", "target_node": "api_gateway", "action": "patch"}
160
+ if "sequential reads" in text_blob or "compressed archive" in text_blob or "egress exceeds" in text_blob or "outbound_transfer_burst" in text_blob:
161
+ return {"threat_type": "exfiltration", "target_node": "database", "action": "honeypot"}
162
+ if (
163
+ any(result.get("beaconing") for result in results) or
164
+ "service account" in text_blob or
165
+ "internal sessions" in text_blob or
166
+ "identity_anomaly" in text_blob or
167
+ "source=auth_service" in text_blob or
168
+ "east_west_fanout" in text_blob
169
+ ):
170
+ return {"threat_type": "lateral_movement", "target_node": "payment_service", "action": "isolate"}
171
+ if "failed logins" in text_blob or "password spray" in text_blob:
172
+ return {"threat_type": "brute_force", "target_node": "auth_service", "action": "rate_limit"}
173
+ return {"threat_type": "benign", "target_node": "api_gateway", "action": "monitor"}
174
+
175
+
176
+ def _complete_evidence_fusion(
177
+ call_tool: Callable[[str, str], Dict[str, Any]],
178
+ results: List[Dict[str, Any]],
179
+ ) -> None:
180
+ belief = infer_threat_from_tool_results(results)
181
+ threat = belief["threat_type"]
182
+ if threat == "benign":
183
+ return
184
+
185
+ called = {
186
+ (str(result.get("tool", "")), str(result.get("node", "")))
187
+ for result in results
188
+ }
189
+ for tool_name, node in THREAT_TOOL_PLAN.get(threat, []):
190
+ if (tool_name, node) not in called:
191
+ results.append(call_tool(tool_name, node))
192
+
193
+
194
+ def attach_tool_results(obs: Dict[str, Any], tool_results: List[Dict[str, Any]]) -> Dict[str, Any]:
195
+ updated = dict(obs)
196
+ updated["tool_results"] = tool_results
197
+ return updated
198
+
199
+
200
+ def summarize_tool_results(tool_results: List[Dict[str, Any]]) -> str:
201
+ if not tool_results:
202
+ return "No SOC tools queried for this turn."
203
+
204
+ lines = []
205
+ for result in tool_results:
206
+ lines.append(json.dumps(_compact_result(result), separators=(",", ":")))
207
+ return "\n".join(lines)
208
+
209
+
210
+ def http_post(env_base_url: str, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
211
+ url = f"{env_base_url.rstrip('/')}{path}"
212
+ req = urllib.request.Request(
213
+ url,
214
+ data=json.dumps(payload).encode(),
215
+ headers={"Content-Type": "application/json"},
216
+ )
217
+ with urllib.request.urlopen(req, timeout=60) as response:
218
+ return json.loads(response.read())
219
+
220
+
221
+ def _compact_result(result: Dict[str, Any]) -> Dict[str, Any]:
222
+ keep = [
223
+ "tool",
224
+ "node",
225
+ "evidence_type",
226
+ "verified",
227
+ "confidence",
228
+ "events",
229
+ "containment",
230
+ "persistence",
231
+ "beaconing",
232
+ "criticality",
233
+ "dependencies",
234
+ "risk",
235
+ "finding",
236
+ "recommended_mitigation",
237
+ "safe_actions",
238
+ ]
239
+ return {key: result[key] for key in keep if key in result}
tests/test_regression.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import unittest
3
+ from pathlib import Path
4
+
5
+
6
+ REPO_ROOT = Path(__file__).resolve().parents[1]
7
+ PACKAGE_ROOT = REPO_ROOT / "adaptshield"
8
+
9
+ if str(REPO_ROOT) not in sys.path:
10
+ sys.path.insert(0, str(REPO_ROOT))
11
+ if str(PACKAGE_ROOT) not in sys.path:
12
+ sys.path.insert(0, str(PACKAGE_ROOT))
13
+
14
+ import __init__ as adaptshield
15
+ import server.app as server_app
16
+ import train as train_module
17
+ from client import AdaptshieldEnv
18
+ from models import AdaptShieldAction
19
+ from server.adaptshield_environment import AdaptShieldEnvironment
20
+ from server.grader import normalize_episode_score, _required_tool_fusion
21
+
22
+
23
+ class PackageRegressionTests(unittest.TestCase):
24
+ def test_package_import_exports_expected_symbols(self) -> None:
25
+ self.assertIn("AdaptShieldAction", adaptshield.__all__)
26
+ self.assertIn("AdaptShieldObservation", adaptshield.__all__)
27
+ self.assertIn("AdaptshieldEnv", adaptshield.__all__)
28
+
29
+ def test_server_app_imports_fastapi_instance(self) -> None:
30
+ self.assertEqual(server_app.app.__class__.__name__, "FastAPI")
31
+
32
+
33
+ class EnvironmentRegressionTests(unittest.TestCase):
34
+ def test_phase_flow_accepts_both_action_shapes(self) -> None:
35
+ env = AdaptShieldEnvironment("direct-triage")
36
+
37
+ phase1_obs = env.reset()
38
+ self.assertEqual(phase1_obs.phase, 1)
39
+ self.assertEqual(phase1_obs.turn, 1)
40
+ self.assertEqual(phase1_obs.metadata["normalized_score"], 0.50)
41
+ self.assertIn("mission_profile", phase1_obs.metadata)
42
+ self.assertEqual(phase1_obs.metadata["world_split"], "train")
43
+ self.assertIn(phase1_obs.metadata["world_family"], {"train-a", "train-b"})
44
+
45
+ phase2_obs = env.step(
46
+ AdaptShieldAction(
47
+ threat_type="brute_force",
48
+ confidence=0.9,
49
+ target_node="auth_service",
50
+ recommended_action="rate_limit",
51
+ )
52
+ )
53
+ self.assertEqual(phase2_obs.phase, 2)
54
+ self.assertEqual(phase2_obs.phase1_assessment["recommended_action"], "rate_limit")
55
+
56
+ next_turn_obs = env.step(
57
+ AdaptShieldAction(action="rate_limit", target_node="auth_service")
58
+ )
59
+ self.assertEqual(next_turn_obs.phase, 1)
60
+ self.assertGreaterEqual(next_turn_obs.reward, 0.65)
61
+ self.assertIn("requires stronger SOC evidence", next_turn_obs.last_action_result)
62
+ self.assertIn("business_impact", next_turn_obs.metadata["score_breakdown"])
63
+ self.assertIn("dependency_blast_radius", next_turn_obs.metadata["score_breakdown"])
64
+ self.assertIn("mission_alignment", next_turn_obs.metadata["score_breakdown"])
65
+ self.assertIn("active_defenses", next_turn_obs.metadata)
66
+ self.assertIn("available_tools", next_turn_obs.metadata)
67
+ tool_names = {tool["name"] for tool in next_turn_obs.metadata["available_tools"]}
68
+ self.assertTrue({
69
+ "identity_lookup",
70
+ "change_calendar_lookup",
71
+ "netflow_lookup",
72
+ }.issubset(tool_names))
73
+
74
+ env = AdaptShieldEnvironment("direct-triage")
75
+ env.reset()
76
+ env.call_tool("log_search", node="auth_service")
77
+ env.step(
78
+ AdaptShieldAction(
79
+ threat_type="brute_force",
80
+ confidence=0.9,
81
+ target_node="auth_service",
82
+ recommended_action="rate_limit",
83
+ )
84
+ )
85
+ verified_obs = env.step(
86
+ AdaptShieldAction(action="rate_limit", target_node="auth_service")
87
+ )
88
+ self.assertGreaterEqual(verified_obs.reward, 0.9)
89
+ self.assertIn("Optimal: rate_limit", verified_obs.last_action_result)
90
+
91
+ def test_client_payload_omits_empty_metadata_and_serializes_enums(self) -> None:
92
+ client = AdaptshieldEnv(base_url="http://localhost:7860")
93
+
94
+ phase1_payload = client._step_payload(
95
+ AdaptShieldAction(
96
+ threat_type="benign",
97
+ confidence=0.8,
98
+ target_node="auth_service",
99
+ recommended_action="monitor",
100
+ )
101
+ )
102
+ self.assertEqual(
103
+ phase1_payload,
104
+ {
105
+ "threat_type": "benign",
106
+ "confidence": 0.8,
107
+ "target_node": "auth_service",
108
+ "recommended_action": "monitor",
109
+ },
110
+ )
111
+
112
+ phase2_payload = client._step_payload(
113
+ AdaptShieldAction(action="rate_limit", target_node="auth_service")
114
+ )
115
+ self.assertEqual(
116
+ phase2_payload,
117
+ {"action": "rate_limit", "target_node": "auth_service"},
118
+ )
119
+
120
+ def test_hard_task_records_verified_tool_evidence(self) -> None:
121
+ env = AdaptShieldEnvironment("polymorphic-zero-day")
122
+ for _ in range(8):
123
+ obs = env.reset()
124
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
125
+ if not turn_config.get("is_benign", False):
126
+ break
127
+ else:
128
+ self.fail("Expected a non-benign hard-task reset within 8 attempts")
129
+
130
+ self.assertIn("available_tools", obs.metadata)
131
+ self.assertNotIn("foothold_established", obs.metadata)
132
+
133
+ target = str(turn_config.get("correct_target", "auth_service"))
134
+ for tool_name in sorted(_required_tool_fusion("polymorphic-zero-day", str(turn_config.get("strategy", "benign")))):
135
+ tool_result = env.call_tool(tool_name, node=target)
136
+ self.assertNotIn("verified", tool_result)
137
+ self.assertNotIn("evidence_type", tool_result)
138
+ self.assertTrue(tool_result.get("result_summary"))
139
+
140
+ env.step(
141
+ AdaptShieldAction(
142
+ threat_type=turn_config.get("strategy", "brute_force"),
143
+ confidence=0.9,
144
+ target_node=target,
145
+ recommended_action=turn_config.get("correct_action", "monitor"),
146
+ )
147
+ )
148
+ obs = env.step(
149
+ AdaptShieldAction(
150
+ action=turn_config.get("correct_action", "monitor"),
151
+ target_node=target,
152
+ )
153
+ )
154
+ breakdown = obs.metadata["score_breakdown"]
155
+ self.assertTrue(breakdown["tool_verification_required"])
156
+ self.assertTrue(breakdown["tool_evidence_found"])
157
+ self.assertGreaterEqual(obs.reward, 0.65)
158
+
159
+ def test_enterprise_context_tools_return_public_fields_only(self) -> None:
160
+ env = AdaptShieldEnvironment("polymorphic-zero-day")
161
+ env.reset()
162
+
163
+ identity = env.call_tool("identity_lookup", node="payment_service")
164
+ self.assertIn("account", identity)
165
+ self.assertIn("recent_source_host", identity)
166
+ self.assertNotIn("verified", identity)
167
+ self.assertNotIn("evidence_type", identity)
168
+
169
+ change = env.call_tool("change_calendar_lookup", node="api_gateway")
170
+ self.assertIn("scheduled", change)
171
+ self.assertIn("change_status", change)
172
+ self.assertNotIn("verified", change)
173
+ self.assertNotIn("evidence_type", change)
174
+
175
+ netflow = env.call_tool("netflow_lookup", node="database")
176
+ self.assertIn("traffic_pattern", netflow)
177
+ self.assertIn("east_west_connections", netflow)
178
+ self.assertNotIn("verified", netflow)
179
+ self.assertNotIn("evidence_type", netflow)
180
+
181
+ def test_dual_pivot_requires_tool_confirmation_after_pivot(self) -> None:
182
+ env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
183
+ env.reset()
184
+
185
+ for _ in range(3):
186
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
187
+ env.step(
188
+ AdaptShieldAction(
189
+ threat_type=str(turn_config.get("strategy", "brute_force")),
190
+ confidence=0.9,
191
+ target_node=str(turn_config.get("correct_target", "auth_service")),
192
+ recommended_action=str(turn_config.get("correct_action", "monitor")),
193
+ )
194
+ )
195
+ obs = env.step(
196
+ AdaptShieldAction(
197
+ action=str(turn_config.get("correct_action", "monitor")),
198
+ target_node=str(turn_config.get("correct_target", "auth_service")),
199
+ )
200
+ )
201
+ self.assertFalse(obs.done)
202
+
203
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
204
+ self.assertEqual(turn_config.get("strategy"), "lateral_movement")
205
+ target = str(turn_config.get("correct_target", "payment_service"))
206
+
207
+ env.step(
208
+ AdaptShieldAction(
209
+ threat_type="lateral_movement",
210
+ confidence=0.9,
211
+ target_node=target,
212
+ recommended_action=str(turn_config.get("correct_action", "isolate")),
213
+ )
214
+ )
215
+ obs = env.step(
216
+ AdaptShieldAction(
217
+ action=str(turn_config.get("correct_action", "isolate")),
218
+ target_node=target,
219
+ )
220
+ )
221
+ self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
222
+ self.assertFalse(obs.metadata["score_breakdown"]["tool_evidence_found"])
223
+ self.assertIn("requires stronger SOC evidence", obs.last_action_result)
224
+
225
+ env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
226
+ env.reset()
227
+ for _ in range(3):
228
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
229
+ env.step(
230
+ AdaptShieldAction(
231
+ threat_type=str(turn_config.get("strategy", "brute_force")),
232
+ confidence=0.9,
233
+ target_node=str(turn_config.get("correct_target", "auth_service")),
234
+ recommended_action=str(turn_config.get("correct_action", "monitor")),
235
+ )
236
+ )
237
+ env.step(
238
+ AdaptShieldAction(
239
+ action=str(turn_config.get("correct_action", "monitor")),
240
+ target_node=str(turn_config.get("correct_target", "auth_service")),
241
+ )
242
+ )
243
+
244
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
245
+ target = str(turn_config.get("correct_target", "payment_service"))
246
+ env.call_tool("edr_status", node=target)
247
+ env.call_tool("log_search", node=target)
248
+ env.call_tool("identity_lookup", node=target)
249
+ env.step(
250
+ AdaptShieldAction(
251
+ threat_type="lateral_movement",
252
+ confidence=0.9,
253
+ target_node=target,
254
+ recommended_action=str(turn_config.get("correct_action", "isolate")),
255
+ )
256
+ )
257
+ obs = env.step(
258
+ AdaptShieldAction(
259
+ action=str(turn_config.get("correct_action", "isolate")),
260
+ target_node=target,
261
+ )
262
+ )
263
+ self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
264
+ self.assertTrue(obs.metadata["score_breakdown"]["tool_evidence_found"])
265
+ self.assertIn("Optimal: isolate", obs.last_action_result)
266
+
267
+ def test_world_family_metadata_and_surfaces_are_selectable(self) -> None:
268
+ env = AdaptShieldEnvironment(
269
+ "direct-triage",
270
+ world_split="eval",
271
+ world_family="eval-x",
272
+ )
273
+ obs = env.reset()
274
+ self.assertEqual(obs.metadata["world_split"], "eval")
275
+ self.assertEqual(obs.metadata["world_family"], "eval-x")
276
+ alerts_blob = " ".join(obs.active_alerts).lower()
277
+ self.assertTrue(
278
+ "auth rejection burst" in alerts_blob or
279
+ "credential reuse sweep" in alerts_blob
280
+ )
281
+
282
+ def test_operational_modes_change_medium_and_hard_optimal_actions(self) -> None:
283
+ medium_env = AdaptShieldEnvironment(
284
+ "dual-pivot",
285
+ operational_mode="evidence_preservation",
286
+ world_family="train-b",
287
+ )
288
+ medium_env.reset()
289
+ for _ in range(3):
290
+ turn_config = dict(getattr(medium_env, "_turn_config", {}) or {})
291
+ medium_env.step(
292
+ AdaptShieldAction(
293
+ threat_type=str(turn_config.get("strategy", "brute_force")),
294
+ confidence=0.9,
295
+ target_node=str(turn_config.get("correct_target", "auth_service")),
296
+ recommended_action=str(turn_config.get("correct_action", "monitor")),
297
+ )
298
+ )
299
+ medium_env.step(
300
+ AdaptShieldAction(
301
+ action=str(turn_config.get("correct_action", "monitor")),
302
+ target_node=str(turn_config.get("correct_target", "auth_service")),
303
+ )
304
+ )
305
+ self.assertEqual(getattr(medium_env, "_turn_config", {}).get("strategy"), "lateral_movement")
306
+ self.assertEqual(getattr(medium_env, "_turn_config", {}).get("correct_action"), "honeypot")
307
+
308
+ hard_env = AdaptShieldEnvironment(
309
+ "polymorphic-zero-day",
310
+ operational_mode="forensic_hold",
311
+ world_family="eval-y",
312
+ )
313
+ hard_obs = hard_env.reset()
314
+ adjusted = hard_env._apply_operational_mode({
315
+ "strategy": "exfiltration",
316
+ "attack_stage": "exploit",
317
+ "is_benign": False,
318
+ "correct_action": "isolate",
319
+ "correct_target": "database",
320
+ "network_nodes": {"payment_service": {"status": "healthy", "request_rate": 85}},
321
+ "active_alerts": [],
322
+ })
323
+ self.assertEqual(hard_obs.metadata["operational_mode"], "forensic_hold")
324
+ self.assertEqual(adjusted.get("correct_action"), "honeypot")
325
+
326
+ def test_prompt_bank_builds_phase_rows_without_gpu_deps(self) -> None:
327
+ rows = train_module.build_prompt_bank(
328
+ tokenizer=None,
329
+ selected_task="all",
330
+ curriculum=True,
331
+ rollout_episodes=3,
332
+ max_steps=6,
333
+ use_tools=True,
334
+ seed=42,
335
+ )
336
+ self.assertTrue(rows)
337
+ phases = {int(row["phase"]) for row in rows}
338
+ tasks = {str(row["task"]) for row in rows}
339
+ self.assertIn(1, phases)
340
+ self.assertIn(2, phases)
341
+ self.assertTrue(tasks.intersection({"direct-triage", "dual-pivot", "polymorphic-zero-day"}))
342
+ hard_rows = [row for row in rows if row["task"] == "polymorphic-zero-day"]
343
+ self.assertTrue(hard_rows)
344
+ self.assertTrue(any(int(row["tool_calls"]) >= 2 for row in hard_rows))
345
+
346
+ def test_normalized_score_uses_reachable_reward_ceiling(self) -> None:
347
+ rewards = [0.99] * 10
348
+ self.assertEqual(normalize_episode_score(rewards), 0.99)
349
+
350
+
351
+ if __name__ == "__main__":
352
+ unittest.main()
tool_baseline.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Tool-aware AdaptShield baseline for world-modeling demos."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ from typing import Any, Dict, List
9
+
10
+ from baseline import (
11
+ BENCHMARK,
12
+ MAX_STEPS,
13
+ POLICY,
14
+ TASKS,
15
+ action_from_payload,
16
+ log_end,
17
+ log_step,
18
+ phase1_payload as no_tool_phase1_payload,
19
+ phase2_payload as no_tool_phase2_payload,
20
+ print_replay,
21
+ )
22
+ from server.adaptshield_environment import AdaptShieldEnvironment
23
+ from soc_tools import infer_threat_from_tool_results, investigate_local
24
+
25
+
26
+ MODEL_NAME = "tool-aware-baseline"
27
+
28
+ def log_start(task: str) -> None:
29
+ print(f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}", flush=True)
30
+
31
+
32
+ def phase2_payload(obs: Any, belief_by_turn: Dict[int, Dict[str, str]]) -> Dict[str, Any]:
33
+ """Use belief inferred from observable SOC tool evidence when Phase 2 is ambiguous."""
34
+ belief = belief_by_turn.get(int(obs.turn), {})
35
+ if obs.task_name == "polymorphic-zero-day" and belief:
36
+ return {
37
+ "action": belief["action"],
38
+ "target_node": belief["target_node"],
39
+ "reasoning": "inferred from observable SOC tool fields",
40
+ }
41
+
42
+ return no_tool_phase2_payload(obs)
43
+
44
+
45
+ def phase1_payload(obs: Any, belief_by_turn: Dict[int, Dict[str, str]]) -> Dict[str, Any]:
46
+ """Use tool-derived belief in Phase 1 so the baseline is tool-aware end to end."""
47
+ belief = belief_by_turn.get(int(obs.turn), {})
48
+ if obs.task_name == "polymorphic-zero-day" and belief:
49
+ return {
50
+ "threat_type": belief["threat_type"],
51
+ "confidence": 0.86,
52
+ "target_node": belief["target_node"],
53
+ "recommended_action": belief["action"],
54
+ "reasoning": "classified from observable SOC tool fields",
55
+ }
56
+
57
+ return no_tool_phase1_payload(obs)
58
+
59
+ def run_task(task: str, emit_logs: bool = True) -> Dict[str, Any]:
60
+ env = AdaptShieldEnvironment(task_name=task)
61
+ obs = env.reset()
62
+ rewards: List[float] = []
63
+ steps = 0
64
+ belief_by_turn: Dict[int, Dict[str, str]] = {}
65
+
66
+ if emit_logs:
67
+ log_start(task)
68
+
69
+ while not obs.done and steps < MAX_STEPS:
70
+ if obs.phase == 1:
71
+ tool_results = investigate_local(env, obs, use_tools=True)
72
+ belief_by_turn[int(obs.turn)] = infer_threat_from_tool_results(tool_results)
73
+ payload = phase1_payload(obs, belief_by_turn)
74
+ else:
75
+ payload = phase2_payload(obs, belief_by_turn)
76
+
77
+ obs = env.step(action_from_payload(payload))
78
+ reward = float(obs.reward)
79
+ rewards.append(reward)
80
+ steps += 1
81
+
82
+ if emit_logs:
83
+ log_step(steps, payload, reward, obs.done)
84
+
85
+ metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
86
+ score = float(metadata.get("normalized_score", 0.01))
87
+ success = obs.done and 0.01 <= score <= 0.99
88
+
89
+ if emit_logs:
90
+ log_end(success, steps, score, rewards)
91
+ tool_trace = metadata.get("tool_trace") or []
92
+ print(f"[TOOLS] calls={len(tool_trace)} trace={json.dumps(tool_trace[-4:], separators=(',', ':'))}")
93
+
94
+ return {
95
+ "task": task,
96
+ "score": score,
97
+ "steps": steps,
98
+ "done": bool(obs.done),
99
+ "rewards": rewards,
100
+ "metadata": metadata,
101
+ "normalized_score_present": "normalized_score" in metadata,
102
+ "success": success,
103
+ }
104
+
105
+
106
+ def parse_args() -> argparse.Namespace:
107
+ parser = argparse.ArgumentParser(description="Run AdaptShield tool-aware baseline.")
108
+ parser.add_argument("--task", default="polymorphic-zero-day", choices=TASKS + ["all"])
109
+ parser.add_argument("--replay", action="store_true")
110
+ return parser.parse_args()
111
+
112
+
113
+ def main() -> int:
114
+ args = parse_args()
115
+ tasks = TASKS if args.task == "all" else [args.task]
116
+
117
+ for index, task in enumerate(tasks):
118
+ if index:
119
+ print()
120
+ result = run_task(task, emit_logs=True)
121
+ if args.replay:
122
+ print_replay(task, result["metadata"])
123
+
124
+ return 0
125
+
126
+
127
+ if __name__ == "__main__":
128
+ raise SystemExit(main())
train.py ADDED
@@ -0,0 +1,1332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """AdaptShield onsite GPU training harness with safe local fallback."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import inspect
8
+ import json
9
+ import os
10
+ import random
11
+ import sys
12
+ import time
13
+ from pathlib import Path
14
+ from typing import Any, Dict, List, Tuple
15
+
16
+
17
+ REPO_ROOT = Path(__file__).resolve().parent
18
+
19
+ if str(REPO_ROOT) not in sys.path:
20
+ sys.path.insert(0, str(REPO_ROOT))
21
+
22
+ from models import AdaptShieldAction
23
+ from server.adaptshield_environment import AdaptShieldEnvironment
24
+ from train_smoke import TASKS, run_smoke_training
25
+ from plot_training import plot as plot_training_output
26
+ from soc_tools import attach_tool_results, investigate_local_with_depth, summarize_tool_results
27
+
28
+
29
+ MODEL_CHOICES = {
30
+ "0.5b": "unsloth/Qwen2.5-0.5B-Instruct",
31
+ "1.5b": "unsloth/Qwen2.5-1.5B-Instruct",
32
+ "3b": "unsloth/Qwen2.5-3B-Instruct",
33
+ "7b": "unsloth/Qwen2.5-7B-Instruct",
34
+ }
35
+ DEFAULT_MODEL = "1.5b"
36
+ MAX_SEQ_LEN = 2048
37
+ LORA_RANK = 16
38
+
39
+ CURRICULUM_STAGES = [
40
+ ("direct-triage", 0.30),
41
+ ("dual-pivot", 0.40),
42
+ ("polymorphic-zero-day", 0.30),
43
+ ]
44
+
45
+
46
+ PHASE1_SYS = """You are a Threat Analyst for a 4-node enterprise network.
47
+ Analyze SIEM metrics and alerts. Identify the threat.
48
+
49
+ Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
50
+ Nodes: auth_service, payment_service, database, api_gateway
51
+ Actions: rate_limit, isolate, honeypot, patch, monitor
52
+ If SOC tool evidence is provided, use it to update your belief before classifying.
53
+
54
+ Respond ONLY with valid JSON:
55
+ {"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}"""
56
+
57
+
58
+ PHASE2_SYS = """You are a Tactical Executor. Act only on the analyst handoff.
59
+ You cannot see raw network data in Phase 2.
60
+ Use the analyst handoff plus any SOC tool trace from this turn.
61
+
62
+ Actions: rate_limit, isolate, honeypot, patch, monitor
63
+ Nodes: auth_service, payment_service, database, api_gateway
64
+
65
+ Respond ONLY with valid JSON:
66
+ {"action":"...","target_node":"...","reasoning":"..."}"""
67
+
68
+
69
+ def obs_to_dict(obs: Any) -> Dict[str, Any]:
70
+ if hasattr(obs, "model_dump"):
71
+ return obs.model_dump(mode="json")
72
+ return dict(obs)
73
+
74
+
75
+ def make_phase1_prompt(obs: Dict[str, Any]) -> str:
76
+ return "\n".join([
77
+ "Network nodes:",
78
+ json.dumps(obs.get("network_nodes", {}), indent=2),
79
+ "",
80
+ "Active alerts:",
81
+ "\n".join(obs.get("active_alerts", [])),
82
+ "",
83
+ "SOC tool evidence:",
84
+ summarize_tool_results(obs.get("tool_results", [])),
85
+ "",
86
+ "Recent history:",
87
+ json.dumps(obs.get("history", [])[-3:], indent=2),
88
+ "",
89
+ "Classify the threat:",
90
+ ])
91
+
92
+
93
+ def make_phase2_prompt(obs: Dict[str, Any]) -> str:
94
+ metadata = obs.get("metadata", {}) if isinstance(obs.get("metadata", {}), dict) else {}
95
+ current_turn = int(obs.get("turn", 0) or 0)
96
+ tool_trace = [
97
+ row for row in metadata.get("tool_trace", [])
98
+ if int(row.get("turn", -1)) == current_turn
99
+ ]
100
+ return "\n".join([
101
+ "Threat assessment from analyst:",
102
+ json.dumps(obs.get("phase1_assessment", {}), indent=2),
103
+ "",
104
+ "SOC tool trace for this turn:",
105
+ json.dumps(tool_trace, indent=2),
106
+ "",
107
+ "Choose the defensive action:",
108
+ ])
109
+
110
+
111
+ def build_messages(obs: Dict[str, Any]) -> List[Dict[str, str]]:
112
+ if int(obs.get("phase", 1)) == 1:
113
+ return [
114
+ {"role": "system", "content": PHASE1_SYS},
115
+ {"role": "user", "content": make_phase1_prompt(obs)},
116
+ ]
117
+ return [
118
+ {"role": "system", "content": PHASE2_SYS},
119
+ {"role": "user", "content": make_phase2_prompt(obs)},
120
+ ]
121
+
122
+
123
+ def task_for_episode(
124
+ episode: int,
125
+ total_episodes: int,
126
+ selected_task: str,
127
+ curriculum: bool,
128
+ ) -> Tuple[str, str]:
129
+ if not curriculum:
130
+ if selected_task == "all":
131
+ task = TASKS[(episode - 1) % len(TASKS)]
132
+ return task, "round_robin"
133
+ return selected_task, "fixed"
134
+
135
+ progress = episode / max(1, total_episodes)
136
+ cumulative = 0.0
137
+ for task, fraction in CURRICULUM_STAGES:
138
+ cumulative += fraction
139
+ if progress <= cumulative:
140
+ return task, f"curriculum:{task}"
141
+ return CURRICULUM_STAGES[-1][0], f"curriculum:{CURRICULUM_STAGES[-1][0]}"
142
+
143
+
144
+ def save_metrics(
145
+ output_dir: Path,
146
+ rows: List[Dict[str, Any]],
147
+ model_name: str,
148
+ episodes: int,
149
+ curriculum: bool,
150
+ use_tools: bool,
151
+ trainer: str = "pg",
152
+ evaluation_rows: List[Dict[str, Any]] | None = None,
153
+ heldout_evaluation_rows: List[Dict[str, Any]] | None = None,
154
+ prompt_bank_size: int = 0,
155
+ extra: Dict[str, Any] | None = None,
156
+ ) -> Path:
157
+ output_dir.mkdir(parents=True, exist_ok=True)
158
+ best_score = max((float(row["score"]) for row in rows), default=0.0)
159
+ metrics_path = output_dir / "metrics.json"
160
+ payload = {
161
+ "model": model_name,
162
+ "episodes": episodes,
163
+ "curriculum": curriculum,
164
+ "curriculum_stages": CURRICULUM_STAGES,
165
+ "use_tools": use_tools,
166
+ "trainer": trainer,
167
+ "rows": rows,
168
+ "best_score": best_score,
169
+ }
170
+ if evaluation_rows is not None:
171
+ payload["evaluation_rows"] = evaluation_rows
172
+ if heldout_evaluation_rows is not None:
173
+ payload["heldout_evaluation_rows"] = heldout_evaluation_rows
174
+ if prompt_bank_size:
175
+ payload["prompt_bank_size"] = prompt_bank_size
176
+ if extra:
177
+ payload.update(extra)
178
+ metrics_path.write_text(json.dumps(payload, indent=2))
179
+ return metrics_path
180
+
181
+
182
+ def maybe_plot(metrics_path: Path, output_dir: Path) -> None:
183
+ try:
184
+ plot_training_output(metrics_path, output_dir / "reward_curve.png")
185
+ except Exception as exc:
186
+ print(f"Plot generation skipped: {exc}")
187
+
188
+
189
+ def parse_response(text: str, phase: int) -> Dict[str, Any]:
190
+ """Parse model JSON. Invalid output becomes a safe phase-correct action."""
191
+ if "```" in text:
192
+ for part in text.split("```"):
193
+ if "{" in part:
194
+ text = part.strip().removeprefix("json").strip()
195
+ break
196
+
197
+ try:
198
+ parsed = json.loads(text)
199
+ if phase == 1:
200
+ return {
201
+ "threat_type": str(parsed.get("threat_type", "brute_force")),
202
+ "confidence": float(parsed.get("confidence", 0.5)),
203
+ "target_node": str(parsed.get("target_node", "auth_service")),
204
+ "recommended_action": str(parsed.get("recommended_action", "monitor")),
205
+ "reasoning": str(parsed.get("reasoning", "")),
206
+ }
207
+ return {
208
+ "action": str(parsed.get("action", "monitor")),
209
+ "target_node": str(parsed.get("target_node", "auth_service")),
210
+ "reasoning": str(parsed.get("reasoning", "")),
211
+ }
212
+ except Exception:
213
+ if phase == 1:
214
+ return {
215
+ "threat_type": "brute_force",
216
+ "confidence": 0.5,
217
+ "target_node": "auth_service",
218
+ "recommended_action": "monitor",
219
+ "reasoning": "parse_error",
220
+ }
221
+ return {
222
+ "action": "monitor",
223
+ "target_node": "auth_service",
224
+ "reasoning": "parse_error",
225
+ }
226
+
227
+
228
+ def render_messages(messages: List[Dict[str, str]], tokenizer: Any | None = None) -> str:
229
+ if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
230
+ return tokenizer.apply_chat_template(
231
+ messages,
232
+ tokenize=False,
233
+ add_generation_prompt=True,
234
+ )
235
+ return "\n\n".join(
236
+ f"{message.get('role', 'user').upper()}:\n{message.get('content', '')}"
237
+ for message in messages
238
+ )
239
+
240
+
241
+ def generate_response(model: Any, tokenizer: Any, messages: List[Dict[str, str]]) -> Tuple[str, str]:
242
+ import torch
243
+
244
+ prompt = render_messages(messages, tokenizer=tokenizer)
245
+ device = getattr(model, "device", None)
246
+ if device is None:
247
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
248
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
249
+
250
+ pad_token_id = (
251
+ tokenizer.pad_token_id
252
+ if getattr(tokenizer, "pad_token_id", None) is not None
253
+ else tokenizer.eos_token_id
254
+ )
255
+
256
+ with torch.no_grad():
257
+ _normalize_generation_config(model)
258
+ output_ids = model.generate(
259
+ **inputs,
260
+ max_new_tokens=220,
261
+ temperature=0.7,
262
+ do_sample=True,
263
+ pad_token_id=pad_token_id,
264
+ )
265
+
266
+ new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
267
+ response = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
268
+ return prompt, response
269
+
270
+
271
+ def _current_reference(env: AdaptShieldEnvironment) -> Dict[str, Any]:
272
+ turn_config = dict(getattr(env, "_turn_config", {}) or {})
273
+ is_benign = bool(turn_config.get("is_benign", False))
274
+ threat_type = "benign" if is_benign else str(turn_config.get("strategy", "benign"))
275
+ target_node = str(turn_config.get("correct_target", "auth_service"))
276
+ expected_action = str(turn_config.get("correct_action", "monitor"))
277
+ return {
278
+ "threat_type": threat_type,
279
+ "target_node": target_node,
280
+ "expected_action": expected_action,
281
+ "stage": str(turn_config.get("attack_stage", getattr(env._attacker, "current_stage", lambda: "recon")())),
282
+ "is_benign": is_benign,
283
+ }
284
+
285
+
286
+ def _align_trainable_dtypes(model: Any, target_dtype: Any | None = None) -> str:
287
+ """Keep LoRA/trainable params on the same compute dtype as the main model.
288
+
289
+ Some adapter checkpoints reload trainable LoRA weights as float32, while
290
+ Unsloth GRPO kernels run activations in float16/bfloat16. That mismatch
291
+ trips fast_lora matmuls at runtime. We fix only trainable floating params.
292
+ """
293
+ import torch
294
+
295
+ if target_dtype is None:
296
+ for param in model.parameters():
297
+ if param.is_floating_point() and not param.requires_grad:
298
+ target_dtype = param.dtype
299
+ break
300
+ if target_dtype is None:
301
+ for param in model.parameters():
302
+ if param.is_floating_point():
303
+ target_dtype = param.dtype
304
+ break
305
+ if target_dtype is None:
306
+ return "no-floating-params"
307
+
308
+ converted = 0
309
+ for param in model.parameters():
310
+ if param.requires_grad and param.is_floating_point() and param.dtype != target_dtype:
311
+ param.data = param.data.to(target_dtype)
312
+ converted += 1
313
+
314
+ for buffer_name, buffer in model.named_buffers():
315
+ if "lora_" in buffer_name and buffer.is_floating_point() and buffer.dtype != target_dtype:
316
+ buffer.data = buffer.data.to(target_dtype)
317
+
318
+ if getattr(model, "generation_config", None) is not None:
319
+ _normalize_generation_config(model)
320
+
321
+ return f"{target_dtype} ({converted} trainable params aligned)"
322
+
323
+
324
+ def _normalize_generation_config(model: Any) -> None:
325
+ generation_config = getattr(model, "generation_config", None)
326
+ if generation_config is None:
327
+ return
328
+ for field in ("max_length",):
329
+ try:
330
+ setattr(generation_config, field, None)
331
+ except Exception:
332
+ continue
333
+
334
+
335
+ def _load_training_model_and_tokenizer(
336
+ model_name: str,
337
+ model_key: str,
338
+ max_seq_length: int,
339
+ compute_dtype: Any,
340
+ seed: int,
341
+ ):
342
+ from unsloth import FastLanguageModel
343
+
344
+ adapter_path = model_name if _looks_like_adapter_path(model_name) else ""
345
+ base_model_name = MODEL_CHOICES[model_key] if adapter_path else model_name
346
+ model, tokenizer = FastLanguageModel.from_pretrained(
347
+ model_name=base_model_name,
348
+ max_seq_length=max_seq_length,
349
+ load_in_4bit=True,
350
+ dtype=compute_dtype,
351
+ )
352
+
353
+ if adapter_path:
354
+ from peft import PeftModel
355
+
356
+ model = PeftModel.from_pretrained(
357
+ model,
358
+ adapter_path,
359
+ is_trainable=True,
360
+ autocast_adapter_dtype=False,
361
+ )
362
+ try:
363
+ from transformers import AutoTokenizer
364
+
365
+ tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
366
+ except Exception:
367
+ pass
368
+ else:
369
+ model = FastLanguageModel.get_peft_model(
370
+ model,
371
+ r=LORA_RANK,
372
+ target_modules=[
373
+ "q_proj", "k_proj", "v_proj", "o_proj",
374
+ "gate_proj", "up_proj", "down_proj",
375
+ ],
376
+ lora_alpha=LORA_RANK * 2,
377
+ lora_dropout=0.0,
378
+ bias="none",
379
+ use_gradient_checkpointing="unsloth",
380
+ random_state=seed,
381
+ )
382
+
383
+ return model, tokenizer
384
+
385
+
386
+ def _teacher_payload(phase: int, reference: Dict[str, Any]) -> Dict[str, Any]:
387
+ if phase == 1:
388
+ return {
389
+ "threat_type": reference["threat_type"],
390
+ "confidence": 0.92 if reference["threat_type"] != "benign" else 0.78,
391
+ "target_node": reference["target_node"],
392
+ "recommended_action": reference["expected_action"],
393
+ "reasoning": "reference policy",
394
+ }
395
+ return {
396
+ "action": reference["expected_action"],
397
+ "target_node": reference["target_node"],
398
+ "reasoning": "reference policy",
399
+ }
400
+
401
+
402
+ def build_prompt_bank(
403
+ tokenizer: Any | None,
404
+ selected_task: str,
405
+ curriculum: bool,
406
+ rollout_episodes: int,
407
+ max_steps: int,
408
+ use_tools: bool,
409
+ seed: int,
410
+ world_split: str = "train",
411
+ world_family: str | None = None,
412
+ hard_multiplier: int = 2,
413
+ borderline_bonus: int = 1,
414
+ ) -> List[Dict[str, Any]]:
415
+ random.seed(seed)
416
+ rows: List[Dict[str, Any]] = []
417
+ for episode in range(1, rollout_episodes + 1):
418
+ task, stage = task_for_episode(
419
+ episode=episode,
420
+ total_episodes=rollout_episodes,
421
+ selected_task=selected_task,
422
+ curriculum=curriculum,
423
+ )
424
+ env = AdaptShieldEnvironment(
425
+ task_name=task,
426
+ world_split=world_split,
427
+ world_family=world_family,
428
+ )
429
+ obs = env.reset()
430
+ step_count = 0
431
+ while not obs.done and step_count < max_steps:
432
+ phase = int(getattr(obs, "phase", 1))
433
+ tool_results = investigate_local_with_depth(
434
+ env,
435
+ obs,
436
+ use_tools=use_tools,
437
+ thorough=True,
438
+ )
439
+ obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
440
+ messages = build_messages(obs_dict)
441
+ reference = _current_reference(env)
442
+ rows.append({
443
+ "prompt": render_messages(messages, tokenizer=tokenizer),
444
+ "task": task,
445
+ "stage": stage,
446
+ "phase": phase,
447
+ "turn": int(getattr(obs, "turn", 0) or 0),
448
+ "attack_stage": reference["stage"],
449
+ "world_split": getattr(env, "_world_split", world_split),
450
+ "world_family": getattr(env, "_world_family", world_family or ""),
451
+ "operational_mode": getattr(env, "_operational_mode", ""),
452
+ "expected_threat_type": reference["threat_type"],
453
+ "expected_target_node": reference["target_node"],
454
+ "expected_recommended_action": reference["expected_action"] if phase == 1 else "",
455
+ "expected_action": reference["expected_action"] if phase == 2 else "",
456
+ "tool_calls": len(tool_results),
457
+ "history_length": len(obs_dict.get("history", [])),
458
+ "difficulty_tags": _difficulty_tags(
459
+ task=task,
460
+ phase=phase,
461
+ attack_stage=reference["stage"],
462
+ tool_calls=len(tool_results),
463
+ handoff_quality=str((obs_dict.get("phase1_assessment") or {}).get("handoff_quality", "")),
464
+ ),
465
+ })
466
+ base_row = rows[-1]
467
+ for _ in range(_prompt_bank_extra_copies(
468
+ row=base_row,
469
+ hard_multiplier=hard_multiplier,
470
+ borderline_bonus=borderline_bonus,
471
+ )):
472
+ rows.append(dict(base_row))
473
+ obs = env.step(AdaptShieldAction(**_teacher_payload(phase, reference)))
474
+ step_count += 1
475
+ return rows
476
+
477
+
478
+ def _difficulty_tags(
479
+ task: str,
480
+ phase: int,
481
+ attack_stage: str,
482
+ tool_calls: int,
483
+ handoff_quality: str,
484
+ ) -> List[str]:
485
+ tags: List[str] = []
486
+ if task == "polymorphic-zero-day":
487
+ tags.append("hard")
488
+ elif task == "dual-pivot":
489
+ tags.append("medium")
490
+ if phase == 2:
491
+ tags.append("phase2")
492
+ if attack_stage in {"exploit", "exfiltration"}:
493
+ tags.append("late_stage")
494
+ if tool_calls >= 3:
495
+ tags.append("tool_fusion")
496
+ if handoff_quality == "degraded":
497
+ tags.append("borderline")
498
+ return tags
499
+
500
+
501
+ def _prompt_bank_extra_copies(
502
+ row: Dict[str, Any],
503
+ hard_multiplier: int,
504
+ borderline_bonus: int,
505
+ ) -> int:
506
+ tags = set(row.get("difficulty_tags", []) or [])
507
+ extra = 0
508
+ if row.get("task") == "polymorphic-zero-day":
509
+ extra += max(0, hard_multiplier - 1)
510
+ elif row.get("task") == "dual-pivot" and "late_stage" in tags:
511
+ extra += 1
512
+ if "borderline" in tags or ("phase2" in tags and "tool_fusion" in tags and "late_stage" in tags):
513
+ extra += max(0, borderline_bonus)
514
+ return extra
515
+
516
+
517
+ def _completion_to_text(completion: Any) -> str:
518
+ if isinstance(completion, str):
519
+ return completion
520
+ if isinstance(completion, dict):
521
+ if "content" in completion:
522
+ return str(completion.get("content", ""))
523
+ if "text" in completion:
524
+ return str(completion.get("text", ""))
525
+ if isinstance(completion, list):
526
+ parts = []
527
+ for item in completion:
528
+ if isinstance(item, dict):
529
+ parts.append(str(item.get("content", item.get("text", ""))))
530
+ else:
531
+ parts.append(str(item))
532
+ return "".join(parts)
533
+ return str(completion)
534
+
535
+
536
+ def _phase1_reward(
537
+ parsed: Dict[str, Any],
538
+ expected_threat_type: str,
539
+ expected_target_node: str,
540
+ expected_recommended_action: str,
541
+ ) -> float:
542
+ reward = 0.08
543
+ if parsed.get("threat_type") == expected_threat_type:
544
+ reward += 0.36
545
+ if parsed.get("target_node") == expected_target_node:
546
+ reward += 0.20
547
+ if parsed.get("recommended_action") == expected_recommended_action:
548
+ reward += 0.18
549
+ try:
550
+ confidence = float(parsed.get("confidence", 0.5))
551
+ except Exception:
552
+ confidence = 0.5
553
+ if 0.0 <= confidence <= 1.0:
554
+ reward += 0.05
555
+ if parsed.get("threat_type") == expected_threat_type and confidence >= 0.65:
556
+ reward += 0.06
557
+ elif parsed.get("threat_type") != expected_threat_type and confidence >= 0.80:
558
+ reward -= 0.05
559
+ if parsed.get("recommended_action") == "monitor" and expected_threat_type != "benign":
560
+ reward -= 0.05
561
+ return max(0.01, min(0.99, round(reward, 2)))
562
+
563
+
564
+ def _phase2_reward(
565
+ parsed: Dict[str, Any],
566
+ expected_action: str,
567
+ expected_target_node: str,
568
+ tool_calls: int,
569
+ ) -> float:
570
+ reward = 0.08
571
+ if parsed.get("action") == expected_action:
572
+ reward += 0.62
573
+ if parsed.get("target_node") == expected_target_node:
574
+ reward += 0.18
575
+ if parsed.get("action") == expected_action and tool_calls >= 2:
576
+ reward += 0.07
577
+ if parsed.get("action") == "monitor" and expected_action != "monitor":
578
+ reward -= 0.08
579
+ return max(0.01, min(0.99, round(reward, 2)))
580
+
581
+
582
+ def build_grpo_reward_fn():
583
+ def reward_fn(completions: List[Any], **kwargs: Any) -> List[float]:
584
+ phases = kwargs.get("phase", [])
585
+ expected_threat_types = kwargs.get("expected_threat_type", [])
586
+ expected_targets = kwargs.get("expected_target_node", [])
587
+ expected_recommended_actions = kwargs.get("expected_recommended_action", [])
588
+ expected_actions = kwargs.get("expected_action", [])
589
+ tool_calls = kwargs.get("tool_calls", [])
590
+ rewards: List[float] = []
591
+ for index, completion in enumerate(completions):
592
+ phase = int(phases[index]) if phases else 1
593
+ text = _completion_to_text(completion)
594
+ parsed = parse_response(text, phase)
595
+ if phase == 1:
596
+ reward = _phase1_reward(
597
+ parsed=parsed,
598
+ expected_threat_type=str(expected_threat_types[index]),
599
+ expected_target_node=str(expected_targets[index]),
600
+ expected_recommended_action=str(expected_recommended_actions[index]),
601
+ )
602
+ else:
603
+ reward = _phase2_reward(
604
+ parsed=parsed,
605
+ expected_action=str(expected_actions[index]),
606
+ expected_target_node=str(expected_targets[index]),
607
+ tool_calls=int(tool_calls[index]) if tool_calls else 0,
608
+ )
609
+ rewards.append(reward)
610
+ return rewards
611
+
612
+ return reward_fn
613
+
614
+
615
+ def _filter_supported_kwargs(callable_obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
616
+ try:
617
+ signature = inspect.signature(callable_obj)
618
+ except (TypeError, ValueError):
619
+ return kwargs
620
+ valid = {}
621
+ for key, value in kwargs.items():
622
+ if key in signature.parameters:
623
+ valid[key] = value
624
+ return valid
625
+
626
+
627
+ def _trainer_log_rows(log_history: List[Dict[str, Any]], selected_task: str) -> List[Dict[str, Any]]:
628
+ rows: List[Dict[str, Any]] = []
629
+ for entry in log_history:
630
+ step = entry.get("step")
631
+ if step is None:
632
+ continue
633
+ reward_keys = [
634
+ "reward",
635
+ "mean_reward",
636
+ "rewards/mean",
637
+ "objective",
638
+ "objective/rlhf_reward",
639
+ ]
640
+ score = None
641
+ for key in reward_keys:
642
+ if key in entry:
643
+ try:
644
+ score = float(entry[key])
645
+ break
646
+ except Exception:
647
+ continue
648
+ if score is None:
649
+ score = 0.50
650
+ row = {
651
+ "episode": int(step),
652
+ "task": "mixed" if selected_task == "all" else selected_task,
653
+ "stage": "grpo",
654
+ "score": max(0.01, min(0.99, score)),
655
+ "loss": float(entry.get("loss", 0.0) or 0.0),
656
+ "learning_rate": float(entry.get("learning_rate", 0.0) or 0.0),
657
+ }
658
+ rows.append(row)
659
+ return rows
660
+
661
+
662
+ def evaluate_model_suite(
663
+ model: Any,
664
+ tokenizer: Any,
665
+ selected_task: str,
666
+ eval_episodes: int,
667
+ max_steps: int,
668
+ use_tools: bool,
669
+ world_split: str = "train",
670
+ world_family: str | None = None,
671
+ seed_start: int | None = None,
672
+ ) -> List[Dict[str, Any]]:
673
+ tasks = TASKS if selected_task == "all" else [selected_task]
674
+ rows: List[Dict[str, Any]] = []
675
+ for task in tasks:
676
+ scores: List[float] = []
677
+ steps: List[int] = []
678
+ tool_calls: List[int] = []
679
+ original_seed = os.environ.get("ADAPTSHIELD_SEED")
680
+ for episode_index in range(eval_episodes):
681
+ if seed_start is not None:
682
+ os.environ["ADAPTSHIELD_SEED"] = str(seed_start + len(rows) * 100 + episode_index)
683
+ _, metrics = run_model_episode(
684
+ model=model,
685
+ tokenizer=tokenizer,
686
+ task=task,
687
+ max_steps=max_steps,
688
+ use_tools=use_tools,
689
+ world_split=world_split,
690
+ world_family=world_family,
691
+ )
692
+ scores.append(float(metrics["score"]))
693
+ steps.append(int(metrics["steps"]))
694
+ tool_calls.append(int(metrics["tool_calls"]))
695
+ if original_seed is None:
696
+ os.environ.pop("ADAPTSHIELD_SEED", None)
697
+ else:
698
+ os.environ["ADAPTSHIELD_SEED"] = original_seed
699
+ rows.append({
700
+ "episode": len(rows) + 1,
701
+ "task": task,
702
+ "stage": "evaluation",
703
+ "score": round(sum(scores) / len(scores), 3) if scores else 0.50,
704
+ "steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
705
+ "tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
706
+ "eval_episodes": eval_episodes,
707
+ "world_split": world_split,
708
+ "world_family": world_family or "auto",
709
+ })
710
+ return rows
711
+
712
+
713
+ def run_model_episode(
714
+ model: Any,
715
+ tokenizer: Any,
716
+ task: str,
717
+ max_steps: int,
718
+ use_tools: bool,
719
+ world_split: str = "train",
720
+ world_family: str | None = None,
721
+ ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
722
+ env = AdaptShieldEnvironment(
723
+ task_name=task,
724
+ world_split=world_split,
725
+ world_family=world_family,
726
+ )
727
+ obs = env.reset()
728
+ samples: List[Dict[str, Any]] = []
729
+ rewards: List[float] = []
730
+ tool_calls = 0
731
+
732
+ while not obs.done and len(samples) < max_steps:
733
+ phase = int(getattr(obs, "phase", 1))
734
+ tool_results = investigate_local_with_depth(
735
+ env,
736
+ obs,
737
+ use_tools=use_tools,
738
+ thorough=True,
739
+ )
740
+ tool_calls += len(tool_results)
741
+ obs_dict = obs_to_dict(obs)
742
+ obs_dict = attach_tool_results(obs_dict, tool_results)
743
+ messages = build_messages(obs_dict)
744
+ prompt, response = generate_response(model, tokenizer, messages)
745
+ payload = parse_response(response, phase)
746
+
747
+ try:
748
+ obs = env.step(AdaptShieldAction(**payload))
749
+ reward = float(obs.reward)
750
+ except Exception as exc:
751
+ reward = 0.01
752
+ samples.append({
753
+ "prompt": prompt,
754
+ "response": response,
755
+ "reward": reward,
756
+ "phase": phase,
757
+ "tool_calls": len(tool_results),
758
+ "error": str(exc),
759
+ })
760
+ break
761
+
762
+ rewards.append(reward)
763
+ samples.append({
764
+ "prompt": prompt,
765
+ "response": response,
766
+ "reward": reward,
767
+ "phase": phase,
768
+ "tool_calls": len(tool_results),
769
+ "error": None,
770
+ })
771
+
772
+ metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
773
+ if "normalized_score" not in metadata:
774
+ raise RuntimeError("normalized_score missing after training episode")
775
+
776
+ return samples, {
777
+ "score": float(metadata["normalized_score"]),
778
+ "steps": len(samples),
779
+ "reward_sum": sum(rewards),
780
+ "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
781
+ "tool_calls": tool_calls,
782
+ "world_split": world_split,
783
+ "world_family": metadata.get("world_family", world_family or "auto"),
784
+ "operational_mode": metadata.get("operational_mode", "unknown"),
785
+ }
786
+
787
+
788
+ def train_policy_gradient(args: argparse.Namespace) -> None:
789
+ import torch
790
+ from torch.optim import AdamW
791
+
792
+ random.seed(args.seed)
793
+ torch.manual_seed(args.seed)
794
+
795
+ model_name = args.model_path or MODEL_CHOICES[args.model]
796
+ output_dir = Path(args.output)
797
+ output_dir.mkdir(parents=True, exist_ok=True)
798
+
799
+ print("AdaptShield policy-gradient GPU training")
800
+ print(f"Task: {args.task}")
801
+ print(f"Curriculum: {args.curriculum}")
802
+ print(f"Use tools: {args.use_tools}")
803
+ print(f"Model: {model_name}")
804
+ print(f"Episodes: {args.episodes}")
805
+ print(f"Output: {output_dir}")
806
+ print()
807
+
808
+ model, tokenizer = _load_training_model_and_tokenizer(
809
+ model_name=model_name,
810
+ model_key=args.model,
811
+ max_seq_length=MAX_SEQ_LEN,
812
+ compute_dtype=None,
813
+ seed=args.seed,
814
+ )
815
+ from unsloth import FastLanguageModel
816
+ FastLanguageModel.for_training(model)
817
+ dtype_summary = _align_trainable_dtypes(model)
818
+ print(f"Aligned trainable parameter dtypes: {dtype_summary}")
819
+ optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
820
+
821
+ rows: List[Dict[str, Any]] = []
822
+ best_score = -1.0
823
+ for episode in range(1, args.episodes + 1):
824
+ started = time.time()
825
+ task, stage = task_for_episode(
826
+ episode=episode,
827
+ total_episodes=args.episodes,
828
+ selected_task=args.task,
829
+ curriculum=args.curriculum,
830
+ )
831
+ samples, metrics = run_model_episode(
832
+ model=model,
833
+ tokenizer=tokenizer,
834
+ task=task,
835
+ max_steps=args.max_steps,
836
+ use_tools=args.use_tools,
837
+ world_split=args.train_world_split,
838
+ )
839
+ rewards = [float(sample["reward"]) for sample in samples]
840
+ baseline = sum(rewards) / len(rewards) if rewards else 0.0
841
+ total_loss = 0.0
842
+
843
+ for sample in samples:
844
+ advantage = float(sample["reward"]) - baseline
845
+ full_text = sample["prompt"] + sample["response"] + tokenizer.eos_token
846
+ inputs = tokenizer(
847
+ full_text,
848
+ return_tensors="pt",
849
+ truncation=True,
850
+ max_length=MAX_SEQ_LEN,
851
+ ).to("cuda")
852
+ outputs = model(**inputs, labels=inputs["input_ids"])
853
+ loss = outputs.loss * (-advantage)
854
+
855
+ optimizer.zero_grad()
856
+ loss.backward()
857
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
858
+ optimizer.step()
859
+ total_loss += float(loss.item())
860
+
861
+ row = {
862
+ "episode": episode,
863
+ "task": task,
864
+ "stage": stage,
865
+ "score": metrics["score"],
866
+ "steps": metrics["steps"],
867
+ "reward_sum": metrics["reward_sum"],
868
+ "mean_reward": metrics["mean_reward"],
869
+ "tool_calls": metrics["tool_calls"],
870
+ "loss": total_loss,
871
+ "seconds": round(time.time() - started, 2),
872
+ }
873
+ rows.append(row)
874
+
875
+ print(
876
+ f"episode={episode:03d} task={task:<20} "
877
+ f"stage={stage:<32} "
878
+ f"score={row['score']:.3f} mean_reward={row['mean_reward']:.3f} "
879
+ f"loss={row['loss']:.4f} steps={row['steps']:02d} tools={row['tool_calls']:02d}"
880
+ )
881
+
882
+ if row["score"] > best_score:
883
+ best_score = row["score"]
884
+ model.save_pretrained(output_dir / "best")
885
+ tokenizer.save_pretrained(output_dir / "best")
886
+
887
+ if args.save_every and episode % args.save_every == 0:
888
+ model.save_pretrained(output_dir / f"checkpoint-{episode}")
889
+ tokenizer.save_pretrained(output_dir / f"checkpoint-{episode}")
890
+
891
+ model.save_pretrained(output_dir / "final")
892
+ tokenizer.save_pretrained(output_dir / "final")
893
+
894
+ evaluation_rows = evaluate_model_suite(
895
+ model=model,
896
+ tokenizer=tokenizer,
897
+ selected_task=args.task,
898
+ eval_episodes=args.eval_episodes,
899
+ max_steps=args.max_steps,
900
+ use_tools=args.use_tools,
901
+ world_split=args.train_world_split,
902
+ seed_start=args.heldout_seed,
903
+ )
904
+ heldout_evaluation_rows = evaluate_model_suite(
905
+ model=model,
906
+ tokenizer=tokenizer,
907
+ selected_task=args.task,
908
+ eval_episodes=args.eval_episodes,
909
+ max_steps=args.max_steps,
910
+ use_tools=args.use_tools,
911
+ world_split=args.heldout_world_split,
912
+ seed_start=args.heldout_seed,
913
+ )
914
+
915
+ metrics_path = save_metrics(
916
+ output_dir=output_dir,
917
+ rows=rows,
918
+ model_name=model_name,
919
+ episodes=args.episodes,
920
+ curriculum=args.curriculum,
921
+ use_tools=args.use_tools,
922
+ trainer="pg",
923
+ evaluation_rows=evaluation_rows,
924
+ heldout_evaluation_rows=heldout_evaluation_rows,
925
+ extra={
926
+ "train_world_split": args.train_world_split,
927
+ "heldout_world_split": args.heldout_world_split,
928
+ "heldout_seed": args.heldout_seed,
929
+ },
930
+ )
931
+ if args.plot:
932
+ maybe_plot(metrics_path, output_dir)
933
+ print()
934
+ print(f"Training complete. Best score: {best_score:.3f}")
935
+ print("Post-train online evaluation:")
936
+ for row in evaluation_rows:
937
+ print(
938
+ f" task={row['task']:<20} score={row['score']:.3f} "
939
+ f"steps={row['steps']} tools={row['tool_calls']}"
940
+ )
941
+ print("Held-out family evaluation:")
942
+ for row in heldout_evaluation_rows:
943
+ print(
944
+ f" task={row['task']:<20} score={row['score']:.3f} "
945
+ f"steps={row['steps']} tools={row['tool_calls']}"
946
+ )
947
+ print(f"Metrics saved to: {metrics_path}")
948
+
949
+
950
+ def train_grpo(args: argparse.Namespace) -> None:
951
+ from datasets import Dataset
952
+ from trl import GRPOConfig, GRPOTrainer
953
+ import torch
954
+
955
+ random.seed(args.seed)
956
+ torch.manual_seed(args.seed)
957
+
958
+ model_name = args.model_path or MODEL_CHOICES[args.model]
959
+ output_dir = Path(args.output)
960
+ output_dir.mkdir(parents=True, exist_ok=True)
961
+
962
+ print("AdaptShield GRPO training")
963
+ print(f"Task: {args.task}")
964
+ print(f"Curriculum: {args.curriculum}")
965
+ print(f"Use tools: {args.use_tools}")
966
+ print(f"Model: {model_name}")
967
+ print(f"Prompt-bank episodes: {args.prompt_bank_episodes}")
968
+ print(f"GRPO epochs: {args.grpo_epochs}")
969
+ print(f"Eval episodes: {args.eval_episodes}")
970
+ print(f"Output: {output_dir}")
971
+ print()
972
+
973
+ bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
974
+ compute_dtype = torch.bfloat16 if bf16_supported else torch.float16
975
+ model, tokenizer = _load_training_model_and_tokenizer(
976
+ model_name=model_name,
977
+ model_key=args.model,
978
+ max_seq_length=MAX_SEQ_LEN,
979
+ compute_dtype=compute_dtype,
980
+ seed=args.seed,
981
+ )
982
+ from unsloth import FastLanguageModel
983
+ if getattr(tokenizer, "pad_token", None) is None:
984
+ tokenizer.pad_token = tokenizer.eos_token
985
+ if getattr(model, "config", None) is not None:
986
+ try:
987
+ model.config.return_dict = True
988
+ except Exception:
989
+ pass
990
+ try:
991
+ model.config.use_cache = False
992
+ except Exception:
993
+ pass
994
+ if getattr(model, "generation_config", None) is not None:
995
+ try:
996
+ model.generation_config.pad_token_id = tokenizer.pad_token_id
997
+ except Exception:
998
+ pass
999
+ FastLanguageModel.for_training(model)
1000
+ dtype_summary = _align_trainable_dtypes(model, target_dtype=compute_dtype)
1001
+ print(f"Using GRPO compute dtype: {compute_dtype}")
1002
+ print(f"Aligned trainable parameter dtypes: {dtype_summary}")
1003
+
1004
+ prompt_bank = build_prompt_bank(
1005
+ tokenizer=tokenizer,
1006
+ selected_task=args.task,
1007
+ curriculum=args.curriculum,
1008
+ rollout_episodes=args.prompt_bank_episodes,
1009
+ max_steps=args.max_steps,
1010
+ use_tools=args.use_tools,
1011
+ seed=args.seed,
1012
+ world_split=args.train_world_split,
1013
+ hard_multiplier=args.prompt_bank_hard_multiplier,
1014
+ borderline_bonus=args.prompt_bank_borderline_bonus,
1015
+ )
1016
+ if not prompt_bank:
1017
+ raise RuntimeError("Prompt bank is empty; cannot start GRPO training.")
1018
+
1019
+ dataset = Dataset.from_list(prompt_bank)
1020
+ reward_fn = build_grpo_reward_fn()
1021
+
1022
+ config_kwargs = {
1023
+ "output_dir": str(output_dir),
1024
+ "learning_rate": args.lr,
1025
+ "per_device_train_batch_size": args.per_device_batch_size,
1026
+ "gradient_accumulation_steps": args.gradient_accumulation_steps,
1027
+ "num_train_epochs": args.grpo_epochs,
1028
+ "max_prompt_length": MAX_SEQ_LEN - 256,
1029
+ "max_completion_length": 256,
1030
+ "num_generations": args.num_generations,
1031
+ "logging_steps": 1,
1032
+ "save_strategy": "no" if args.save_every <= 0 else "steps",
1033
+ "report_to": "none",
1034
+ "remove_unused_columns": False,
1035
+ "bf16": bf16_supported,
1036
+ "fp16": not bf16_supported,
1037
+ "max_grad_norm": 1.0,
1038
+ "seed": args.seed,
1039
+ }
1040
+ if args.save_every > 0:
1041
+ config_kwargs["save_steps"] = args.save_every
1042
+ grpo_config = GRPOConfig(**_filter_supported_kwargs(GRPOConfig, config_kwargs))
1043
+
1044
+ trainer_kwargs = {
1045
+ "model": model,
1046
+ "reward_funcs": [reward_fn],
1047
+ "args": grpo_config,
1048
+ "train_dataset": dataset,
1049
+ "processing_class": tokenizer,
1050
+ "tokenizer": tokenizer,
1051
+ }
1052
+ trainer = GRPOTrainer(**_filter_supported_kwargs(GRPOTrainer, trainer_kwargs))
1053
+ trainer.train()
1054
+
1055
+ model.save_pretrained(output_dir / "final")
1056
+ tokenizer.save_pretrained(output_dir / "final")
1057
+
1058
+ log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
1059
+ train_rows = _trainer_log_rows(log_history, selected_task=args.task)
1060
+ if not train_rows:
1061
+ train_rows = [{
1062
+ "episode": index + 1,
1063
+ "task": "mixed" if args.task == "all" else args.task,
1064
+ "stage": "grpo",
1065
+ "score": 0.50,
1066
+ } for index in range(max(1, args.grpo_epochs))]
1067
+
1068
+ try:
1069
+ evaluation_rows = evaluate_model_suite(
1070
+ model=model,
1071
+ tokenizer=tokenizer,
1072
+ selected_task=args.task,
1073
+ eval_episodes=args.eval_episodes,
1074
+ max_steps=args.max_steps,
1075
+ use_tools=args.use_tools,
1076
+ world_split=args.train_world_split,
1077
+ seed_start=args.heldout_seed,
1078
+ )
1079
+ except Exception as exc:
1080
+ print(f"GRPO in-distribution evaluation failed: {exc}")
1081
+ evaluation_rows = []
1082
+ try:
1083
+ heldout_evaluation_rows = evaluate_model_suite(
1084
+ model=model,
1085
+ tokenizer=tokenizer,
1086
+ selected_task=args.task,
1087
+ eval_episodes=args.eval_episodes,
1088
+ max_steps=args.max_steps,
1089
+ use_tools=args.use_tools,
1090
+ world_split=args.heldout_world_split,
1091
+ seed_start=args.heldout_seed,
1092
+ )
1093
+ except Exception as exc:
1094
+ print(f"GRPO held-out evaluation failed: {exc}")
1095
+ heldout_evaluation_rows = []
1096
+
1097
+ metrics_path = save_metrics(
1098
+ output_dir=output_dir,
1099
+ rows=train_rows,
1100
+ model_name=model_name,
1101
+ episodes=max(1, len(train_rows)),
1102
+ curriculum=args.curriculum,
1103
+ use_tools=args.use_tools,
1104
+ trainer="grpo",
1105
+ evaluation_rows=evaluation_rows,
1106
+ heldout_evaluation_rows=heldout_evaluation_rows,
1107
+ prompt_bank_size=len(prompt_bank),
1108
+ extra={
1109
+ "train_world_split": args.train_world_split,
1110
+ "heldout_world_split": args.heldout_world_split,
1111
+ "heldout_seed": args.heldout_seed,
1112
+ "base_model": model_name,
1113
+ },
1114
+ )
1115
+ if args.plot:
1116
+ maybe_plot(metrics_path, output_dir)
1117
+ print("GRPO training complete.")
1118
+ print(f"Prompt bank size: {len(prompt_bank)}")
1119
+ print("Post-train online evaluation:")
1120
+ for row in evaluation_rows:
1121
+ print(
1122
+ f" task={row['task']:<20} score={row['score']:.3f} "
1123
+ f"steps={row['steps']} tools={row['tool_calls']}"
1124
+ )
1125
+ print("Held-out family evaluation:")
1126
+ for row in heldout_evaluation_rows:
1127
+ print(
1128
+ f" task={row['task']:<20} score={row['score']:.3f} "
1129
+ f"steps={row['steps']} tools={row['tool_calls']}"
1130
+ )
1131
+ if log_history:
1132
+ final_keys = sorted(log_history[-1].keys())
1133
+ print(f"Trainer log keys: {final_keys}")
1134
+ print(f"Metrics saved to: {metrics_path}")
1135
+
1136
+
1137
+ def _looks_like_adapter_path(model_name: str) -> bool:
1138
+ path = Path(str(model_name))
1139
+ return path.exists() and (path / "adapter_config.json").exists()
1140
+
1141
+
1142
+ def run_fallback_smoke(args: argparse.Namespace) -> None:
1143
+ if args.use_tools:
1144
+ run_tool_fallback_smoke(args)
1145
+ return
1146
+
1147
+ if args.curriculum:
1148
+ tasks = [
1149
+ task_for_episode(
1150
+ episode=episode,
1151
+ total_episodes=min(args.episodes, args.smoke_episodes),
1152
+ selected_task=args.task,
1153
+ curriculum=True,
1154
+ )[0]
1155
+ for episode in range(1, min(args.episodes, args.smoke_episodes) + 1)
1156
+ ]
1157
+ else:
1158
+ tasks = TASKS if args.task == "all" else [args.task]
1159
+
1160
+ rows = run_smoke_training(
1161
+ tasks=tasks,
1162
+ episodes=min(args.episodes, args.smoke_episodes),
1163
+ output=Path(args.output) / "train_smoke.csv",
1164
+ seed=args.seed,
1165
+ epsilon=0.85,
1166
+ epsilon_decay=0.94,
1167
+ epsilon_floor=0.08,
1168
+ lr=0.35,
1169
+ max_steps=args.max_steps,
1170
+ )
1171
+ output_dir = Path(args.output)
1172
+ metrics_rows = []
1173
+ for row in rows:
1174
+ row = dict(row)
1175
+ episode = int(row["episode"])
1176
+ _, stage = task_for_episode(
1177
+ episode=episode,
1178
+ total_episodes=min(args.episodes, args.smoke_episodes),
1179
+ selected_task=args.task,
1180
+ curriculum=args.curriculum,
1181
+ )
1182
+ row["stage"] = stage
1183
+ metrics_rows.append(row)
1184
+
1185
+ metrics_path = save_metrics(
1186
+ output_dir=output_dir,
1187
+ rows=metrics_rows,
1188
+ model_name="smoke-tabular-policy",
1189
+ episodes=min(args.episodes, args.smoke_episodes),
1190
+ curriculum=args.curriculum,
1191
+ use_tools=False,
1192
+ )
1193
+ print(f"Metrics saved to: {metrics_path}")
1194
+ if args.plot:
1195
+ maybe_plot(metrics_path, output_dir)
1196
+
1197
+
1198
+ def run_tool_fallback_smoke(args: argparse.Namespace) -> None:
1199
+ """No-GPU tool-aware rehearsal. This validates flow, not model learning."""
1200
+ from tool_baseline import run_task as run_tool_task
1201
+
1202
+ total = min(args.episodes, args.smoke_episodes)
1203
+ if args.curriculum:
1204
+ tasks = [
1205
+ task_for_episode(
1206
+ episode=episode,
1207
+ total_episodes=total,
1208
+ selected_task=args.task,
1209
+ curriculum=True,
1210
+ )[0]
1211
+ for episode in range(1, total + 1)
1212
+ ]
1213
+ else:
1214
+ tasks = TASKS if args.task == "all" else [args.task]
1215
+
1216
+ print("AdaptShield tool-aware smoke evaluation")
1217
+ print("Mode: no-GPU flow validation, not model learning")
1218
+ print(f"Tasks: {', '.join(tasks)}")
1219
+ print(f"Episodes: {total}")
1220
+ print()
1221
+
1222
+ rows: List[Dict[str, Any]] = []
1223
+ for episode in range(1, total + 1):
1224
+ task = tasks[(episode - 1) % len(tasks)]
1225
+ result = run_tool_task(task, emit_logs=False)
1226
+ metadata = result.get("metadata", {})
1227
+ tool_calls = len(metadata.get("tool_trace", [])) if isinstance(metadata, dict) else 0
1228
+ _, stage = task_for_episode(
1229
+ episode=episode,
1230
+ total_episodes=total,
1231
+ selected_task=args.task,
1232
+ curriculum=args.curriculum,
1233
+ )
1234
+ row = {
1235
+ "episode": episode,
1236
+ "task": task,
1237
+ "stage": stage,
1238
+ "score": result["score"],
1239
+ "steps": result["steps"],
1240
+ "reward_sum": sum(result["rewards"]),
1241
+ "mean_reward": sum(result["rewards"]) / len(result["rewards"]) if result["rewards"] else 0.0,
1242
+ "tool_calls": tool_calls,
1243
+ "status": "PASS" if result["success"] else "FAIL",
1244
+ }
1245
+ rows.append(row)
1246
+ print(
1247
+ f"episode={episode:03d} task={task:<20} "
1248
+ f"score={row['score']:.3f} steps={row['steps']:02d} "
1249
+ f"tools={tool_calls:02d} {row['status']}"
1250
+ )
1251
+
1252
+ output_dir = Path(args.output)
1253
+ metrics_path = save_metrics(
1254
+ output_dir=output_dir,
1255
+ rows=rows,
1256
+ model_name="tool-aware-smoke-policy",
1257
+ episodes=total,
1258
+ curriculum=args.curriculum,
1259
+ use_tools=True,
1260
+ )
1261
+ print(f"Metrics saved to: {metrics_path}")
1262
+ if args.plot:
1263
+ maybe_plot(metrics_path, output_dir)
1264
+
1265
+
1266
+ def parse_args() -> argparse.Namespace:
1267
+ parser = argparse.ArgumentParser(description="AdaptShield training harness.")
1268
+ parser.add_argument("--task", default="direct-triage", choices=TASKS + ["all"])
1269
+ parser.add_argument("--model", default=DEFAULT_MODEL, choices=list(MODEL_CHOICES))
1270
+ parser.add_argument("--model-path", default="", help="Optional local/HF adapter path to continue training from.")
1271
+ parser.add_argument("--episodes", type=int, default=60)
1272
+ parser.add_argument("--max-steps", type=int, default=30)
1273
+ parser.add_argument("--output", default="checkpoints/adaptshield")
1274
+ parser.add_argument("--seed", type=int, default=42)
1275
+ parser.add_argument("--lr", type=float, default=1e-5)
1276
+ parser.add_argument("--save-every", type=int, default=20)
1277
+ parser.add_argument("--smoke", action="store_true", help="Force dependency-free smoke mode.")
1278
+ parser.add_argument("--smoke-episodes", type=int, default=30)
1279
+ parser.add_argument("--curriculum", action="store_true", help="Train direct -> dual -> hard instead of fixed/round-robin tasks.")
1280
+ parser.add_argument("--use-tools", action="store_true", help="Let GPU training query SOC tools before hard-task actions.")
1281
+ parser.add_argument("--plot", action="store_true", help="Generate reward_curve.png from metrics.json after training.")
1282
+ parser.add_argument("--trainer", default="auto", choices=["auto", "pg", "grpo"], help="Training backend: safe policy-gradient fallback or TRL GRPO.")
1283
+ parser.add_argument("--prompt-bank-episodes", type=int, default=24, help="Reference rollout episodes used to build the GRPO prompt bank.")
1284
+ parser.add_argument("--prompt-bank-hard-multiplier", type=int, default=2, help="Duplicate hard-task GRPO prompts this many times to emphasize difficult slices.")
1285
+ parser.add_argument("--prompt-bank-borderline-bonus", type=int, default=1, help="Extra copies for degraded-handoff / borderline GRPO prompts.")
1286
+ parser.add_argument("--grpo-epochs", type=int, default=1, help="Number of epochs over the prompt bank for GRPO runs.")
1287
+ parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt when TRL path is active.")
1288
+ parser.add_argument("--per-device-batch-size", type=int, default=1, help="Per-device batch size for GRPO training.")
1289
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4, help="Gradient accumulation for GRPO training.")
1290
+ parser.add_argument("--eval-episodes", type=int, default=2, help="Online environment episodes per task after GPU training.")
1291
+ parser.add_argument("--train-world-split", default="train", choices=["train", "eval"], help="World split used for training/prompt-bank generation.")
1292
+ parser.add_argument("--heldout-world-split", default="eval", choices=["train", "eval"], help="World split used for held-out evaluation.")
1293
+ parser.add_argument("--heldout-seed", type=int, default=314, help="Seed offset used for held-out evaluation episodes.")
1294
+ return parser.parse_args()
1295
+
1296
+
1297
+ def main() -> int:
1298
+ args = parse_args()
1299
+ if args.smoke:
1300
+ run_fallback_smoke(args)
1301
+ return 0
1302
+
1303
+ trainer_choice = args.trainer
1304
+ if trainer_choice == "auto":
1305
+ try:
1306
+ import datasets # noqa: F401
1307
+ import trl # noqa: F401
1308
+ trainer_choice = "grpo"
1309
+ except ImportError:
1310
+ trainer_choice = "pg"
1311
+
1312
+ try:
1313
+ if trainer_choice == "grpo":
1314
+ train_grpo(args)
1315
+ else:
1316
+ train_policy_gradient(args)
1317
+ except ImportError as exc:
1318
+ print(f"GPU training dependency missing for trainer={trainer_choice}: {exc}")
1319
+ if trainer_choice == "grpo":
1320
+ print("Falling back to policy-gradient GPU trainer.")
1321
+ try:
1322
+ train_policy_gradient(args)
1323
+ return 0
1324
+ except ImportError as nested_exc:
1325
+ print(f"Policy-gradient fallback also unavailable: {nested_exc}")
1326
+ print("Falling back to dependency-free smoke training.")
1327
+ run_fallback_smoke(args)
1328
+ return 0
1329
+
1330
+
1331
+ if __name__ == "__main__":
1332
+ raise SystemExit(main())
train_sft.py ADDED
@@ -0,0 +1,611 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Supervised fine-tuning for AdaptShield chat-style demonstrations."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import random
10
+ from pathlib import Path
11
+ from typing import Any, Dict, List
12
+
13
+ from train import (
14
+ DEFAULT_MODEL,
15
+ LORA_RANK,
16
+ MAX_SEQ_LEN,
17
+ MODEL_CHOICES,
18
+ _align_trainable_dtypes,
19
+ _filter_supported_kwargs,
20
+ _normalize_generation_config,
21
+ evaluate_model_suite,
22
+ run_model_episode,
23
+ )
24
+
25
+
26
+ def load_jsonl(path: Path) -> List[Dict[str, Any]]:
27
+ rows: List[Dict[str, Any]] = []
28
+ with path.open("r", encoding="utf-8") as handle:
29
+ for line in handle:
30
+ line = line.strip()
31
+ if not line:
32
+ continue
33
+ rows.append(json.loads(line))
34
+ if not rows:
35
+ raise RuntimeError(f"No training rows found in {path}")
36
+ return rows
37
+
38
+
39
+ def build_loss_plot(log_history: List[Dict[str, Any]], output_path: Path) -> None:
40
+ try:
41
+ import matplotlib.pyplot as plt
42
+ except ImportError:
43
+ print("matplotlib not installed; skipping loss plot")
44
+ return
45
+
46
+ xs: List[int] = []
47
+ ys: List[float] = []
48
+ for index, entry in enumerate(log_history, start=1):
49
+ if "loss" not in entry:
50
+ continue
51
+ step = int(entry.get("step", index) or index)
52
+ try:
53
+ loss = float(entry["loss"])
54
+ except Exception:
55
+ continue
56
+ xs.append(step)
57
+ ys.append(loss)
58
+
59
+ if not xs:
60
+ print("No loss entries found; skipping loss plot")
61
+ return
62
+
63
+ plt.figure(figsize=(10, 5))
64
+ plt.plot(xs, ys, color="#0f4c81", linewidth=2, label="training loss")
65
+ plt.xlabel("Training step")
66
+ plt.ylabel("Loss")
67
+ plt.title("AdaptShield SFT Loss Curve")
68
+ plt.grid(alpha=0.3)
69
+ plt.legend()
70
+ plt.tight_layout()
71
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
72
+ plt.close()
73
+
74
+
75
+ def build_reward_plot(rows: List[Dict[str, Any]], output_path: Path) -> None:
76
+ try:
77
+ import matplotlib.pyplot as plt
78
+ except ImportError:
79
+ print("matplotlib not installed; skipping reward plot")
80
+ return
81
+
82
+ if not rows:
83
+ print("No held-out reward rows found; skipping reward plot")
84
+ return
85
+
86
+ checkpoint_labels = [str(row["checkpoint"]) for row in rows]
87
+ in_distribution_scores = [float(row["in_distribution_score"]) for row in rows]
88
+ heldout_scores = [float(row["heldout_score"]) for row in rows]
89
+
90
+ plt.figure(figsize=(10, 5))
91
+ plt.plot(
92
+ range(len(rows)),
93
+ in_distribution_scores,
94
+ color="#136f63",
95
+ linewidth=2.5,
96
+ marker="o",
97
+ label="in-distribution mean reward",
98
+ )
99
+ plt.plot(
100
+ range(len(rows)),
101
+ heldout_scores,
102
+ color="#8a3ffc",
103
+ linewidth=2.5,
104
+ marker="s",
105
+ label="held-out family mean reward",
106
+ )
107
+ plt.xticks(range(len(rows)), checkpoint_labels, rotation=35, ha="right")
108
+ plt.xlabel("Checkpoint")
109
+ plt.ylabel("normalized_score")
110
+ plt.title("AdaptShield In-Distribution vs Held-out Reward Curve")
111
+ plt.ylim(0.0, 1.0)
112
+ plt.grid(alpha=0.3)
113
+ plt.legend()
114
+ plt.tight_layout()
115
+ plt.savefig(output_path, dpi=150, bbox_inches="tight")
116
+ plt.close()
117
+
118
+
119
+ def render_example(example: Dict[str, Any], tokenizer: Any) -> str:
120
+ if "messages" in example:
121
+ return tokenizer.apply_chat_template(
122
+ example["messages"],
123
+ tokenize=False,
124
+ add_generation_prompt=False,
125
+ )
126
+ return str(example["text"])
127
+
128
+
129
+ def _checkpoint_sort_key(path: Path) -> tuple[int, str]:
130
+ if path.name == "final":
131
+ return (10**9, path.name)
132
+ if path.name.startswith("checkpoint-"):
133
+ try:
134
+ return (int(path.name.split("-", 1)[1]), path.name)
135
+ except Exception:
136
+ return (10**8, path.name)
137
+ return (10**7, path.name)
138
+
139
+
140
+ def checkpoint_dirs(output_dir: Path) -> List[Path]:
141
+ checkpoints = [
142
+ path for path in output_dir.iterdir()
143
+ if path.is_dir() and (path.name.startswith("checkpoint-") or path.name == "final")
144
+ ]
145
+ return sorted(checkpoints, key=_checkpoint_sort_key)
146
+
147
+
148
+ def evaluate_suite_with_seed(
149
+ model: Any,
150
+ tokenizer: Any,
151
+ selected_task: str,
152
+ eval_episodes: int,
153
+ max_steps: int,
154
+ use_tools: bool,
155
+ seed_start: int,
156
+ world_split: str,
157
+ world_family: str | None = None,
158
+ ) -> List[Dict[str, Any]]:
159
+ tasks = ["direct-triage", "dual-pivot", "polymorphic-zero-day"] if selected_task == "all" else [selected_task]
160
+ rows: List[Dict[str, Any]] = []
161
+ original_seed = os.environ.get("ADAPTSHIELD_SEED")
162
+ try:
163
+ for task_index, task in enumerate(tasks):
164
+ scores: List[float] = []
165
+ steps: List[int] = []
166
+ tool_calls: List[int] = []
167
+ for episode_index in range(eval_episodes):
168
+ os.environ["ADAPTSHIELD_SEED"] = str(seed_start + task_index * 100 + episode_index)
169
+ try:
170
+ _, metrics = run_model_episode(
171
+ model=model,
172
+ tokenizer=tokenizer,
173
+ task=task,
174
+ max_steps=max_steps,
175
+ use_tools=use_tools,
176
+ world_split=world_split,
177
+ world_family=world_family,
178
+ )
179
+ except Exception as exc:
180
+ print(f" eval episode failed (task={task}, ep={episode_index}): {exc}")
181
+ continue
182
+ scores.append(float(metrics["score"]))
183
+ steps.append(int(metrics["steps"]))
184
+ tool_calls.append(int(metrics["tool_calls"]))
185
+ rows.append({
186
+ "task": task,
187
+ "score": round(sum(scores) / len(scores), 3) if scores else 0.50,
188
+ "steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
189
+ "tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
190
+ "eval_episodes": eval_episodes,
191
+ "successful_episodes": len(scores),
192
+ "seed_start": seed_start,
193
+ "world_split": world_split,
194
+ "world_family": world_family or "auto",
195
+ })
196
+ finally:
197
+ if original_seed is None:
198
+ os.environ.pop("ADAPTSHIELD_SEED", None)
199
+ else:
200
+ os.environ["ADAPTSHIELD_SEED"] = original_seed
201
+ return rows
202
+
203
+
204
+ def _free_gpu(*objects: Any) -> None:
205
+ """Best-effort release of GPU memory between checkpoint evaluations."""
206
+ import gc
207
+
208
+ for obj in objects:
209
+ try:
210
+ del obj
211
+ except Exception:
212
+ pass
213
+ gc.collect()
214
+ try:
215
+ import torch
216
+
217
+ if torch.cuda.is_available():
218
+ torch.cuda.empty_cache()
219
+ torch.cuda.synchronize()
220
+ except Exception:
221
+ pass
222
+
223
+
224
+ def _load_checkpoint_for_eval(
225
+ checkpoint_dir: Path,
226
+ base_model_name: str,
227
+ max_seq_length: int,
228
+ ) -> tuple[Any, Any]:
229
+ """Load an adapter checkpoint robustly, falling back to PEFT if needed."""
230
+ from unsloth import FastLanguageModel
231
+
232
+ is_adapter_only = (checkpoint_dir / "adapter_config.json").exists() and not (
233
+ checkpoint_dir / "config.json"
234
+ ).exists()
235
+
236
+ if not is_adapter_only:
237
+ try:
238
+ return FastLanguageModel.from_pretrained(
239
+ model_name=str(checkpoint_dir),
240
+ max_seq_length=max_seq_length,
241
+ load_in_4bit=True,
242
+ dtype=None,
243
+ )
244
+ except Exception as exc:
245
+ print(f" direct load failed for {checkpoint_dir.name}: {exc}; "
246
+ "falling back to base+adapter loader.")
247
+
248
+ model, tokenizer = FastLanguageModel.from_pretrained(
249
+ model_name=base_model_name,
250
+ max_seq_length=max_seq_length,
251
+ load_in_4bit=True,
252
+ dtype=None,
253
+ )
254
+ from peft import PeftModel
255
+
256
+ model = PeftModel.from_pretrained(
257
+ model,
258
+ str(checkpoint_dir),
259
+ is_trainable=False,
260
+ autocast_adapter_dtype=False,
261
+ )
262
+ try:
263
+ from transformers import AutoTokenizer
264
+
265
+ tokenizer = AutoTokenizer.from_pretrained(str(checkpoint_dir), trust_remote_code=True)
266
+ except Exception:
267
+ pass
268
+ return model, tokenizer
269
+
270
+
271
+ def evaluate_saved_checkpoints(
272
+ output_dir: Path,
273
+ model_key: str,
274
+ max_seq_length: int,
275
+ selected_task: str,
276
+ eval_episodes: int,
277
+ max_steps: int,
278
+ use_tools: bool,
279
+ heldout_seed: int,
280
+ train_world_split: str,
281
+ heldout_world_split: str,
282
+ ) -> List[Dict[str, Any]]:
283
+ base_model_name = MODEL_CHOICES[model_key]
284
+ rows: List[Dict[str, Any]] = []
285
+ for index, checkpoint_dir in enumerate(checkpoint_dirs(output_dir)):
286
+ print(f"Held-out evaluating checkpoint: {checkpoint_dir.name}")
287
+ model = None
288
+ tokenizer = None
289
+ try:
290
+ model, tokenizer = _load_checkpoint_for_eval(
291
+ checkpoint_dir=checkpoint_dir,
292
+ base_model_name=base_model_name,
293
+ max_seq_length=max_seq_length,
294
+ )
295
+ _normalize_generation_config(model)
296
+ _align_trainable_dtypes(model)
297
+ in_distribution_rows = evaluate_suite_with_seed(
298
+ model=model,
299
+ tokenizer=tokenizer,
300
+ selected_task=selected_task,
301
+ eval_episodes=eval_episodes,
302
+ max_steps=max_steps,
303
+ use_tools=use_tools,
304
+ seed_start=heldout_seed + index * 1000,
305
+ world_split=train_world_split,
306
+ )
307
+ heldout_rows = evaluate_suite_with_seed(
308
+ model=model,
309
+ tokenizer=tokenizer,
310
+ selected_task=selected_task,
311
+ eval_episodes=eval_episodes,
312
+ max_steps=max_steps,
313
+ use_tools=use_tools,
314
+ seed_start=heldout_seed + index * 1000,
315
+ world_split=heldout_world_split,
316
+ )
317
+ in_distribution_score = round(
318
+ sum(float(row["score"]) for row in in_distribution_rows) / max(1, len(in_distribution_rows)),
319
+ 3,
320
+ )
321
+ heldout_score = round(
322
+ sum(float(row["score"]) for row in heldout_rows) / max(1, len(heldout_rows)),
323
+ 3,
324
+ )
325
+ rows.append({
326
+ "checkpoint": checkpoint_dir.name,
327
+ "in_distribution_score": in_distribution_score,
328
+ "heldout_score": heldout_score,
329
+ "in_distribution_rows": in_distribution_rows,
330
+ "heldout_rows": heldout_rows,
331
+ })
332
+ except Exception as exc:
333
+ print(f" checkpoint eval failed for {checkpoint_dir.name}: {exc}")
334
+ rows.append({
335
+ "checkpoint": checkpoint_dir.name,
336
+ "in_distribution_score": 0.0,
337
+ "heldout_score": 0.0,
338
+ "error": str(exc),
339
+ })
340
+ finally:
341
+ _free_gpu(model, tokenizer)
342
+ model = None
343
+ tokenizer = None
344
+ return rows
345
+
346
+
347
+ def train_sft(args: argparse.Namespace) -> None:
348
+ from unsloth import FastLanguageModel
349
+ from datasets import Dataset
350
+ from trl import SFTTrainer
351
+ import torch
352
+
353
+ random.seed(args.seed)
354
+ torch.manual_seed(args.seed)
355
+
356
+ dataset_path = Path(args.dataset)
357
+ rows = load_jsonl(dataset_path)
358
+ if args.max_rows and args.max_rows > 0:
359
+ rows = rows[: args.max_rows]
360
+
361
+ model_name = MODEL_CHOICES[args.model]
362
+ output_dir = Path(args.output)
363
+ output_dir.mkdir(parents=True, exist_ok=True)
364
+
365
+ print("AdaptShield SFT training")
366
+ print(f"Dataset: {dataset_path}")
367
+ print(f"Rows: {len(rows)}")
368
+ print(f"Model: {model_name}")
369
+ print(f"Epochs: {args.epochs}")
370
+ print(f"Batch size: {args.per_device_batch_size}")
371
+ print(f"Grad accumulation: {args.gradient_accumulation_steps}")
372
+ print(f"Learning rate: {args.lr}")
373
+ print(f"Output: {output_dir}")
374
+ print()
375
+
376
+ model, tokenizer = FastLanguageModel.from_pretrained(
377
+ model_name=model_name,
378
+ max_seq_length=args.max_seq_length,
379
+ load_in_4bit=True,
380
+ dtype=None,
381
+ )
382
+ model = FastLanguageModel.get_peft_model(
383
+ model,
384
+ r=LORA_RANK,
385
+ target_modules=[
386
+ "q_proj", "k_proj", "v_proj", "o_proj",
387
+ "gate_proj", "up_proj", "down_proj",
388
+ ],
389
+ lora_alpha=LORA_RANK * 2,
390
+ lora_dropout=0.0,
391
+ bias="none",
392
+ use_gradient_checkpointing="unsloth",
393
+ random_state=args.seed,
394
+ )
395
+ if getattr(tokenizer, "pad_token", None) is None:
396
+ tokenizer.pad_token = tokenizer.eos_token
397
+ _normalize_generation_config(model)
398
+ _align_trainable_dtypes(model)
399
+
400
+ prepared_rows = [{"text": render_example(row, tokenizer), **row} for row in rows]
401
+ dataset = Dataset.from_list(prepared_rows)
402
+
403
+ bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
404
+
405
+ try:
406
+ from trl import SFTConfig
407
+ train_config_cls = SFTConfig
408
+ except ImportError:
409
+ from transformers import TrainingArguments
410
+ train_config_cls = TrainingArguments
411
+
412
+ config_kwargs = {
413
+ "output_dir": str(output_dir),
414
+ "learning_rate": args.lr,
415
+ "per_device_train_batch_size": args.per_device_batch_size,
416
+ "gradient_accumulation_steps": args.gradient_accumulation_steps,
417
+ "num_train_epochs": args.epochs,
418
+ "logging_steps": 1,
419
+ "save_strategy": "steps",
420
+ "save_steps": args.save_steps,
421
+ "report_to": "none",
422
+ "seed": args.seed,
423
+ "bf16": bf16_supported,
424
+ "fp16": not bf16_supported,
425
+ "max_seq_length": args.max_seq_length,
426
+ "dataset_text_field": "text",
427
+ "dataset_num_proc": 1,
428
+ "packing": False,
429
+ }
430
+ train_args = train_config_cls(
431
+ **_filter_supported_kwargs(train_config_cls, config_kwargs)
432
+ )
433
+
434
+ trainer_kwargs = {
435
+ "model": model,
436
+ "train_dataset": dataset,
437
+ "args": train_args,
438
+ "processing_class": tokenizer,
439
+ "tokenizer": tokenizer,
440
+ "dataset_text_field": "text",
441
+ "dataset_num_proc": 1,
442
+ "max_seq_length": args.max_seq_length,
443
+ "packing": False,
444
+ }
445
+ trainer = SFTTrainer(**_filter_supported_kwargs(SFTTrainer, trainer_kwargs))
446
+ trainer.train()
447
+
448
+ final_dir = output_dir / "final"
449
+ model.save_pretrained(final_dir)
450
+ tokenizer.save_pretrained(final_dir)
451
+
452
+ log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
453
+ loss_plot_path = output_dir / "loss_curve.png"
454
+ try:
455
+ build_loss_plot(log_history, loss_plot_path)
456
+ except Exception as exc:
457
+ print(f"Loss plot generation skipped: {exc}")
458
+
459
+ metrics: Dict[str, Any] = {
460
+ "trainer": "sft",
461
+ "model": model_name,
462
+ "dataset": str(dataset_path),
463
+ "rows": len(rows),
464
+ "epochs": args.epochs,
465
+ "learning_rate": args.lr,
466
+ "evaluation_rows": [],
467
+ "heldout_evaluation_rows": [],
468
+ "heldout_seed": args.heldout_seed,
469
+ "train_world_split": args.train_world_split,
470
+ "heldout_world_split": args.heldout_world_split,
471
+ "reward_curve_rows": [],
472
+ "log_history": log_history,
473
+ }
474
+ metrics_path = output_dir / "sft_metrics.json"
475
+
476
+ def _flush_metrics() -> None:
477
+ metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")
478
+
479
+ _flush_metrics()
480
+
481
+ try:
482
+ metrics["evaluation_rows"] = evaluate_suite_with_seed(
483
+ model=model,
484
+ tokenizer=tokenizer,
485
+ selected_task=args.eval_task,
486
+ eval_episodes=args.eval_episodes,
487
+ max_steps=args.eval_max_steps,
488
+ use_tools=args.use_tools,
489
+ seed_start=args.heldout_seed,
490
+ world_split=args.train_world_split,
491
+ )
492
+ except Exception as exc:
493
+ print(f"In-distribution evaluation failed: {exc}")
494
+ _flush_metrics()
495
+
496
+ try:
497
+ metrics["heldout_evaluation_rows"] = evaluate_suite_with_seed(
498
+ model=model,
499
+ tokenizer=tokenizer,
500
+ selected_task=args.eval_task,
501
+ eval_episodes=args.eval_episodes,
502
+ max_steps=args.eval_max_steps,
503
+ use_tools=args.use_tools,
504
+ seed_start=args.heldout_seed,
505
+ world_split=args.heldout_world_split,
506
+ )
507
+ except Exception as exc:
508
+ print(f"Held-out evaluation failed: {exc}")
509
+ _flush_metrics()
510
+
511
+ reward_curve_rows: List[Dict[str, Any]] = []
512
+ if args.skip_reward_curve:
513
+ print("Skipping per-checkpoint reward curve (--skip-reward-curve).")
514
+ else:
515
+ # Free training-time model before reloading checkpoints to avoid OOM.
516
+ _free_gpu(model, trainer)
517
+ try:
518
+ reward_curve_rows = evaluate_saved_checkpoints(
519
+ output_dir=output_dir,
520
+ model_key=args.model,
521
+ max_seq_length=args.max_seq_length,
522
+ selected_task=args.eval_task,
523
+ eval_episodes=args.eval_episodes,
524
+ max_steps=args.eval_max_steps,
525
+ use_tools=args.use_tools,
526
+ heldout_seed=args.heldout_seed,
527
+ train_world_split=args.train_world_split,
528
+ heldout_world_split=args.heldout_world_split,
529
+ )
530
+ except Exception as exc:
531
+ print(f"Per-checkpoint reward curve failed: {exc}")
532
+ metrics["reward_curve_rows"] = reward_curve_rows
533
+ _flush_metrics()
534
+
535
+ reward_plot_path = output_dir / "reward_curve.png"
536
+ if reward_curve_rows:
537
+ try:
538
+ build_reward_plot(reward_curve_rows, reward_plot_path)
539
+ except Exception as exc:
540
+ print(f"Reward plot generation skipped: {exc}")
541
+
542
+ evaluation_rows = metrics["evaluation_rows"]
543
+ heldout_evaluation_rows = metrics["heldout_evaluation_rows"]
544
+
545
+ print("SFT complete.")
546
+ print(f"Saved adapter to: {final_dir}")
547
+ print(f"Loss curve: {loss_plot_path}")
548
+ print(f"Reward curve: {reward_plot_path}")
549
+ print(f"Metrics: {metrics_path}")
550
+ print("Post-train evaluation:")
551
+ for row in evaluation_rows:
552
+ print(
553
+ f" task={row['task']:<20} score={row['score']:.3f} "
554
+ f"steps={row['steps']} tools={row['tool_calls']}"
555
+ )
556
+ print("Held-out checkpoint reward curve:")
557
+ for row in reward_curve_rows:
558
+ print(
559
+ f" checkpoint={row['checkpoint']:<16} "
560
+ f"in_dist={row['in_distribution_score']:.3f} "
561
+ f"heldout={row['heldout_score']:.3f}"
562
+ )
563
+
564
+
565
+ def main() -> None:
566
+ parser = argparse.ArgumentParser(description="AdaptShield supervised fine-tuning")
567
+ parser.add_argument(
568
+ "--dataset",
569
+ default="data/adaptshield_sft.jsonl",
570
+ help="Path to JSONL dataset from generate_sft_data.py",
571
+ )
572
+ parser.add_argument(
573
+ "--model",
574
+ default=DEFAULT_MODEL,
575
+ choices=list(MODEL_CHOICES.keys()),
576
+ )
577
+ parser.add_argument("--output", default="checkpoints/sft-run")
578
+ parser.add_argument("--epochs", type=float, default=1.0)
579
+ parser.add_argument("--lr", type=float, default=2e-4)
580
+ parser.add_argument("--seed", type=int, default=42)
581
+ parser.add_argument("--heldout-seed", type=int, default=314)
582
+ parser.add_argument("--train-world-split", default="train", choices=["train", "eval"])
583
+ parser.add_argument("--heldout-world-split", default="eval", choices=["train", "eval"])
584
+ parser.add_argument("--max-rows", type=int, default=0)
585
+ parser.add_argument("--max-seq-length", type=int, default=MAX_SEQ_LEN)
586
+ parser.add_argument("--per-device-batch-size", type=int, default=2)
587
+ parser.add_argument("--gradient-accumulation-steps", type=int, default=4)
588
+ parser.add_argument("--save-steps", type=int, default=40)
589
+ parser.add_argument(
590
+ "--eval-task",
591
+ default="all",
592
+ choices=["all", "direct-triage", "dual-pivot", "polymorphic-zero-day"],
593
+ )
594
+ parser.add_argument("--eval-episodes", type=int, default=2)
595
+ parser.add_argument("--eval-max-steps", type=int, default=20)
596
+ parser.add_argument(
597
+ "--use-tools",
598
+ action="store_true",
599
+ help="Use SOC tools during post-train evaluation.",
600
+ )
601
+ parser.add_argument(
602
+ "--skip-reward-curve",
603
+ action="store_true",
604
+ help="Skip the per-checkpoint reward curve sweep (faster, avoids OOM).",
605
+ )
606
+ args = parser.parse_args()
607
+ train_sft(args)
608
+
609
+
610
+ if __name__ == "__main__":
611
+ main()
train_smoke.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Dependency-free training-readiness smoke test for AdaptShield."""
3
+
4
+ from __future__ import annotations
5
+
6
+ import argparse
7
+ import csv
8
+ import random
9
+ import sys
10
+ from pathlib import Path
11
+ from typing import Any, Dict, Iterable, List, Tuple
12
+
13
+
14
+ REPO_ROOT = Path(__file__).resolve().parent
15
+
16
+ if str(REPO_ROOT) not in sys.path:
17
+ sys.path.insert(0, str(REPO_ROOT))
18
+
19
+ from models import AdaptShieldAction
20
+ from server.adaptshield_environment import AdaptShieldEnvironment
21
+
22
+
23
+ TASKS = ["direct-triage", "dual-pivot", "polymorphic-zero-day"]
24
+ POLICY = {
25
+ "brute_force": ("auth_service", "rate_limit"),
26
+ "lateral_movement": ("payment_service", "isolate"),
27
+ "exfiltration": ("database", "honeypot"),
28
+ "supply_chain": ("api_gateway", "patch"),
29
+ "benign": ("api_gateway", "monitor"),
30
+ }
31
+ ACTION_SPACE = [
32
+ ("auth_service", "rate_limit"),
33
+ ("payment_service", "isolate"),
34
+ ("database", "honeypot"),
35
+ ("api_gateway", "patch"),
36
+ ("api_gateway", "monitor"),
37
+ ]
38
+
39
+
40
+ def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
41
+ auth = network_nodes.get("auth_service", {})
42
+ payment = network_nodes.get("payment_service", {})
43
+ database = network_nodes.get("database", {})
44
+ gateway = network_nodes.get("api_gateway", {})
45
+
46
+ if float(auth.get("error_rate", 0.0)) >= 0.10:
47
+ return "brute_force"
48
+ if payment.get("status") == "suspicious" or float(payment.get("cpu", 0)) >= 55:
49
+ return "lateral_movement"
50
+ if float(database.get("outbound_mb", 0)) >= 50:
51
+ return "exfiltration"
52
+ if gateway.get("status") == "suspicious":
53
+ return "supply_chain"
54
+ return "benign"
55
+
56
+
57
+ class TabularDefensePolicy:
58
+ """Tiny epsilon-greedy policy used only to verify trainability."""
59
+
60
+ def __init__(self, epsilon: float, lr: float) -> None:
61
+ self.epsilon = epsilon
62
+ self.lr = lr
63
+ self.q: Dict[str, Dict[Tuple[str, str], float]] = {
64
+ threat: {action: 0.50 for action in ACTION_SPACE}
65
+ for threat in POLICY
66
+ }
67
+
68
+ def choose_phase1(self, obs: Any) -> Dict[str, Any]:
69
+ threat = classify_from_metrics(obs.network_nodes)
70
+ target, action = POLICY[threat]
71
+ return {
72
+ "threat_type": threat,
73
+ "confidence": 0.90,
74
+ "target_node": target,
75
+ "recommended_action": action,
76
+ "reasoning": "smoke-train metric policy",
77
+ }
78
+
79
+ def choose_phase2(self, obs: Any) -> Tuple[Dict[str, Any], str, Tuple[str, str]]:
80
+ assessment = obs.phase1_assessment or {}
81
+ threat = str(assessment.get("threat_type", "benign"))
82
+ choices = self.q.get(threat, self.q["benign"])
83
+
84
+ if random.random() < self.epsilon:
85
+ target, action = random.choice(ACTION_SPACE)
86
+ else:
87
+ best_value = max(choices.values())
88
+ best_actions = [
89
+ action for action, value in choices.items()
90
+ if value == best_value
91
+ ]
92
+ target, action = random.choice(best_actions)
93
+
94
+ return {
95
+ "action": action,
96
+ "target_node": target,
97
+ "reasoning": "epsilon-greedy smoke policy",
98
+ }, threat, (target, action)
99
+
100
+ def update(self, threat: str, selected: Tuple[str, str], reward: float) -> None:
101
+ choices = self.q.setdefault(
102
+ threat,
103
+ {action: 0.50 for action in ACTION_SPACE},
104
+ )
105
+ old_value = choices.get(selected, 0.50)
106
+ choices[selected] = old_value + self.lr * (reward - old_value)
107
+
108
+ def decay(self, rate: float, floor: float) -> None:
109
+ self.epsilon = max(floor, self.epsilon * rate)
110
+
111
+
112
+ def run_episode(task: str, policy: TabularDefensePolicy, max_steps: int) -> Dict[str, Any]:
113
+ env = AdaptShieldEnvironment(task_name=task)
114
+ obs = env.reset()
115
+ rewards: List[float] = []
116
+ steps = 0
117
+
118
+ while not obs.done and steps < max_steps:
119
+ if obs.phase == 1:
120
+ payload = policy.choose_phase1(obs)
121
+ obs = env.step(AdaptShieldAction(**payload))
122
+ else:
123
+ payload, threat, selected = policy.choose_phase2(obs)
124
+ obs = env.step(AdaptShieldAction(**payload))
125
+ policy.update(threat, selected, float(obs.reward))
126
+
127
+ rewards.append(float(obs.reward))
128
+ steps += 1
129
+
130
+ metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
131
+ if "normalized_score" not in metadata:
132
+ raise RuntimeError("normalized_score missing during smoke training")
133
+
134
+ metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
135
+ return {
136
+ "task": task,
137
+ "score": float(metadata.get("normalized_score", 0.01)),
138
+ "reward_sum": sum(rewards),
139
+ "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
140
+ "steps": steps,
141
+ "done": bool(obs.done),
142
+ "normalized_score_present": "normalized_score" in metadata,
143
+ }
144
+
145
+
146
+ def write_rows(path: Path, rows: Iterable[Dict[str, Any]]) -> None:
147
+ path.parent.mkdir(parents=True, exist_ok=True)
148
+ rows = list(rows)
149
+ if not rows:
150
+ return
151
+
152
+ with path.open("w", newline="") as handle:
153
+ writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
154
+ writer.writeheader()
155
+ writer.writerows(rows)
156
+
157
+
158
+ def trend(values: List[float]) -> Tuple[float, float]:
159
+ if not values:
160
+ return 0.0, 0.0
161
+ window = max(1, len(values) // 5)
162
+ first = sum(values[:window]) / window
163
+ last = sum(values[-window:]) / window
164
+ return first, last
165
+
166
+
167
+ def run_smoke_training(
168
+ tasks: List[str],
169
+ episodes: int,
170
+ output: Path,
171
+ seed: int,
172
+ epsilon: float,
173
+ epsilon_decay: float,
174
+ epsilon_floor: float,
175
+ lr: float,
176
+ max_steps: int,
177
+ ) -> List[Dict[str, Any]]:
178
+ random.seed(seed)
179
+ policy = TabularDefensePolicy(epsilon=epsilon, lr=lr)
180
+ rows: List[Dict[str, Any]] = []
181
+
182
+ print("AdaptShield smoke training")
183
+ print(f"Tasks: {', '.join(tasks)}")
184
+ print(f"Episodes: {episodes}")
185
+ print(f"Output: {output}")
186
+ print()
187
+
188
+ for episode in range(1, episodes + 1):
189
+ task = tasks[(episode - 1) % len(tasks)]
190
+ result = run_episode(task=task, policy=policy, max_steps=max_steps)
191
+ result.update({
192
+ "episode": episode,
193
+ "epsilon": round(policy.epsilon, 4),
194
+ "status": "PASS" if result["done"] and result["normalized_score_present"] else "FAIL",
195
+ })
196
+ rows.append(result)
197
+ policy.decay(epsilon_decay, epsilon_floor)
198
+
199
+ print(
200
+ f"episode={episode:03d} task={task:<20} "
201
+ f"score={result['score']:.3f} steps={result['steps']:02d} "
202
+ f"epsilon={result['epsilon']:.3f} {result['status']}"
203
+ )
204
+
205
+ write_rows(output, rows)
206
+
207
+ scores = [float(row["score"]) for row in rows]
208
+ first, last = trend(scores)
209
+ print()
210
+ print(f"First-window avg score: {first:.3f}")
211
+ print(f"Last-window avg score: {last:.3f}")
212
+ print(f"Score delta: {last - first:+.3f}")
213
+ print(f"Saved CSV: {output}")
214
+ print("Smoke training verdict: PASS")
215
+ return rows
216
+
217
+
218
+ def parse_args() -> argparse.Namespace:
219
+ parser = argparse.ArgumentParser(description="Run cheap AdaptShield training smoke test.")
220
+ parser.add_argument("--task", default="direct-triage", choices=TASKS + ["all"])
221
+ parser.add_argument("--episodes", type=int, default=30)
222
+ parser.add_argument("--output", default="training_runs/train_smoke.csv")
223
+ parser.add_argument("--seed", type=int, default=42)
224
+ parser.add_argument("--epsilon", type=float, default=0.85)
225
+ parser.add_argument("--epsilon-decay", type=float, default=0.94)
226
+ parser.add_argument("--epsilon-floor", type=float, default=0.08)
227
+ parser.add_argument("--lr", type=float, default=0.35)
228
+ parser.add_argument("--max-steps", type=int, default=30)
229
+ return parser.parse_args()
230
+
231
+
232
+ def main() -> int:
233
+ args = parse_args()
234
+ tasks = TASKS if args.task == "all" else [args.task]
235
+ run_smoke_training(
236
+ tasks=tasks,
237
+ episodes=args.episodes,
238
+ output=Path(args.output),
239
+ seed=args.seed,
240
+ epsilon=args.epsilon,
241
+ epsilon_decay=args.epsilon_decay,
242
+ epsilon_floor=args.epsilon_floor,
243
+ lr=args.lr,
244
+ max_steps=args.max_steps,
245
+ )
246
+ return 0
247
+
248
+
249
+ if __name__ == "__main__":
250
+ raise SystemExit(main())
uv.lock ADDED
The diff for this file is too large to render. See raw diff