Spaces:
Running
Running
Initial deploy of AdaptShield two-phase cybersecurity environment
Browse files- .gitignore +28 -0
- .python-version +1 -0
- Dockerfile +47 -0
- README.md +428 -5
- __init__.py +23 -0
- assets/_make_headline_chart.py +79 -0
- assets/architecture_overview.svg +95 -0
- assets/headline_results.png +0 -0
- assets/training_pipeline.svg +87 -0
- baseline.py +204 -0
- build_benchmark_table.py +103 -0
- client.py +107 -0
- eval_tasks.py +46 -0
- generate_sft_data.py +197 -0
- inference.py +351 -0
- launch_hf_grpo_job.py +403 -0
- launch_hf_sft_job.py +363 -0
- models.py +243 -0
- openenv.yaml +26 -0
- plot_sft_checkpoint_curve.py +93 -0
- plot_training.py +106 -0
- pyproject.toml +31 -0
- server/Dockerfile +47 -0
- server/__init__.py +11 -0
- server/adaptshield_environment.py +1324 -0
- server/app.py +173 -0
- server/attacker.py +361 -0
- server/grader.py +501 -0
- server/requirements.txt +6 -0
- server/scenarios.py +264 -0
- smoke_test.py +59 -0
- soc_tools.py +239 -0
- tests/test_regression.py +352 -0
- tool_baseline.py +128 -0
- train.py +1332 -0
- train_sft.py +611 -0
- train_smoke.py +250 -0
- uv.lock +0 -0
.gitignore
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
*.so
|
| 5 |
+
.Python
|
| 6 |
+
build/
|
| 7 |
+
develop-eggs/
|
| 8 |
+
dist/
|
| 9 |
+
downloads/
|
| 10 |
+
eggs/
|
| 11 |
+
.eggs/
|
| 12 |
+
lib/
|
| 13 |
+
lib64/
|
| 14 |
+
parts/
|
| 15 |
+
sdist/
|
| 16 |
+
var/
|
| 17 |
+
wheels/
|
| 18 |
+
share/python-wheels/
|
| 19 |
+
*.egg-info/
|
| 20 |
+
.installed.cfg
|
| 21 |
+
*.egg
|
| 22 |
+
MANIFEST
|
| 23 |
+
.venv
|
| 24 |
+
venv/
|
| 25 |
+
ENV/
|
| 26 |
+
env.bak/
|
| 27 |
+
venv.bak/
|
| 28 |
+
.DS_Store
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.11.9
|
Dockerfile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
FROM ${BASE_IMAGE} AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
RUN apt-get update && \
|
| 7 |
+
apt-get install -y --no-install-recommends git curl && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
COPY . /app/env
|
| 11 |
+
WORKDIR /app/env
|
| 12 |
+
|
| 13 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 14 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 15 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 16 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 20 |
+
if [ -f uv.lock ]; then \
|
| 21 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 22 |
+
else \
|
| 23 |
+
uv sync --no-install-project --no-editable; \
|
| 24 |
+
fi
|
| 25 |
+
|
| 26 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 27 |
+
if [ -f uv.lock ]; then \
|
| 28 |
+
uv sync --frozen --no-editable; \
|
| 29 |
+
else \
|
| 30 |
+
uv sync --no-editable; \
|
| 31 |
+
fi
|
| 32 |
+
|
| 33 |
+
FROM ${BASE_IMAGE}
|
| 34 |
+
WORKDIR /app
|
| 35 |
+
|
| 36 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 37 |
+
COPY --from=builder /app/env /app/env
|
| 38 |
+
|
| 39 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 40 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 41 |
+
|
| 42 |
+
EXPOSE 7860
|
| 43 |
+
|
| 44 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
|
| 45 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 46 |
+
|
| 47 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
|
README.md
CHANGED
|
@@ -1,10 +1,433 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Janus (AdaptShield)
|
| 3 |
+
emoji: 🛡️
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
license: mit
|
| 9 |
+
tags:
|
| 10 |
+
- openenv
|
| 11 |
+
- security
|
| 12 |
+
- reinforcement-learning
|
| 13 |
+
- cybersecurity
|
| 14 |
+
short_description: Two-phase adaptive cybersecurity benchmark for LLMs
|
| 15 |
---
|
| 16 |
|
| 17 |
+
# Janus (AdaptShield) — Two-Phase Adaptive Cybersecurity Benchmark
|
| 18 |
+
|
| 19 |
+
**AdaptShield** is the environment: a two-phase agentic cybersecurity
|
| 20 |
+
simulator where an LLM defends a 4-node enterprise network against an
|
| 21 |
+
adversary that shifts strategy mid-episode. **Janus** is the model we
|
| 22 |
+
trained on it — a Qwen2.5-1.5B LoRA, supervised then refined with GRPO.
|
| 23 |
+
On the hardest task Janus scores 0.90 on a held-out world family it
|
| 24 |
+
never saw during training; a tool-aware heuristic baseline scores 0.18
|
| 25 |
+
on the same task.
|
| 26 |
+
|
| 27 |
+
The skill being tested is narrow on purpose. Not threat classification.
|
| 28 |
+
Not generic tool calling. The benchmark targets one thing: real-time
|
| 29 |
+
adaptation when the attacker's playbook changes mid-incident. The section
|
| 30 |
+
[Why this matters](#why-this-matters) explains why we think that's the
|
| 31 |
+
gap, and the [Results](#results) section is where the gap closes.
|
| 32 |
+
|
| 33 |
+
## Project Links
|
| 34 |
+
|
| 35 |
+
- **HF Space (live env):** `TODO`
|
| 36 |
+
- **Colab notebook (SFT + GRPO reproducer, free T4):** `TODO`
|
| 37 |
+
- **Artifacts / model repo:** [`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus)
|
| 38 |
+
- **Demo video:** `TODO`
|
| 39 |
+
- **Blog / writeup:** `TODO`
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## Why this matters
|
| 44 |
+
|
| 45 |
+
Most cyber-agent demos test threat classification or generic tool
|
| 46 |
+
calling. Real production breaches don't look like that. They look like
|
| 47 |
+
this:
|
| 48 |
+
|
| 49 |
+
In April 2026 attackers compromised Context.ai, used its OAuth
|
| 50 |
+
integration into a Vercel employee's Google Workspace, and pivoted from
|
| 51 |
+
shadow AI through identity into Vercel's internal systems, where they
|
| 52 |
+
enumerated and decrypted customer environment variables. The same week,
|
| 53 |
+
a Broken Object Level Authorization flaw in Lovable.dev let any
|
| 54 |
+
free-tier account read source code, Supabase credentials, Stripe keys
|
| 55 |
+
and AI chat histories from other tenants — including projects built by
|
| 56 |
+
AI itself. Eight months earlier, the Tea dating app left a Firebase
|
| 57 |
+
bucket open and 72,000 verification selfies and driver's licenses of
|
| 58 |
+
women on a safety app were scraped to 4chan within hours.
|
| 59 |
+
|
| 60 |
+
Three different failure modes — identity hijack via shadow AI, broken
|
| 61 |
+
authz in vibe-coded apps, classic cloud misconfig — but the same
|
| 62 |
+
underlying problem for the defender's agent. The environment is shifting
|
| 63 |
+
faster than any static training distribution can keep up with, and the
|
| 64 |
+
real attacker doesn't sit still while you classify them.
|
| 65 |
+
|
| 66 |
+
AdaptShield is built around that pressure. The environment forces the
|
| 67 |
+
agent to (1) act on partial evidence, (2) hand judgment across two
|
| 68 |
+
roles with an information bottleneck between them, (3) trade security
|
| 69 |
+
correctness against operational blast radius, and (4) re-plan when the
|
| 70 |
+
adversary's playbook changes mid-episode. Each of those is a separate
|
| 71 |
+
failure mode in production SOC tooling, and the benchmark scores all
|
| 72 |
+
four at once.
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
## Results
|
| 77 |
+
|
| 78 |
+
Numbers below come from the production run on Hugging Face L4 Jobs,
|
| 79 |
+
training Qwen2.5-1.5B-Instruct with a LoRA adapter. Eval is 50
|
| 80 |
+
deterministic seeds per task, evaluated on a held-out world family
|
| 81 |
+
the policy never saw during training.
|
| 82 |
+
|
| 83 |
+

|
| 84 |
+
|
| 85 |
+
On the hard task (`polymorphic-zero-day`) the tool-aware heuristic
|
| 86 |
+
baseline scores 0.18 and Janus holds 0.90 on the held-out family. On
|
| 87 |
+
the easier tasks the lift is smaller because the rule baseline is
|
| 88 |
+
already near the ceiling; the benchmark is shaped so adaptation only
|
| 89 |
+
matters where it should.
|
| 90 |
+
|
| 91 |
+
### Benchmark comparison (full table)
|
| 92 |
+
|
| 93 |
+
| Task | No-tool baseline | Tool-aware baseline | SFT (train family) | SFT (held-out) | GRPO (train) | GRPO (held-out) |
|
| 94 |
+
|------|-----------------:|-------------------:|-------------------:|---------------:|-------------:|----------------:|
|
| 95 |
+
| `direct-triage` | 0.860 | 0.990 | 0.990 | 0.990 | 0.990 | 0.990 |
|
| 96 |
+
| `dual-pivot` | 0.650 | 0.640 | 0.825 | 0.825 | 0.825 | 0.825 |
|
| 97 |
+
| `polymorphic-zero-day` | 0.380 | 0.180 | 0.960 | 0.930 | **0.883** | **0.902** |
|
| 98 |
+
|
| 99 |
+
Two things in this table are worth flagging.
|
| 100 |
+
|
| 101 |
+
The tool-aware baseline scores 0.18 on the hard task — worse than the
|
| 102 |
+
no-tool baseline at 0.38. That isn't a bug in the baseline; it's that
|
| 103 |
+
bolting tools onto a heuristic without learning when to trust them
|
| 104 |
+
makes the agent over-trigger on injected false positives. You see the
|
| 105 |
+
same pattern in production with rule-based SOAR playbooks against
|
| 106 |
+
adaptive adversaries.
|
| 107 |
+
|
| 108 |
+
Held-out GRPO (0.902) actually edges out train-family GRPO (0.883). That
|
| 109 |
+
is evidence the policy is generalizing across world templates rather
|
| 110 |
+
than memorizing them. Without splitting the eval by world family this
|
| 111 |
+
finding wouldn't be visible — same-seed evaluation would have credited
|
| 112 |
+
the model for memorization it didn't do.
|
| 113 |
+
|
| 114 |
+
### SFT — loss and held-out reward
|
| 115 |
+
|
| 116 |
+

|
| 117 |
+
|
| 118 |
+

|
| 119 |
+
|
| 120 |
+
### GRPO — refinement on the polymorphic adversary
|
| 121 |
+
|
| 122 |
+

|
| 123 |
+
|
| 124 |
+
### Training runs
|
| 125 |
+
|
| 126 |
+
Three production runs on Hugging Face Jobs produced the artifacts in this
|
| 127 |
+
README. Stdout logs are public and the per-step / per-episode metrics
|
| 128 |
+
files are next to the adapters.
|
| 129 |
+
|
| 130 |
+
| Run | Trainer | GPU | Steps / Episodes | Train wall-clock | Logs | Metrics |
|
| 131 |
+
|-----|---------|-----|------------------|------------------|------|---------|
|
| 132 |
+
| [`sft_worldsplit_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/sft_worldsplit_1_5b) | SFT (LoRA) | L4 ×1 | 378 steps | 9m 49s | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/sft_worldsplit_1_5b.log) | [trainer_state](https://huggingface.co/SaiManish123/Janus/blob/main/sft_worldsplit_1_5b/checkpoint-378/trainer_state.json) |
|
| 133 |
+
| [`grpo_worldsplit_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/grpo_worldsplit_1_5b) | GRPO, mixed curriculum | L4 ×1 | 1,628 episodes | 1h 26m | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/grpo_worldsplit_1_5b.log) | [per-episode](https://huggingface.co/SaiManish123/Janus/blob/main/grpo_worldsplit_1_5b/metrics.json) |
|
| 134 |
+
| [`grpo_polymorphic_zero_day_1_5b`](https://huggingface.co/SaiManish123/Janus/tree/main/grpo_polymorphic_zero_day_1_5b) | GRPO, hard-task focus | L4 ×1 | 4,357 episodes | 3h 17m | [stdout](https://huggingface.co/SaiManish123/Janus/blob/main/logs/grpo_polymorphic_zero_day_1_5b.log) | [per-episode](https://huggingface.co/SaiManish123/Janus/blob/main/grpo_polymorphic_zero_day_1_5b/metrics.json) |
|
| 135 |
+
|
| 136 |
+
The curriculum run mixes all three tasks (weights `direct-triage: 0.3 /
|
| 137 |
+
dual-pivot: 0.4 / polymorphic-zero-day: 0.3`). The polymorphic run
|
| 138 |
+
trains exclusively on the hard task to push hard-task performance
|
| 139 |
+
without distraction from saturated tiers. Per-episode reward in both
|
| 140 |
+
runs stabilizes within the first ~500 episodes and stays there for the
|
| 141 |
+
rest of the schedule.
|
| 142 |
+
|
| 143 |
+
---
|
| 144 |
+
|
| 145 |
+
## Architecture
|
| 146 |
+
|
| 147 |
+

|
| 148 |
+
|
| 149 |
+
Each episode runs against a sampled mission profile, world-family
|
| 150 |
+
template, and latent operational mode. The Threat Analyst investigates
|
| 151 |
+
raw enterprise evidence through SOC tools and emits a structured
|
| 152 |
+
handoff. The Tactical Executor sees only that handoff (not the raw
|
| 153 |
+
state) and chooses the mitigation. A deterministic Python grader scores
|
| 154 |
+
security correctness, business impact, dependency blast radius, and
|
| 155 |
+
mission alignment. There is no LLM-as-judge anywhere in the loop.
|
| 156 |
+
|
| 157 |
+
## Training Pipeline
|
| 158 |
+
|
| 159 |
+

|
| 160 |
+
|
| 161 |
+
Five steps, each reproducible from the repo:
|
| 162 |
+
|
| 163 |
+
1. Generate SFT demonstrations by rolling AdaptShield episodes with a
|
| 164 |
+
rule-based Phase 1 expert and a tool-aware Phase 2 expert.
|
| 165 |
+
2. Train a LoRA adapter on Qwen2.5-1.5B (or 0.5B for the Colab
|
| 166 |
+
reproducer) with supervised fine-tuning on those demos.
|
| 167 |
+
3. Evaluate on both train-family and held-out-family worlds. The split
|
| 168 |
+
is by world template, not by seed, so memorizing a template doesn't
|
| 169 |
+
transfer across the split.
|
| 170 |
+
4. Refine the SFT adapter with GRPO on a curriculum weighted toward
|
| 171 |
+
`polymorphic-zero-day`. The deterministic grader is the reward.
|
| 172 |
+
5. Publish adapters, curves, metrics, and benchmark tables to
|
| 173 |
+
[`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus).
|
| 174 |
+
|
| 175 |
+
A free-tier Colab notebook reproduces steps 1–4 end-to-end on a T4 in
|
| 176 |
+
roughly 35 minutes using Qwen2.5-0.5B and reduced episode budgets. The
|
| 177 |
+
numbers in this README come from the 1.5B run on a Hugging Face L4 Job.
|
| 178 |
+
|
| 179 |
+
---
|
| 180 |
+
|
| 181 |
+
## Environment Description
|
| 182 |
+
|
| 183 |
+
The agent defends a 4-node enterprise network (`auth_service`,
|
| 184 |
+
`payment_service`, `database`, `api_gateway`). Each turn has two phases:
|
| 185 |
+
|
| 186 |
+
**Phase 1 — Threat Analyst.** Agent reads SIEM metrics, can call SOC tools
|
| 187 |
+
(log search, network telemetry, threat intel lookup), and emits a
|
| 188 |
+
structured `Phase1Action` with threat type, target node, confidence and a
|
| 189 |
+
recommended action.
|
| 190 |
+
|
| 191 |
+
**Phase 2 — Tactical Executor.** Agent receives only the Phase 1
|
| 192 |
+
assessment (blind to raw state) and emits a `Phase2Action`. The analyst
|
| 193 |
+
has to communicate clearly because the executor cannot double-check the
|
| 194 |
+
network.
|
| 195 |
+
|
| 196 |
+
The attacker escalates through `recon → exploit → exfiltration` if the
|
| 197 |
+
agent fails to respond correctly. On the hard task, the attacker shifts
|
| 198 |
+
strategy mid-episode and seeds false-positive noise that looks like a
|
| 199 |
+
real attack but isn't — punishing reflexive isolation.
|
| 200 |
+
|
| 201 |
+
### Observation Space
|
| 202 |
+
|
| 203 |
+
```json
|
| 204 |
+
{
|
| 205 |
+
"phase": "1 or 2",
|
| 206 |
+
"network_nodes": {
|
| 207 |
+
"auth_service": {"status": "...", "request_rate": 0, "error_rate": 0.0, "cpu": 0}
|
| 208 |
+
},
|
| 209 |
+
"active_alerts": ["raw metric alert strings — no MITRE codes"],
|
| 210 |
+
"attack_stage": "recon | exploit | exfiltration | none",
|
| 211 |
+
"history": [{"turn": "1", "p1": "classified:brute_force", "p2": "rate_limit→auth_service"}],
|
| 212 |
+
"phase1_assessment": {"threat_type": "...", "confidence": 0.9, "target_node": "..."},
|
| 213 |
+
"metadata": {"normalized_score": 0.72}
|
| 214 |
+
}
|
| 215 |
+
```
|
| 216 |
+
|
| 217 |
+
Phase 2 observations have empty `network_nodes` and `active_alerts` — the
|
| 218 |
+
executor only sees the analyst's handoff.
|
| 219 |
+
|
| 220 |
+
### Action Space
|
| 221 |
+
|
| 222 |
+
**Phase 1 (`Phase1Action`):**
|
| 223 |
+
```json
|
| 224 |
+
{"threat_type": "brute_force", "confidence": 0.9, "target_node": "auth_service", "recommended_action": "rate_limit", "reasoning": "..."}
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
**Phase 2 (`Phase2Action`):**
|
| 228 |
+
```json
|
| 229 |
+
{"action": "rate_limit", "target_node": "auth_service", "reasoning": "..."}
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
Valid actions: `rate_limit`, `isolate`, `honeypot`, `patch`, `monitor`.
|
| 233 |
+
|
| 234 |
+
### Tasks
|
| 235 |
+
|
| 236 |
+
| Task | Difficulty | Description | Rule baseline (seed 42) |
|
| 237 |
+
|------|-----------|-------------|--------------:|
|
| 238 |
+
| `direct-triage` | Easy | Single fixed strategy | ~0.87 |
|
| 239 |
+
| `dual-pivot` | Medium | Two alternating strategies | ~0.76 |
|
| 240 |
+
| `polymorphic-zero-day` | Hard | All four strategies + mid-episode shift + noise | ~0.52 |
|
| 241 |
+
|
| 242 |
+
### Reward Function
|
| 243 |
+
|
| 244 |
+
| Outcome | Reward |
|
| 245 |
+
|---------|-------:|
|
| 246 |
+
| Phase 1 threat type correct | +0.15 |
|
| 247 |
+
| Phase 1 target node correct | +0.10 |
|
| 248 |
+
| Phase 2 optimal action + correct target | +0.39 |
|
| 249 |
+
| Phase 2 heavy-handed but effective | +0.18 |
|
| 250 |
+
| Phase 2 wrong action | -0.25 |
|
| 251 |
+
| False positive on benign event | -0.39 |
|
| 252 |
+
| Catastrophic: database exfiltrated | -0.49, `done=True` |
|
| 253 |
+
|
| 254 |
+
Scores are clipped to the closed interval `[0.01, 0.99]` — the grader never
|
| 255 |
+
emits exactly 0 or 1, which keeps GRPO advantages well-defined.
|
| 256 |
+
|
| 257 |
+
### Operational Impact Layer
|
| 258 |
+
|
| 259 |
+
AdaptShield also scores business impact, so the agent is rewarded for
|
| 260 |
+
stopping the attack without ignoring operational blast radius. Each
|
| 261 |
+
service has a criticality weight and a dependency fan-out:
|
| 262 |
+
|
| 263 |
+
| Service | Criticality | Downstream dependency risk |
|
| 264 |
+
|---------|------------:|----------------------------|
|
| 265 |
+
| `auth_service` | 0.70 | `payment_service` |
|
| 266 |
+
| `payment_service` | 0.90 | `api_gateway` |
|
| 267 |
+
| `database` | 1.00 | `payment_service`, `api_gateway` |
|
| 268 |
+
| `api_gateway` | 0.80 | `auth_service`, `payment_service`, `database` |
|
| 269 |
+
|
| 270 |
+
Actions have bounded disruption costs (`monitor` = none, `isolate` =
|
| 271 |
+
highest). The grader emits `business_impact`, `availability_impact`,
|
| 272 |
+
`security_risk`, `dependency_blast_radius`, and `operational_penalty`
|
| 273 |
+
inside `score_breakdown`. The reward adjustment is capped at `±0.05` per
|
| 274 |
+
turn, which keeps the training signal stable while leaving the replay
|
| 275 |
+
detailed enough to explain whether the agent stopped the attack cleanly
|
| 276 |
+
or caused unnecessary business disruption getting there.
|
| 277 |
+
|
| 278 |
+
### Mission-Aware Objectives
|
| 279 |
+
|
| 280 |
+
Each task carries a mission profile, visible in observation metadata and
|
| 281 |
+
appended to the system prompt:
|
| 282 |
+
|
| 283 |
+
| Task | Mission | Primary Asset | SLA Priority | Risk Tolerance |
|
| 284 |
+
|------|---------|---------------|--------------|----------------|
|
| 285 |
+
| `direct-triage` | `login_stability` | `auth_service` | availability | medium |
|
| 286 |
+
| `dual-pivot` | `checkout_continuity` | `payment_service` | availability | medium |
|
| 287 |
+
| `polymorphic-zero-day` | `breach_containment` | `database` | containment | low |
|
| 288 |
+
|
| 289 |
+
The grader emits `mission_alignment` and `mission_adjustment`, capped at
|
| 290 |
+
`±0.04` per turn. This makes the agent optimize for the operational
|
| 291 |
+
mission, not just the threat label. Availability-priority missions
|
| 292 |
+
discourage unnecessary isolation of the primary asset; containment
|
| 293 |
+
missions reward decisive correct containment of the crown-jewel
|
| 294 |
+
database.
|
| 295 |
+
|
| 296 |
+
### Design choices that aren't obvious
|
| 297 |
+
|
| 298 |
+
A few decisions in the environment that look like details but matter
|
| 299 |
+
for what the benchmark actually measures:
|
| 300 |
+
|
| 301 |
+
- **Information bottleneck between phases.** Phase 2's observation has
|
| 302 |
+
empty `network_nodes` and `active_alerts`. The executor only sees
|
| 303 |
+
Phase 1's structured handoff. If Phase 1 can't communicate clearly,
|
| 304 |
+
Phase 2 fails — and you see it in the score, not in a separate metric.
|
| 305 |
+
This is what makes the env actually test cross-role coordination
|
| 306 |
+
rather than just two independent policies stitched together.
|
| 307 |
+
- **Train/eval split by world family, not by seed.** The world templates
|
| 308 |
+
used for training are disjoint from the ones used for held-out
|
| 309 |
+
evaluation. A model that overfits to a specific service-name pattern
|
| 310 |
+
or a specific alert distribution will pass train evals and fail
|
| 311 |
+
held-out. Same-seed evaluation would have hidden this.
|
| 312 |
+
- **Bounded scoring interval `[0.01, 0.99]`.** The grader never emits
|
| 313 |
+
exactly 0 or 1. This keeps GRPO advantage estimates well-defined —
|
| 314 |
+
saturating rewards collapse the variance the algorithm needs.
|
| 315 |
+
- **Bounded auxiliary signals.** Operational impact is capped at `±0.05`
|
| 316 |
+
per turn and mission alignment at `±0.04`. They steer the policy
|
| 317 |
+
without dominating the security signal, so the training curve doesn't
|
| 318 |
+
get hijacked by a single side-objective.
|
| 319 |
+
- **Deterministic Python grader, no LLM-as-judge.** Rewards come from
|
| 320 |
+
strategy matching against a fixed ground-truth attacker, not from a
|
| 321 |
+
judge model. The benchmark cannot be gamed by a more eloquent policy.
|
| 322 |
+
- **Phase-1 alerts are raw metric strings, not MITRE codes.** The agent
|
| 323 |
+
has to do the classification, not match a label to a label. This is
|
| 324 |
+
what makes the tool-aware baseline collapse on the hard task: heuristic
|
| 325 |
+
classification doesn't survive injected noise.
|
| 326 |
+
|
| 327 |
+
---
|
| 328 |
+
|
| 329 |
+
## Reproduce it
|
| 330 |
+
|
| 331 |
+
### Free-tier Colab (recommended for judges)
|
| 332 |
+
|
| 333 |
+
Open the Colab notebook linked above and run top-to-bottom. It will:
|
| 334 |
+
|
| 335 |
+
- install the exact pinned dependency stack used in the HF Job
|
| 336 |
+
- generate SFT demos from the environment
|
| 337 |
+
- train an SFT LoRA on Qwen2.5-0.5B (T4-friendly)
|
| 338 |
+
- run GRPO refinement on top of that SFT adapter
|
| 339 |
+
- print the benchmark table and inline the production training curves
|
| 340 |
+
from `SaiManish123/Janus` so you can compare scaled-down vs. full runs
|
| 341 |
+
|
| 342 |
+
End-to-end runtime on a Colab T4 is roughly 35 minutes.
|
| 343 |
+
|
| 344 |
+
### Local setup
|
| 345 |
+
|
| 346 |
+
```bash
|
| 347 |
+
pip install openenv-core
|
| 348 |
+
git clone https://github.com/SaiManish123/adaptshield
|
| 349 |
+
cd adaptshield
|
| 350 |
+
python -m adaptshield.server.app
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
+
### Run inference against the live environment
|
| 354 |
+
|
| 355 |
+
```bash
|
| 356 |
+
export HF_TOKEN=your_token
|
| 357 |
+
export ADAPTSHIELD_TASK=direct-triage # or dual-pivot / polymorphic-zero-day
|
| 358 |
+
export ENV_BASE_URL=http://localhost:7860
|
| 359 |
+
python inference.py # run from the repo root
|
| 360 |
+
```
|
| 361 |
+
|
| 362 |
+
`inference.py` honors the evaluator contract: `[START]`, `[STEP]`, `[END]`
|
| 363 |
+
stdout markers and credentials read only from environment variables.
|
| 364 |
+
|
| 365 |
+
### Smoke test
|
| 366 |
+
|
| 367 |
+
```bash
|
| 368 |
+
python smoke_test.py
|
| 369 |
+
```
|
| 370 |
+
|
| 371 |
+
Spins the env up in-process and walks one episode of each task with a
|
| 372 |
+
deterministic policy. Should finish in <10 seconds.
|
| 373 |
+
|
| 374 |
+
### Regression tests
|
| 375 |
+
|
| 376 |
+
```bash
|
| 377 |
+
adaptshield/.venv/bin/python -m unittest tests.test_regression -v
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
### Baseline scores
|
| 381 |
+
|
| 382 |
+
With `ADAPTSHIELD_SEED=42`, the deterministic rule baseline produces:
|
| 383 |
+
|
| 384 |
+
| Task | Score | Steps | Status |
|
| 385 |
+
|------|------:|------:|--------|
|
| 386 |
+
| `direct-triage` | 0.870 | 10 | PASS |
|
| 387 |
+
| `dual-pivot` | 0.760 | 12 | PASS |
|
| 388 |
+
| `polymorphic-zero-day` | 0.520 | 16 | PASS |
|
| 389 |
+
|
| 390 |
+
Difficulty staircase: **PASS**.
|
| 391 |
+
|
| 392 |
+
---
|
| 393 |
+
|
| 394 |
+
## Repository layout
|
| 395 |
+
|
| 396 |
+
```
|
| 397 |
+
adaptshield/
|
| 398 |
+
├── server/ # FastAPI server (OpenEnv-compatible)
|
| 399 |
+
├── client.py # OpenEnv client (no server-internal imports)
|
| 400 |
+
├── models.py # Phase1Action / Phase2Action schemas
|
| 401 |
+
├── soc_tools.py # SIEM, log search, threat intel SOC tools
|
| 402 |
+
├── eval_tasks.py # task definitions + difficulty staircase
|
| 403 |
+
├── baseline.py # deterministic rule baseline
|
| 404 |
+
├── tool_baseline.py # tool-aware heuristic baseline
|
| 405 |
+
├── generate_sft_data.py # rolls episodes → SFT JSONL
|
| 406 |
+
├── train_sft.py # LoRA SFT trainer (Unsloth + TRL)
|
| 407 |
+
├── train.py # GRPO trainer (Unsloth + TRL)
|
| 408 |
+
├── plot_training.py # reward / loss curve plotting
|
| 409 |
+
├── build_benchmark_table.py # eval matrix builder
|
| 410 |
+
├── inference.py # judge-facing entry point
|
| 411 |
+
├── smoke_test.py # one-shot in-process smoke test
|
| 412 |
+
├── tests/test_regression.py # determinism + reward regression tests
|
| 413 |
+
├── openenv.yaml # OpenEnv manifest
|
| 414 |
+
└── Dockerfile # HF Space container
|
| 415 |
+
```
|
| 416 |
+
|
| 417 |
+
## Engineering notes
|
| 418 |
+
|
| 419 |
+
`AdaptShieldEnvironment` extends OpenEnv's `Environment` base class and
|
| 420 |
+
follows the Gym-style API (`reset`, `step`, `state`). The client in
|
| 421 |
+
`client.py` talks to the server only through HTTP — no shared imports,
|
| 422 |
+
no leaking of server internals. None of the SOC tools are named
|
| 423 |
+
`reset`, `step`, `state`, or `close`, so they don't collide with the
|
| 424 |
+
reserved MCP tool names. Grading is deterministic Python; the reward
|
| 425 |
+
signal and the benchmark scores both come from strategy matching
|
| 426 |
+
against a fixed ground-truth attacker, never from an LLM judge.
|
| 427 |
+
|
| 428 |
+
All adapters, curves, metrics, and benchmark tables for the 1.5B run
|
| 429 |
+
are public on [`SaiManish123/Janus`](https://huggingface.co/SaiManish123/Janus).
|
| 430 |
+
|
| 431 |
+
## License
|
| 432 |
+
|
| 433 |
+
MIT.
|
__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""AdaptShield environment package."""
|
| 8 |
+
|
| 9 |
+
from client import AdaptshieldEnv
|
| 10 |
+
from models import (
|
| 11 |
+
AdaptShieldAction,
|
| 12 |
+
AdaptShieldObservation,
|
| 13 |
+
AdaptshieldAction,
|
| 14 |
+
AdaptshieldObservation,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
"AdaptShieldAction",
|
| 19 |
+
"AdaptShieldObservation",
|
| 20 |
+
"AdaptshieldAction",
|
| 21 |
+
"AdaptshieldObservation",
|
| 22 |
+
"AdaptshieldEnv",
|
| 23 |
+
]
|
assets/_make_headline_chart.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Render the headline benchmark chart for README.
|
| 2 |
+
|
| 3 |
+
Produces a clean grouped bar chart of held-out evaluation scores
|
| 4 |
+
(tool-aware baseline / SFT / GRPO) across the three difficulty tiers.
|
| 5 |
+
|
| 6 |
+
Numbers are pulled directly from
|
| 7 |
+
https://huggingface.co/SaiManish123/Janus benchmark tables and are
|
| 8 |
+
identical to the values in README.md so the figure stays in sync.
|
| 9 |
+
|
| 10 |
+
Run: python assets/_make_headline_chart.py
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import os
import tempfile
from pathlib import Path

# Point matplotlib at a writable config/cache directory before it is imported.
# setdefault keeps any value the caller already exported; the fallback lives
# under the platform temp dir instead of a hard-coded "/tmp".
os.environ.setdefault(
    "MPLCONFIGDIR", os.path.join(tempfile.gettempdir(), "mpl-adaptshield")
)

import matplotlib.pyplot as plt
import numpy as np

# The PNG is written next to this script.
OUT = Path(__file__).parent / "headline_results.png"

# One entry per difficulty tier; the three score lists are index-aligned.
tasks = ["direct-triage\n(easy)", "dual-pivot\n(medium)", "polymorphic-zero-day\n(hard)"]
tool_baseline = [0.990, 0.640, 0.180]
sft_heldout = [0.990, 0.825, 0.930]
grpo_heldout = [0.990, 0.825, 0.902]

x = np.arange(len(tasks))
width = 0.26

fig, ax = plt.subplots(figsize=(9.5, 4.6), dpi=150)

c_tool = "#9aa0a6"
c_sft = "#1f6feb"
c_grpo = "#d63b2f"

# Three bars per task: baseline left of centre, SFT centred, GRPO right.
b1 = ax.bar(x - width, tool_baseline, width, label="Tool-aware baseline", color=c_tool, edgecolor="white", linewidth=0.6)
b2 = ax.bar(x, sft_heldout, width, label="SFT (held-out)", color=c_sft, edgecolor="white", linewidth=0.6)
b3 = ax.bar(x + width, grpo_heldout, width, label="GRPO (held-out)", color=c_grpo, edgecolor="white", linewidth=0.6)

for bars in (b1, b2, b3):
    ax.bar_label(bars, fmt="%.2f", padding=3, fontsize=9, color="#333")

ax.set_ylim(0, 1.08)
ax.set_yticks(np.arange(0, 1.01, 0.2))
ax.set_ylabel("Mean score (0.01–0.99 grader)", fontsize=10)
ax.set_xticks(x)
ax.set_xticklabels(tasks, fontsize=10)
ax.set_title(
    "AdaptShield held-out evaluation · Qwen2.5-1.5B · 50 deterministic seeds / task",
    fontsize=11.5, pad=12, color="#222",
)

ax.spines["top"].set_visible(False)
ax.spines["right"].set_visible(False)
ax.spines["left"].set_color("#cccccc")
ax.spines["bottom"].set_color("#cccccc")
ax.tick_params(colors="#555")
ax.yaxis.grid(True, color="#eeeeee", linewidth=0.8)
ax.set_axisbelow(True)

# Annotate the hard task. The multiplier is derived from the plotted numbers
# (GRPO vs. tool-aware baseline) so the caption stays in sync with the bars;
# with the current values it renders as "5.0×".
hard = len(tasks) - 1
lift = grpo_heldout[hard] / tool_baseline[hard]
ax.annotate(
    f"{lift:.1f}× lift on the only task that\nactually requires adaptation",
    xy=(hard + width, grpo_heldout[hard]),
    xytext=(hard - 0.15, 0.45),
    fontsize=9, color="#444",
    arrowprops=dict(arrowstyle="->", color="#888", lw=0.9, connectionstyle="arc3,rad=-0.2"),
)

# Legend sits below the axes, laid out in a single row.
ax.legend(
    loc="lower left", frameon=False, fontsize=9.5, ncol=3,
    bbox_to_anchor=(0.0, -0.22),
)

plt.tight_layout()
fig.savefig(OUT, bbox_inches="tight", facecolor="white")
print(f"wrote {OUT}")
|
assets/architecture_overview.svg
ADDED
|
|
assets/headline_results.png
ADDED
|
assets/training_pipeline.svg
ADDED
|
|
baseline.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Rule-based AdaptShield baseline with evaluator-style stdout."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import sys
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 14 |
+
|
| 15 |
+
if str(REPO_ROOT) not in sys.path:
|
| 16 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 17 |
+
|
| 18 |
+
from models import AdaptShieldAction
|
| 19 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
TASKS = ["direct-triage", "dual-pivot", "polymorphic-zero-day"]
|
| 23 |
+
BENCHMARK = "adaptshield"
|
| 24 |
+
MODEL_NAME = "rule-baseline"
|
| 25 |
+
MAX_STEPS = 30
|
| 26 |
+
|
| 27 |
+
POLICY = {
|
| 28 |
+
"brute_force": ("auth_service", "rate_limit"),
|
| 29 |
+
"lateral_movement": ("payment_service", "isolate"),
|
| 30 |
+
"exfiltration": ("database", "honeypot"),
|
| 31 |
+
"supply_chain": ("api_gateway", "patch"),
|
| 32 |
+
"benign": ("api_gateway", "monitor"),
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def log_start(task: str) -> None:
    """Print the evaluator-style ``[START]`` line for a task run.

    Args:
        task: Name of the task about to be executed.
    """
    print(f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}", flush=True)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def log_step(step: int, action: Dict[str, Any], reward: float, done: bool) -> None:
    """Print one evaluator-style ``[STEP]`` line.

    Args:
        step: 1-based index of the step that was just executed.
        action: Payload sent to the environment; serialised as compact JSON.
        reward: Reward returned for the step, printed with two decimals.
        done: Whether the episode finished on this step.
    """
    action_str = json.dumps(action, separators=(",", ":"))
    # Keep the line bounded: anything longer than 100 chars is cut to 97 + "...".
    if len(action_str) > 100:
        action_str = action_str[:97] + "..."
    print(
        f"[STEP] step={step} action={action_str} "
        f"reward={reward:.2f} done={str(done).lower()} error=null",
        flush=True,
    )
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the evaluator-style ``[END]`` summary line for an episode.

    Args:
        success: Whether the run counts as successful.
        steps: Number of environment steps taken.
        score: Final score, printed with three decimals.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    reward_str = ",".join(f"{reward:.2f}" for reward in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={reward_str}",
        flush=True,
    )
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def print_replay(task: str, metadata: Dict[str, Any]) -> None:
    """Print a human-readable, turn-by-turn replay of a finished episode.

    Args:
        task: Task name shown in the replay header.
        metadata: Final observation metadata; the replay rows are read from
            its ``"episode_replay"`` entry (missing or empty is tolerated).
    """
    replay = metadata.get("episode_replay") or []
    print()
    print(f"Replay: {task}")
    if not replay:
        print(" No replay metadata available.")
        return

    for row in replay:
        shift = " STRATEGY SHIFT:" if row.get("shift") else ""
        foothold = " FOOTHOLD:" if row.get("foothold_transition") else ""
        blast_radius = row.get("blast_radius") or []
        blast = ",".join(blast_radius) if blast_radius else "none"
        active_defenses = row.get("active_defenses") or []
        # ``or`` (not a .get default) so an explicit None impact prints as 0.00
        # instead of raising TypeError in float().
        impact = float(row.get("impact") or 0.0)
        print(
            f" Turn {row.get('turn')}:"
            f"{shift}{foothold} {row.get('p1')} -> {row.get('p2_action')}"
            f"({row.get('target', 'unknown')}) -> {row.get('result')}"
            f" | impact={impact:.2f}"
            f" blast_radius={blast}"
            f" active_controls={len(active_defenses)}"
            f" mission={row.get('mission_alignment', 'neutral')}"
        )
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
    """Classify the current threat from per-node metrics with fixed rules.

    Rules are checked in order and the first match wins:
    auth error rate >= 0.10 -> ``"brute_force"``; payment status
    ``"suspicious"`` or cpu >= 55 -> ``"lateral_movement"``; database
    outbound >= 50 MB -> ``"exfiltration"``; gateway status ``"suspicious"``
    -> ``"supply_chain"``; otherwise ``"benign"``.

    Args:
        network_nodes: Mapping of node name to that node's metric dict.
            Missing nodes and missing/``None`` metrics are treated as zero.

    Returns:
        One of the threat-type strings listed above.
    """
    auth = network_nodes.get("auth_service", {})
    payment = network_nodes.get("payment_service", {})
    database = network_nodes.get("database", {})
    gateway = network_nodes.get("api_gateway", {})

    # ``or 0`` so a metric that is present but None does not break float().
    if float(auth.get("error_rate") or 0.0) >= 0.10:
        return "brute_force"
    if payment.get("status") == "suspicious" or float(payment.get("cpu") or 0) >= 55:
        return "lateral_movement"
    if float(database.get("outbound_mb") or 0) >= 50:
        return "exfiltration"
    if gateway.get("status") == "suspicious":
        return "supply_chain"
    return "benign"
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def phase1_payload(obs) -> Dict[str, Any]:
    """Build the phase-1 (assessment) action payload for an observation.

    The threat type comes from ``classify_from_metrics`` applied to
    ``obs.network_nodes``; the target node and recommended action are looked
    up in ``POLICY``.

    Args:
        obs: Environment observation exposing ``network_nodes``.

    Returns:
        Dict with ``threat_type``, ``confidence`` (fixed at 0.90),
        ``target_node``, ``recommended_action`` and ``reasoning``.
    """
    threat_type = classify_from_metrics(obs.network_nodes)
    target_node, action = POLICY[threat_type]
    return {
        "threat_type": threat_type,
        "confidence": 0.90,
        "target_node": target_node,
        "recommended_action": action,
        "reasoning": "rule-based metric classifier",
    }
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def phase2_payload(obs) -> Dict[str, Any]:
    """Build the phase-2 (response) action payload for an observation.

    Executes whatever the phase-1 assessment recommended. If the assessment
    is missing, or lacks an action/target, the ``POLICY`` entry for its threat
    type is used, and the ``"benign"`` entry for an unknown threat type.

    Args:
        obs: Environment observation exposing ``phase1_assessment``.

    Returns:
        Dict with ``action``, ``target_node`` and ``reasoning``.
    """
    assessment = obs.phase1_assessment or {}
    threat_type = str(assessment.get("threat_type", "benign"))
    fallback_target, fallback_action = POLICY.get(threat_type, POLICY["benign"])
    action = str(assessment.get("recommended_action") or fallback_action)
    target_node = str(assessment.get("target_node") or fallback_target)
    return {
        "action": action,
        "target_node": target_node,
        "reasoning": "execute analyst recommendation",
    }
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def action_from_payload(payload: Dict[str, Any]) -> AdaptShieldAction:
    """Wrap a payload dict in an ``AdaptShieldAction`` (keys become keyword args)."""
    return AdaptShieldAction(**payload)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def run_task(task: str, emit_logs: bool = True) -> Dict[str, Any]:
    """Run one episode of ``task`` with the rule-based policy.

    Args:
        task: Task name passed to ``AdaptShieldEnvironment``.
        emit_logs: When true, print ``[START]``/``[STEP]``/``[END]`` lines.

    Returns:
        Dict with ``task``, ``score``, ``steps``, ``done``, ``rewards``,
        ``metadata``, ``normalized_score_present`` and ``success``.
    """
    env = AdaptShieldEnvironment(task_name=task)
    obs = env.reset()
    rewards: List[float] = []
    steps = 0

    if emit_logs:
        log_start(task)

    # MAX_STEPS is a safety cap in case the environment never reports done.
    while not obs.done and steps < MAX_STEPS:
        if obs.phase == 1:
            payload = phase1_payload(obs)
        else:
            payload = phase2_payload(obs)

        obs = env.step(action_from_payload(payload))
        # Treat a missing (None) reward as 0.0 rather than crashing in float().
        reward = float(obs.reward or 0.0)
        rewards.append(reward)
        steps += 1

        if emit_logs:
            log_step(steps, payload, reward, obs.done)

    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    # 0.01 is the fallback when the environment did not report a score.
    score = float(metadata.get("normalized_score", 0.01))
    # Coerce so ``success`` is always a real bool, like the "done" field below.
    success = bool(obs.done) and 0.01 <= score <= 0.99

    if emit_logs:
        log_end(success, steps, score, rewards)

    return {
        "task": task,
        "score": score,
        "steps": steps,
        "done": bool(obs.done),
        "rewards": rewards,
        "metadata": metadata,
        "normalized_score_present": "normalized_score" in metadata,
        "success": success,
    }
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def parse_args() -> argparse.Namespace:
    """Parse command-line options for the baseline runner.

    Returns:
        Namespace with ``task`` (a task name or ``"all"``) and ``replay``.
    """
    parser = argparse.ArgumentParser(description="Run AdaptShield rule baseline.")
    parser.add_argument(
        "--task",
        default="direct-triage",
        choices=TASKS + ["all"],
        help="Task to run, or 'all' for every task.",
    )
    parser.add_argument(
        "--replay",
        action="store_true",
        help="Print a human-readable final episode replay.",
    )
    return parser.parse_args()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def main() -> int:
    """Run the selected task(s) with logging and optionally print replays.

    Returns:
        Process exit code; always 0.
    """
    args = parse_args()
    tasks = TASKS if args.task == "all" else [args.task]

    for index, task in enumerate(tasks):
        # Blank line between consecutive task logs, but not before the first.
        if index:
            print()
        result = run_task(task, emit_logs=True)
        if args.replay:
            print_replay(task, result["metadata"])

    return 0
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
if __name__ == "__main__":
|
| 204 |
+
raise SystemExit(main())
|
build_benchmark_table.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Build a README-friendly benchmark table from baselines and training metrics."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List
|
| 10 |
+
|
| 11 |
+
from baseline import TASKS, run_task as run_no_tool_task
|
| 12 |
+
from tool_baseline import run_task as run_tool_task
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def rows_to_map(rows: List[Dict[str, Any]]) -> Dict[str, Dict[str, Any]]:
    """Index evaluation rows by their ``"task"`` value (stringified).

    If several rows share a task, the last one wins.

    Raises:
        KeyError: If a row has no ``"task"`` key.
    """
    return {str(row["task"]): row for row in rows}
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def load_metrics(path: Path) -> Dict[str, Any]:
    """Read a UTF-8 JSON metrics file and return its parsed content.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    return json.loads(path.read_text(encoding="utf-8"))
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def markdown_table(headers: List[str], rows: List[List[str]]) -> str:
    """Render a pipe-delimited Markdown table.

    Args:
        headers: Column titles.
        rows: Table body; each row is a list of already-formatted cell strings.

    Returns:
        Header line, ``---`` separator line and one line per row, joined with
        newlines and without a trailing newline.
    """
    lines = [
        "| " + " | ".join(headers) + " |",
        "| " + " | ".join(["---"] * len(headers)) + " |",
    ]
    lines.extend("| " + " | ".join(row) + " |" for row in rows)
    return "\n".join(lines)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def fmt(value: float | None) -> str:
    """Format a score with three decimals, or ``"-"`` when it is missing."""
    if value is None:
        return "-"
    return f"{float(value):.3f}"
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main() -> int:
    """Build the benchmark comparison table and write it to disk.

    Runs both rule baselines for every task, reads SFT (and optionally GRPO)
    metrics from JSON, then writes a Markdown table to ``--output`` and a JSON
    summary next to it (same path with a ``.json`` suffix).

    Returns:
        Process exit code; always 0.
    """
    parser = argparse.ArgumentParser(description="Build AdaptShield benchmark comparison table.")
    parser.add_argument("--sft-metrics", required=True, help="Path to sft_metrics.json")
    parser.add_argument("--grpo-metrics", default="", help="Optional path to GRPO metrics.json")
    parser.add_argument("--output", default="artifacts/benchmark_table.md")
    args = parser.parse_args()

    sft_metrics = load_metrics(Path(args.sft_metrics))
    grpo_metrics = load_metrics(Path(args.grpo_metrics)) if args.grpo_metrics else {}

    no_tool_rows = {task: run_no_tool_task(task, emit_logs=False) for task in TASKS}
    tool_rows = {task: run_tool_task(task, emit_logs=False) for task in TASKS}
    sft_eval = rows_to_map(sft_metrics.get("evaluation_rows", []))
    sft_heldout = rows_to_map(sft_metrics.get("heldout_evaluation_rows", []))
    grpo_eval = rows_to_map(grpo_metrics.get("evaluation_rows", [])) if grpo_metrics else {}
    grpo_heldout = rows_to_map(grpo_metrics.get("heldout_evaluation_rows", [])) if grpo_metrics else {}

    # A task missing from any metrics map yields None, which fmt() prints as "-".
    rows: List[List[str]] = [
        [
            task,
            fmt(no_tool_rows[task]["score"]),
            fmt(tool_rows[task]["score"]),
            fmt(sft_eval.get(task, {}).get("score")),
            fmt(sft_heldout.get(task, {}).get("score")),
            fmt(grpo_eval.get(task, {}).get("score")),
            fmt(grpo_heldout.get(task, {}).get("score")),
        ]
        for task in TASKS
    ]

    md = markdown_table(
        headers=[
            "Task",
            "No-tool baseline",
            "Tool-aware baseline",
            "SFT (train family)",
            "SFT (held-out family)",
            "GRPO (train family)",
            "GRPO (held-out family)",
        ],
        rows=rows,
    )

    # The GRPO sections are empty dicts (not per-task Nones) when no GRPO rows exist.
    summary = {
        "no_tool_baseline": {task: no_tool_rows[task]["score"] for task in TASKS},
        "tool_baseline": {task: tool_rows[task]["score"] for task in TASKS},
        "sft_train_family": {task: sft_eval.get(task, {}).get("score") for task in TASKS},
        "sft_heldout_family": {task: sft_heldout.get(task, {}).get("score") for task in TASKS},
        "grpo_train_family": {task: grpo_eval.get(task, {}).get("score") for task in TASKS} if grpo_eval else {},
        "grpo_heldout_family": {task: grpo_heldout.get(task, {}).get("score") for task in TASKS} if grpo_heldout else {},
    }

    output_path = Path(args.output)
    json_path = output_path.with_suffix(".json")
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(md + "\n", encoding="utf-8")
    json_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(md)
    print()
    print(f"Saved markdown table to: {output_path}")
    print(f"Saved JSON summary to: {json_path}")
    return 0
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
raise SystemExit(main())
|
client.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""AdaptShield environment client."""
|
| 8 |
+
|
| 9 |
+
from typing import Any, Dict
|
| 10 |
+
|
| 11 |
+
from openenv.core import EnvClient
|
| 12 |
+
from openenv.core.client_types import StepResult
|
| 13 |
+
from openenv.core.env_server.types import State
|
| 14 |
+
|
| 15 |
+
from models import AdaptShieldAction, AdaptShieldObservation
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class AdaptshieldEnv(
    EnvClient[AdaptShieldAction, AdaptShieldObservation, State]
):
    """
    Client for the Adaptshield Environment.

    This client maintains a persistent WebSocket connection to the environment server,
    enabling efficient multi-step interactions with lower latency.
    Each client instance has its own dedicated environment session on the server.

    Example:
        >>> # Connect to a running server
        >>> with AdaptshieldEnv(base_url="http://localhost:7860") as client:
        ...     result = client.reset()
        ...     print(result.observation.phase)
        ...
        ...     result = client.step(AdaptShieldAction(
        ...         threat_type="brute_force",
        ...         confidence=0.9,
        ...         target_node="auth_service",
        ...         recommended_action="rate_limit",
        ...     ))
        ...     print(result.observation.phase1_assessment)

    Example with Docker:
        >>> # Automatically start container and connect
        >>> client = AdaptshieldEnv.from_docker_image("adaptshield-env:latest")
        >>> try:
        ...     result = client.reset()
        ...     result = client.step(AdaptShieldAction(
        ...         threat_type="benign",
        ...         confidence=0.8,
        ...         target_node="auth_service",
        ...         recommended_action="monitor",
        ...     ))
        ... finally:
        ...     client.close()
    """

    def _step_payload(self, action: AdaptShieldAction) -> Dict[str, Any]:
        """
        Convert AdaptShieldAction to a JSON-safe payload.

        Args:
            action: AdaptShieldAction instance

        Returns:
            Dictionary representation suitable for JSON encoding
        """
        # Only fields that are set to a non-None, non-default value are sent.
        # NOTE(review): exclude_defaults also drops a field explicitly set to
        # its default value - confirm the server applies the same defaults.
        return action.model_dump(
            mode="json",
            exclude_none=True,
            exclude_defaults=True,
        )

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult[AdaptShieldObservation]:
        """
        Parse server response into StepResult[AdaptShieldObservation].

        Args:
            payload: JSON response data from server

        Returns:
            StepResult with AdaptShieldObservation
        """
        # Copy so the caller's payload is not mutated, then backfill done/reward
        # from the top level when the nested observation does not carry them.
        obs_data = dict(payload.get("observation", {}))
        obs_data.setdefault("done", payload.get("done", False))
        obs_data.setdefault("reward", payload.get("reward", 0.0))
        observation = AdaptShieldObservation(**obs_data)

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict[str, Any]) -> State:
        """
        Parse server response into State object.

        Args:
            payload: JSON response from state request

        Returns:
            State object with episode_id and step_count
        """
        return State(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
        )
|
eval_tasks.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Run all AdaptShield tasks with the local rule baseline."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from baseline import TASKS, run_task
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def status_for(result: dict) -> str:
    """Return ``"PASS"`` or ``"FAIL"`` for one task result.

    A result passes when the episode finished, the environment reported a
    ``normalized_score``, and the score lies in the closed range [0.01, 0.99].

    Args:
        result: Dict with ``score``, ``done`` and ``normalized_score_present``.
    """
    score = result["score"]
    passed = (
        result["done"]
        and result["normalized_score_present"]
        and 0.01 <= score <= 0.99
    )
    return "PASS" if passed else "FAIL"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def main() -> int:
    """Run every task with the rule baseline and print a summary table.

    Also reports whether scores strictly decrease from one task to the next
    in ``TASKS`` order (the "difficulty staircase"); that check is
    informational and does not affect the exit code.

    Returns:
        0 if every task passes ``status_for``, otherwise 1.
    """
    results = [run_task(task, emit_logs=False) for task in TASKS]
    # Compute each status once; it is needed for both the table and the exit code.
    statuses = [status_for(result) for result in results]

    print("AdaptShield Evaluation")
    print()
    print(f"{'Task':<24} {'Score':>7} {'Steps':>5} {'normalized_score':>18} {'Status':>8}")
    print("-" * 68)

    for result, status in zip(results, statuses):
        normalized = "yes" if result["normalized_score_present"] else "no"
        print(
            f"{result['task']:<24} "
            f"{result['score']:>7.3f} "
            f"{result['steps']:>5} "
            f"{normalized:>18} "
            f"{status:>8}"
        )

    scores = [result["score"] for result in results]
    staircase = all(left > right for left, right in zip(scores, scores[1:]))
    print()
    print(f"Difficulty staircase: {'PASS' if staircase else 'FAIL'}")

    return 0 if all(status == "PASS" for status in statuses) else 1
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
if __name__ == "__main__":
|
| 46 |
+
raise SystemExit(main())
|
generate_sft_data.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Generate supervised fine-tuning data directly from AdaptShield rollouts."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import random
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
from models import AdaptShieldAction
|
| 13 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 14 |
+
from train import (
|
| 15 |
+
TASKS,
|
| 16 |
+
_current_reference,
|
| 17 |
+
_teacher_payload,
|
| 18 |
+
build_messages,
|
| 19 |
+
obs_to_dict,
|
| 20 |
+
render_messages,
|
| 21 |
+
task_for_episode,
|
| 22 |
+
)
|
| 23 |
+
from soc_tools import attach_tool_results, investigate_local_with_depth
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def build_dataset(
    selected_task: str,
    curriculum: bool,
    use_tools: bool,
    rollout_episodes: int,
    max_steps: int,
    seed: int,
    world_split: str,
    world_family: str | None,
) -> List[Dict[str, Any]]:
    """Roll out teacher-driven episodes and collect one SFT row per step.

    At each step the prompt is built from the observation (plus any tool
    results), the teacher payload is derived from the environment's current
    reference, the row is recorded, and the teacher payload is then applied
    to advance the environment.

    Args:
        selected_task: Task name, or ``"all"``; forwarded to ``task_for_episode``.
        curriculum: Forwarded to ``task_for_episode`` to pick the schedule.
        use_tools: Forwarded to ``investigate_local_with_depth``.
        rollout_episodes: Number of episodes to roll out.
        max_steps: Per-episode cap on environment steps.
        seed: Seed applied to the global ``random`` module before sampling.
        world_split: World-family split passed to the environment.
        world_family: Optional fixed world family passed to the environment.

    Returns:
        List of row dicts (prompt messages, teacher response, and metadata).
    """
    random.seed(seed)
    rows: List[Dict[str, Any]] = []

    for episode in range(1, rollout_episodes + 1):
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=rollout_episodes,
            selected_task=selected_task,
            curriculum=curriculum,
        )
        env = AdaptShieldEnvironment(
            task_name=task,
            world_split=world_split,
            world_family=world_family,
        )
        obs = env.reset()
        step_count = 0
        # Only the hardest task requests the thorough investigation mode.
        thorough = task == "polymorphic-zero-day"

        while not obs.done and step_count < max_steps:
            phase = int(getattr(obs, "phase", 1))
            tool_results = investigate_local_with_depth(
                env,
                obs,
                use_tools=use_tools,
                thorough=thorough,
            )
            obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
            messages = build_messages(obs_dict)
            reference = _current_reference(env)
            teacher_payload = _teacher_payload(phase, reference)
            response_text = json.dumps(teacher_payload, separators=(",", ":"))

            rows.append({
                "task": task,
                "stage": stage,
                "episode": episode,
                "turn": int(getattr(obs, "turn", 0) or 0),
                "phase": phase,
                "attack_stage": reference["stage"],
                # Prefer what the environment actually resolved; fall back to
                # the requested values when it does not expose the attribute.
                "world_split": getattr(env, "_world_split", world_split),
                "world_family": getattr(env, "_world_family", world_family or ""),
                "operational_mode": getattr(env, "_operational_mode", ""),
                "is_benign": bool(reference["is_benign"]),
                "expected_threat_type": reference["threat_type"],
                "expected_target_node": reference["target_node"],
                "expected_action": reference["expected_action"],
                "tool_calls": len(tool_results),
                "messages": messages,
                "response": response_text,
                "text": f"{render_messages(messages)}\n\nASSISTANT:\n{response_text}",
            })

            obs = env.step(AdaptShieldAction(**teacher_payload))
            step_count += 1

    return rows
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def summarize_rows(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Count dataset rows per task, per phase, and with at least one tool call.

    Args:
        rows: Rows as produced by ``build_dataset``.

    Returns:
        Dict with ``rows`` (total), ``task_counts`` (one entry per name in
        ``TASKS``; rows for other tasks are not counted there),
        ``phase_counts`` and ``rows_with_tool_calls``.
    """
    by_task = {task: 0 for task in TASKS}
    by_phase = {1: 0, 2: 0}
    with_tools = 0

    for row in rows:
        task = str(row.get("task", ""))
        # A missing or falsy phase is counted as phase 1.
        phase = int(row.get("phase", 1) or 1)
        if task in by_task:
            by_task[task] += 1
        by_phase[phase] = by_phase.get(phase, 0) + 1
        if int(row.get("tool_calls", 0) or 0) > 0:
            with_tools += 1

    return {
        "rows": len(rows),
        "task_counts": by_task,
        "phase_counts": by_phase,
        "rows_with_tool_calls": with_tools,
    }
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def main() -> None:
    """Parse CLI options, generate the SFT dataset, and write it to disk.

    Writes one JSON object per line to ``--output`` and a count summary to the
    same path with its suffix replaced by ``.summary.json``.
    """
    parser = argparse.ArgumentParser(description="Generate AdaptShield SFT JSONL data")
    parser.add_argument(
        "--task",
        default="all",
        choices=["all", *TASKS],
        help="Task to sample. Use 'all' with --curriculum for mixed data.",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        default=120,
        help="Number of rollout episodes to sample.",
    )
    parser.add_argument(
        "--max-steps",
        type=int,
        default=20,
        help="Maximum env steps per episode.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Dataset generation seed.",
    )
    parser.add_argument(
        "--curriculum",
        action="store_true",
        help="Use easy->medium->hard sampling schedule.",
    )
    parser.add_argument(
        "--use-tools",
        action="store_true",
        help="Include SOC tool evidence in prompts where applicable.",
    )
    parser.add_argument(
        "--output",
        default="data/adaptshield_sft.jsonl",
        help="Where to write the JSONL dataset.",
    )
    parser.add_argument(
        "--world-split",
        default="train",
        choices=["train", "eval"],
        help="World-family split used to generate the dataset.",
    )
    parser.add_argument(
        "--world-family",
        default=None,
        help="Optional fixed world family override (e.g. train-a, eval-x).",
    )
    args = parser.parse_args()

    rows = build_dataset(
        selected_task=args.task,
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        rollout_episodes=args.episodes,
        max_steps=args.max_steps,
        seed=args.seed,
        world_split=args.world_split,
        world_family=args.world_family,
    )

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with output_path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, ensure_ascii=True) + "\n")

    summary = summarize_rows(rows)
    summary_path = output_path.with_suffix(".summary.json")
    summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8")

    print(f"Wrote {len(rows)} rows to {output_path}")
    print(f"Summary saved to {summary_path}")
    print(json.dumps(summary, indent=2))
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
if __name__ == "__main__":
|
| 197 |
+
main()
|
inference.py
ADDED
|
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield Inference Script
|
| 3 |
+
|
| 4 |
+
Single task per run. Emits mandatory [START]/[STEP]/[END] stdout format.
|
| 5 |
+
All credentials read from environment — never hardcoded.
|
| 6 |
+
|
| 7 |
+
Required env vars (injected by evaluator):
|
| 8 |
+
API_KEY: Evaluator's LiteLLM proxy key (checked first)
|
| 9 |
+
API_BASE_URL: LLM endpoint
|
| 10 |
+
MODEL_NAME: Model identifier
|
| 11 |
+
|
| 12 |
+
Optional env vars:
|
| 13 |
+
HF_TOKEN: Fallback if API_KEY not set
|
| 14 |
+
ADAPTSHIELD_TASK: Task name (default: direct-triage)
|
| 15 |
+
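ENV_BASE_URL: Environment server URL (default: http://localhost:7860)
ADAPTSHIELD_USE_TOOLS: SOC tool mode: 1/true/yes/on forces tools on, 0/false/no/off forces them off; the default "auto" enables them only for the polymorphic-zero-day task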
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
import sys
|
| 21 |
+
import textwrap
|
| 22 |
+
from typing import Any, Dict, List, Optional
|
| 23 |
+
import urllib.request
|
| 24 |
+
import urllib.error
|
| 25 |
+
|
| 26 |
+
from openai import OpenAI
|
| 27 |
+
|
| 28 |
+
from client import AdaptshieldEnv
|
| 29 |
+
from models import AdaptShieldAction
|
| 30 |
+
from soc_tools import attach_tool_results, investigate_http, summarize_tool_results
|
| 31 |
+
|
| 32 |
+
# ── Configuration — read from env, NEVER hardcode ──────────────────────────
# API_KEY wins; HF_TOKEN is consulted only when API_KEY is unset or empty.
API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN", "")
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
TASK_NAME = os.environ.get("ADAPTSHIELD_TASK", "direct-triage")
# Trailing slashes are removed so request paths can be appended directly.
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
BENCHMARK = "adaptshield"  # environment name printed in the [START] line
MAX_STEPS = 25  # upper bound on steps taken in one episode
SUCCESS_THRESHOLD = 0.50  # minimum final score reported as success
# Lower-cased on read; any value other than an explicit on/off token defers
# the decision to should_use_tools().
USE_TOOLS_SETTING = os.environ.get("ADAPTSHIELD_USE_TOOLS", "auto").lower()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ── Mandatory stdout format ────────────────────────────────────────────────
|
| 45 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Print the mandatory ``[START]`` line naming the task, environment and model."""
    header = "[START] task={} env={} model={}".format(task, env, model)
    print(header, flush=True)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def log_step(step: int, action: str, reward: float,
             done: bool, error: Optional[str]) -> None:
    """Print the mandatory ``[STEP]`` line for one environment step.

    The reward is rendered with two decimals, ``done`` as a lower-case
    boolean, and a missing or empty ``error`` as the literal ``null``.
    """
    error_field = error or "null"
    done_field = str(done).lower()
    line = (
        f"[STEP] step={step} action={action} "
        f"reward={reward:.2f} done={done_field} error={error_field}"
    )
    print(line, flush=True)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def log_end(success: bool, steps: int, score: float,
            rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` line summarising the episode.

    The score is rendered with three decimals and the per-step rewards as a
    comma-separated list with two decimals each.
    """
    reward_fields = [format(value, ".2f") for value in rewards]
    summary = "[END] success={} steps={} score={} rewards={}".format(
        str(success).lower(),
        steps,
        format(score, ".3f"),
        ",".join(reward_fields),
    )
    print(summary, flush=True)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ── Environment calls ──────────────────────────────────────────────────────
|
| 70 |
+
def env_post(path: str, data: Dict) -> Dict:
    """POST ``data`` as JSON to ``ENV_BASE_URL + path`` and return the decoded JSON reply.

    The request uses a 60 second timeout; network and decoding errors
    propagate to the caller.
    """
    target = f"{ENV_BASE_URL}{path}"
    payload = json.dumps(data).encode()
    request = urllib.request.Request(
        target,
        data=payload,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        raw = response.read()
    return json.loads(raw)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Return ``obs`` as a plain dict.

    Objects exposing ``model_dump`` (Pydantic observations from the
    persistent client) are dumped in JSON mode; anything else is passed to
    ``dict()``.
    """
    if not hasattr(obs, "model_dump"):
        return dict(obs)
    return obs.model_dump(mode="json")
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def build_env_action(parsed: Dict[str, Any], phase: int) -> AdaptShieldAction:
    """Turn model output into an ``AdaptShieldAction``, falling back to a safe default.

    If constructing the action from ``parsed`` raises for any reason, a
    ``monitor`` action on ``auth_service`` is returned instead: the Phase 1
    (classification) shape when ``phase == 1``, the Phase 2 (execution)
    shape otherwise.
    """
    try:
        action = AdaptShieldAction(**parsed)
    except Exception:
        # Fields shared by both fallback shapes.
        fallback_fields: Dict[str, Any] = {
            "target_node": "auth_service",
            "reasoning": "validated fallback",
        }
        if phase == 1:
            fallback_fields.update(
                threat_type="brute_force",
                confidence=0.5,
                recommended_action="monitor",
            )
        else:
            fallback_fields["action"] = "monitor"
        return AdaptShieldAction(**fallback_fields)
    return action
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ── Score computation — strictly (0.01, 0.99) ─────────────────────────────
|
| 108 |
+
def safe_score(rewards: List[float], meta: Dict) -> float:
    """Return the episode score clamped to the closed range [0.01, 0.99].

    The score is taken from ``meta["normalized_score"]`` when that value is
    a usable number. Otherwise it is estimated from ``rewards`` as the sum
    of rewards above 0.50 divided by ``len(rewards) * 0.99``; with no
    rewards at all the neutral value 0.50 is used.

    Args:
        rewards: Per-step rewards collected during the episode.
        meta: Observation metadata. Anything that is not a dict is treated
            as empty instead of raising.

    Returns:
        A float between 0.01 and 0.99 inclusive.
    """
    normalized = meta.get("normalized_score") if isinstance(meta, dict) else None

    raw: Optional[float] = None
    if normalized is not None:
        try:
            raw = float(normalized)
        except (TypeError, ValueError):
            raw = None
        else:
            # NaN is the only float unequal to itself. Left in place it would
            # slip through min()/max() and be reported as 0.99.
            if raw != raw:
                raw = None

    if raw is None:
        if rewards:
            positive = sum(r for r in rewards if r > 0.50)
            raw = positive / (len(rewards) * 0.99)
        else:
            raw = 0.50

    return max(0.01, min(0.99, raw))
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ── System prompts ─────────────────────────────────────────────────────────
|
| 121 |
+
PHASE1_SYS = textwrap.dedent("""
|
| 122 |
+
You are a Threat Analyst for a 4-node enterprise network.
|
| 123 |
+
Analyze the SIEM metrics and alerts. Identify the threat type.
|
| 124 |
+
|
| 125 |
+
Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
|
| 126 |
+
If SOC tool evidence is provided, use it to update your belief before classifying.
|
| 127 |
+
|
| 128 |
+
Respond ONLY with valid JSON:
|
| 129 |
+
{"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}
|
| 130 |
+
|
| 131 |
+
Nodes: auth_service, payment_service, database, api_gateway
|
| 132 |
+
Actions: rate_limit, isolate, honeypot, patch, monitor
|
| 133 |
+
""").strip()
|
| 134 |
+
|
| 135 |
+
PHASE2_SYS = textwrap.dedent("""
|
| 136 |
+
You are a Tactical Executor. Act on the threat assessment provided.
|
| 137 |
+
You cannot see raw network data. Use the analyst assessment plus any SOC tool trace.
|
| 138 |
+
|
| 139 |
+
rate_limit=throttle traffic, isolate=take offline, honeypot=redirect attacker,
|
| 140 |
+
patch=fix vulnerability, monitor=observe only
|
| 141 |
+
|
| 142 |
+
Respond ONLY with valid JSON:
|
| 143 |
+
{"action":"...","target_node":"...","reasoning":"..."}
|
| 144 |
+
|
| 145 |
+
Nodes: auth_service, payment_service, database, api_gateway
|
| 146 |
+
""").strip()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_action(client: OpenAI, obs: Dict) -> Dict[str, Any]:
    """Ask the LLM for the current phase's action and return it as a dict.

    Phase 1 (``obs["phase"] == 1``, also the default) sends the network
    nodes, alerts, SOC tool evidence and history with the analyst prompt.
    Any other phase sends the phase 1 assessment plus the tool-trace rows
    recorded for the current turn with the executor prompt.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation as a plain dict.

    Returns:
        The JSON object produced by the model. If the API call fails, the
        reply is not valid JSON, or the JSON is not an object, a
        phase-appropriate ``monitor`` fallback dict is returned and a
        ``[DEBUG]`` line is printed.
    """
    phase = obs.get("phase", 1)

    if phase == 1:
        sys_msg = PHASE1_SYS
        user_msg = "\n".join([
            "Network nodes:",
            json.dumps(obs.get("network_nodes", {}), indent=2),
            "\nAlerts:",
            "\n".join(obs.get("active_alerts", [])),
            "\nSOC tool evidence:",
            summarize_tool_results(obs.get("tool_results", [])),
            "\nHistory:",
            json.dumps(obs.get("history", []), indent=2),
            "\nClassify the threat:",
        ])
        fallback = {
            "threat_type": "brute_force", "confidence": 0.5,
            "target_node": "auth_service", "recommended_action": "monitor",
            "reasoning": "fallback",
        }
    else:
        sys_msg = PHASE2_SYS
        # Metadata may arrive as a non-dict value; treat that as "no metadata".
        metadata = obs.get("metadata", {}) if isinstance(obs.get("metadata", {}), dict) else {}
        current_turn = int(obs.get("turn", 0) or 0)
        # Only tool-trace rows recorded for the current turn are shown.
        tool_trace = [
            row for row in metadata.get("tool_trace", [])
            if int(row.get("turn", -1)) == current_turn
        ]
        user_msg = "\n".join([
            "Threat assessment from analyst:",
            json.dumps(obs.get("phase1_assessment", {}), indent=2),
            "\nSOC tool trace for this turn:",
            json.dumps(tool_trace, indent=2),
            "\nChoose your defensive action:",
        ])
        fallback = {
            "action": "monitor",
            "target_node": "auth_service",
            "reasoning": "fallback",
        }

    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=0.1,
            max_tokens=300,
            stream=False,
        )
        text = (resp.choices[0].message.content or "").strip()

        # Strip markdown fences: keep the first fence-delimited segment that
        # contains an opening brace.
        if "```" in text:
            for part in text.split("```"):
                if "{" in part:
                    # removeprefix drops only the literal "json" language tag,
                    # whereas lstrip("json") would strip any leading run of
                    # the characters j, s, o and n.
                    text = part.strip().removeprefix("json").strip()
                    break

        parsed = json.loads(text)
        # Callers unpack the result as keyword arguments, so only a JSON
        # object is acceptable; lists and scalars take the fallback path.
        if not isinstance(parsed, dict):
            raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
        return parsed

    except Exception as exc:
        print(f"[DEBUG] phase={phase} parse error: {exc}", flush=True)
        return fallback
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def should_use_tools(task_name: str) -> bool:
    """Decide whether the SOC tool path is used for ``task_name``.

    An explicit on/off token in ``USE_TOOLS_SETTING`` wins; any other value
    enables tools only for the ``polymorphic-zero-day`` task.
    """
    explicit = {
        "1": True, "true": True, "yes": True, "on": True,
        "0": False, "false": False, "no": False, "off": False,
    }
    if USE_TOOLS_SETTING in explicit:
        return explicit[USE_TOOLS_SETTING]
    return task_name == "polymorphic-zero-day"
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def run_soc_episode(client: OpenAI, use_tools: bool) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode through the ``/soc/*`` HTTP endpoints.

    Each step gathers SOC tool results, asks the model for an action,
    posts it to ``/soc/step`` and logs a ``[STEP]`` line. Any exception
    while building or posting the action ends the episode with reward 0.0.

    Returns:
        ``(rewards, steps_taken, last_observation)``.
    """
    reward_log: List[float] = []
    last_step = 0

    reset_reply = env_post("/soc/reset", {"task": TASK_NAME})
    session_id = str(reset_reply.get("session_id", ""))
    obs = dict(reset_reply.get("observation", {}))

    # Nothing to do when the environment reports a finished episode on reset.
    if bool(obs.get("done", False)):
        return reward_log, last_step, obs

    for step in range(1, MAX_STEPS + 1):
        evidence = investigate_http(
            env_base_url=ENV_BASE_URL,
            session_id=session_id,
            obs=obs,
            use_tools=use_tools,
            thorough=True,
        )
        parsed = get_action(client, attach_tool_results(obs, evidence))

        # Compact, length-capped rendering of the action for the log line.
        action_str = json.dumps(parsed, separators=(",", ":"))
        if len(action_str) > 100:
            action_str = action_str[:97] + "..."

        error = None
        try:
            action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
            payload = action.model_dump(
                mode="json",
                exclude_none=True,
                exclude_defaults=True,
            )
            reply = env_post("/soc/step", {"session_id": session_id, "action": payload})
            obs = dict(reply.get("observation", {}))
            reward = float(reply.get("reward", obs.get("reward", 0.0)))
            done = bool(reply.get("done", obs.get("done", False)))
        except Exception as exc:
            reward, done, error = 0.0, True, str(exc)[:80]

        reward_log.append(reward)
        last_step = step
        log_step(step=step, action=action_str, reward=reward, done=done, error=error)

        if done:
            break

    return reward_log, last_step, obs
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def run_openenv_episode(client: OpenAI) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode through the synchronous ``AdaptshieldEnv`` client (no SOC tools).

    Each step asks the model for an action, sends it with ``env.step`` and
    logs a ``[STEP]`` line. Any exception while building or sending the
    action ends the episode with reward 0.0.

    Returns:
        ``(rewards, steps_taken, last_observation)``.
    """
    reward_log: List[float] = []
    last_step = 0
    obs: Dict[str, Any] = {}

    env = AdaptshieldEnv(base_url=ENV_BASE_URL).sync()
    with env:
        reset_result = env.reset(task_name=TASK_NAME)
        obs = obs_to_dict(reset_result.observation)
        done = bool(reset_result.done or obs.get("done", False))

        step = 1
        while not done and step <= MAX_STEPS:
            parsed = get_action(client, obs)

            # Compact, length-capped rendering of the action for the log line.
            action_str = json.dumps(parsed, separators=(",", ":"))
            if len(action_str) > 100:
                action_str = action_str[:97] + "..."

            error = None
            try:
                action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
                step_result = env.step(action)
                obs = obs_to_dict(step_result.observation)
                # Use the observation's reward only when the step result has none.
                if step_result.reward is not None:
                    reward = float(step_result.reward)
                else:
                    reward = float(obs.get("reward", 0.0))
                done = bool(step_result.done or obs.get("done", False))
            except Exception as exc:
                reward, done, error = 0.0, True, str(exc)[:80]

            reward_log.append(reward)
            last_step = step
            log_step(step=step, action=action_str, reward=reward,
                     done=done, error=error)
            step += 1

    return reward_log, last_step, obs
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
def main() -> None:
    """Run a single episode for ``TASK_NAME`` and print the [START]/[STEP]/[END] lines.

    The SOC tool path is used when ``should_use_tools`` says so, otherwise
    the OpenEnv client path. If the episode raises, a ``[DEBUG]`` line is
    printed and the run is reported as failed with score 0.10.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Defaults reported by [END] when the episode fails before producing results.
    rewards: List[float] = []
    steps_taken: int = 0
    score: float = 0.50
    success: bool = False
    obs: Dict = {}

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        if not should_use_tools(TASK_NAME):
            rewards, steps_taken, obs = run_openenv_episode(client)
        else:
            rewards, steps_taken, obs = run_soc_episode(client, use_tools=True)

        score = safe_score(rewards, obs.get("metadata", {}))
        success = score >= SUCCESS_THRESHOLD

    except Exception as exc:
        print(f"[DEBUG] episode error: {exc}", flush=True)
        score = 0.10

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
# Run the episode only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
launch_hf_grpo_job.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Launch AdaptShield GRPO refinement on Hugging Face Jobs."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import shlex
|
| 8 |
+
import subprocess
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import HfApi, get_token, run_job
|
| 13 |
+
from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
|
| 14 |
+
|
| 15 |
+
from train import MODEL_CHOICES, TASKS
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Directory containing this script; used as the working directory for git queries.
REPO_ROOT = Path(__file__).resolve().parent
# NOTE(review): presumably the default container image for the job; its use is
# outside this part of the file, confirm against the argument parser.
DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _should_retry_hf(exc: Exception) -> bool:
|
| 23 |
+
response = getattr(exc, "response", None)
|
| 24 |
+
status_code = getattr(response, "status_code", None)
|
| 25 |
+
return status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _retry_hf_call(fn, *args, retries: int = 4, delay_s: float = 2.0, **kwargs):
|
| 29 |
+
last_exc = None
|
| 30 |
+
for attempt in range(retries):
|
| 31 |
+
try:
|
| 32 |
+
return fn(*args, **kwargs)
|
| 33 |
+
except Exception as exc:
|
| 34 |
+
last_exc = exc
|
| 35 |
+
if not _should_retry_hf(exc) or attempt == retries - 1:
|
| 36 |
+
raise
|
| 37 |
+
sleep_for = delay_s * (2 ** attempt)
|
| 38 |
+
print(f"Retrying HF API call after transient error ({exc}); sleeping {sleep_for:.1f}s")
|
| 39 |
+
time.sleep(sleep_for)
|
| 40 |
+
raise last_exc # pragma: no cover
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def infer_repo_url() -> str:
    """Return the ``remote.origin.url`` configured for the git checkout at ``REPO_ROOT``.

    Raises:
        subprocess.CalledProcessError: If ``git config`` exits non-zero.
        RuntimeError: If the command succeeds but prints an empty value.
    """
    completed = subprocess.run(
        ["git", "config", "--get", "remote.origin.url"],
        cwd=REPO_ROOT,
        check=True,
        capture_output=True,
        text=True,
    )
    url = completed.stdout.strip()
    if url:
        return url
    raise RuntimeError("Could not infer git remote.origin.url")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def repo_namespace(repo_id: str) -> str:
    """Return the namespace of ``repo_id``, i.e. the text before its first ``/``.

    Raises:
        RuntimeError: If ``repo_id`` contains no ``/``.
    """
    namespace, slash, _name = repo_id.partition("/")
    if not slash:
        raise RuntimeError(f"Invalid repo id: {repo_id}. Expected namespace/name.")
    return namespace
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def authenticated_username(api: HfApi) -> str | None:
    """Return the account name reported by ``api.whoami``, or None if unavailable.

    The first non-empty string among the ``name``, ``fullname`` and ``user``
    keys is used. Any exception from ``whoami``, or a non-dict reply,
    yields None.
    """
    try:
        info = api.whoami(cache=True)
    except Exception:
        return None
    if not isinstance(info, dict):
        return None
    candidates = (info.get(key) for key in ("name", "fullname", "user"))
    return next(
        (value for value in candidates if isinstance(value, str) and value),
        None,
    )
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def validate_repo_access(
    api: HfApi,
    repo_id: str,
    repo_type: str,
    skip_create: bool,
    allow_cross_namespace: bool,
) -> None:
    """Check that the target repo is plausibly writable before launching a job.

    A mismatch between the authenticated account and the repo namespace is
    an error unless ``allow_cross_namespace`` is set, in which case it is
    only printed as a warning. When ``skip_create`` is set or the repo is a
    model repo, the repo must also already exist and be readable.

    Raises:
        RuntimeError: On a disallowed namespace mismatch, or when the repo
            cannot be found or verified.
    """
    owner = repo_namespace(repo_id)
    username = authenticated_username(api)
    namespace_mismatch = bool(username) and owner != username
    if namespace_mismatch:
        message = (
            f"Authenticated HF account appears to be '{username}', but target repo is under '{owner}'. "
            "Use a repo under the same namespace or pass --allow-cross-namespace only if you are certain "
            "this token has write access there."
        )
        if not allow_cross_namespace:
            raise RuntimeError(message)
        print(f"Warning: {message}")

    must_exist = skip_create or repo_type == "model"
    if not must_exist:
        return

    try:
        _retry_hf_call(api.repo_info, repo_id=repo_id, repo_type=repo_type)
    except RepositoryNotFoundError as exc:
        raise RuntimeError(
            f"Repo '{repo_id}' ({repo_type}) was not found or is not accessible with the current token."
        ) from exc
    except HfHubHTTPError as exc:
        raise RuntimeError(f"Could not verify repo '{repo_id}' ({repo_type}): {exc}") from exc
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def validate_source_artifacts(
    api: HfApi,
    repo_id: str,
    repo_type: str,
    subdir: str,
) -> None:
    """Verify that the source repo holds the SFT artifacts the GRPO run needs.

    The required files are ``{subdir}/final/adapter_config.json`` and
    ``{subdir}/sft_metrics.json``.

    Raises:
        RuntimeError: If the file listing fails or a required file is missing.
    """
    try:
        listing = _retry_hf_call(api.list_repo_files, repo_id=repo_id, repo_type=repo_type)
        files = set(listing)
    except Exception as exc:
        raise RuntimeError(f"Could not list files for source repo '{repo_id}' ({repo_type}): {exc}") from exc

    required = {
        f"{subdir}/final/adapter_config.json",
        f"{subdir}/sft_metrics.json",
    }
    missing = sorted(required - files)
    if missing:
        raise RuntimeError(
            "Source repo is missing required SFT artifacts: " + ", ".join(missing)
        )
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def build_command(args: argparse.Namespace, repo_url: str, output_subdir: str) -> str:
    """Build the shell script that the remote job executes.

    The script, in order: prints the baseline torch version, installs git
    and clones ``repo_url`` into ``/workspace/adaptshield`` if no checkout
    exists, installs the project and its pinned dependencies, verifies the
    environment with import smoke checks, downloads the SFT adapter from
    ``args.source_repo``, runs ``train.py --trainer grpo``, tries to build
    the benchmark table, and uploads the output folder to the repo named by
    the ``RUNS_REPO`` / ``RUNS_REPO_TYPE`` environment variables.

    Args:
        args: Parsed CLI arguments supplying the source repo and training
            hyper-parameters interpolated into the script.
        repo_url: Git URL to clone; shell-quoted before interpolation.
        output_subdir: Sub-directory under ``checkpoints/`` that receives
            the run output; also used as the upload path inside the repo.

    Returns:
        The script text.
    """
    output_path = f"/workspace/adaptshield/checkpoints/{output_subdir}"

    # Doubled braces ({{ }}) are literal braces for the embedded Python
    # snippets; single braces are filled in here.
    # NOTE(review): only repo_url is shlex-quoted. args.task, args.model and
    # output_path reach the shell verbatim; confirm they cannot contain
    # whitespace or shell metacharacters.
    return f"""
set -euo pipefail
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
export PYTHONWARNINGS="ignore::FutureWarning"
export HF_HUB_ENABLE_HF_TRANSFER=1
export PIP_DISABLE_PIP_VERSION_CHECK=1

python - <<'PY'
import torch
print(f"baseline torch={{torch.__version__}}, cuda={{torch.version.cuda}}")
PY

apt-get update -qq
apt-get install -y -qq git
if [ ! -d /workspace/adaptshield/.git ]; then
  rm -rf /workspace/adaptshield
  git clone --depth 1 {shlex.quote(repo_url)} /workspace/adaptshield
fi
cd /workspace/adaptshield
python -m pip install --upgrade pip wheel setuptools
# ninja+packaging let any source-built dep that DOES sneak in compile cleanly.
pip install --upgrade ninja packaging
pip install -e .
pip uninstall -y torchaudio || true

# Unsloth ships CUDA/torch-pinned extras (cu124 + torch 2.6.0 + xformers+triton wheels).
# We deliberately use `cu124-torch260` (NOT the `ampere` variant) because:
# * cu124-torch260 pins torch 2.6 + xformers + triton via prebuilt wheels (no source builds).
# * cu124-ampere-torch260 ALSO tries to install flash-attn; if its prebuilt wheel URL doesn't
# match the image's python/cxx11abi exactly, pip falls through to source-building flash-attn
# (10-30 min, often fails with "ModuleNotFoundError: No module named 'torch'" because PEP 517
# build isolation hides torch). Unsloth's xformers/triton attention is plenty fast on L4.
# `unsloth[cu124-torch260]` transitively installs `unsloth[huggingface]` which pins ALL of
# transformers / trl / peft / accelerate / datasets / bitsandbytes / tokenizers / safetensors
# to versions Unsloth has tested together. Do NOT add a `--no-deps` override on top of this —
# previous attempts to do so downgraded peft/trl below what Unsloth requires.
# --no-build-isolation lets any incidental source build (e.g. a stray dep) see system torch.
pip install --upgrade --no-build-isolation "unsloth[cu124-torch260]"

# Pin transformers to a single known-good version. Why this is necessary:
# Unsloth's pyproject allows transformers >=4.51.3 ... <=5.5.0. Pip prefers the latest, so it
# picks 5.5.0 by default. But transformers 4.x requires huggingface-hub<1.0 while 5.x requires
# hub>=1.5,<2.0 — and unsloth's pyproject does NOT bound hub. So a separate `pip install hub<1.0`
# silently breaks transformers 5.x (and a separate `pip install hub>=1.5` silently breaks 4.x).
# The only robust fix is to pin transformers and let pip select the matching hub in the SAME
# resolution step. We pick 4.57.6 because:
# * latest 4.x release on PyPI (so qwen3, etc. are supported);
# * not on Unsloth's blocklist (4.57.0/.4/.5 are; 4.57.6 is fine);
# * pulls huggingface-hub<1.0 automatically (no separate hub pin needed).
pip install "transformers==4.57.6"

# torchao comes preinstalled in the base image at a version that requires torch 2.7+
# (it calls torch.utils._pytree.register_constant which doesn't exist in torch 2.6, so
# `import torchao` crashes with AttributeError). transformers' quantizer registry imports
# torchao unconditionally if it's installed (`is_torchao_available()` only checks package
# metadata, not import-ability). With torchao GONE, that check returns False and transformers
# skips torchao cleanly. We don't use torchao quantization anyway — we use bitsandbytes 4-bit.
pip uninstall -y torchao || true

# Optional helpers we use directly (matplotlib for plots, hf_transfer for fast download/upload).
pip install --upgrade matplotlib hf_transfer

# Hard guard: if torch was upgraded, bitsandbytes will fail at import; fail FAST with a clear log.
python - <<'PY'
import sys, torch
if not torch.__version__.startswith("2.6."):
    print(f"FATAL: torch was upgraded to {{torch.__version__}}; aborting before training.")
    sys.exit(2)
print(f"torch ok: {{torch.__version__}} cuda={{torch.version.cuda}}")
PY

# Smoke-test the actual modules we use. unsloth MUST import before transformers/trl
# per its own warning. Importing transformers also triggers its OWN runtime version check on
# huggingface_hub and tokenizers, AND eagerly imports any installed quantizer backend
# (torchao, bnb, etc.) — so if anything is mis-pinned this line raises a clear error before
# training starts.
python - <<'PY'
import sys, importlib.util

# Pre-flight: torchao must be GONE (preinstalled version requires torch>=2.7 and crashes
# `import torchao` on torch 2.6). If it leaked back in, fail with a precise message.
if importlib.util.find_spec("torchao") is not None:
    print("FATAL: torchao is installed; on torch 2.6 it crashes transformers at import. "
          "Run `pip uninstall -y torchao` and rebuild.")
    sys.exit(2)

import unsloth  # noqa: F401 (must be first)
import torch, transformers, trl, peft, datasets, bitsandbytes, huggingface_hub
print(
    f"unsloth={{unsloth.__version__}} transformers={{transformers.__version__}} "
    f"trl={{trl.__version__}} peft={{peft.__version__}} bnb={{bitsandbytes.__version__}} "
    f"hub={{huggingface_hub.__version__}} datasets={{datasets.__version__}}"
)
expected_transformers = "4.57.6"
if transformers.__version__ != expected_transformers:
    print(
        f"FATAL: transformers={{transformers.__version__}} but pinned to {{expected_transformers}}. "
        f"Pip resolution drifted; aborting before training."
    )
    sys.exit(2)
import train, build_benchmark_table  # noqa: F401
print("Dependency smoke check passed.")
PY

python - <<'PY'
from huggingface_hub import snapshot_download
from pathlib import Path

repo_id = {args.source_repo!r}
repo_type = {args.source_repo_type!r}
subdir = {args.source_subdir!r}
local_dir = snapshot_download(repo_id=repo_id, repo_type=repo_type)
adapter_path = Path(local_dir) / subdir / "final"
sft_metrics_path = Path(local_dir) / subdir / "sft_metrics.json"
if not adapter_path.exists():
    raise RuntimeError(f"SFT adapter path not found: {{adapter_path}}")
if not sft_metrics_path.exists():
    raise RuntimeError(f"SFT metrics path not found: {{sft_metrics_path}}")
print(adapter_path)
Path("/workspace/adaptshield/.grpo_adapter_path.txt").write_text(str(adapter_path), encoding="utf-8")
Path("/workspace/adaptshield/.grpo_sft_metrics_path.txt").write_text(str(sft_metrics_path), encoding="utf-8")
PY

ADAPTER_PATH=$(cat /workspace/adaptshield/.grpo_adapter_path.txt)
SFT_METRICS_PATH=$(cat /workspace/adaptshield/.grpo_sft_metrics_path.txt)

python train.py \\
  --trainer grpo \\
  --task {args.task} \\
  --curriculum \\
  --use-tools \\
  --model {args.model} \\
  --model-path "$ADAPTER_PATH" \\
  --lr {args.lr} \\
  --prompt-bank-episodes {args.prompt_bank_episodes} \\
  --max-steps {args.max_steps} \\
  --prompt-bank-hard-multiplier {args.prompt_bank_hard_multiplier} \\
  --prompt-bank-borderline-bonus {args.prompt_bank_borderline_bonus} \\
  --grpo-epochs {args.grpo_epochs} \\
  --num-generations {args.num_generations} \\
  --per-device-batch-size {args.per_device_batch_size} \\
  --gradient-accumulation-steps {args.gradient_accumulation_steps} \\
  --save-every {args.save_every} \\
  --eval-episodes {args.eval_episodes} \\
  --train-world-split train \\
  --heldout-world-split eval \\
  --heldout-seed {args.heldout_seed} \\
  --output {output_path} \\
  --plot

if ! python build_benchmark_table.py \\
  --sft-metrics "$SFT_METRICS_PATH" \\
  --grpo-metrics {output_path}/metrics.json \\
  --output {output_path}/benchmark_table.md; then
  echo "Benchmark table generation failed; continuing with core artifacts."
fi

python - <<'PY'
import os
import time
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
repo_id = os.environ["RUNS_REPO"]
repo_type = os.environ["RUNS_REPO_TYPE"]
output_dir = {output_path!r}
subdir = {output_subdir!r}

last_exc = None
for attempt in range(4):
    try:
        api.upload_folder(
            repo_id=repo_id,
            repo_type=repo_type,
            folder_path=output_dir,
            path_in_repo=subdir,
        )
        last_exc = None
        break
    except Exception as exc:
        last_exc = exc
        response = getattr(exc, "response", None)
        status_code = getattr(response, "status_code", None)
        if status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600):
            sleep_for = 2 ** attempt
            print(f"Transient upload error: {{exc}}; retrying in {{sleep_for}}s")
            time.sleep(sleep_for)
            continue
        raise
if last_exc is not None:
    raise last_exc
print("Uploaded artifacts to", repo_id)
PY
"""
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def default_output_subdir(task: str, model: str) -> str:
    """Build the default artifact folder name for a GRPO run.

    Args:
        task: Task identifier; the special value ``"all"`` selects the
            ``worldsplit`` folder name, any other value has its dashes
            replaced by underscores.
        model: Model size label (e.g. ``"1.5b"``); dots become underscores.

    Returns:
        ``grpo_worldsplit_<model>`` for ``task == "all"``, otherwise
        ``grpo_<task>_<model>``.
    """
    scope = "worldsplit" if task == "all" else task.replace("-", "_")
    return "_".join(("grpo", scope, model.replace(".", "_")))
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def main() -> int:
    """Parse CLI options, verify repo access, and submit the GRPO job.

    The runs repo and the source (SFT) repo are both validated before the job
    is submitted; the runs repo is created unless ``--skip-create`` is given.

    Returns:
        ``0`` once the job has been submitted.

    Raises:
        RuntimeError: If no Hugging Face token is available locally.
    """
    repo_kinds = ["dataset", "model"]
    parser = argparse.ArgumentParser(description="Launch AdaptShield GRPO refinement on Hugging Face Jobs")

    # Artifact destination and the git repo cloned inside the job.
    parser.add_argument("--runs-repo", required=True)
    parser.add_argument("--runs-repo-type", default="model", choices=repo_kinds)
    parser.add_argument("--skip-create", action="store_true")
    parser.add_argument("--allow-cross-namespace", action="store_true")
    parser.add_argument("--repo-url", default=None)

    # Where the previously produced SFT artifacts live.
    parser.add_argument("--source-repo", required=True, help="Repo containing SFT artifacts.")
    parser.add_argument("--source-repo-type", default="model", choices=repo_kinds)
    parser.add_argument("--source-subdir", default="sft_worldsplit_1_5b", help="Subdirectory containing the SFT output.")

    # Task/model selection and job hardware.
    parser.add_argument("--task", default="all", choices=TASKS + ["all"])
    parser.add_argument("--model", default="1.5b", choices=list(MODEL_CHOICES))
    parser.add_argument("--flavor", default="l4x1")
    parser.add_argument("--timeout", default="6h")

    # Numeric hyperparameters, registered in this order: (flag, type, default).
    numeric_options = (
        ("--lr", float, 1e-5),
        ("--prompt-bank-episodes", int, 120),
        ("--max-steps", int, 20),
        ("--prompt-bank-hard-multiplier", int, 3),
        ("--prompt-bank-borderline-bonus", int, 2),
        ("--grpo-epochs", int, 1),
        ("--num-generations", int, 2),
        ("--per-device-batch-size", int, 1),
        ("--gradient-accumulation-steps", int, 4),
        ("--save-every", int, 0),
        ("--eval-episodes", int, 2),
        ("--heldout-seed", int, 314),
    )
    for flag, caster, default in numeric_options:
        parser.add_argument(flag, type=caster, default=default)

    parser.add_argument("--output-subdir", default="")
    args = parser.parse_args()

    # An empty --output-subdir falls back to a name derived from task + model.
    args.output_subdir = args.output_subdir or default_output_subdir(args.task, args.model)

    token = get_token()
    if not token:
        raise RuntimeError("No Hugging Face token found. Run `hf auth login` first.")

    repo_url = args.repo_url or infer_repo_url()
    api = HfApi(token=token)

    # Check the destination first, then the source repo (never created here),
    # then the SFT artifacts inside it.
    validate_repo_access(api, args.runs_repo, args.runs_repo_type, args.skip_create, args.allow_cross_namespace)
    validate_repo_access(api, args.source_repo, args.source_repo_type, True, args.allow_cross_namespace)
    validate_source_artifacts(api, args.source_repo, args.source_repo_type, args.source_subdir)
    if not args.skip_create:
        _retry_hf_call(api.create_repo, repo_id=args.runs_repo, repo_type=args.runs_repo_type, private=True, exist_ok=True)

    command = build_command(args=args, repo_url=repo_url, output_subdir=args.output_subdir)
    job_env = {
        "RUNS_REPO": args.runs_repo,
        "RUNS_REPO_TYPE": args.runs_repo_type,
    }
    job = _retry_hf_call(
        run_job,
        image=DEFAULT_IMAGE,
        command=["bash", "-lc", command],
        flavor=args.flavor,
        timeout=args.timeout,
        namespace=repo_namespace(args.runs_repo),
        env=job_env,
        secrets={"HF_TOKEN": token},
    )

    print("Job launched successfully.")
    print(f"Job ID: {job.id}")
    print(f"Job URL: {job.url}")
    print(f"Artifacts repo: {args.runs_repo}")
    print(f"Artifacts path: {args.output_subdir}")
    return 0
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
launch_hf_sft_job.py
ADDED
|
@@ -0,0 +1,363 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Launch an AdaptShield SFT training run on Hugging Face Jobs."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import shlex
|
| 8 |
+
import subprocess
|
| 9 |
+
import time
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
from huggingface_hub import HfApi, get_token, run_job
|
| 13 |
+
from huggingface_hub.errors import HfHubHTTPError, RepositoryNotFoundError
|
| 14 |
+
|
| 15 |
+
from train import MODEL_CHOICES
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Directory containing this launcher; used as the working directory when
# asking git for the `origin` remote URL.
REPO_ROOT = Path(__file__).resolve().parent
# Container image the HF Job runs in. Its torch 2.6 / CUDA 12.4 pairing matches
# the `unsloth[cu124-torch260]` extra and the torch 2.6 guard in build_command.
DEFAULT_IMAGE = "pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel"
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _should_retry_hf(exc: Exception) -> bool:
|
| 23 |
+
response = getattr(exc, "response", None)
|
| 24 |
+
status_code = getattr(response, "status_code", None)
|
| 25 |
+
return status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _retry_hf_call(fn, *args, retries: int = 4, delay_s: float = 2.0, **kwargs):
    """Call ``fn(*args, **kwargs)``, retrying transient HF API failures.

    Args:
        fn: Callable to invoke.
        *args: Positional arguments forwarded to *fn*.
        retries: Maximum number of attempts.
        delay_s: Base delay in seconds; the wait doubles after every failed
            attempt (``delay_s * 2 ** attempt``).
        **kwargs: Keyword arguments forwarded to *fn*.

    Returns:
        Whatever *fn* returns on its first successful attempt.

    Raises:
        Exception: The original error, re-raised immediately when it is not
            classified as transient by ``_should_retry_hf`` or when the last
            attempt fails.
    """
    failure = None
    final_attempt = retries - 1
    for attempt in range(retries):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            failure = exc
            transient = _should_retry_hf(exc)
            if attempt == final_attempt or not transient:
                raise
            backoff = delay_s * (2 ** attempt)
            print(f"Retrying HF API call after transient error ({exc}); sleeping {backoff:.1f}s")
            time.sleep(backoff)
    raise failure  # pragma: no cover
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def repo_namespace(repo_id: str) -> str:
    """Return the namespace part (text before the first ``/``) of a repo id.

    Raises:
        RuntimeError: If *repo_id* contains no ``/`` separator.
    """
    namespace, separator, _ = repo_id.partition("/")
    if not separator:
        raise RuntimeError(f"Invalid repo id: {repo_id}. Expected namespace/name.")
    return namespace
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def authenticated_username(api: HfApi) -> str | None:
    """Best-effort lookup of the account name behind the current token.

    Calls ``api.whoami(cache=True)`` and returns the first non-empty string
    found under the keys ``name``, ``fullname`` or ``user`` (in that order).

    Returns:
        The account name, or ``None`` when the lookup raises, the response is
        not a dict, or none of the keys holds a non-empty string.
    """
    try:
        info = api.whoami(cache=True)
    except Exception:
        # Identity lookup is advisory only; callers treat None as "unknown".
        return None
    if not isinstance(info, dict):
        return None
    candidates = (info.get(key) for key in ("name", "fullname", "user"))
    return next((value for value in candidates if isinstance(value, str) and value), None)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def validate_artifact_repo(
    api: HfApi,
    repo_id: str,
    repo_type: str,
    skip_create: bool,
    allow_cross_namespace: bool,
) -> None:
    """Check the artifacts repo before a job is launched.

    Two checks are performed:

    1. If the authenticated account name is known and differs from the repo's
       namespace, either fail or (with *allow_cross_namespace*) print a warning.
    2. If *skip_create* is set, confirm the repo can be read via
       ``api.repo_info``.

    Args:
        api: Authenticated Hub client.
        repo_id: Artifacts repo in ``namespace/name`` form.
        repo_type: Repo type passed through to ``api.repo_info``.
        skip_create: When true, the repo must already exist and be reachable.
        allow_cross_namespace: Downgrade a namespace mismatch to a warning.

    Raises:
        RuntimeError: On a namespace mismatch that is not allowed, or when the
            repo cannot be found or verified.
    """
    owner = repo_namespace(repo_id)
    username = authenticated_username(api)

    if username and username != owner:
        message = (
            f"Authenticated HF account appears to be '{username}', but artifacts repo is under '{owner}'. "
            "Use a repo under the same namespace or pass --allow-cross-namespace only if you are certain "
            "this token has write access there."
        )
        if allow_cross_namespace:
            print(f"Warning: {message}")
        else:
            raise RuntimeError(message)

    if not skip_create:
        return

    try:
        _retry_hf_call(api.repo_info, repo_id=repo_id, repo_type=repo_type)
    except RepositoryNotFoundError as exc:
        # Handled before the broader HTTP error so the message can be specific.
        raise RuntimeError(
            f"Artifacts repo '{repo_id}' ({repo_type}) was not found or is not accessible "
            "with the current token. Create it manually under the correct namespace or use "
            "a repo you definitely own before launching the job."
        ) from exc
    except HfHubHTTPError as exc:
        raise RuntimeError(
            f"Could not verify artifacts repo '{repo_id}' ({repo_type}) before launch: {exc}"
        ) from exc
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def infer_repo_url() -> str:
    """Return the URL of this checkout's ``origin`` remote.

    Runs ``git config --get remote.origin.url`` in ``REPO_ROOT``.

    Returns:
        The remote URL with surrounding whitespace stripped.

    Raises:
        RuntimeError: If git is not installed, the command exits non-zero
            (for example, no ``origin`` remote is configured or the directory
            is not a git checkout), or the configured URL is empty.
    """
    try:
        result = subprocess.run(
            ["git", "config", "--get", "remote.origin.url"],
            cwd=REPO_ROOT,
            check=True,
            capture_output=True,
            text=True,
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as exc:
        # `git config --get` exits non-zero when the key is missing, and
        # FileNotFoundError means git itself is unavailable. Both are the
        # common "cannot infer" cases, so report them with an actionable hint
        # rather than a raw subprocess traceback.
        raise RuntimeError(
            "Could not infer git remote.origin.url; pass --repo-url explicitly."
        ) from exc
    repo_url = result.stdout.strip()
    if not repo_url:
        raise RuntimeError("Could not infer git remote.origin.url")
    return repo_url
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def build_command(args: argparse.Namespace, repo_url: str, output_subdir: str) -> str:
    """Render the bash script executed inside the HF Job for an SFT run.

    The script clones the repo, installs the pinned training stack, runs
    dependency smoke checks, generates the SFT dataset, trains, and uploads
    the output folder plus the dataset summary to the runs repo.

    Args:
        args: Parsed CLI options. Reads ``skip_reward_curve``,
            ``dataset_episodes``, ``max_steps``, ``seed``, ``model``,
            ``epochs``, ``lr``, ``per_device_batch_size``,
            ``gradient_accumulation_steps``, ``save_steps``, ``heldout_seed``
            and ``eval_episodes``.
        repo_url: Git URL to clone; shell-quoted before interpolation.
        output_subdir: Folder name used both under the local checkpoints
            directory and as the upload path inside the runs repo.

    Returns:
        The complete script as a single string, suitable for ``bash -lc``.
    """
    dataset_path = "/workspace/adaptshield/data/adaptshield_sft_worldsplit.jsonl"
    output_path = f"/workspace/adaptshield/checkpoints/{output_subdir}"
    summary_path = "/workspace/adaptshield/data/adaptshield_sft_worldsplit.summary.json"
    extra_train_flags = "--skip-reward-curve" if args.skip_reward_curve else ""
    # output_subdir is free-form user input, so quote it for the shell the same
    # way repo_url is quoted. The embedded Python heredocs keep using repr().
    shell_output_path = shlex.quote(output_path)

    return f"""
set -euo pipefail
export TRANSFORMERS_NO_ADVISORY_WARNINGS=1
export PYTHONWARNINGS="ignore::FutureWarning"
export HF_HUB_ENABLE_HF_TRANSFER=1
export PIP_DISABLE_PIP_VERSION_CHECK=1

python - <<'PY'
import torch
print(f"baseline torch={{torch.__version__}}, cuda={{torch.version.cuda}}")
PY

apt-get update -qq
apt-get install -y -qq git
if [ ! -d /workspace/adaptshield/.git ]; then
  rm -rf /workspace/adaptshield
  git clone --depth 1 {shlex.quote(repo_url)} /workspace/adaptshield
fi
cd /workspace/adaptshield
python -m pip install --upgrade pip wheel setuptools
# ninja+packaging let any source-built dep that DOES sneak in compile cleanly.
pip install --upgrade ninja packaging
pip install -e .
pip uninstall -y torchaudio || true

# Unsloth ships CUDA/torch-pinned extras (cu124 + torch 2.6.0 + xformers+triton wheels).
# We deliberately use `cu124-torch260` (NOT the `ampere` variant) because:
# * cu124-torch260 pins torch 2.6 + xformers + triton via prebuilt wheels (no source builds).
# * cu124-ampere-torch260 ALSO tries to install flash-attn; if its prebuilt wheel URL doesn't
# match the image's python/cxx11abi exactly, pip falls through to source-building flash-attn
# (10-30 min, often fails with "ModuleNotFoundError: No module named 'torch'" because PEP 517
# build isolation hides torch). Unsloth's xformers/triton attention is plenty fast on L4.
# `unsloth[cu124-torch260]` transitively installs `unsloth[huggingface]` which pins ALL of
# transformers / trl / peft / accelerate / datasets / bitsandbytes / tokenizers / safetensors
# to versions Unsloth has tested together. Do NOT add a `--no-deps` override on top of this —
# previous attempts to do so downgraded peft/trl below what Unsloth requires.
# --no-build-isolation lets any incidental source build (e.g. a stray dep) see system torch.
pip install --upgrade --no-build-isolation "unsloth[cu124-torch260]"

# Pin transformers to a single known-good version. Why this is necessary:
# Unsloth's pyproject allows transformers >=4.51.3 ... <=5.5.0. Pip prefers the latest, so it
# picks 5.5.0 by default. But transformers 4.x requires huggingface-hub<1.0 while 5.x requires
# hub>=1.5,<2.0 — and unsloth's pyproject does NOT bound hub. So a separate `pip install hub<1.0`
# silently breaks transformers 5.x (and a separate `pip install hub>=1.5` silently breaks 4.x).
# The only robust fix is to pin transformers and let pip select the matching hub in the SAME
# resolution step. We pick 4.57.6 because:
# * latest 4.x release on PyPI (so qwen3, etc. are supported);
# * not on Unsloth's blocklist (4.57.0/.4/.5 are; 4.57.6 is fine);
# * pulls huggingface-hub<1.0 automatically (no separate hub pin needed).
pip install "transformers==4.57.6"

# torchao comes preinstalled in the base image at a version that requires torch 2.7+
# (it calls torch.utils._pytree.register_constant which doesn't exist in torch 2.6, so
# `import torchao` crashes with AttributeError). transformers' quantizer registry imports
# torchao unconditionally if it's installed (`is_torchao_available()` only checks package
# metadata, not import-ability). With torchao GONE, that check returns False and transformers
# skips torchao cleanly. We don't use torchao quantization anyway — we use bitsandbytes 4-bit.
pip uninstall -y torchao || true

# Optional helpers we use directly (matplotlib for plots, hf_transfer for fast download/upload).
pip install --upgrade matplotlib hf_transfer

# Hard guard: if torch was upgraded, bitsandbytes will fail at import; fail FAST with a clear log.
python - <<'PY'
import sys, torch
if not torch.__version__.startswith("2.6."):
    print(f"FATAL: torch was upgraded to {{torch.__version__}}; aborting before training.")
    sys.exit(2)
print(f"torch ok: {{torch.__version__}} cuda={{torch.version.cuda}}")
PY

# Smoke-test the actual modules we use. unsloth MUST import before transformers/trl
# per its own warning. Importing transformers also triggers its OWN runtime version check on
# huggingface_hub and tokenizers, AND eagerly imports any installed quantizer backend
# (torchao, bnb, etc.) — so if anything is mis-pinned this line raises a clear error before
# training starts.
python - <<'PY'
import sys, importlib.util

# Pre-flight: torchao must be GONE (preinstalled version requires torch>=2.7 and crashes
# `import torchao` on torch 2.6). If it leaked back in, fail with a precise message.
if importlib.util.find_spec("torchao") is not None:
    print("FATAL: torchao is installed; on torch 2.6 it crashes transformers at import. "
          "Run `pip uninstall -y torchao` and rebuild.")
    sys.exit(2)

import unsloth  # noqa: F401 (must be first)
import torch, transformers, trl, peft, datasets, bitsandbytes, huggingface_hub
print(
    f"unsloth={{unsloth.__version__}} transformers={{transformers.__version__}} "
    f"trl={{trl.__version__}} peft={{peft.__version__}} bnb={{bitsandbytes.__version__}} "
    f"hub={{huggingface_hub.__version__}} datasets={{datasets.__version__}}"
)
expected_transformers = "4.57.6"
if transformers.__version__ != expected_transformers:
    print(
        f"FATAL: transformers={{transformers.__version__}} but pinned to {{expected_transformers}}. "
        f"Pip resolution drifted; aborting before training."
    )
    sys.exit(2)
import train, train_sft, generate_sft_data  # noqa: F401
print("Dependency smoke check passed.")
PY

python generate_sft_data.py \\
  --task all \\
  --curriculum \\
  --use-tools \\
  --episodes {args.dataset_episodes} \\
  --max-steps {args.max_steps} \\
  --seed {args.seed} \\
  --world-split train \\
  --output {dataset_path}

python train_sft.py \\
  --dataset {dataset_path} \\
  --model {args.model} \\
  --epochs {args.epochs} \\
  --lr {args.lr} \\
  --per-device-batch-size {args.per_device_batch_size} \\
  --gradient-accumulation-steps {args.gradient_accumulation_steps} \\
  --save-steps {args.save_steps} \\
  --heldout-seed {args.heldout_seed} \\
  --train-world-split train \\
  --heldout-world-split eval \\
  --eval-task all \\
  --eval-episodes {args.eval_episodes} \\
  --use-tools \\
  --output {shell_output_path} \\
  {extra_train_flags}

python - <<'PY'
import os
import time
from huggingface_hub import HfApi

api = HfApi(token=os.environ["HF_TOKEN"])
repo_id = os.environ["RUNS_REPO"]
repo_type = os.environ["RUNS_REPO_TYPE"]
output_dir = {output_path!r}
summary_path = {summary_path!r}
subdir = {output_subdir!r}

last_exc = None
for attempt in range(4):
    try:
        api.upload_folder(
            repo_id=repo_id,
            repo_type=repo_type,
            folder_path=output_dir,
            path_in_repo=subdir,
        )
        api.upload_file(
            repo_id=repo_id,
            repo_type=repo_type,
            path_or_fileobj=summary_path,
            path_in_repo=f"{{subdir}}/adaptshield_sft_worldsplit.summary.json",
        )
        last_exc = None
        break
    except Exception as exc:
        last_exc = exc
        response = getattr(exc, "response", None)
        status_code = getattr(response, "status_code", None)
        if status_code == 429 or (isinstance(status_code, int) and 500 <= status_code < 600):
            sleep_for = 2 ** attempt
            print(f"Transient upload error: {{exc}}; retrying in {{sleep_for}}s")
            time.sleep(sleep_for)
            continue
        raise
if last_exc is not None:
    raise last_exc
print("Uploaded artifacts to", repo_id)
PY
"""
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def main() -> int:
    """Parse CLI options, validate the artifacts repo, and submit the SFT job.

    The artifacts repo is created (private) unless ``--skip-create`` is given,
    then the rendered training script is submitted to HF Jobs.

    Returns:
        ``0`` once the job has been submitted.

    Raises:
        RuntimeError: If no Hugging Face token is available locally.
    """
    parser = argparse.ArgumentParser(description="Launch AdaptShield SFT training on Hugging Face Jobs")

    # Artifact destination and the git repo cloned inside the job.
    parser.add_argument("--runs-repo", required=True, help="Artifact repo to upload outputs to, e.g. username/adaptshield-runs")
    parser.add_argument("--runs-repo-type", default="dataset", choices=["dataset", "model"], help="Repo type used to store training artifacts.")
    parser.add_argument("--skip-create", action="store_true", help="Skip repo creation and assume the artifacts repo already exists.")
    parser.add_argument("--allow-cross-namespace", action="store_true", help="Allow uploads to a repo owned by a different namespace than the authenticated account.")
    parser.add_argument("--repo-url", default=None, help="Git repo URL to clone inside the HF Job. Defaults to remote.origin.url")

    # Model selection and job hardware.
    parser.add_argument("--model", default="1.5b", choices=list(MODEL_CHOICES))
    parser.add_argument("--flavor", default="l4x1", help="HF Jobs hardware flavor, e.g. l4x1, a10g-small, a100-large")
    parser.add_argument("--timeout", default="6h", help="HF Jobs timeout, e.g. 6h")

    # Numeric hyperparameters, registered in this order: (flag, type, default).
    numeric_options = (
        ("--dataset-episodes", int, 240),
        ("--max-steps", int, 20),
        ("--epochs", float, 1.0),
        ("--lr", float, 2e-4),
        ("--per-device-batch-size", int, 2),
        ("--gradient-accumulation-steps", int, 4),
        ("--save-steps", int, 40),
        ("--eval-episodes", int, 2),
        ("--seed", int, 42),
        ("--heldout-seed", int, 314),
    )
    for flag, caster, default in numeric_options:
        parser.add_argument(flag, type=caster, default=default)

    parser.add_argument(
        "--skip-reward-curve",
        action="store_true",
        help="Skip the per-checkpoint held-out reward sweep inside train_sft.py.",
    )
    parser.add_argument("--output-subdir", default=None, help="Optional output folder name in the runs dataset repo")
    args = parser.parse_args()

    token = get_token()
    if not token:
        raise RuntimeError("No Hugging Face token found. Run `hf auth login` first.")

    repo_url = args.repo_url or infer_repo_url()

    # Without an explicit --output-subdir, derive the folder from the model label.
    output_subdir = args.output_subdir
    if not output_subdir:
        output_subdir = "sft_worldsplit_" + args.model.replace(".", "_")

    api = HfApi(token=token)
    validate_artifact_repo(
        api,
        args.runs_repo,
        args.runs_repo_type,
        args.skip_create,
        args.allow_cross_namespace,
    )
    if not args.skip_create:
        _retry_hf_call(api.create_repo, repo_id=args.runs_repo, repo_type=args.runs_repo_type, private=True, exist_ok=True)

    command = build_command(args=args, repo_url=repo_url, output_subdir=output_subdir)
    job_env = {
        "RUNS_REPO": args.runs_repo,
        "RUNS_REPO_TYPE": args.runs_repo_type,
    }
    job = _retry_hf_call(
        run_job,
        image=DEFAULT_IMAGE,
        command=["bash", "-lc", command],
        flavor=args.flavor,
        timeout=args.timeout,
        namespace=repo_namespace(args.runs_repo),
        env=job_env,
        secrets={"HF_TOKEN": token},
    )

    print("Job launched successfully.")
    print(f"Job ID: {job.id}")
    print(f"Job URL: {job.url}")
    print(f"Artifacts repo: {args.runs_repo}")
    print(f"Artifacts path: {output_subdir}")
    return 0
|
| 360 |
+
|
| 361 |
+
|
| 362 |
+
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
models.py
ADDED
|
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# AdaptShield — Pydantic Data Models
|
| 5 |
+
#
|
| 6 |
+
# CRITICAL DESIGN DECISION: Phase1Action and Phase2Action are SEPARATE classes.
|
| 7 |
+
# A single combined class with optional fields causes 500 errors when the
|
| 8 |
+
# evaluator sends a Phase 2 payload and Pydantic tries to validate Phase 1 fields.
|
| 9 |
+
|
| 10 |
+
from enum import Enum
|
| 11 |
+
from typing import Any, Dict, List, Optional
|
| 12 |
+
|
| 13 |
+
from openenv.core.env_server.types import Action, Observation
|
| 14 |
+
from pydantic import Field, model_validator
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DefenseAction(str, Enum):
    """
    Strict action space for the Tactical Executor (Phase 2).
    Using Enum prevents LLM hallucination from reaching the grader.

    Members also subclass ``str``, so each one compares equal to its raw
    string value (e.g. ``DefenseAction.ISOLATE == "isolate"``).
    """
    RATE_LIMIT = "rate_limit"  # Light — throttles traffic, keeps service online
    ISOLATE = "isolate"  # Heavy — takes node offline, stops spread
    HONEYPOT = "honeypot"  # Strategic — redirects attacker to decoy
    PATCH = "patch"  # Targeted — fixes supply chain vulnerability
    MONITOR = "monitor"  # Passive — gather info, risk escalation
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ThreatType(str, Enum):
    """Known attack strategies the Threat Analyst can classify.

    Members subclass ``str``, so each one compares equal to its raw string
    value. The values are the same labels listed in the description of
    ``Phase1Action.threat_type``.
    """
    BRUTE_FORCE = "brute_force"
    LATERAL_MOVEMENT = "lateral_movement"
    EXFILTRATION = "exfiltration"
    SUPPLY_CHAIN = "supply_chain"
    BENIGN = "benign"  # No attack in progress
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class Phase1Action(Action):
    """
    Threat Analyst output — pure reasoning, no defensive action.

    The agent reads raw network state and produces a structured
    threat assessment. This is graded independently for classification
    accuracy before Phase 2 acts on it.

    Every field except ``reasoning`` is required. ``threat_type`` and
    ``target_node`` are plain strings (no enum validation at this layer),
    ``confidence`` is bounded to [0.0, 1.0], and ``recommended_action`` must
    be a ``DefenseAction`` value.
    """
    # Free-form string; the expected labels are listed in the description.
    threat_type: str = Field(
        ...,
        description="Identified attack strategy: brute_force, lateral_movement, "
                    "exfiltration, supply_chain, or benign",
    )
    # Validated to lie within the closed interval [0.0, 1.0].
    confidence: float = Field(
        ...,
        ge=0.0,
        le=1.0,
        description="Confidence in the threat classification (0.0 to 1.0)",
    )
    # Free-form string; the expected node names are listed in the description.
    target_node: str = Field(
        ...,
        description="Primary affected node: auth_service, payment_service, "
                    "database, or api_gateway",
    )
    # Enum-validated, unlike threat_type and target_node.
    recommended_action: DefenseAction = Field(
        ...,
        description="Recommended defense action for Phase 2 to execute",
    )
    # Optional free text; defaults to None when omitted.
    reasoning: Optional[str] = Field(
        default=None,
        description="Chain of thought. Not graded. Helps training stability.",
    )
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class Phase2Action(Action):
    """
    Tactical Executor output — defensive action based ONLY on Phase 1 assessment.

    Phase 2 agent is deliberately blind to raw network state.
    It receives only the Phase 1 threat assessment and must act on it.

    ``action`` and ``target_node`` are required; ``reasoning`` is optional.
    """
    # Enum-validated defense action.
    action: DefenseAction = Field(
        ...,
        description="Defense action to execute",
    )
    # Free-form string; the expected node names are listed in the description.
    target_node: str = Field(
        ...,
        description="Node to apply action to: auth_service, payment_service, "
                    "database, or api_gateway",
    )
    # Optional free text; defaults to None when omitted.
    reasoning: Optional[str] = Field(
        default=None,
        description="Chain of thought. Not graded.",
    )
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class AdaptShieldAction(Action):
    """
    Single transport-level action model for the OpenEnv HTTP server.

    Because the environment alternates between two phases, one payload type
    has to carry either a Threat Analyst (Phase 1) assessment or a Tactical
    Executor (Phase 2) command. All fields are therefore optional at the
    schema level, and ``validate_phase_shape`` enforces that exactly one of
    the two shapes is present and complete.
    """

    # --- Phase 1 fields -----------------------------------------------------
    threat_type: Optional[str] = Field(
        default=None,
        description="Phase 1 only: identified attack strategy",
    )
    confidence: Optional[float] = Field(
        default=None,
        ge=0.0,
        le=1.0,
        description="Phase 1 only: confidence in the threat classification",
    )
    # --- Shared by both phases ----------------------------------------------
    target_node: Optional[str] = Field(
        default=None,
        description="Target node for either phase",
    )
    # --- Phase 1 field --------------------------------------------------------
    recommended_action: Optional[DefenseAction] = Field(
        default=None,
        description="Phase 1 only: recommended follow-up action",
    )
    # --- Phase 2 field --------------------------------------------------------
    action: Optional[DefenseAction] = Field(
        default=None,
        description="Phase 2 only: defensive action to execute",
    )
    # --- Optional for both phases ---------------------------------------------
    reasoning: Optional[str] = Field(
        default=None,
        description="Optional one-sentence rationale",
    )

    @model_validator(mode="after")
    def validate_phase_shape(self) -> "AdaptShieldAction":
        """Reject payloads that mix phases, match neither, or are incomplete.

        A payload counts as Phase 1 when any of ``threat_type``,
        ``confidence`` or ``recommended_action`` is set, and as Phase 2 when
        ``action`` is set. ``target_node`` is required in both shapes.

        Raises:
            ValueError: If both shapes are present, neither is present, or a
                required field of the detected shape is missing.
        """
        is_phase1 = not (
            self.threat_type is None
            and self.confidence is None
            and self.recommended_action is None
        )
        is_phase2 = self.action is not None

        if is_phase1 and is_phase2:
            raise ValueError(
                "Action payload must be either Phase 1 or Phase 2, not both."
            )
        if not (is_phase1 or is_phase2):
            raise ValueError(
                "Action payload must contain Phase 1 fields or a Phase 2 action."
            )

        # Exactly one shape is active here; pick the fields it must provide.
        if is_phase1:
            required = ("threat_type", "confidence", "target_node", "recommended_action")
        else:
            required = ("action", "target_node")

        missing = [name for name in required if getattr(self, name) is None]
        if missing:
            raise ValueError(
                f"Missing required fields for this phase: {', '.join(missing)}"
            )

        return self
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class AdaptShieldObservation(Observation):
    """
    Observation returned after each step.

    Phase 1 observation: contains full network state (network_nodes, active_alerts).
    Phase 2 observation: network_nodes and active_alerts are EMPTY.
    phase1_assessment contains the Phase 1 output.

    Episode number is NEVER included — agent must rely on signals only.
    """

    # Episode identity and turn bookkeeping.
    scenario_id: str = Field(default="")
    task_name: str = Field(default="")
    phase: int = Field(default=1, description="1 = Threat Analyst turn, 2 = Tactical Executor turn")
    turn: int = Field(default=0)
    max_turns: int = Field(default=5)

    # Raw network view: filled for Phase 1, left empty for Phase 2.
    network_nodes: Dict[str, Any] = Field(default_factory=dict)
    active_alerts: List[str] = Field(default_factory=list)
    attack_stage: str = Field(
        default="none", description="Current attack progression stage: recon, exploit, exfiltration, none"
    )

    # Rolling window over the most recent turns (last 3).
    history: List[Dict[str, str]] = Field(default_factory=list)

    # Hand-off from the analyst to the executor; only set for Phase 2.
    phase1_assessment: Optional[Dict[str, Any]] = Field(
        default=None, description="Populated only in Phase 2. Phase 2 agent sees ONLY this."
    )

    # Static context supplied to the agent.
    system_context: str = Field(default="")
    available_actions: List[str] = Field(default_factory=list)

    # Outcome of the previous step.
    last_action_result: Optional[str] = Field(default=None)
    reward: float = Field(default=0.0)
    done: bool = Field(default=False)
    metadata: Dict[str, Any] = Field(default_factory=dict)

    def model_dump(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        """
        Serialize the observation without letting ``metadata`` be excluded.

        OpenEnv's serializer asks for ``metadata`` to be dropped from the
        nested observation. AdaptShield publishes normalized_score in that
        field, so only that one entry is removed from the caller's ``exclude``
        argument; every other exclusion (reward/done handling included) is
        forwarded unchanged.
        """
        requested = kwargs.get("exclude")
        strips_metadata = isinstance(requested, (set, dict)) and "metadata" in requested

        if strips_metadata:
            if isinstance(requested, set):
                kwargs["exclude"] = {name for name in requested if name != "metadata"}
            else:
                kwargs["exclude"] = {
                    name: spec for name, spec in requested.items() if name != "metadata"
                }

        return super().model_dump(*args, **kwargs)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
# Backward-compatible aliases for earlier package names.
# Both names below point at the very same classes (no subclassing), so code
# that still imports the older "Adaptshield…" capitalisation keeps working.
AdaptshieldAction = AdaptShieldAction
AdaptshieldObservation = AdaptShieldObservation
|
openenv.yaml
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
spec_version: 1
|
| 2 |
+
name: adaptshield
|
| 3 |
+
type: space
|
| 4 |
+
runtime: fastapi
|
| 5 |
+
app: server.app:app
|
| 6 |
+
port: 7860
|
| 7 |
+
description: >
|
| 8 |
+
AdaptShield is a two-phase agentic cybersecurity environment that trains
|
| 9 |
+
LLMs to adapt to polymorphic adversarial strategies. An agent acts as
|
| 10 |
+
Threat Analyst (Phase 1) then Tactical Executor (Phase 2), defending a
|
| 11 |
+
simulated 4-node enterprise network against a scripted attacker that progresses
|
| 12 |
+
through attack stages and shifts strategy mid-episode. Grading is fully
|
| 13 |
+
deterministic via Python strategy matching. No LLM-as-judge components.
|
| 14 |
+
tasks:
|
| 15 |
+
- name: direct-triage
|
| 16 |
+
difficulty: easy
|
| 17 |
+
description: Single fixed strategy. Agent learns baseline threat response.
|
| 18 |
+
max_steps: 5
|
| 19 |
+
- name: dual-pivot
|
| 20 |
+
difficulty: medium
|
| 21 |
+
description: Two strategies alternating every 20 episodes. Detect and adapt.
|
| 22 |
+
max_steps: 6
|
| 23 |
+
- name: polymorphic-zero-day
|
| 24 |
+
difficulty: hard
|
| 25 |
+
description: All four strategies with mid-episode shift and false-positive noise.
|
| 26 |
+
max_steps: 8
|
plot_sft_checkpoint_curve.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Plot an SFT checkpoint curve with an optional honest baseline start point."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def load_json(path: Path) -> dict[str, Any]:
    """Decode the UTF-8 JSON document stored at *path* and return it."""
    with path.open("r", encoding="utf-8") as handle:
        return json.load(handle)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def mean_baseline(benchmark: dict[str, Any], key: str) -> float:
    """Return the arithmetic mean of the non-null values stored under *key*.

    Args:
        benchmark: Parsed benchmark JSON; ``benchmark[key]`` is expected to be
            a mapping whose values are numbers (or ``None`` for missing runs).
        key: Name of the benchmark section to average.

    Returns:
        The mean of every value under *key* that is not ``None``.

    Raises:
        ValueError: if *key* is absent, maps to ``null``/an empty mapping, or
            contains only ``None`` values.
    """
    # `or {}` also covers an explicit JSON null under the key, which would
    # otherwise fail on `.values()` with an unrelated AttributeError.
    values = benchmark.get(key) or {}
    numeric = [float(value) for value in values.values() if value is not None]
    if not numeric:
        raise ValueError(f"No numeric values found under benchmark key '{key}'")
    return sum(numeric) / len(numeric)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args() -> argparse.Namespace:
    """Define the checkpoint-curve CLI and parse the process arguments.

    ``--metrics`` and ``--output`` are mandatory; the three ``--baseline-*``
    options control the optional baseline point placed before the curve.
    """
    cli = argparse.ArgumentParser(description="Plot SFT checkpoint learning curve with optional baseline point.")

    # Mandatory input and output locations.
    for flag, help_text in (
        ("--metrics", "Path to sft_metrics.json"),
        ("--output", "Where to write the PNG"),
    ):
        cli.add_argument(flag, required=True, help=help_text)

    # Optional baseline point; an empty --baseline-json leaves it disabled.
    cli.add_argument(
        "--baseline-json",
        default="",
        help="Optional benchmark_table.json path used to prepend a real baseline point.",
    )
    cli.add_argument(
        "--baseline-key",
        default="tool_baseline",
        choices=["tool_baseline", "no_tool_baseline"],
        help="Which benchmark JSON field to average for the prepended baseline point.",
    )
    cli.add_argument(
        "--baseline-label",
        default="baseline",
        help="X-axis label for the prepended baseline point.",
    )
    return cli.parse_args()
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def main() -> int:
    """Plot train-family and held-out scores per SFT checkpoint to a PNG.

    Reads ``reward_curve_rows`` from the metrics file. If a baseline JSON is
    given, one extra point (the mean of the chosen benchmark section) is
    placed in front of both curves. Writes the figure to ``--output`` and
    prints that path.

    Returns:
        0 on success.

    Raises:
        SystemExit: if the metrics file has no rows or matplotlib is missing.
    """
    args = parse_args()

    rows = load_json(Path(args.metrics)).get("reward_curve_rows", []) or []
    if not rows:
        raise SystemExit("No reward_curve_rows found in the provided SFT metrics file.")

    labels = [str(row["checkpoint"]) for row in rows]
    train_scores = [float(row["in_distribution_score"]) for row in rows]
    heldout_scores = [float(row["heldout_score"]) for row in rows]

    if args.baseline_json:
        # The same baseline value anchors both curves at the leftmost tick.
        baseline_value = mean_baseline(load_json(Path(args.baseline_json)), args.baseline_key)
        labels.insert(0, args.baseline_label)
        train_scores.insert(0, baseline_value)
        heldout_scores.insert(0, baseline_value)

    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError as exc:
        raise SystemExit(f"matplotlib is required to plot this curve: {exc}") from exc

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(11, 5))
    ax.plot(labels, train_scores, marker="o", linewidth=2, color="#174c7a", label="train family")
    ax.plot(labels, heldout_scores, marker="s", linewidth=2, color="#6d4acb", label="held-out family")
    ax.set_title("Janus SFT Checkpoint Learning Curve")
    ax.set_xlabel("Checkpoint")
    ax.set_ylabel("normalized_score")
    ax.set_ylim(0.0, 1.0)
    ax.grid(alpha=0.25)
    ax.legend()
    plt.xticks(rotation=30, ha="right")
    fig.tight_layout()
    fig.savefig(output_path, dpi=160)
    print(output_path)
    return 0
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# Script entry point: run main() and use its integer return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
plot_training.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Plot AdaptShield training CSV or metrics JSON."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import csv
|
| 8 |
+
import json
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import List, Tuple
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def load_scores(path: Path) -> Tuple[List[int], List[float], str, List[str]]:
    """Load per-episode scores from a metrics JSON file or a training CSV.

    A ``.json`` suffix (any letter case) selects the JSON layout, where rows
    are read from ``rows`` or, if that is empty, ``evaluation_rows``, and the
    label comes from ``model``. Any other suffix is parsed as CSV with a
    header row.

    Args:
        path: File to read; both formats are decoded as UTF-8.

    Returns:
        ``(episodes, scores, label, stages)``. ``stages`` holds each row's
        ``stage`` value, falling back to ``task`` and then ``""``.

    Raises:
        KeyError: if a row lacks ``episode`` or ``score``.
        ValueError: if those values cannot be converted to int / float.
    """
    if path.suffix.lower() == ".json":
        data = json.loads(path.read_text(encoding="utf-8"))
        # Trailing `or []` keeps a null "evaluation_rows" from breaking iteration.
        rows = data.get("rows", []) or data.get("evaluation_rows", []) or []
        label = str(data.get("model", "adaptshield"))
    else:
        # newline="" is what the csv module requires so that quoted fields
        # containing line breaks are read back unchanged.
        with path.open(newline="", encoding="utf-8") as handle:
            rows = list(csv.DictReader(handle))
        label = "adaptshield-smoke"

    # Both formats share the same row schema from here on.
    episodes = [int(row["episode"]) for row in rows]
    scores = [float(row["score"]) for row in rows]
    stages = [str(row.get("stage", row.get("task", ""))) for row in rows]
    return episodes, scores, label, stages
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def moving_average(values: List[float], window: int) -> List[float]:
    """Return the trailing mean of *values* over at most *window* items.

    Position ``i`` of the result averages ``values[i - window + 1 : i + 1]``,
    clipped at the start of the list, so early positions average fewer items.
    The result has the same length as *values*.
    """
    averages = []
    for stop in range(1, len(values) + 1):
        recent = values[max(0, stop - window):stop]
        averages.append(sum(recent) / len(recent))
    return averages
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def plot(path: Path, output: Path) -> None:
    """Draw the training reward curve found at *path* into a PNG at *output*.

    The figure shows the raw per-episode score, a trailing moving average,
    and a dashed vertical marker wherever the stage label changes. When
    matplotlib cannot be imported, no file is written; a short text summary
    comparing the first and last fifth of the run is printed instead.

    Raises:
        SystemExit: if the input contains no scores.
    """
    episodes, scores, label, stages = load_scores(path)
    if not scores:
        raise SystemExit("No scores found to plot.")

    try:
        import matplotlib
        matplotlib.use("Agg")
        import matplotlib.pyplot as plt
    except ImportError:
        # Text-only fallback: compare the first and last fifth (at least one
        # episode each) of the run.
        span = max(1, len(scores) // 5)
        first = sum(scores[:span]) / span
        last = sum(scores[-span:]) / span
        print("matplotlib is not installed; skipping PNG generation.")
        print(f"Episodes: {len(scores)}")
        print(f"First-window avg: {first:.3f}")
        print(f"Last-window avg: {last:.3f}")
        print(f"Delta: {last - first:+.3f}")
        return

    # Smoothing window: a fifth of the run, kept within 1..10 episodes.
    window = min(10, max(1, len(scores) // 5))
    smoothed = moving_average(scores, window)

    output.parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(episodes, scores, color="#6b8fbf", alpha=0.35, label="raw score")
    ax.plot(episodes, smoothed, color="#123c69", linewidth=2.5, label=f"{window}-episode avg")

    for boundary_episode, stage_name in stage_boundaries(episodes, stages):
        ax.axvline(boundary_episode, color="#c44e52", linestyle="--", alpha=0.45)
        ax.text(
            boundary_episode,
            0.04,
            stage_name.replace("curriculum:", ""),
            rotation=90,
            fontsize=8,
            color="#7a1f24",
        )

    ax.set_title(f"AdaptShield Training Curve ({label})")
    ax.set_xlabel("Episode")
    ax.set_ylabel("normalized_score")
    ax.set_ylim(0.0, 1.0)
    ax.grid(alpha=0.25)
    ax.legend()
    fig.tight_layout()
    fig.savefig(output, dpi=160)
    print(f"Saved plot: {output}")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def stage_boundaries(episodes: List[int], stages: List[str]) -> List[Tuple[int, str]]:
    """Return ``(episode, stage)`` for each row whose stage differs from the previous row.

    The first row never counts as a boundary. Rows are paired positionally,
    so any surplus entries in the longer list are ignored.
    """
    if not stages:
        return []

    rows = list(zip(episodes, stages))
    # Walk consecutive row pairs and keep those where the stage label changes.
    return [
        (episode, stage)
        for (_, earlier_stage), (episode, stage) in zip(rows, rows[1:])
        if stage != earlier_stage
    ]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def parse_args() -> argparse.Namespace:
    """Parse the plotter's ``--input`` and ``--output`` command-line options.

    Both are optional and default to files under ``training_runs/``.
    """
    cli = argparse.ArgumentParser(description="Plot AdaptShield training output.")
    option_defaults = {
        "--input": "training_runs/train_smoke.csv",
        "--output": "training_runs/reward_curve.png",
    }
    for flag, fallback in option_defaults.items():
        cli.add_argument(flag, default=fallback)
    return cli.parse_args()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def main() -> int:
    """Parse the command line, render the plot, and return exit status 0."""
    options = parse_args()
    source, destination = Path(options.input), Path(options.output)
    plot(source, destination)
    return 0
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
# Script entry point: run main() and use its integer return value as the
# process exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|
pyproject.toml
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=45", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv-adaptshield"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "AdaptShield: Two-Phase Adaptive Cybersecurity RL Environment"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core[core]>=0.2.2",
|
| 12 |
+
"fastapi>=0.111.0",
|
| 13 |
+
"openai>=1.0.0",
|
| 14 |
+
"uvicorn>=0.24.0",
|
| 15 |
+
"pydantic>=2.0.0",
|
| 16 |
+
"python-dotenv>=1.0.0",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[project.optional-dependencies]
|
| 20 |
+
dev = [
|
| 21 |
+
"pytest>=8.0.0",
|
| 22 |
+
"pytest-cov>=4.0.0",
|
| 23 |
+
]
|
| 24 |
+
|
| 25 |
+
[project.scripts]
|
| 26 |
+
server = "adaptshield.server.app:main"
|
| 27 |
+
|
| 28 |
+
[tool.setuptools]
|
| 29 |
+
include-package-data = true
|
| 30 |
+
packages = ["adaptshield", "adaptshield.server"]
|
| 31 |
+
package-dir = { "adaptshield" = ".", "adaptshield.server" = "server" }
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
ARG BASE_IMAGE=ghcr.io/meta-pytorch/openenv-base:latest
|
| 2 |
+
FROM ${BASE_IMAGE} AS builder
|
| 3 |
+
|
| 4 |
+
WORKDIR /app
|
| 5 |
+
|
| 6 |
+
RUN apt-get update && \
|
| 7 |
+
apt-get install -y --no-install-recommends git curl && \
|
| 8 |
+
rm -rf /var/lib/apt/lists/*
|
| 9 |
+
|
| 10 |
+
COPY . /app/env
|
| 11 |
+
WORKDIR /app/env
|
| 12 |
+
|
| 13 |
+
RUN if ! command -v uv >/dev/null 2>&1; then \
|
| 14 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh && \
|
| 15 |
+
mv /root/.local/bin/uv /usr/local/bin/uv && \
|
| 16 |
+
mv /root/.local/bin/uvx /usr/local/bin/uvx; \
|
| 17 |
+
fi
|
| 18 |
+
|
| 19 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 20 |
+
if [ -f uv.lock ]; then \
|
| 21 |
+
uv sync --frozen --no-install-project --no-editable; \
|
| 22 |
+
else \
|
| 23 |
+
uv sync --no-install-project --no-editable; \
|
| 24 |
+
fi
|
| 25 |
+
|
| 26 |
+
RUN --mount=type=cache,target=/root/.cache/uv \
|
| 27 |
+
if [ -f uv.lock ]; then \
|
| 28 |
+
uv sync --frozen --no-editable; \
|
| 29 |
+
else \
|
| 30 |
+
uv sync --no-editable; \
|
| 31 |
+
fi
|
| 32 |
+
|
| 33 |
+
FROM ${BASE_IMAGE}
|
| 34 |
+
WORKDIR /app
|
| 35 |
+
|
| 36 |
+
COPY --from=builder /app/env/.venv /app/.venv
|
| 37 |
+
COPY --from=builder /app/env /app/env
|
| 38 |
+
|
| 39 |
+
ENV PATH="/app/.venv/bin:$PATH"
|
| 40 |
+
ENV PYTHONPATH="/app/env:$PYTHONPATH"
|
| 41 |
+
|
| 42 |
+
EXPOSE 7860
|
| 43 |
+
|
| 44 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=30s --retries=3 \
|
| 45 |
+
CMD curl -f http://localhost:7860/health || exit 1
|
| 46 |
+
|
| 47 |
+
CMD ["sh", "-c", "cd /app/env && uvicorn server.app:app --host 0.0.0.0 --port 7860"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the BSD-style license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
"""Adaptshield environment server components."""
|
| 8 |
+
|
| 9 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 10 |
+
|
| 11 |
+
__all__ = ["AdaptShieldEnvironment"]
|
server/adaptshield_environment.py
ADDED
|
@@ -0,0 +1,1324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield Environment
|
| 3 |
+
|
| 4 |
+
Two-phase agentic cybersecurity environment implementing full OpenEnv spec.
|
| 5 |
+
|
| 6 |
+
Phase 1 (Threat Analyst): Agent reads raw SIEM state, outputs threat assessment.
|
| 7 |
+
Phase 2 (Tactical Executor): Agent reads ONLY Phase 1 output, executes defense.
|
| 8 |
+
|
| 9 |
+
The attacker progresses through stages (recon→exploit→exfiltration) if agent
|
| 10 |
+
fails to act. On the hard task, strategy shifts mid-episode after turn 3.
|
| 11 |
+
|
| 12 |
+
OpenEnv compliance:
|
| 13 |
+
- reset() returns initial observation
|
| 14 |
+
- step() returns observation with reward, done, info
|
| 15 |
+
- state property returns current State
|
| 16 |
+
- SUPPORTS_CONCURRENT_SESSIONS = True
|
| 17 |
+
- normalized_score ALWAYS present in metadata
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sys
|
| 22 |
+
from enum import Enum
|
| 23 |
+
from typing import Any, Dict, List, Optional
|
| 24 |
+
from uuid import uuid4
|
| 25 |
+
|
| 26 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 27 |
+
|
| 28 |
+
from openenv.core.env_server.interfaces import Environment
|
| 29 |
+
from openenv.core.env_server.types import State
|
| 30 |
+
|
| 31 |
+
from models import AdaptShieldAction, Phase1Action, Phase2Action, AdaptShieldObservation
|
| 32 |
+
from server.attacker import AttackerEngine
|
| 33 |
+
from server.grader import grade_step, normalize_episode_score, _clamp
|
| 34 |
+
from server.scenarios import (
|
| 35 |
+
TASK_CONFIGS,
|
| 36 |
+
build_phase1_obs,
|
| 37 |
+
build_phase2_obs,
|
| 38 |
+
choose_operational_mode,
|
| 39 |
+
choose_world_family,
|
| 40 |
+
mission_profile_for,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Integer "time to live" for each defense action, keyed by action name.
# NOTE(review): presumably the number of turns a defense stays active after
# it is applied; the consuming code is outside this view, so confirm there.
DEFENSE_TTL = {
    "rate_limit": 2,
    "isolate": 2,
    "honeypot": 3,
    "patch": 4,
}

# Label for the operational side effect associated with each defense action.
# Uses the same four action names as DEFENSE_TTL.
DEFENSE_SIDE_EFFECT = {
    "rate_limit": "login_latency",
    "isolate": "service_downtime",
    "honeypot": "attacker_redirection",
    "patch": "temporary_restart",
}
|
| 57 |
+
|
| 58 |
+
# Catalogue of the SOC investigation tools. Each entry has the tool's short
# name, the HTTP path it is served under ("/tools/<name>"), and a one-line
# description of what it reports.
# NOTE(review): presumably surfaced to the agent in observations; confirm
# where this list is consumed.
AVAILABLE_SOC_TOOLS = [
    {
        "name": "log_search",
        "endpoint": "/tools/log_search",
        "description": "Search recent SIEM/application logs for a node and time window.",
    },
    {
        "name": "cmdb_lookup",
        "endpoint": "/tools/cmdb_lookup",
        "description": "Inspect service ownership, criticality, dependencies, and blast radius.",
    },
    {
        "name": "edr_status",
        "endpoint": "/tools/edr_status",
        "description": "Check endpoint containment, persistence, beaconing, and active controls.",
    },
    {
        "name": "vuln_lookup",
        "endpoint": "/tools/vuln_lookup",
        "description": "Query internal package/advisory risk for supply-chain investigations.",
    },
    {
        "name": "identity_lookup",
        "endpoint": "/tools/identity_lookup",
        "description": "Inspect account type, privilege level, normal host affinity, and anomalous identity use.",
    },
    {
        "name": "change_calendar_lookup",
        "endpoint": "/tools/change_calendar_lookup",
        "description": "Check whether maintenance, deploys, or patch windows were scheduled for the target service.",
    },
    {
        "name": "netflow_lookup",
        "endpoint": "/tools/netflow_lookup",
        "description": "Inspect east-west and outbound traffic summaries for enterprise network pivots and data movement.",
    },
]
|
| 95 |
+
|
| 96 |
+
# Owning team for each of the four simulated network nodes, keyed by node name.
SERVICE_OWNERS = {
    "auth_service": "identity-platform",
    "payment_service": "checkout-platform",
    "database": "data-platform",
    "api_gateway": "edge-platform",
}

# Per-node identity baseline: the account associated with the node, its type
# and privilege level, and the hosts that account normally appears on.
IDENTITY_CONTEXT = {
    "auth_service": {
        "account": "svc_auth_frontend",
        "account_type": "service_account",
        "privilege_level": "medium",
        "normal_hosts": ["auth_service", "api_gateway"],
    },
    "payment_service": {
        "account": "svc_checkout",
        "account_type": "service_account",
        "privilege_level": "high",
        "normal_hosts": ["payment_service"],
    },
    "database": {
        "account": "svc_data_sync",
        "account_type": "service_account",
        "privilege_level": "high",
        "normal_hosts": ["database", "payment_service"],
    },
    "api_gateway": {
        "account": "deploy_bot",
        "account_type": "automation",
        "privilege_level": "medium",
        "normal_hosts": ["api_gateway"],
    },
}

# Scheduled change window for each node: the time range (the "Z" suffix
# denotes UTC), the kind of change, and the account expected to perform it.
# Each expected_actor matches that node's "account" in IDENTITY_CONTEXT.
CHANGE_CALENDAR = {
    "auth_service": {
        "window": "03:00-03:20Z",
        "change_type": "auth policy sync",
        "expected_actor": "svc_auth_frontend",
    },
    "payment_service": {
        "window": "02:30-02:45Z",
        "change_type": "checkout rollout",
        "expected_actor": "svc_checkout",
    },
    "database": {
        "window": "04:00-04:30Z",
        "change_type": "backup and index maintenance",
        "expected_actor": "svc_data_sync",
    },
    "api_gateway": {
        "window": "03:10-03:25Z",
        "change_type": "gateway deploy",
        "expected_actor": "deploy_bot",
    },
}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
class AdaptShieldEnvironment(Environment):
|
| 155 |
+
"""
|
| 156 |
+
AdaptShield: Two-Phase Adaptive Cybersecurity RL Environment.
|
| 157 |
+
|
| 158 |
+
Example:
|
| 159 |
+
>>> env = AdaptShieldEnvironment(task_name="direct-triage")
|
| 160 |
+
>>> obs = env.reset()
|
| 161 |
+
>>> # Phase 1 — classify the threat
|
| 162 |
+
>>> obs2 = env.step(Phase1Action(
|
| 163 |
+
... threat_type="brute_force", confidence=0.9,
|
| 164 |
+
... target_node="auth_service", recommended_action="rate_limit"
|
| 165 |
+
... ))
|
| 166 |
+
>>> print(obs2.phase) # 2
|
| 167 |
+
>>> # Phase 2 — execute the defense
|
| 168 |
+
>>> obs3 = env.step(Phase2Action(
|
| 169 |
+
... action="rate_limit", target_node="auth_service"
|
| 170 |
+
... ))
|
| 171 |
+
>>> print(obs3.reward) # reward signal
|
| 172 |
+
"""
|
| 173 |
+
|
| 174 |
+
SUPPORTS_CONCURRENT_SESSIONS: bool = True
|
| 175 |
+
|
| 176 |
+
def __init__(
    self,
    task_name: str = "direct-triage",
    world_split: str | None = None,
    world_family: str | None = None,
    operational_mode: str | None = None,
):
    """
    Build the environment for one task and initialise empty episode state.

    Args:
        task_name: Key into TASK_CONFIGS. Unknown names silently fall back
            to "direct-triage".
        world_split: "train" or "eval". When falsy, the
            ADAPTSHIELD_WORLD_SPLIT environment variable is used
            (default "train"); any other value is sanitised to "train".
        world_family: Requested world family. When falsy, the
            ADAPTSHIELD_WORLD_FAMILY environment variable is used.
        operational_mode: Requested operational mode. When falsy, the
            ADAPTSHIELD_OPERATIONAL_MODE environment variable is used.
    """
    resolved_task = task_name if task_name in TASK_CONFIGS else "direct-triage"
    self._task_name = resolved_task
    self._config = TASK_CONFIGS[resolved_task]

    # Explicit arguments win; environment variables are the fallback.
    split_request = world_split or os.environ.get("ADAPTSHIELD_WORLD_SPLIT", "train")
    self._world_split = self._sanitize_world_split(split_request)
    self._requested_world_family = world_family or os.environ.get("ADAPTSHIELD_WORLD_FAMILY")
    self._requested_operational_mode = (
        operational_mode or os.environ.get("ADAPTSHIELD_OPERATIONAL_MODE")
    )

    # Resolve the concrete world/mode, then the objects that depend on them.
    self._world_family = choose_world_family(self._world_split, self._requested_world_family)
    self._operational_mode = choose_operational_mode(
        resolved_task, self._requested_operational_mode
    )
    self._mission_profile = mission_profile_for(
        resolved_task, self._operational_mode, self._world_family
    )
    self._attacker = AttackerEngine(resolved_task, world_family=self._world_family)
    self._state = State(episode_id=str(uuid4()), step_count=0)

    # Episode state (reset() re-initialises all of this except _turn_config
    # and _last_obs, which it overwrites with fresh values).
    self._turn: int = 0
    self._phase: int = 1
    self._rewards: List[float] = []
    self._done: bool = False
    self._last_reward: float = 0.0
    self._history: List[Dict[str, str]] = []
    self._phase1_output: Optional[Dict[str, Any]] = None
    self._phase1_grading_output: Optional[Dict[str, Any]] = None
    self._turn_config: Optional[Dict[str, Any]] = None
    self._consecutive_wrong: int = 0
    self._last_obs: Optional[AdaptShieldObservation] = None
    self._episode_replay: List[Dict[str, Any]] = []
    self._last_replay_strategy: Optional[str] = None
    self._active_defenses: List[Dict[str, Any]] = []
    self._foothold_established: bool = False
    self._tool_trace: List[Dict[str, Any]] = []
    self._turn_tool_evidence: Dict[int, List[Dict[str, Any]]] = {}
    self._turn_tool_results: Dict[int, List[Dict[str, Any]]] = {}
|
| 216 |
+
|
| 217 |
+
# ── OpenEnv interface ──────────────────────────────────────────────────
|
| 218 |
+
|
| 219 |
+
def reset(self, task_name: str | None = None) -> AdaptShieldObservation:
    """
    Reset environment. Optionally switch task via task_name.
    Always returns Phase 1 observation (Threat Analyst turn).

    Args:
        task_name: Optional key into TASK_CONFIGS. A falsy or unknown value
            leaves the current task unchanged.

    Returns:
        The Phase 1 observation for turn 1 of the new episode; it is also
        cached as the last observation.
    """
    if task_name and task_name in TASK_CONFIGS:
        self._task_name = task_name
        self._config = TASK_CONFIGS[task_name]
    # Re-resolve the world, mode, mission profile and attacker for the
    # (possibly unchanged) task before starting the new episode.
    self._world_family = choose_world_family(self._world_split, self._requested_world_family)
    self._operational_mode = choose_operational_mode(self._task_name, self._requested_operational_mode)
    self._mission_profile = mission_profile_for(self._task_name, self._operational_mode, self._world_family)
    self._attacker = AttackerEngine(self._task_name, world_family=self._world_family)

    # Fresh episode identity and per-episode bookkeeping; turns are 1-based.
    self._state = State(episode_id=str(uuid4()), step_count=0)
    self._turn = 1
    self._phase = 1
    self._rewards = []
    self._done = False
    self._last_reward = 0.0
    self._history = []
    self._phase1_output = None
    self._phase1_grading_output = None
    self._consecutive_wrong = 0
    self._episode_replay = []
    self._last_replay_strategy = None
    self._active_defenses = []
    self._foothold_established = False
    self._tool_trace = []
    self._turn_tool_evidence = {}
    self._turn_tool_results = {}

    self._attacker.reset_episode()
    self._turn_config = self._prepare_turn_config(self._attacker.build_observation())

    obs_dict = build_phase1_obs(
        turn_config=self._turn_config,
        history=self._history,
        task_name=self._task_name,
        turn=self._turn,
        max_turns=self._config["max_turns"],
        episode_id=self._state.episode_id,
        mission_profile=self._mission_profile,
    )
    obs = self._to_obs(obs_dict)
    obs.metadata = self._metadata_with_defenses(obs.metadata)
    self._last_obs = obs
    return obs
|
| 266 |
+
|
| 267 |
+
def step(
    self, action: AdaptShieldAction | Phase1Action | Phase2Action
) -> AdaptShieldObservation:  # type: ignore[override]
    """
    Execute one step.

    Accepts either Phase1Action or Phase2Action.
    Phase 1 → transitions to Phase 2 (nothing is graded; the returned
    observation carries the previous reward, or 0.01 when there is none).
    Phase 2 → grades action, advances turn, returns to Phase 1.

    Once the episode is done, the last observation is returned unchanged.
    Any exception raised while stepping is converted into an error
    observation instead of propagating to the caller.
    """
    if self._done:
        return self._last_obs or self._error_observation(
            "Episode already completed."
        )

    try:
        self._state.step_count += 1

        # ── Phase 1 → Phase 2 transition ──────────────────────────────
        if self._phase == 1:
            # Fields are read with getattr so any action type is accepted;
            # missing fields fall back to neutral defaults.
            phase1_output = {
                "threat_type": _action_value(getattr(action, "threat_type", None), "unknown"),
                "confidence": _action_float(getattr(action, "confidence", None), 0.5),
                "target_node": _action_value(getattr(action, "target_node", None), "unknown"),
                "recommended_action": _action_value(getattr(action, "recommended_action", None), "monitor"),
                "reasoning": str(getattr(action, "reasoning", "") or ""),
            }
            # Grading later uses this untouched copy; the Phase 2 observation
            # is built from the output of _degrade_handoff instead.
            self._phase1_grading_output = dict(phase1_output)
            self._phase1_output = _degrade_handoff(
                phase1_output=phase1_output,
                turn_config=self._turn_config or {},
                task_name=self._task_name,
                turn=self._turn,
            )
            self._phase = 2
            current_score = normalize_episode_score(self._rewards)

            obs_dict = build_phase2_obs(
                phase1_output=self._phase1_output,
                history=self._history,
                task_name=self._task_name,
                turn=self._turn,
                max_turns=self._config["max_turns"],
                episode_id=self._state.episode_id,
                current_score=current_score,
                mission_profile=self._mission_profile,
            )
            obs = self._to_obs(obs_dict)
            # No new grade in Phase 1: carry the last reward, floored at 0.01.
            obs.reward = _clamp(self._last_reward if self._last_reward > 0 else 0.01)
            obs.metadata = self._metadata_with_defenses({
                "episode_id": self._state.episode_id,
                "normalized_score": float(current_score),
                "mission_profile": self._mission_profile,
            })
            self._last_obs = obs
            return obs

        # ── Phase 2 — grade and advance turn ──────────────────────────
        p2 = {
            "action": _action_value(getattr(action, "action", None), "monitor"),
            "target_node": _action_value(getattr(action, "target_node", None), "unknown"),
            "reasoning": str(getattr(action, "reasoning", "") or ""),
        }

        # Snapshot stage and foothold before grading so the grader sees the
        # state the agent acted on, not the state after this action.
        current_stage = self._attacker.current_stage()
        foothold_before = self._foothold_established
        reward, catastrophic, info = grade_step(
            phase1_action=self._phase1_grading_output or self._phase1_output or {},
            phase2_action=p2,
            turn_config=self._turn_config or {},
            stage=current_stage,
            consecutive_wrong=self._consecutive_wrong,
            task_name=self._task_name,
            foothold_established=foothold_before,
            mission_profile=self._mission_profile,
            tool_context=self._tool_context_for_turn(),
        )

        reward = _clamp(_action_float(reward, 0.01))
        self._register_active_defense(p2)
        foothold_transition = self._update_foothold_state(
            p2=p2,
            info=info,
            stage=current_stage,
        )
        info["foothold_established"] = self._foothold_established
        info["foothold_transition"] = foothold_transition

        # Track consecutive wrong actions for stage escalation
        if info.get("acted_correctly", False):
            self._consecutive_wrong = 0
        else:
            self._consecutive_wrong += 1

        self._rewards.append(reward)
        self._last_reward = reward

        # Update history
        replay_strategy = self._attacker.current_strategy()
        # A shift is only reported from the second graded turn onwards.
        strategy_shift = (
            self._last_replay_strategy is not None and
            replay_strategy != self._last_replay_strategy
        )
        self._last_replay_strategy = replay_strategy
        self._episode_replay.append({
            "turn": self._turn,
            "p1": (self._phase1_output or {}).get("threat_type", "unknown"),
            "p2_action": p2["action"],
            "target": p2["target_node"],
            "result": _replay_result(info),
            "shift": strategy_shift,
            "impact": float(info.get("business_impact", 0.0)),
            "blast_radius": info.get("dependency_blast_radius", []),
            "active_defenses": self._active_defense_snapshot(),
            "foothold_established": self._foothold_established,
            "foothold_transition": foothold_transition,
            "mission_alignment": info.get("mission_alignment", "neutral"),
            "tool_calls": info.get("tool_count", 0),
            "tool_evidence_found": info.get("tool_evidence_found", False),
        })

        # Compact, string-only history entry (score reason cut to 80 chars).
        self._history.append({
            "turn": str(self._turn),
            "p1": f"classified:{(self._phase1_output or {}).get('threat_type','?')}",
            "p2": f"{p2['action']}→{p2['target_node']}",
            "result": info.get("score_reason", "")[:80],
            "reward": f"{reward:.2f}",
        })

        # Advance attacker
        self._attacker.advance_turn(
            agent_acted_correctly=info.get("acted_correctly", False)
        )
        self._decay_active_defenses()

        # Advance turn
        self._turn += 1
        self._phase = 1
        self._phase1_output = None
        self._phase1_grading_output = None

        # The episode ends on a catastrophic grade or once the turn counter
        # (already incremented) exceeds the configured maximum.
        episode_done = catastrophic or (self._turn > self._config["max_turns"])
        self._done = episode_done

        # Compute normalized score — ALWAYS present
        norm_score = normalize_episode_score(self._rewards)

        if not episode_done:
            # Next turn: new attacker observation and a Phase 1 observation.
            self._turn_config = self._prepare_turn_config(self._attacker.build_observation())
            obs_dict = build_phase1_obs(
                turn_config=self._turn_config,
                history=self._history,
                task_name=self._task_name,
                turn=self._turn,
                max_turns=self._config["max_turns"],
                episode_id=self._state.episode_id,
                mission_profile=self._mission_profile,
            )
            obs = self._to_obs(obs_dict)
            obs.reward = reward
            obs.done = False
            obs.last_action_result = info.get("score_reason", "")
            obs.metadata = self._metadata_with_defenses({
                "episode_id": self._state.episode_id,
                "normalized_score": float(norm_score),
                "score_breakdown": info,
                "turns_completed": self._turn - 1,
                "consecutive_wrong": self._consecutive_wrong,
                "mission_profile": self._mission_profile,
            })
        else:
            # Terminal observation: placeholder turn config plus the full
            # reward list and episode replay in the metadata.
            self._attacker.advance_episode()
            obs_dict = build_phase1_obs(
                turn_config={"network_nodes": {}, "active_alerts": ["[EPISODE COMPLETE]"],
                             "attack_stage": "none", "is_benign": False,
                             "strategy": "none", "correct_action": "none", "correct_target": "none"},
                history=self._history,
                task_name=self._task_name,
                turn=self._turn,
                max_turns=self._config["max_turns"],
                episode_id=self._state.episode_id,
                mission_profile=self._mission_profile,
            )
            obs = self._to_obs(obs_dict)
            obs.reward = reward
            obs.done = True
            obs.last_action_result = info.get("score_reason", "")
            obs.metadata = self._metadata_with_defenses({
                "episode_id": self._state.episode_id,
                "normalized_score": float(norm_score),
                "score_breakdown": info,
                "raw_rewards": self._rewards,
                "catastrophic": catastrophic,
                "turns_completed": self._turn - 1,
                "episode_replay": self._episode_replay,
                "mission_profile": self._mission_profile,
            })

        self._last_obs = obs
        return obs
    except Exception as exc:
        # Never raise out of step(): surface the failure as an observation.
        # State mutated before the failure is not rolled back.
        return self._error_observation(f"step_error: {exc}")
|
| 469 |
+
|
| 470 |
+
@property
|
| 471 |
+
def state(self) -> State:
|
| 472 |
+
"""Returns State with episode_id and step_count per OpenEnv spec."""
|
| 473 |
+
return self._state
|
| 474 |
+
|
| 475 |
+
# ── Internal ──────────────────────────────────────────────────────────
|
| 476 |
+
|
| 477 |
+
def _to_obs(self, d: Dict[str, Any]) -> AdaptShieldObservation:
    """
    Convert an observation dict into an AdaptShieldObservation.

    Every constructor field is read from ``d``; fields missing from ``d``
    take the fallback listed below (task name and max turns fall back to
    this environment's current values).
    """
    fallbacks: Dict[str, Any] = {
        "scenario_id": "",
        "task_name": self._task_name,
        "phase": 1,
        "turn": 0,
        "max_turns": self._config["max_turns"],
        "network_nodes": {},
        "active_alerts": [],
        "attack_stage": "none",
        "history": [],
        "phase1_assessment": None,
        "last_action_result": None,
        "system_context": "",
        "available_actions": [],
        "reward": 0.0,
        "done": False,
        "metadata": {"normalized_score": 0.50},
    }
    fields = {name: d.get(name, default) for name, default in fallbacks.items()}
    return AdaptShieldObservation(**fields)
|
| 496 |
+
|
| 497 |
+
@staticmethod
|
| 498 |
+
def _sanitize_world_split(value: str) -> str:
|
| 499 |
+
return value if value in {"train", "eval"} else "train"
|
| 500 |
+
|
| 501 |
+
def _error_observation(self, error_message: str) -> AdaptShieldObservation:
    """
    Build a safe observation that reports ``error_message`` instead of raising.

    The observation matches the current phase, carries the last reward
    (floored at 0.01), reports the error both as ``last_action_result`` and
    in ``metadata["error"]``, and is cached as the last observation.
    """
    norm_score = float(normalize_episode_score(self._rewards))
    previous_reward = self._last_reward
    reward = _clamp(previous_reward if previous_reward > 0 else 0.01)

    # Arguments shared by both observation builders.
    shared = dict(
        history=self._history,
        task_name=self._task_name,
        turn=self._turn,
        max_turns=self._config["max_turns"],
        episode_id=self._state.episode_id,
        mission_profile=self._mission_profile,
    )

    if self._phase == 2:
        obs_dict = build_phase2_obs(
            phase1_output=self._phase1_output or {},
            current_score=norm_score,
            **shared,
        )
    else:
        turn_config = self._turn_config
        if not turn_config:
            # No usable turn config yet: show the error as the only alert.
            turn_config = {
                "network_nodes": {},
                "active_alerts": [f"[ERROR] {error_message}"],
                "attack_stage": "none",
                "is_benign": False,
                "strategy": "unknown",
                "correct_action": "monitor",
                "correct_target": "unknown",
            }
        obs_dict = build_phase1_obs(turn_config=turn_config, **shared)

    obs = self._to_obs(obs_dict)
    obs.reward = float(reward)
    obs.done = bool(self._done)
    obs.last_action_result = error_message
    obs.metadata = self._metadata_with_defenses({
        "episode_id": self._state.episode_id,
        "normalized_score": norm_score,
        "error": error_message,
        "turns_completed": max(0, self._turn - 1),
        "mission_profile": self._mission_profile,
    })
    self._last_obs = obs
    return obs
|
| 550 |
+
|
| 551 |
+
def call_tool(self, tool_name: str, **params: Any) -> Dict[str, Any]:
    """
    Query the local SOC tool surface.

    These tools reveal partial evidence, not ground-truth answers. They are
    stateful because responses depend on the current turn, attacker stage,
    foothold state, active defenses, and previous actions.

    Args:
        tool_name: Name of the tool to run; surrounding whitespace is ignored.
        **params: Tool parameters. ``node`` (or ``target_node`` when ``node``
            is absent) selects the node; ``query`` and ``package`` are passed
            to the tools that accept them.

    Returns:
        The public view of the tool result. Unknown tools yield a result with
        an ``error`` and the list of available tools; any exception yields a
        ``tool_error`` dict instead of propagating.
    """
    try:
        tool_name = str(tool_name or "").strip()
        node = str(params.get("node", params.get("target_node", "unknown")) or "unknown")

        # Handlers are lazy so only the selected tool reads its parameters.
        handlers = {
            "log_search": lambda: self._tool_log_search(
                node=node, query=str(params.get("query", ""))
            ),
            "cmdb_lookup": lambda: self._tool_cmdb_lookup(node=node),
            "edr_status": lambda: self._tool_edr_status(node=node),
            "vuln_lookup": lambda: self._tool_vuln_lookup(
                node=node, package=str(params.get("package", ""))
            ),
            "identity_lookup": lambda: self._tool_identity_lookup(node=node),
            "change_calendar_lookup": lambda: self._tool_change_calendar_lookup(node=node),
            "netflow_lookup": lambda: self._tool_netflow_lookup(node=node),
        }

        handler = handlers.get(tool_name)
        if handler is not None:
            result = handler()
        else:
            result = {
                "tool": tool_name,
                "turn": self._turn,
                "node": node,
                "error": f"Unknown tool: {tool_name}",
                "available_tools": [tool["name"] for tool in AVAILABLE_SOC_TOOLS],
            }

        self._record_tool_result(result)
        return _public_tool_result(result)
    except Exception as exc:
        return {
            "tool": str(tool_name),
            "turn": self._turn,
            "error": f"tool_error: {exc}",
        }
|
| 594 |
+
|
| 595 |
+
def _tool_log_search(self, node: str, query: str = "") -> Dict[str, Any]:
|
| 596 |
+
turn_config = self._turn_config or {}
|
| 597 |
+
hard_task = self._task_name == "polymorphic-zero-day"
|
| 598 |
+
strategy = str(turn_config.get("strategy", "benign"))
|
| 599 |
+
target = str(turn_config.get("correct_target", "none"))
|
| 600 |
+
stage = str(turn_config.get("attack_stage", self._attacker.current_stage()))
|
| 601 |
+
relevant = node == target
|
| 602 |
+
|
| 603 |
+
events = [f"{node}: routine health heartbeat accepted"]
|
| 604 |
+
evidence_type = "benign"
|
| 605 |
+
confidence = 0.32
|
| 606 |
+
|
| 607 |
+
if relevant and strategy == "brute_force":
|
| 608 |
+
events = [
|
| 609 |
+
"auth_service: 143 failed logins grouped by /24 source subnet",
|
| 610 |
+
"auth_service: password spray pattern across active user accounts",
|
| 611 |
+
"payment_service: no matching checkout errors during same window",
|
| 612 |
+
]
|
| 613 |
+
evidence_type = "brute_force"
|
| 614 |
+
confidence = 0.88
|
| 615 |
+
elif relevant and strategy == "lateral_movement":
|
| 616 |
+
if hard_task:
|
| 617 |
+
events = [
|
| 618 |
+
"auth_service: low-volume identity failures persist from earlier abuse window",
|
| 619 |
+
"payment_service: shared workload credential reused from a new east-west origin",
|
| 620 |
+
"payment_service: callback mesh fans out across private ports after short-lived sessions",
|
| 621 |
+
]
|
| 622 |
+
confidence = 0.77
|
| 623 |
+
else:
|
| 624 |
+
events = [
|
| 625 |
+
"payment_service: service account login from previously unseen host",
|
| 626 |
+
"auth_service -> payment_service connections exceed baseline",
|
| 627 |
+
"payment_service: short-lived internal sessions fan out to private ports",
|
| 628 |
+
]
|
| 629 |
+
confidence = 0.84
|
| 630 |
+
evidence_type = "lateral_movement"
|
| 631 |
+
elif relevant and strategy == "exfiltration":
|
| 632 |
+
if hard_task:
|
| 633 |
+
events = [
|
| 634 |
+
"database: bulk row access concentrates on high-value records",
|
| 635 |
+
"database: staging bundle assembled before outbound transfer burst",
|
| 636 |
+
"netflow: outbound transfer volume from database exceeds 7-day baseline",
|
| 637 |
+
]
|
| 638 |
+
confidence = 0.82
|
| 639 |
+
else:
|
| 640 |
+
events = [
|
| 641 |
+
"database: sequential reads concentrated on sensitive tables",
|
| 642 |
+
"database: compressed archive created before outbound transfer spike",
|
| 643 |
+
"netflow: database egress exceeds 7-day baseline for this hour",
|
| 644 |
+
]
|
| 645 |
+
confidence = 0.90
|
| 646 |
+
evidence_type = "exfiltration"
|
| 647 |
+
elif relevant and strategy == "supply_chain":
|
| 648 |
+
if hard_task:
|
| 649 |
+
events = [
|
| 650 |
+
"api_gateway: release artifact requested from untrusted update channel",
|
| 651 |
+
"api_gateway: dependency integrity drift detected in gateway manifest",
|
| 652 |
+
"api_gateway: new outbound connection follows release request",
|
| 653 |
+
]
|
| 654 |
+
confidence = 0.79
|
| 655 |
+
else:
|
| 656 |
+
events = [
|
| 657 |
+
"api_gateway: unsigned package update requested by deploy_bot",
|
| 658 |
+
"api_gateway: dependency hash differs from approved registry hash",
|
| 659 |
+
"api_gateway: new outbound connection follows update attempt",
|
| 660 |
+
]
|
| 661 |
+
confidence = 0.86
|
| 662 |
+
evidence_type = "supply_chain"
|
| 663 |
+
elif turn_config.get("is_benign", False):
|
| 664 |
+
events = [
|
| 665 |
+
f"{node}: scheduled maintenance event matches change calendar",
|
| 666 |
+
f"{node}: no persistence, beaconing, or sensitive data access observed",
|
| 667 |
+
]
|
| 668 |
+
confidence = 0.65
|
| 669 |
+
|
| 670 |
+
return {
|
| 671 |
+
"tool": "log_search",
|
| 672 |
+
"turn": self._turn,
|
| 673 |
+
"phase": self._phase,
|
| 674 |
+
"node": node,
|
| 675 |
+
"query": query,
|
| 676 |
+
"events": [self._surface_text(event) for event in events],
|
| 677 |
+
"evidence_type": evidence_type,
|
| 678 |
+
"confidence": confidence,
|
| 679 |
+
"verified": relevant and evidence_type == strategy,
|
| 680 |
+
}
|
| 681 |
+
|
| 682 |
+
def _tool_cmdb_lookup(self, node: str) -> Dict[str, Any]:
    """
    Return operational context for ``node``: owner, criticality,
    dependencies, the current mission profile, and the safe actions.

    Unknown nodes get criticality "unknown", no dependencies, and
    "monitor" as the only safe action.
    """
    # node -> (criticality, dependencies, safe actions)
    catalog = {
        "auth_service": (
            "high",
            ["payment_service"],
            ["rate_limit", "patch", "monitor"],
        ),
        "payment_service": (
            "critical",
            ["api_gateway"],
            ["honeypot", "rate_limit", "monitor"],
        ),
        "database": (
            "critical",
            ["payment_service", "api_gateway"],
            ["honeypot", "monitor"],
        ),
        "api_gateway": (
            "high",
            ["auth_service", "payment_service", "database"],
            ["patch", "rate_limit", "monitor"],
        ),
    }
    criticality, dependencies, safe_actions = catalog.get(
        node, ("unknown", [], ["monitor"])
    )
    summary = f"{node} criticality {criticality}; blast radius {len(dependencies)}"

    return {
        "tool": "cmdb_lookup",
        "turn": self._turn,
        "phase": self._phase,
        "node": node,
        "owner": SERVICE_OWNERS.get(node, "unknown"),
        "criticality": criticality,
        "dependencies": dependencies,
        "mission_profile": self._mission_profile,
        "safe_actions": safe_actions,
        "evidence_type": "operational_context",
        "verified": True,
        "result_summary": self._surface_text(summary),
    }
|
| 716 |
+
|
| 717 |
+
def _tool_edr_status(self, node: str) -> Dict[str, Any]:
|
| 718 |
+
turn_config = self._turn_config or {}
|
| 719 |
+
hard_task = self._task_name == "polymorphic-zero-day"
|
| 720 |
+
strategy = str(turn_config.get("strategy", "benign"))
|
| 721 |
+
target = str(turn_config.get("correct_target", "none"))
|
| 722 |
+
relevant = node == target
|
| 723 |
+
active_controls = [
|
| 724 |
+
control for control in self._active_defenses
|
| 725 |
+
if control.get("target") == node
|
| 726 |
+
]
|
| 727 |
+
|
| 728 |
+
status = {
|
| 729 |
+
"containment": "not_isolated",
|
| 730 |
+
"persistence": False,
|
| 731 |
+
"beaconing": False,
|
| 732 |
+
"active_controls": active_controls,
|
| 733 |
+
}
|
| 734 |
+
evidence_type = "benign"
|
| 735 |
+
confidence = 0.45
|
| 736 |
+
|
| 737 |
+
if relevant and strategy == "lateral_movement":
|
| 738 |
+
if hard_task:
|
| 739 |
+
status.update({
|
| 740 |
+
"containment": "partial" if self._foothold_established else "unconfirmed",
|
| 741 |
+
"persistence": self._foothold_established,
|
| 742 |
+
"beaconing": self._foothold_established,
|
| 743 |
+
"process_note": (
|
| 744 |
+
"callback process remains active under reused workload credential"
|
| 745 |
+
if self._foothold_established else
|
| 746 |
+
"unexpected child process under shared workload credential; no confirmed beacon yet"
|
| 747 |
+
),
|
| 748 |
+
})
|
| 749 |
+
confidence = 0.74 if not self._foothold_established else 0.87
|
| 750 |
+
else:
|
| 751 |
+
status.update({
|
| 752 |
+
"containment": "partial" if self._foothold_established else "not_isolated",
|
| 753 |
+
"persistence": self._foothold_established,
|
| 754 |
+
"beaconing": True,
|
| 755 |
+
"process_note": "unknown child process under service account context",
|
| 756 |
+
})
|
| 757 |
+
confidence = 0.87
|
| 758 |
+
evidence_type = "lateral_movement"
|
| 759 |
+
elif relevant and strategy == "exfiltration":
|
| 760 |
+
status.update({
|
| 761 |
+
"containment": "not_isolated",
|
| 762 |
+
"persistence": self._foothold_established,
|
| 763 |
+
"beaconing": False if hard_task else True,
|
| 764 |
+
"process_note": (
|
| 765 |
+
"bundle staging followed by intermittent outbound worker activity"
|
| 766 |
+
if hard_task else
|
| 767 |
+
"archive process followed by outbound transfer"
|
| 768 |
+
),
|
| 769 |
+
})
|
| 770 |
+
evidence_type = "exfiltration"
|
| 771 |
+
confidence = 0.73 if hard_task else 0.82
|
| 772 |
+
elif active_controls:
|
| 773 |
+
status["containment"] = "control_active"
|
| 774 |
+
confidence = 0.70
|
| 775 |
+
|
| 776 |
+
return {
|
| 777 |
+
"tool": "edr_status",
|
| 778 |
+
"turn": self._turn,
|
| 779 |
+
"phase": self._phase,
|
| 780 |
+
"node": node,
|
| 781 |
+
**status,
|
| 782 |
+
"evidence_type": evidence_type,
|
| 783 |
+
"confidence": confidence,
|
| 784 |
+
"verified": relevant and evidence_type == strategy,
|
| 785 |
+
"process_note": self._surface_text(str(status.get("process_note", ""))) if status.get("process_note") else "",
|
| 786 |
+
}
|
| 787 |
+
|
| 788 |
+
def _tool_vuln_lookup(self, node: str, package: str = "") -> Dict[str, Any]:
|
| 789 |
+
turn_config = self._turn_config or {}
|
| 790 |
+
hard_task = self._task_name == "polymorphic-zero-day"
|
| 791 |
+
strategy = str(turn_config.get("strategy", "benign"))
|
| 792 |
+
target = str(turn_config.get("correct_target", "none"))
|
| 793 |
+
relevant = node == target and strategy == "supply_chain"
|
| 794 |
+
|
| 795 |
+
if relevant:
|
| 796 |
+
advisory = {
|
| 797 |
+
"package": package or "gateway-router",
|
| 798 |
+
"advisory_id": "ADV-AS-042",
|
| 799 |
+
"risk": "critical",
|
| 800 |
+
"finding": (
|
| 801 |
+
"artifact integrity drift from untrusted release channel"
|
| 802 |
+
if hard_task else
|
| 803 |
+
"registry hash mismatch with unsigned update source"
|
| 804 |
+
),
|
| 805 |
+
"recommended_mitigation": "patch from trusted registry",
|
| 806 |
+
}
|
| 807 |
+
evidence_type = "supply_chain"
|
| 808 |
+
confidence = 0.82 if hard_task else 0.91
|
| 809 |
+
else:
|
| 810 |
+
advisory = {
|
| 811 |
+
"package": package or "unknown",
|
| 812 |
+
"advisory_id": None,
|
| 813 |
+
"risk": "none_known",
|
| 814 |
+
"finding": "no matching active internal advisory for this node/package",
|
| 815 |
+
"recommended_mitigation": "continue investigation",
|
| 816 |
+
}
|
| 817 |
+
evidence_type = "benign"
|
| 818 |
+
confidence = 0.55
|
| 819 |
+
|
| 820 |
+
return {
|
| 821 |
+
"tool": "vuln_lookup",
|
| 822 |
+
"turn": self._turn,
|
| 823 |
+
"phase": self._phase,
|
| 824 |
+
"node": node,
|
| 825 |
+
**{
|
| 826 |
+
**advisory,
|
| 827 |
+
"finding": self._surface_text(str(advisory.get("finding", ""))),
|
| 828 |
+
"recommended_mitigation": self._surface_text(str(advisory.get("recommended_mitigation", ""))),
|
| 829 |
+
},
|
| 830 |
+
"evidence_type": evidence_type,
|
| 831 |
+
"confidence": confidence,
|
| 832 |
+
"verified": relevant,
|
| 833 |
+
}
|
| 834 |
+
|
| 835 |
+
def _tool_identity_lookup(self, node: str) -> Dict[str, Any]:
    """
    Return the identity baseline for ``node`` plus this turn's source status.

    An identity anomaly is reported only when ``node`` is the turn's correct
    target and the strategy is lateral_movement or supply_chain. Benign
    turns report a scheduled change window. In the dual-pivot task under
    evidence_preservation mode, any lateral_movement turn gets a trackable
    source status and a mission-alignment summary.
    """
    turn_config = self._turn_config or {}
    strategy = str(turn_config.get("strategy", "benign"))
    target = str(turn_config.get("correct_target", "none"))
    unknown_identity = {
        "account": "unknown",
        "account_type": "unknown",
        "privilege_level": "unknown",
        "normal_hosts": [],
    }
    identity = dict(IDENTITY_CONTEXT.get(node, unknown_identity))

    lookup: Dict[str, Any] = {
        "tool": "identity_lookup",
        "turn": self._turn,
        "phase": self._phase,
        "node": node,
        **identity,
        "recent_source_host": node,
        "source_status": "expected",
        "identity_anomaly": False,
        "result_summary": "",
        "confidence": 0.58,
        "evidence_type": "benign",
        "verified": False,
    }

    on_target = node == target
    if on_target and strategy == "lateral_movement":
        lookup["recent_source_host"] = "auth_service"
        lookup["source_status"] = "unexpected"
        lookup["identity_anomaly"] = True
        lookup["confidence"] = 0.76 if self._task_name == "polymorphic-zero-day" else 0.84
        lookup["evidence_type"] = "lateral_movement"
        lookup["verified"] = True
    elif on_target and strategy == "supply_chain":
        lookup["recent_source_host"] = "external-release-runner"
        lookup["source_status"] = "unexpected"
        lookup["identity_anomaly"] = True
        lookup["confidence"] = 0.73
        lookup["evidence_type"] = "supply_chain"
        lookup["verified"] = True
    elif turn_config.get("is_benign", False):
        # Benign turn: the source is the first normal host, if there is one.
        usual_hosts = identity.get("normal_hosts")
        lookup["recent_source_host"] = usual_hosts[0] if usual_hosts else node
        lookup["source_status"] = "scheduled_change_window"
        lookup["confidence"] = 0.69

    preserve_evidence = (
        self._task_name == "dual-pivot"
        and strategy == "lateral_movement"
        and self._operational_mode == "evidence_preservation"
    )
    if preserve_evidence:
        lookup["source_status"] = "unexpected_but_trackable"
        summary = "Identity trail is intact; preserving visibility before hard containment is mission-aligned."
    else:
        summary = (
            f"account={lookup['account']} source={lookup['recent_source_host']} anomaly={lookup['identity_anomaly']}"
        )
    lookup["result_summary"] = self._surface_text(summary)
    return lookup
|
| 900 |
+
|
| 901 |
+
def _tool_change_calendar_lookup(self, node: str) -> Dict[str, Any]:
    """Build the change-calendar tool result for ``node`` on the current turn.

    The calendar entry is taken from ``CHANGE_CALENDAR`` (falling back to a
    "nothing scheduled" placeholder) and is overlaid with the scheduling
    status, confidence and evidence labels derived from the turn config.

    Args:
        node: Name of the node being looked up.

    Returns:
        A dict holding the tool/turn/phase/node header, the calendar entry
        fields, and ``scheduled``, ``change_status``, ``confidence``,
        ``evidence_type`` and ``verified``.
    """
    cfg = self._turn_config or {}
    strategy = str(cfg.get("strategy", "benign"))
    on_target = node == str(cfg.get("correct_target", "none"))
    supply_chain_hit = on_target and strategy == "supply_chain"
    lateral_hit = on_target and strategy == "lateral_movement"

    placeholder = {
        "window": "none_scheduled",
        "change_type": "none",
        "expected_actor": "unknown",
    }
    calendar_entry = dict(CHANGE_CALENDAR.get(node, placeholder))

    # A benign turn counts as scheduled unless this node is the attack target
    # of a supply-chain or lateral-movement turn.
    scheduled = bool(cfg.get("is_benign", False))
    if supply_chain_hit:
        scheduled = False
        confidence = 0.78 if self._task_name == "polymorphic-zero-day" else 0.87
    elif lateral_hit:
        scheduled = False
        confidence = 0.72
    elif scheduled:
        confidence = 0.66
    else:
        confidence = 0.74

    if (
        self._task_name == "dual-pivot"
        and strategy == "lateral_movement"
        and self._operational_mode == "evidence_preservation"
    ):
        change_status = "forensic_observation_hold"
    elif scheduled:
        change_status = "scheduled"
    else:
        change_status = "no_matching_change"

    if scheduled:
        evidence_type = "benign"
    elif supply_chain_hit:
        evidence_type = "supply_chain"
    else:
        evidence_type = "operational_context"

    # Insertion order mirrors a single literal: header, calendar fields,
    # then the derived fields (which win over same-named calendar keys).
    result: Dict[str, Any] = {
        "tool": "change_calendar_lookup",
        "turn": self._turn,
        "phase": self._phase,
        "node": node,
    }
    result.update(calendar_entry)
    result.update({
        "scheduled": scheduled,
        "change_status": change_status,
        "confidence": confidence,
        "evidence_type": evidence_type,
        "verified": scheduled or supply_chain_hit,
    })
    return result
|
| 939 |
+
|
| 940 |
+
def _tool_netflow_lookup(self, node: str) -> Dict[str, Any]:
|
| 941 |
+
turn_config = self._turn_config or {}
|
| 942 |
+
strategy = str(turn_config.get("strategy", "benign"))
|
| 943 |
+
target = str(turn_config.get("correct_target", "none"))
|
| 944 |
+
hard_task = self._task_name == "polymorphic-zero-day"
|
| 945 |
+
|
| 946 |
+
summary = {
|
| 947 |
+
"tool": "netflow_lookup",
|
| 948 |
+
"turn": self._turn,
|
| 949 |
+
"phase": self._phase,
|
| 950 |
+
"node": node,
|
| 951 |
+
"east_west_connections": 12,
|
| 952 |
+
"outbound_mb": 4,
|
| 953 |
+
"new_destinations": 0,
|
| 954 |
+
"traffic_pattern": "baseline",
|
| 955 |
+
"confidence": 0.57,
|
| 956 |
+
"evidence_type": "benign",
|
| 957 |
+
"verified": False,
|
| 958 |
+
}
|
| 959 |
+
|
| 960 |
+
if node == target and strategy == "lateral_movement":
|
| 961 |
+
summary.update({
|
| 962 |
+
"east_west_connections": 46 if hard_task else 58,
|
| 963 |
+
"outbound_mb": 9,
|
| 964 |
+
"new_destinations": 5 if hard_task else 7,
|
| 965 |
+
"traffic_pattern": "east_west_fanout",
|
| 966 |
+
"confidence": 0.79 if hard_task else 0.88,
|
| 967 |
+
"evidence_type": "lateral_movement",
|
| 968 |
+
"verified": True,
|
| 969 |
+
})
|
| 970 |
+
elif node == target and strategy == "exfiltration":
|
| 971 |
+
summary.update({
|
| 972 |
+
"east_west_connections": 18,
|
| 973 |
+
"outbound_mb": 74 if hard_task else 96,
|
| 974 |
+
"new_destinations": 2,
|
| 975 |
+
"traffic_pattern": "outbound_transfer_burst",
|
| 976 |
+
"confidence": 0.82 if hard_task else 0.91,
|
| 977 |
+
"evidence_type": "exfiltration",
|
| 978 |
+
"verified": True,
|
| 979 |
+
})
|
| 980 |
+
elif node == target and strategy == "supply_chain":
|
| 981 |
+
summary.update({
|
| 982 |
+
"east_west_connections": 16,
|
| 983 |
+
"outbound_mb": 19,
|
| 984 |
+
"new_destinations": 1,
|
| 985 |
+
"traffic_pattern": "post_deploy_callback",
|
| 986 |
+
"confidence": 0.71,
|
| 987 |
+
"evidence_type": "supply_chain",
|
| 988 |
+
"verified": True,
|
| 989 |
+
})
|
| 990 |
+
|
| 991 |
+
if (
|
| 992 |
+
self._task_name == "polymorphic-zero-day" and
|
| 993 |
+
self._operational_mode in {"forensic_hold", "business_continuity"} and
|
| 994 |
+
node == target and strategy in {"lateral_movement", "exfiltration"}
|
| 995 |
+
):
|
| 996 |
+
summary["traffic_pattern"] = "controlled_observation_window"
|
| 997 |
+
summary["confidence"] = max(float(summary["confidence"]), 0.83)
|
| 998 |
+
return summary
|
| 999 |
+
|
| 1000 |
+
def _record_tool_result(self, result: Dict[str, Any]) -> None:
    """Store a tool result in the per-turn bookkeeping structures.

    Every result is added to ``_turn_tool_results`` and, in a compact
    form with a summary string, to ``_tool_trace``. Results flagged as
    verified are also added to ``_turn_tool_evidence``.

    Args:
        result: Raw tool result dict; missing fields fall back to the
            current turn/phase or to ``"unknown"`` / ``0.0``.
    """
    # A missing or falsy turn falls back to the environment's current turn.
    turn = int(result.get("turn", self._turn) or self._turn)
    phase = result.get("phase", self._phase)
    tool = result.get("tool", "unknown")
    node = result.get("node", "unknown")
    confidence = float(result.get("confidence", 0.0) or 0.0)

    internal = {
        "turn": turn,
        "phase": phase,
        "tool": tool,
        "node": node,
        "evidence_type": result.get("evidence_type", "unknown"),
        "verified": bool(result.get("verified", False)),
        "confidence": confidence,
    }
    self._turn_tool_results.setdefault(turn, []).append(internal)

    # The trace keeps the turn value exactly as reported by the tool.
    self._tool_trace.append({
        "turn": result.get("turn", self._turn),
        "phase": phase,
        "tool": tool,
        "node": node,
        "confidence": confidence,
        "summary": _tool_summary(result),
    })

    if internal["verified"]:
        self._turn_tool_evidence.setdefault(turn, []).append(internal)
|
| 1025 |
+
|
| 1026 |
+
def _tool_context_for_turn(self) -> Dict[str, Any]:
|
| 1027 |
+
evidence = list(self._turn_tool_evidence.get(self._turn, []))
|
| 1028 |
+
return {
|
| 1029 |
+
"turn": self._turn,
|
| 1030 |
+
"tool_count": len([
|
| 1031 |
+
row for row in self._tool_trace
|
| 1032 |
+
if int(row.get("turn", -1)) == self._turn
|
| 1033 |
+
]),
|
| 1034 |
+
"evidence": evidence,
|
| 1035 |
+
"tool_results": list(self._turn_tool_results.get(self._turn, [])),
|
| 1036 |
+
}
|
| 1037 |
+
|
| 1038 |
+
def _update_foothold_state(
|
| 1039 |
+
self,
|
| 1040 |
+
p2: Dict[str, str],
|
| 1041 |
+
info: Dict[str, Any],
|
| 1042 |
+
stage: str,
|
| 1043 |
+
) -> bool:
|
| 1044 |
+
if (
|
| 1045 |
+
self._task_name != "polymorphic-zero-day" or
|
| 1046 |
+
self._foothold_established or
|
| 1047 |
+
stage not in ("exploit", "exfiltration")
|
| 1048 |
+
):
|
| 1049 |
+
return False
|
| 1050 |
+
|
| 1051 |
+
if p2.get("action") == "monitor" or not info.get("acted_correctly", False):
|
| 1052 |
+
self._foothold_established = True
|
| 1053 |
+
return True
|
| 1054 |
+
|
| 1055 |
+
return False
|
| 1056 |
+
|
| 1057 |
+
def _register_active_defense(self, p2: Dict[str, str]) -> None:
    """Record the phase-2 action as an active control, if it has a TTL.

    Actions absent from ``DEFENSE_TTL`` are ignored. An existing control
    with the same action and target is replaced by a fresh entry.

    Args:
        p2: Phase-2 action dict; reads ``action`` (default ``"monitor"``)
            and ``target_node`` (default ``"unknown"``).
    """
    action = p2.get("action", "monitor")
    if action not in DEFENSE_TTL:
        return

    target = p2.get("target_node", "unknown")

    # Drop any previous control for the same (action, target) pair.
    remaining = []
    for control in self._active_defenses:
        if control["action"] == action and control["target"] == target:
            continue
        remaining.append(control)
    self._active_defenses = remaining

    fresh_control = {
        "action": action,
        "target": target,
        "ttl": DEFENSE_TTL[action],
        "side_effect": DEFENSE_SIDE_EFFECT[action],
    }
    self._active_defenses.append(fresh_control)
|
| 1073 |
+
|
| 1074 |
+
def _decay_active_defenses(self) -> None:
|
| 1075 |
+
next_controls = []
|
| 1076 |
+
for control in self._active_defenses:
|
| 1077 |
+
updated = dict(control)
|
| 1078 |
+
updated["ttl"] = int(updated.get("ttl", 0)) - 1
|
| 1079 |
+
if updated["ttl"] > 0:
|
| 1080 |
+
next_controls.append(updated)
|
| 1081 |
+
self._active_defenses = next_controls
|
| 1082 |
+
|
| 1083 |
+
def _active_defense_snapshot(self) -> List[Dict[str, Any]]:
|
| 1084 |
+
return [dict(control) for control in self._active_defenses]
|
| 1085 |
+
|
| 1086 |
+
def _metadata_with_defenses(self, metadata: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``metadata`` enriched with environment context.

    Adds the active-defense snapshot, mission/world/mode identifiers,
    the available SOC tools, a copy of the tool trace, and the number of
    tool calls traced for the current turn.

    Args:
        metadata: Base metadata; ``None`` or empty is treated as ``{}``.

    Returns:
        A new dict; the input is not modified.
    """
    enriched = dict(metadata or {})
    enriched.update({
        "active_defenses": self._active_defense_snapshot(),
        "mission_profile": self._mission_profile,
        "world_split": self._world_split,
        "world_family": self._world_family,
        "operational_mode": self._operational_mode,
        "available_tools": [dict(tool) for tool in AVAILABLE_SOC_TOOLS],
        "tool_trace": [dict(row) for row in self._tool_trace],
        "tool_calls_this_turn": sum(
            1 for row in self._tool_trace
            if int(row.get("turn", -1)) == self._turn
        ),
    })
    return enriched
|
| 1100 |
+
|
| 1101 |
+
def _with_active_defense_alerts(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
|
| 1102 |
+
if not self._active_defenses:
|
| 1103 |
+
return turn_config
|
| 1104 |
+
|
| 1105 |
+
updated = dict(turn_config)
|
| 1106 |
+
alerts = list(updated.get("active_alerts", []))
|
| 1107 |
+
for control in self._active_defenses:
|
| 1108 |
+
alerts.append(
|
| 1109 |
+
"[CONTROL] "
|
| 1110 |
+
f"{control['action']} active on {control['target']} "
|
| 1111 |
+
f"(ttl={control['ttl']}, side_effect={control['side_effect']})"
|
| 1112 |
+
)
|
| 1113 |
+
updated["active_alerts"] = alerts
|
| 1114 |
+
return updated
|
| 1115 |
+
|
| 1116 |
+
def _with_foothold_context(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
|
| 1117 |
+
if (
|
| 1118 |
+
self._task_name != "polymorphic-zero-day" or
|
| 1119 |
+
not self._foothold_established or
|
| 1120 |
+
not turn_config.get("is_benign", False)
|
| 1121 |
+
):
|
| 1122 |
+
return turn_config
|
| 1123 |
+
|
| 1124 |
+
updated = dict(turn_config)
|
| 1125 |
+
nodes = {name: dict(metrics) for name, metrics in updated.get("network_nodes", {}).items()}
|
| 1126 |
+
payment = nodes.setdefault("payment_service", {})
|
| 1127 |
+
payment["status"] = "suspicious"
|
| 1128 |
+
payment["cpu"] = max(int(payment.get("cpu", 0)), 58)
|
| 1129 |
+
payment["request_rate"] = max(int(payment.get("request_rate", 0)), 130)
|
| 1130 |
+
|
| 1131 |
+
alerts = list(updated.get("active_alerts", []))
|
| 1132 |
+
alerts.append(
|
| 1133 |
+
"[FOOTHOLD] Low-volume internal callback resembles benign service noise, "
|
| 1134 |
+
"but prior missed containment means deception is required."
|
| 1135 |
+
)
|
| 1136 |
+
|
| 1137 |
+
updated["network_nodes"] = nodes
|
| 1138 |
+
updated["active_alerts"] = alerts
|
| 1139 |
+
updated["attack_stage"] = self._attacker.current_stage()
|
| 1140 |
+
updated["is_benign"] = False
|
| 1141 |
+
updated["strategy"] = "lateral_movement"
|
| 1142 |
+
updated["correct_action"] = "honeypot"
|
| 1143 |
+
updated["correct_target"] = "payment_service"
|
| 1144 |
+
updated["foothold_context"] = True
|
| 1145 |
+
return updated
|
| 1146 |
+
|
| 1147 |
+
def _prepare_turn_config(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
|
| 1148 |
+
updated = self._with_foothold_context(turn_config)
|
| 1149 |
+
updated = self._apply_operational_mode(updated)
|
| 1150 |
+
updated = self._with_active_defense_alerts(updated)
|
| 1151 |
+
return updated
|
| 1152 |
+
|
| 1153 |
+
def _apply_operational_mode(self, turn_config: Dict[str, Any]) -> Dict[str, Any]:
|
| 1154 |
+
updated = dict(turn_config)
|
| 1155 |
+
alerts = list(updated.get("active_alerts", []))
|
| 1156 |
+
strategy = str(updated.get("strategy", "benign"))
|
| 1157 |
+
stage = str(updated.get("attack_stage", "recon"))
|
| 1158 |
+
|
| 1159 |
+
if self._task_name == "dual-pivot" and strategy == "lateral_movement":
|
| 1160 |
+
if self._operational_mode == "evidence_preservation":
|
| 1161 |
+
updated["correct_action"] = "honeypot"
|
| 1162 |
+
alerts.append(
|
| 1163 |
+
self._surface_text(
|
| 1164 |
+
"[MODE] Checkout incident posture favors preserving attacker visibility until identity misuse is corroborated."
|
| 1165 |
+
)
|
| 1166 |
+
)
|
| 1167 |
+
else:
|
| 1168 |
+
updated["correct_action"] = "isolate"
|
| 1169 |
+
alerts.append(
|
| 1170 |
+
self._surface_text(
|
| 1171 |
+
"[MODE] Checkout incident posture favors decisive containment once the pivot is corroborated."
|
| 1172 |
+
)
|
| 1173 |
+
)
|
| 1174 |
+
|
| 1175 |
+
if self._task_name == "polymorphic-zero-day" and not updated.get("is_benign", False):
|
| 1176 |
+
if self._operational_mode in {"forensic_hold", "business_continuity"} and strategy in {"lateral_movement", "exfiltration"}:
|
| 1177 |
+
updated["correct_action"] = "honeypot"
|
| 1178 |
+
alerts.append(
|
| 1179 |
+
self._surface_text(
|
| 1180 |
+
"[MODE] Enterprise posture favors deception over immediate isolation while mapping the callback path."
|
| 1181 |
+
)
|
| 1182 |
+
)
|
| 1183 |
+
elif self._operational_mode == "containment_first":
|
| 1184 |
+
alerts.append(
|
| 1185 |
+
self._surface_text(
|
| 1186 |
+
"[MODE] Enterprise posture favors immediate containment once compromise is corroborated."
|
| 1187 |
+
)
|
| 1188 |
+
)
|
| 1189 |
+
if self._operational_mode == "business_continuity" and stage == "exploit":
|
| 1190 |
+
nodes = {name: dict(metrics) for name, metrics in updated.get("network_nodes", {}).items()}
|
| 1191 |
+
payment = nodes.setdefault("payment_service", {})
|
| 1192 |
+
payment["status"] = "elevated"
|
| 1193 |
+
payment["request_rate"] = max(int(payment.get("request_rate", 0)), 122)
|
| 1194 |
+
updated["network_nodes"] = nodes
|
| 1195 |
+
alerts.append(
|
| 1196 |
+
self._surface_text(
|
| 1197 |
+
"[MODE] Customer traffic remains sensitive; service continuity pressure is elevated during this window."
|
| 1198 |
+
)
|
| 1199 |
+
)
|
| 1200 |
+
|
| 1201 |
+
updated["active_alerts"] = alerts
|
| 1202 |
+
updated["world_split"] = self._world_split
|
| 1203 |
+
updated["world_family"] = self._world_family
|
| 1204 |
+
updated["operational_mode"] = self._operational_mode
|
| 1205 |
+
return updated
|
| 1206 |
+
|
| 1207 |
+
def _surface_text(self, text: str) -> str:
|
| 1208 |
+
return self._attacker._surface(text)
|
| 1209 |
+
|
| 1210 |
+
|
| 1211 |
+
def _action_value(value: Any, default: str) -> str:
|
| 1212 |
+
"""Serialize action fields without leaking Enum member names."""
|
| 1213 |
+
if value is None:
|
| 1214 |
+
return default
|
| 1215 |
+
if isinstance(value, Enum):
|
| 1216 |
+
return str(value.value)
|
| 1217 |
+
return str(value)
|
| 1218 |
+
|
| 1219 |
+
|
| 1220 |
+
def _action_float(value: Any, default: float) -> float:
|
| 1221 |
+
"""Coerce optional numeric action fields to floats with a safe fallback."""
|
| 1222 |
+
if value is None:
|
| 1223 |
+
return float(default)
|
| 1224 |
+
try:
|
| 1225 |
+
return float(value)
|
| 1226 |
+
except (TypeError, ValueError):
|
| 1227 |
+
return float(default)
|
| 1228 |
+
|
| 1229 |
+
|
| 1230 |
+
def _replay_result(info: Dict[str, Any]) -> str:
|
| 1231 |
+
"""Map grader text into compact replay result labels."""
|
| 1232 |
+
reason = str(info.get("score_reason", "")).lower()
|
| 1233 |
+
if "false positive" in reason:
|
| 1234 |
+
return "false_positive"
|
| 1235 |
+
if reason.startswith("unverified"):
|
| 1236 |
+
return "unverified"
|
| 1237 |
+
if reason.startswith("optimal") or reason.startswith("correct") or reason.startswith("context-aware optimal"):
|
| 1238 |
+
return "optimal"
|
| 1239 |
+
if reason.startswith("heavy-handed"):
|
| 1240 |
+
return "heavy"
|
| 1241 |
+
return "wrong"
|
| 1242 |
+
|
| 1243 |
+
|
| 1244 |
+
def _tool_summary(result: Dict[str, Any]) -> str:
|
| 1245 |
+
if result.get("error"):
|
| 1246 |
+
return str(result["error"])[:120]
|
| 1247 |
+
if result.get("tool") == "log_search":
|
| 1248 |
+
events = result.get("events") or []
|
| 1249 |
+
return str(events[0])[:120] if events else "no matching log events"
|
| 1250 |
+
if result.get("tool") == "cmdb_lookup":
|
| 1251 |
+
deps = result.get("dependencies") or []
|
| 1252 |
+
return f"{result.get('node')} criticality={result.get('criticality')} deps={len(deps)}"
|
| 1253 |
+
if result.get("tool") == "edr_status":
|
| 1254 |
+
return (
|
| 1255 |
+
f"containment={result.get('containment')} "
|
| 1256 |
+
f"beaconing={result.get('beaconing')} "
|
| 1257 |
+
f"persistence={result.get('persistence')}"
|
| 1258 |
+
)
|
| 1259 |
+
if result.get("tool") == "vuln_lookup":
|
| 1260 |
+
return f"risk={result.get('risk')} finding={result.get('finding')}"
|
| 1261 |
+
if result.get("tool") == "identity_lookup":
|
| 1262 |
+
return (
|
| 1263 |
+
f"account={result.get('account')} "
|
| 1264 |
+
f"source={result.get('recent_source_host')} "
|
| 1265 |
+
f"anomaly={result.get('identity_anomaly')}"
|
| 1266 |
+
)
|
| 1267 |
+
if result.get("tool") == "change_calendar_lookup":
|
| 1268 |
+
return (
|
| 1269 |
+
f"scheduled={result.get('scheduled')} "
|
| 1270 |
+
f"window={result.get('window')} "
|
| 1271 |
+
f"change={result.get('change_type')}"
|
| 1272 |
+
)
|
| 1273 |
+
if result.get("tool") == "netflow_lookup":
|
| 1274 |
+
return (
|
| 1275 |
+
f"pattern={result.get('traffic_pattern')} "
|
| 1276 |
+
f"east_west={result.get('east_west_connections')} "
|
| 1277 |
+
f"outbound_mb={result.get('outbound_mb')}"
|
| 1278 |
+
)
|
| 1279 |
+
return str(result)[:120]
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
def _public_tool_result(result: Dict[str, Any]) -> Dict[str, Any]:
    """Return the agent-visible view of a tool result.

    The grader labels ``evidence_type`` and ``verified`` are removed and
    ``result_summary`` is set from ``_tool_summary``. The input dict is
    not modified.
    """
    public = dict(result)
    for grader_only_key in ("evidence_type", "verified"):
        public.pop(grader_only_key, None)
    # NOTE(review): this replaces any ``result_summary`` the tool already
    # provided — confirm that overriding tool-specific summaries is intended.
    public["result_summary"] = _tool_summary(result)
    return public
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
def _degrade_handoff(
|
| 1291 |
+
phase1_output: Dict[str, Any],
|
| 1292 |
+
turn_config: Dict[str, Any],
|
| 1293 |
+
task_name: str,
|
| 1294 |
+
turn: int,
|
| 1295 |
+
) -> Dict[str, Any]:
|
| 1296 |
+
"""Degrade hard-task handoffs after the strategy shift without changing grading."""
|
| 1297 |
+
output = dict(phase1_output)
|
| 1298 |
+
if (
|
| 1299 |
+
task_name == "dual-pivot" and
|
| 1300 |
+
turn >= 4 and
|
| 1301 |
+
turn_config.get("strategy") == "lateral_movement"
|
| 1302 |
+
):
|
| 1303 |
+
output["confidence"] = min(float(output.get("confidence", 0.5)), 0.56)
|
| 1304 |
+
output["recommended_action"] = "monitor"
|
| 1305 |
+
output["handoff_quality"] = "degraded"
|
| 1306 |
+
output["handoff_note"] = (
|
| 1307 |
+
"Analyst handoff is less reliable after the pivot; executor should verify "
|
| 1308 |
+
"internal movement before taking payment_service offline."
|
| 1309 |
+
)
|
| 1310 |
+
elif (
|
| 1311 |
+
task_name == "polymorphic-zero-day" and
|
| 1312 |
+
turn >= 4 and
|
| 1313 |
+
turn_config.get("strategy") in {"lateral_movement", "exfiltration", "supply_chain"}
|
| 1314 |
+
):
|
| 1315 |
+
output["confidence"] = min(float(output.get("confidence", 0.5)), 0.36)
|
| 1316 |
+
output["recommended_action"] = "monitor"
|
| 1317 |
+
output["handoff_quality"] = "degraded"
|
| 1318 |
+
output["handoff_note"] = (
|
| 1319 |
+
"Analyst confidence degraded after attacker pivot; executor must verify "
|
| 1320 |
+
"whether monitor is too passive for the current stage."
|
| 1321 |
+
)
|
| 1322 |
+
else:
|
| 1323 |
+
output["handoff_quality"] = "clean"
|
| 1324 |
+
return output
|
server/app.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield FastAPI Server
|
| 3 |
+
|
| 4 |
+
CRITICAL: Uses factory pattern (make_env function), NOT singleton.
|
| 5 |
+
Singleton was the Round 1 failure — always served wrong task.
|
| 6 |
+
Factory creates a fresh isolated instance per evaluator session.
|
| 7 |
+
|
| 8 |
+
openenv validate requires:
|
| 9 |
+
- def main() function present
|
| 10 |
+
- called as main() in if __name__ block (literal string check)
|
| 11 |
+
- port 7860 (HF Spaces default)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import sys
|
| 16 |
+
from typing import Any, Dict
|
| 17 |
+
from uuid import uuid4
|
| 18 |
+
|
| 19 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
from fastapi import Body, HTTPException
|
| 23 |
+
from openenv.core.env_server.http_server import create_app
|
| 24 |
+
except Exception as e:
|
| 25 |
+
raise ImportError(
|
| 26 |
+
"openenv-core required. Install: pip install openenv-core"
|
| 27 |
+
) from e
|
| 28 |
+
|
| 29 |
+
from models import AdaptShieldAction, AdaptShieldObservation
|
| 30 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 31 |
+
|
| 32 |
+
DEFAULT_TASK = os.getenv("ADAPTSHIELD_TASK", "direct-triage")
|
| 33 |
+
SOC_SESSIONS: Dict[str, AdaptShieldEnvironment] = {}
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def make_env() -> AdaptShieldEnvironment:
    """
    Build a brand-new environment for one session.

    Each call constructs a separate instance configured with
    DEFAULT_TASK; nothing is cached or shared between calls, so
    evaluator sessions stay independent.
    """
    fresh_env = AdaptShieldEnvironment(task_name=DEFAULT_TASK)
    return fresh_env
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
app = create_app(
|
| 45 |
+
make_env,
|
| 46 |
+
AdaptShieldAction,
|
| 47 |
+
AdaptShieldObservation,
|
| 48 |
+
env_name="adaptshield",
|
| 49 |
+
max_concurrent_envs=10,
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Upper bound on retained SOC demo sessions. Without a cap, every call to
# /soc/reset would keep another environment alive for the process lifetime.
_MAX_SOC_SESSIONS = 256


@app.post("/soc/reset", tags=["AdaptShield SOC Tools"])
async def soc_reset(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Start a persistent demo session for SOC tool/API workflows.

    Args:
        payload: Optional JSON body; ``task`` selects the task name and
            defaults to ``DEFAULT_TASK``.

    Returns:
        The new ``session_id``, the initial observation as JSON-ready
        data, and the tools advertised in the observation metadata.
    """
    task = str(payload.get("task", DEFAULT_TASK))
    env = AdaptShieldEnvironment(task_name=task)
    obs = env.reset()

    # Evict the oldest sessions first (dicts preserve insertion order) so
    # the registry cannot grow without bound on a long-running server.
    while len(SOC_SESSIONS) >= _MAX_SOC_SESSIONS:
        SOC_SESSIONS.pop(next(iter(SOC_SESSIONS)))

    session_id = str(uuid4())
    SOC_SESSIONS[session_id] = env
    return {
        "session_id": session_id,
        "observation": obs.model_dump(mode="json"),
        "available_tools": obs.metadata.get("available_tools", []),
    }
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
@app.post("/soc/step", tags=["AdaptShield SOC Tools"])
async def soc_step(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Advance an existing SOC session by one action.

    The session is resolved from ``payload["session_id"]``. The
    ``action`` mapping is turned into an ``AdaptShieldAction``; any
    failure while building it is reported as HTTP 422.
    """
    env = _soc_session(payload)
    raw_action = payload.get("action", {})
    try:
        action = AdaptShieldAction(**dict(raw_action))
    except Exception as exc:
        raise HTTPException(status_code=422, detail=str(exc)) from exc

    obs = env.step(action)
    response: Dict[str, Any] = {
        "session_id": payload.get("session_id"),
        "observation": obs.model_dump(mode="json"),
        "reward": float(obs.reward),
        "done": bool(obs.done),
    }
    return response
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
@app.post("/tools/log_search", tags=["AdaptShield SOC Tools"])
async def tool_log_search(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``log_search`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else
    ``"unknown"``; ``query`` defaults to an empty string.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    query = payload.get("query", "")
    return env.call_tool("log_search", node=node, query=query)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
@app.post("/tools/cmdb_lookup", tags=["AdaptShield SOC Tools"])
async def tool_cmdb_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``cmdb_lookup`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else ``"unknown"``.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    return env.call_tool("cmdb_lookup", node=node)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@app.post("/tools/edr_status", tags=["AdaptShield SOC Tools"])
async def tool_edr_status(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``edr_status`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else ``"unknown"``.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    return env.call_tool("edr_status", node=node)
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.post("/tools/vuln_lookup", tags=["AdaptShield SOC Tools"])
async def tool_vuln_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``vuln_lookup`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else
    ``"unknown"``; ``package`` defaults to an empty string.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    package = payload.get("package", "")
    return env.call_tool("vuln_lookup", node=node, package=package)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@app.post("/tools/identity_lookup", tags=["AdaptShield SOC Tools"])
async def tool_identity_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``identity_lookup`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else ``"unknown"``.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    return env.call_tool("identity_lookup", node=node)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@app.post("/tools/change_calendar_lookup", tags=["AdaptShield SOC Tools"])
async def tool_change_calendar_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``change_calendar_lookup`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else ``"unknown"``.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    return env.call_tool("change_calendar_lookup", node=node)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
@app.post("/tools/netflow_lookup", tags=["AdaptShield SOC Tools"])
async def tool_netflow_lookup(payload: Dict[str, Any] = Body(default_factory=dict)) -> Dict[str, Any]:
    """Run the ``netflow_lookup`` tool for the session named in the payload.

    The node comes from ``node``, then ``target_node``, else ``"unknown"``.
    """
    env = _soc_session(payload)
    node = payload.get("node", payload.get("target_node", "unknown"))
    return env.call_tool("netflow_lookup", node=node)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _soc_session(payload: Dict[str, Any]) -> AdaptShieldEnvironment:
    """Resolve the environment for ``payload["session_id"]``.

    Raises:
        HTTPException: 404 when the id is missing or not registered in
            ``SOC_SESSIONS``.
    """
    session_id = str(payload.get("session_id", ""))
    env = SOC_SESSIONS.get(session_id)
    if env is not None:
        return env
    raise HTTPException(
        status_code=404,
        detail="Unknown SOC session. Call /soc/reset first.",
    )
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn on the given host and port."""
    # Imported here so uvicorn is only loaded when the server is started.
    from uvicorn import run as serve

    serve(app, host=host, port=port)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()
    if args.port == 7860:
        # Default path: keep the literal no-argument call that the module
        # docstring says `openenv validate` looks for.
        main()
    else:
        # Honor an explicit --port instead of silently ignoring it.
        main(port=args.port)
|
server/attacker.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield Attacker Engine
|
| 3 |
+
|
| 4 |
+
Deterministic scripted Python adversary — NEVER an LLM.
|
| 5 |
+
Implements attack stage progression within episodes.
|
| 6 |
+
Rotates strategies between episodes on a fixed schedule.
|
| 7 |
+
|
| 8 |
+
The internal episode counter is NEVER passed to the agent.
|
| 9 |
+
Agent must infer strategy from network signals only.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import random
|
| 14 |
+
from typing import Any, Dict, Tuple
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# ── Strategy definitions ───────────────────────────────────────────────────
|
| 18 |
+
# Ground truth per attack strategy: the defensive action and node that counter
# it, plus a human-readable description.
STRATEGY_MAP: Dict[str, Dict[str, str]] = {
    "brute_force": dict(
        correct_action="rate_limit",
        correct_target="auth_service",
        description="Credential stuffing attack on authentication service",
    ),
    "lateral_movement": dict(
        correct_action="isolate",
        correct_target="payment_service",
        description="Attacker pivoting from compromised host to internal nodes",
    ),
    "exfiltration": dict(
        correct_action="honeypot",
        correct_target="database",
        description="Slow sustained data exfiltration from database",
    ),
    "supply_chain": dict(
        correct_action="patch",
        correct_target="api_gateway",
        description="Malicious package injected into update pipeline",
    ),
}

# Order in which strategies are cycled.
STRATEGY_ROTATION = ["brute_force", "lateral_movement", "exfiltration", "supply_chain"]

# Stages of an attack inside one episode, mildest first.
# The attack moves to the next stage when the agent fails to act.
ATTACK_STAGES = ["recon", "exploit", "exfiltration"]

# Metrics reported by every node while the network is healthy.
BASELINE_NODES: Dict[str, Dict[str, Any]] = {
    "auth_service": {
        "status": "healthy",
        "request_rate": 120,
        "error_rate": 0.01,
        "cpu": 22,
        "outbound_mb": 2,
    },
    "payment_service": {
        "status": "healthy",
        "request_rate": 85,
        "error_rate": 0.01,
        "cpu": 31,
        "outbound_mb": 5,
    },
    "database": {
        "status": "healthy",
        "request_rate": 34,
        "error_rate": 0.01,
        "cpu": 18,
        "outbound_mb": 8,
    },
    "api_gateway": {
        "status": "healthy",
        "request_rate": 203,
        "error_rate": 0.02,
        "cpu": 29,
        "outbound_mb": 3,
    },
}

# Pools of surface details that are sampled into alert text.
SOURCE_SUBNETS = ["192.168.1.x", "10.44.8.x", "172.16.22.x", "203.0.113.x"]
SERVICE_ACCOUNTS = ["svc_internal", "svc_billing", "svc_reporter", "deploy_bot"]
PACKAGE_NAMES = ["core-auth-lib", "gateway-router", "payment-sdk", "session-cache"]
DB_TABLES = ["customer_tokens", "invoice_archive", "payment_methods", "audit_events"]
ALERT_SOURCES = ["SIEM", "EDR", "WAF", "NETFLOW"]

# Per world family: canonical alert phrase -> phrase shown in that family.
# "train-a" maps every phrase to itself, i.e. it keeps the canonical wording.
FAMILY_REPLACEMENTS = {
    "train-a": {
        "failed login attempts": "failed login attempts",
        "password spray pattern": "password spray pattern",
        "Unusual internal connections": "Unusual internal connections",
        "Cross-node traffic volume": "Cross-node traffic volume",
        "outbound traffic": "outbound traffic",
        "Large compressed archive creation activity": "Large compressed archive creation activity",
        "unsigned package update request": "unsigned package update request",
        "binary hash mismatch": "binary hash mismatch",
        "Unexpected outbound connection": "Unexpected outbound connection",
        "scheduled analytics cache refresh": "scheduled analytics cache refresh",
    },
    "train-b": {
        "failed login attempts": "authentication rejections",
        "password spray pattern": "credential abuse spread",
        "Unusual internal connections": "Unexpected service-mesh hops",
        "Cross-node traffic volume": "Identity-linked east-west volume",
        "outbound traffic": "egress volume",
        "Large compressed archive creation activity": "archive staging activity",
        "unsigned package update request": "unapproved release request",
        "binary hash mismatch": "manifest integrity drift",
        "Unexpected outbound connection": "post-deploy callback session",
        "scheduled analytics cache refresh": "scheduled cache tuning window",
    },
    "eval-x": {
        "failed login attempts": "auth rejection burst",
        "password spray pattern": "credential reuse sweep",
        "Unusual internal connections": "callback-path fan-out",
        "Cross-node traffic volume": "mesh traffic clustering",
        "outbound traffic": "data egress pressure",
        "Large compressed archive creation activity": "bundle staging activity",
        "unsigned package update request": "release provenance anomaly",
        "binary hash mismatch": "artifact provenance drift",
        "Unexpected outbound connection": "release-linked callback session",
        "scheduled analytics cache refresh": "approved observability warmup",
    },
    "eval-y": {
        "failed login attempts": "lockout storm",
        "password spray pattern": "shared-secret sweep",
        "Unusual internal connections": "lateral fan-out path",
        "Cross-node traffic volume": "cross-domain session churn",
        "outbound traffic": "archive egress volume",
        "Large compressed archive creation activity": "sealed archive staging",
        "unsigned package update request": "cross-approval deploy request",
        "binary hash mismatch": "release integrity anomaly",
        "Unexpected outbound connection": "unknown release callback",
        "scheduled analytics cache refresh": "scheduled edge warmup",
    },
}
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class AttackerEngine:
    """
    Polymorphic scripted attacker with stage progression.

    Within an episode: the attack walks through ATTACK_STAGES
    (recon → exploit → exfiltration). The stage only escalates on turns where
    the agent failed to act correctly, so an early correct action holds it.

    Between episodes: the strategy is picked from STRATEGY_ROTATION using the
    internal episode counter and a per-task shift interval. The "dual-pivot"
    and "polymorphic-zero-day" tasks additionally change strategy mid-episode
    once the turn counter reaches 3.
    """

    def __init__(self, task_name: str, world_family: str = "train-a"):
        """
        Create an engine for one task.

        Args:
            task_name: Task identifier. Known values are "direct-triage",
                "dual-pivot" and "polymorphic-zero-day"; anything else gets
                the "never rotate" shift interval and no benign noise.
            world_family: Key into FAMILY_REPLACEMENTS selecting the alert
                vocabulary. Unknown families leave alert text unchanged.

        Raises:
            ValueError: If ADAPTSHIELD_SEED is set to a non-integer value.
        """
        # Seed the shared `random` generator only when a seed was explicitly
        # requested. Reseeding from random.randint(0, 9999) when the variable
        # is unset would squeeze the process-wide RNG into 10,000 states on
        # every construction without making anything reproducible.
        seed = os.environ.get("ADAPTSHIELD_SEED", "").strip()
        if seed:
            random.seed(int(seed))

        self.task_name = task_name
        self.world_family = world_family
        self._episode = 0            # internal — NEVER passed to agent
        self._turn = 0               # within-episode turn counter
        self._stage_idx = 0          # current attack stage index
        self._escalated = False      # did agent miss a turn?

        # Number of episodes between strategy rotations for each task.
        self._shift_every = {
            "direct-triage": 9999,
            "dual-pivot": 20,
            "polymorphic-zero-day": 10,
        }.get(task_name, 9999)

        # Probability that a turn is replaced by a benign (noise) observation.
        self._noise_rate = 0.15 if task_name == "polymorphic-zero-day" else 0.0

    # ── Public interface ───────────────────────────────────────────────────

    def reset_episode(self) -> None:
        """Reset per-episode state. Called at the start of each episode."""
        self._turn = 0
        self._stage_idx = 0
        self._escalated = False

    def current_strategy(self) -> str:
        """
        Return the active strategy name from STRATEGY_ROTATION.

        The base index advances by one every `_shift_every` episodes. From
        turn 3 onward two tasks shift it again within the episode.
        """
        base_idx = (self._episode // self._shift_every) % len(STRATEGY_ROTATION)

        # Medium task: one clean pivot from brute force to lateral movement.
        # The modulo 2 keeps the pivoted index inside the first two rotation
        # entries whatever the base index was.
        if self.task_name == "dual-pivot" and self._turn >= 3:
            base_idx = (base_idx + 1) % 2

        # Hard task: move on to the next strategy in the rotation.
        if self.task_name == "polymorphic-zero-day" and self._turn >= 3:
            base_idx = (base_idx + 1) % len(STRATEGY_ROTATION)

        return STRATEGY_ROTATION[base_idx]

    def current_stage(self) -> str:
        """Return the current attack stage, clamped to the last stage."""
        return ATTACK_STAGES[min(self._stage_idx, len(ATTACK_STAGES) - 1)]

    def build_observation(self) -> Dict[str, Any]:
        """
        Build the network observation for the current turn.

        The node metrics and alert strings are intentionally ambiguous and
        carry no strategy codes. The returned dict ALSO contains ground-truth
        fields ("attack_stage", "is_benign", "strategy", "correct_action",
        "correct_target") used for grading.

        NOTE(review): presumably the caller strips the ground-truth fields
        before showing the observation to the agent — confirm in the
        environment code.

        Returns:
            Dict with "network_nodes", "active_alerts" and the ground-truth
            fields listed above.
        """
        # Drawn first on every call so the RNG stream stays stable.
        is_benign = random.random() < self._noise_rate
        strategy = self.current_strategy()
        stage = self.current_stage()

        # Copy each node dict so BASELINE_NODES is never mutated.
        nodes = {k: dict(v) for k, v in BASELINE_NODES.items()}

        if is_benign:
            alerts = self._build_benign_alerts()
            nodes["api_gateway"]["cpu"] = 71
            nodes["api_gateway"]["status"] = "elevated"
            return {
                "network_nodes": nodes,
                "active_alerts": alerts,
                "attack_stage": "none",
                "is_benign": True,
                "strategy": "benign",
                "correct_action": "monitor",
                "correct_target": "none",
            }

        nodes, alerts = self._build_attack_state(strategy, stage, nodes)

        return {
            "network_nodes": nodes,
            "active_alerts": alerts,
            "attack_stage": stage,
            "is_benign": False,
            "strategy": strategy,
            "correct_action": STRATEGY_MAP[strategy]["correct_action"],
            "correct_target": STRATEGY_MAP[strategy]["correct_target"],
        }

    def advance_turn(self, agent_acted_correctly: bool) -> None:
        """
        Advance the turn counter and escalate the stage if the agent failed.

        Called after grading each Phase 2 action. The stage index saturates
        at the final entry of ATTACK_STAGES.
        """
        self._turn += 1
        if not agent_acted_correctly:
            self._escalated = True
            self._stage_idx = min(self._stage_idx + 1, len(ATTACK_STAGES) - 1)

    def advance_episode(self) -> None:
        """Increment the episode counter and reset per-episode state."""
        self._episode += 1
        self.reset_episode()

    # ── Internal signal builders ───────────────────────────────────────────

    def _build_benign_alerts(self) -> list:
        """Return one randomly chosen set of convincing but harmless alerts."""
        # RNG draw order (account, prefix, option set) is kept fixed so
        # seeded runs stay reproducible.
        deploy_user = random.choice(SERVICE_ACCOUNTS)
        alert_prefix = self._alert_prefix()
        options = [
            [
                f"{alert_prefix} Scheduled database backup completed: 2.3GB archived",
                f"{alert_prefix} User '{deploy_user}' executed routine deployment script",
                f"{alert_prefix} api_gateway: elevated CPU due to analytics batch job",
            ],
            [
                f"{alert_prefix} Certificate renewal completed for *.internal.corp",
                f"{alert_prefix} Routine log rotation completed on all nodes",
                f"{alert_prefix} payment_service: normal traffic spike during billing window",
            ],
            [
                f"{alert_prefix} Health check: all nodes nominal",
                f"{alert_prefix} auth_service: password policy enforcement triggered for 3 users",
                f"{alert_prefix} Scheduled vulnerability scan completed: 0 critical findings",
            ],
        ]
        return [self._surface(line) for line in random.choice(options)]

    def _build_attack_state(
        self,
        strategy: str,
        stage: str,
        nodes: Dict,
    ) -> Tuple[Dict, list]:
        """
        Build network state and alerts for an active attack.

        Signals are intentionally ambiguous — no strategy names or MITRE codes.
        Metric severity scales with the attack stage.

        Args:
            strategy: Active strategy name; an unrecognized name leaves the
                nodes untouched and yields no alerts.
            stage: Current attack stage; unknown stages use multiplier 1.0.
            nodes: Per-node metric dicts. Mutated in place and returned.

        Returns:
            (nodes, alerts), with alerts passed through the world-family
            vocabulary.
        """
        stage_multiplier = {"recon": 1.0, "exploit": 1.6, "exfiltration": 2.4}
        mult = stage_multiplier.get(stage, 1.0)

        alerts = []
        # The prefix is drawn before any strategy-specific random choice.
        alert_prefix = self._alert_prefix()

        if strategy == "brute_force":
            subnet = random.choice(SOURCE_SUBNETS)
            nodes["auth_service"]["status"] = "degraded" if stage != "recon" else "elevated"
            nodes["auth_service"]["request_rate"] = int(340 * mult)
            nodes["auth_service"]["error_rate"] = min(0.48, 0.18 * mult)
            nodes["auth_service"]["cpu"] = min(96, int(45 * mult))
            if self.task_name == "direct-triage":
                # Harmless gateway load added as a distractor on the easy task.
                nodes["api_gateway"]["status"] = "elevated"
                nodes["api_gateway"]["cpu"] = max(nodes["api_gateway"]["cpu"], 54)
            alerts = [
                f"{alert_prefix} auth_service: {int(120 * mult)} failed login attempts in last 60s from subnet {subnet}",
                f"{alert_prefix} auth_service: error rate {nodes['auth_service']['error_rate']:.0%}; {stage} activity detected",
                f"{alert_prefix} payment_service: normal transaction volume",
            ]
            if self.task_name == "direct-triage":
                alerts.append(
                    f"{alert_prefix} api_gateway: elevated CPU overlaps with scheduled analytics cache refresh"
                )

        elif strategy == "lateral_movement":
            account = random.choice(SERVICE_ACCOUNTS)
            nodes["payment_service"]["status"] = "suspicious"
            nodes["payment_service"]["request_rate"] = int(95 * mult)
            nodes["payment_service"]["cpu"] = int(38 * mult)
            nodes["auth_service"]["request_rate"] = int(155 * mult)
            # Leftover auth noise right after (dual-pivot) or any time after
            # (polymorphic-zero-day) the mid-episode strategy shift.
            add_auth_distractor = (
                (self.task_name == "dual-pivot" and self._turn == 3) or
                (self.task_name == "polymorphic-zero-day" and self._turn >= 3)
            )
            if add_auth_distractor:
                nodes["auth_service"]["status"] = "elevated"
                nodes["auth_service"]["error_rate"] = min(0.36, 0.14 * mult)
                nodes["auth_service"]["cpu"] = min(90, int(42 * mult))
            alerts = [
                f"{alert_prefix} Unusual internal connections: auth_service -> payment_service ({int(40 * mult)}/min)",
                f"{alert_prefix} payment_service: service account '{account}' authenticated from new host",
                f"{alert_prefix} Cross-node traffic volume {mult:.1f}x above 7-day baseline",
            ]
            if add_auth_distractor:
                alerts.append(
                    f"{alert_prefix} auth_service: residual failed-login noise from earlier credential attack"
                )
            if self.task_name == "dual-pivot":
                alerts.append(
                    f"{alert_prefix} payment_service: emergency checkout tuning window is not on the approved change calendar"
                )

        elif strategy == "exfiltration":
            table = random.choice(DB_TABLES)
            nodes["database"]["status"] = "suspicious" if stage != "recon" else "elevated"
            nodes["database"]["outbound_mb"] = int(8 * mult * 12)
            nodes["database"]["request_rate"] = int(34 * mult * 3)
            nodes["database"]["cpu"] = int(18 * mult * 1.5)
            alerts = [
                f"{alert_prefix} database: outbound traffic {nodes['database']['outbound_mb']}MB/hr; {mult:.1f}x above baseline",
                f"{alert_prefix} database: {int(280 * mult)} sequential reads from table '{table}'",
                f"{alert_prefix} Large compressed archive creation activity on database node",
            ]
            if self.task_name == "polymorphic-zero-day":
                alerts.append(
                    f"{alert_prefix} payment_service: reused workload credential appears on east-west callback path"
                )

        elif strategy == "supply_chain":
            package = random.choice(PACKAGE_NAMES)
            nodes["api_gateway"]["status"] = "suspicious"
            nodes["api_gateway"]["cpu"] = int(29 + 20 * mult)
            alerts = [
                f"{alert_prefix} api_gateway: unsigned package update request received from external source",
                f"{alert_prefix} api_gateway: binary hash mismatch on dependency '{package}'",
                f"{alert_prefix} Unexpected outbound connection from api_gateway to unrecognized host",
            ]
            if self.task_name == "polymorphic-zero-day":
                alerts.append(
                    f"{alert_prefix} change calendar: no approved gateway deploy window covers this request"
                )

        return nodes, [self._surface(alert) for alert in alerts]

    def _alert_prefix(self) -> str:
        """Return SOC-style alert metadata: "[SOURCE-ID 03:MM:SSZ]"."""
        # Four draws from the shared RNG, reproducible under ADAPTSHIELD_SEED.
        source = random.choice(ALERT_SOURCES)
        alert_id = random.randint(1000, 9999)
        minute = random.randint(0, 59)
        second = random.randint(0, 59)
        return f"[{source}-{alert_id} 03:{minute:02d}:{second:02d}Z]"

    def _surface(self, text: str) -> str:
        """Rewrite canonical alert phrases into this world family's wording."""
        surfaced = str(text)
        # Replacements run in the mapping's insertion order.
        for source, target in FAMILY_REPLACEMENTS.get(self.world_family, {}).items():
            surfaced = surfaced.replace(source, target)
        return surfaced
|
server/grader.py
ADDED
|
@@ -0,0 +1,501 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield Grader
|
| 3 |
+
|
| 4 |
+
Fully deterministic reward computation.
|
| 5 |
+
No NLP, no LLM-as-judge, no external API calls.
|
| 6 |
+
Runs in milliseconds.
|
| 7 |
+
|
| 8 |
+
INVARIANT: All return values are strictly in [0.01, 0.99].
|
| 9 |
+
normalized_score is ALWAYS present in metadata — never missing.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from typing import Any, Dict, List, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# ── Reward constants ───────────────────────────────────────────────────────
|
| 16 |
+
# Survival: the agent gets this for staying alive.
BASE_REWARD = 0.50
# Phase 1 correctly identified the threat type.
P1_TYPE_BONUS = 0.15
# Phase 1 correctly identified the target node.
P1_TARGET_BONUS = 0.10
# Correct + efficient action (0.50 + 0.39 = 0.89 base max).
P2_OPTIMAL = 0.39
# Correct action but heavy-handed.
P2_HEAVY = 0.18
# Right action without the investigative evidence the task requires.
P2_UNVERIFIED = 0.10
# Wrong action on a real threat.
P2_WRONG = -0.25
# Acted aggressively on a benign event.
FALSE_POSITIVE = -0.39
# Penalty unit applied when the attack stage has escalated.
STAGE_ESCALATION = -0.10
# Database exfiltration completed.
CATASTROPHIC = -0.49

# Most efficient countermeasure per strategy.
OPTIMAL_ACTION = dict(
    brute_force="rate_limit",
    lateral_movement="isolate",
    exfiltration="honeypot",
    supply_chain="patch",
)

# Countermeasure that also works but costs more.
HEAVY_ACTION = dict(
    brute_force="isolate",         # works but unnecessary downtime
    lateral_movement="honeypot",   # works but slower
    exfiltration="isolate",        # works but database goes offline
    supply_chain="isolate",        # works but gateway goes offline
)

# Actions that let an exfiltration complete once it has reached its final stage.
EXFIL_CATASTROPHIC_ACTIONS = ["monitor", "rate_limit", "patch"]

# Operational model: criticality and dependency blast radius.
# Deliberately lightweight: adds realism without making rewards noisy or
# hard to train against.
ASSET_CRITICALITY = dict(
    auth_service=0.70,
    payment_service=0.90,
    database=1.00,
    api_gateway=0.80,
    none=0.00,
    unknown=0.50,
)

SERVICE_DEPENDENCIES = dict(
    auth_service=["payment_service"],
    payment_service=["api_gateway"],
    database=["payment_service", "api_gateway"],
    api_gateway=["auth_service", "payment_service", "database"],
    none=[],
    unknown=[],
)

ACTION_DISRUPTION = dict(
    monitor=0.00,
    patch=0.06,
    rate_limit=0.10,
    honeypot=0.12,
    isolate=0.35,
)

# Upper bounds on the two small reward adjustments.
MAX_OPERATIONAL_PENALTY = 0.05
MAX_MISSION_ADJUSTMENT = 0.04

# Default set of SOC tools associated with each strategy.
BASE_REQUIRED_TOOL_FUSION = {
    "brute_force": {"log_search", "cmdb_lookup"},
    "lateral_movement": {"edr_status", "log_search"},
    "exfiltration": {"log_search", "edr_status"},
    "supply_chain": {"vuln_lookup", "log_search"},
}

# Per-task tool sets, keyed by task name and then strategy.
TASK_REQUIRED_TOOL_FUSION = {
    "direct-triage": {
        "brute_force": {"log_search"},
    },
    "dual-pivot": {
        "lateral_movement": {"edr_status", "log_search", "identity_lookup"},
    },
    "polymorphic-zero-day": {
        "brute_force": {"log_search", "cmdb_lookup", "identity_lookup"},
        "lateral_movement": {"edr_status", "log_search", "identity_lookup", "cmdb_lookup"},
        "exfiltration": {"log_search", "edr_status", "netflow_lookup", "cmdb_lookup"},
        "supply_chain": {"vuln_lookup", "log_search", "change_calendar_lookup", "cmdb_lookup"},
    },
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def grade_step(
    phase1_action: Dict[str, Any],
    phase2_action: Dict[str, Any],
    turn_config: Dict[str, Any],
    stage: str,
    consecutive_wrong: int,
    task_name: str = "",
    foothold_established: bool = False,
    mission_profile: Dict[str, Any] | None = None,
    tool_context: Dict[str, Any] | None = None,
) -> Tuple[float, bool, Dict[str, Any]]:
    """
    Grade a complete two-phase step.

    Args:
        phase1_action:        Agent's Phase 1 output (threat assessment);
                              reads "threat_type" and "target_node".
        phase2_action:        Agent's Phase 2 output (defensive action);
                              reads "action" and "target_node".
        turn_config:          Ground truth from AttackerEngine.build_observation()
        stage:                Current attack stage (recon/exploit/exfiltration)
        consecutive_wrong:    How many consecutive wrong actions agent has taken
        task_name:            Task identifier; selects the tool-verification
                              rules and the contextual countermeasure.
        foothold_established: When True on "polymorphic-zero-day" during
                              lateral movement, the correct action becomes
                              "honeypot" and no heavy alternative is accepted.
        mission_profile:      Optional mission settings forwarded to the
                              mission-alignment adjustment.
        tool_context:         Optional record of SOC tool usage; "tool_count"
                              is read here, the rest is passed to the
                              evidence check.

    Returns:
        (reward, catastrophic_failure, info_dict)
        reward is passed through _clamp (documented invariant: [0.01, 0.99]).
        info_dict always contains "acted_correctly".
    """
    is_benign = turn_config.get("is_benign", False)
    strategy = turn_config.get("strategy", "benign")
    correct_action = turn_config.get("correct_action", "monitor")
    correct_target = turn_config.get("correct_target", "none")
    mission_profile = mission_profile or {}
    tool_context = tool_context or {}
    contextual_countermeasure = (
        task_name == "polymorphic-zero-day" and
        foothold_established and
        strategy == "lateral_movement"
    )
    if contextual_countermeasure:
        correct_action = "honeypot"

    p1_type = str(phase1_action.get("threat_type", ""))
    p1_target = str(phase1_action.get("target_node", ""))
    p2_action = str(phase2_action.get("action", "monitor"))
    p2_target = str(phase2_action.get("target_node", ""))

    reward = BASE_REWARD
    catastrophic = False

    info = {
        "strategy": strategy,
        "stage": stage,
        "correct_action": correct_action,
        "correct_target": correct_target,
        "p1_threat_type": p1_type,
        "p1_target": p1_target,
        "p2_action": p2_action,
        "p2_target": p2_target,
        "is_benign": is_benign,
        "p1_type_correct": False,
        "p1_target_correct": False,
        "score_reason": "",
        "business_impact": 0.0,
        "availability_impact": 0.0,
        "security_risk": 0.0,
        "dependency_blast_radius": [],
        "operational_penalty": 0.0,
        "foothold_established": foothold_established,
        "contextual_countermeasure": contextual_countermeasure,
        "mission_profile": mission_profile,
        "mission_alignment": "neutral",
        "mission_adjustment": 0.0,
        "tool_verification_required": False,
        "tool_evidence_found": False,
        "tool_fusion_required": [],
        "tool_fusion_found": [],
        "tool_count": int(tool_context.get("tool_count", 0) or 0),
    }

    # ── False positive path ────────────────────────────────────────────────
    if is_benign:
        if p2_action in ("rate_limit", "isolate", "honeypot", "patch"):
            reward += FALSE_POSITIVE
            info["acted_correctly"] = False
            info["score_reason"] = (
                f"False positive: used {p2_action} on benign event"
            )
        else:
            reward += 0.15
            info["acted_correctly"] = True
            info["score_reason"] = "Correct: monitored benign event without overreacting"
        benign_kind = "false_positive" if p2_action != "monitor" else "optimal"
        _apply_operational_impact(
            info=info,
            action=p2_action,
            target=p2_target,
            result_kind=benign_kind,
            strategy=strategy,
            stage=stage,
        )
        _apply_mission_alignment(
            info=info,
            action=p2_action,
            target=p2_target,
            result_kind=benign_kind,
            mission_profile=mission_profile,
        )
        reward += info["mission_adjustment"]
        reward -= info["operational_penalty"]
        return _clamp(reward), False, info

    # ── Phase 1 accuracy bonuses ───────────────────────────────────────────
    if p1_type == strategy:
        reward += P1_TYPE_BONUS
        info["p1_type_correct"] = True

    if p1_target == correct_target:
        reward += P1_TARGET_BONUS
        info["p1_target_correct"] = True

    # ── Catastrophic failure: exfiltration completed ───────────────────────
    if (strategy == "exfiltration" and
            stage == "exfiltration" and
            p2_action in EXFIL_CATASTROPHIC_ACTIONS):
        # Deliberately discards any Phase 1 bonus earned above.
        reward = BASE_REWARD + CATASTROPHIC
        catastrophic = True
        # Keep the info contract identical to the other return paths.
        info["acted_correctly"] = False
        info["score_reason"] = (
            "CATASTROPHIC: Database exfiltration completed — "
            f"agent used {p2_action} instead of honeypot"
        )
        _apply_operational_impact(
            info=info,
            action=p2_action,
            target=p2_target,
            result_kind="wrong",
            strategy=strategy,
            stage=stage,
        )
        _apply_mission_alignment(
            info=info,
            action=p2_action,
            target=p2_target,
            result_kind="wrong",
            mission_profile=mission_profile,
        )
        reward += info["mission_adjustment"]
        return _clamp(reward), catastrophic, info

    # ── Stage escalation penalty ───────────────────────────────────────────
    if stage == "exploit" and consecutive_wrong >= 1:
        reward += STAGE_ESCALATION
    elif stage == "exfiltration" and consecutive_wrong >= 2:
        reward += STAGE_ESCALATION * 2

    # ── Phase 2 action grading ─────────────────────────────────────────────
    optimal = correct_action
    # None means "no heavy-handed alternative is accepted". A string sentinel
    # such as "" would match an agent that submits an empty action.
    heavy: str | None = None if contextual_countermeasure else HEAVY_ACTION.get(strategy)
    if heavy == optimal:
        heavy = None
    used_heavy = heavy is not None and p2_action == heavy

    requires_tool_verification = (
        strategy in OPTIMAL_ACTION and
        (
            task_name == "polymorphic-zero-day" or
            (task_name == "dual-pivot" and strategy == "lateral_movement") or
            (task_name == "direct-triage" and strategy == "brute_force")
        )
    )
    required_tools = _required_tool_fusion(task_name=task_name, strategy=strategy)
    tool_evidence_found, fusion_found = _has_relevant_tool_evidence(
        tool_context=tool_context,
        strategy=strategy,
        target=correct_target,
        required_tools=required_tools,
    )
    info["tool_verification_required"] = requires_tool_verification
    info["tool_evidence_found"] = tool_evidence_found
    info["tool_fusion_required"] = sorted(required_tools)
    info["tool_fusion_found"] = sorted(fusion_found)

    if (
        p2_action == optimal and
        p2_target == correct_target and
        requires_tool_verification and
        not tool_evidence_found
    ):
        reward += P2_UNVERIFIED
        result_kind = "unverified"
        info["score_reason"] = (
            f"Unverified correct action: {p2_action} on {p2_target} would help, "
            f"but {task_name or 'this task'} requires stronger SOC evidence before full credit"
        )

    elif p2_action == optimal and p2_target == correct_target:
        reward += P2_OPTIMAL
        result_kind = "optimal"
        if contextual_countermeasure:
            info["score_reason"] = (
                f"Context-aware optimal: {p2_action} on {p2_target} — "
                "foothold already established, so deception beats isolation"
            )
        else:
            info["score_reason"] = (
                f"Optimal: {p2_action} on {p2_target} — attack stopped efficiently"
            )

    elif p2_action == optimal and p2_target != correct_target:
        reward += P2_HEAVY * 0.5
        result_kind = "wrong_target"
        info["score_reason"] = (
            f"Right action ({p2_action}) but wrong target "
            f"(got {p2_target}, needed {correct_target})"
        )

    elif used_heavy and p2_target == correct_target:
        reward += P2_HEAVY
        result_kind = "heavy"
        info["score_reason"] = (
            f"Heavy-handed: {p2_action} stopped attack on {p2_target} "
            "but caused unnecessary service disruption"
        )

    else:
        reward += P2_WRONG
        result_kind = "wrong"
        info["score_reason"] = (
            f"Wrong: {p2_action} on {p2_target} — "
            f"needed {correct_action} on {correct_target}"
        )

    # Optimal or accepted heavy action on the right node counts as correct,
    # including the "unverified" case.
    acted_correctly = (p2_action == optimal or used_heavy) and p2_target == correct_target
    info["acted_correctly"] = acted_correctly
    _apply_operational_impact(
        info=info,
        action=p2_action,
        target=p2_target,
        result_kind=result_kind,
        strategy=strategy,
        stage=stage,
    )
    _apply_mission_alignment(
        info=info,
        action=p2_action,
        target=p2_target,
        result_kind=result_kind,
        mission_profile=mission_profile,
    )
    reward += info["mission_adjustment"]
    reward -= info["operational_penalty"]

    return _clamp(reward), catastrophic, info
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _apply_mission_alignment(
|
| 351 |
+
info: Dict[str, Any],
|
| 352 |
+
action: str,
|
| 353 |
+
target: str,
|
| 354 |
+
result_kind: str,
|
| 355 |
+
mission_profile: Dict[str, Any],
|
| 356 |
+
) -> None:
|
| 357 |
+
sla_priority = str(mission_profile.get("sla_priority", "balanced"))
|
| 358 |
+
primary_asset = str(mission_profile.get("primary_asset", "unknown"))
|
| 359 |
+
risk_tolerance = str(mission_profile.get("risk_tolerance", "medium"))
|
| 360 |
+
|
| 361 |
+
adjustment = 0.0
|
| 362 |
+
alignment = "neutral"
|
| 363 |
+
|
| 364 |
+
if sla_priority == "availability" and action == "isolate" and target == primary_asset:
|
| 365 |
+
adjustment -= MAX_MISSION_ADJUSTMENT
|
| 366 |
+
alignment = "sla_violation"
|
| 367 |
+
elif sla_priority == "availability" and result_kind == "optimal" and action in ("rate_limit", "patch", "monitor"):
|
| 368 |
+
adjustment += MAX_MISSION_ADJUSTMENT / 2
|
| 369 |
+
alignment = "sla_aligned"
|
| 370 |
+
elif sla_priority == "containment" and result_kind == "optimal" and action in ("honeypot", "isolate", "patch"):
|
| 371 |
+
adjustment += MAX_MISSION_ADJUSTMENT / 2
|
| 372 |
+
alignment = "containment_aligned"
|
| 373 |
+
elif risk_tolerance == "low" and result_kind in ("wrong", "wrong_target"):
|
| 374 |
+
adjustment -= MAX_MISSION_ADJUSTMENT / 2
|
| 375 |
+
alignment = "risk_misaligned"
|
| 376 |
+
|
| 377 |
+
info["mission_alignment"] = alignment
|
| 378 |
+
info["mission_adjustment"] = round(adjustment, 2)
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def _apply_operational_impact(
    info: Dict[str, Any],
    action: str,
    target: str,
    result_kind: str,
    strategy: str,
    stage: str,
) -> None:
    """
    Add deterministic business-impact telemetry and a small bounded penalty.

    Mutates ``info`` in place, writing ``business_impact``,
    ``availability_impact``, ``security_risk``, ``dependency_blast_radius``
    and ``operational_penalty``.

    The penalty never exceeds MAX_OPERATIONAL_PENALTY (the original note
    describes this cap as 0.05) so existing learning curves keep their shape
    while demos can explain service criticality and blast radius.
    """
    criticality = ASSET_CRITICALITY.get(target, ASSET_CRITICALITY["unknown"])
    disruption = ACTION_DISRUPTION.get(action, 0.10)
    dependents = SERVICE_DEPENDENCIES.get(target, [])
    # Each dependent service widens the blast radius, saturating at 1.0.
    dependency_factor = min(1.0, 0.15 * len(dependents))

    availability = round(min(1.0, disruption * (criticality + dependency_factor)), 2)
    security = _security_risk(result_kind=result_kind, strategy=strategy, stage=stage)
    impact = round(min(1.0, availability + security), 2)

    if result_kind == "optimal":
        penalty = 0.0
    elif result_kind == "unverified":
        # Unverified actions are charged at half rate.
        penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY / 2), 2)
    else:
        penalty = round(min(MAX_OPERATIONAL_PENALTY, impact * MAX_OPERATIONAL_PENALTY), 2)

    info["business_impact"] = impact
    info["availability_impact"] = availability
    info["security_risk"] = security
    # Copy so a consumer mutating the telemetry cannot corrupt the shared
    # SERVICE_DEPENDENCIES table.
    info["dependency_blast_radius"] = list(dependents) if disruption > 0 else []
    info["operational_penalty"] = penalty
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def _security_risk(result_kind: str, strategy: str, stage: str) -> float:
|
| 419 |
+
if result_kind in ("optimal", "heavy"):
|
| 420 |
+
return 0.0
|
| 421 |
+
if result_kind == "unverified":
|
| 422 |
+
return 0.08
|
| 423 |
+
if result_kind == "false_positive":
|
| 424 |
+
return 0.0
|
| 425 |
+
|
| 426 |
+
stage_risk = {
|
| 427 |
+
"recon": 0.18,
|
| 428 |
+
"exploit": 0.32,
|
| 429 |
+
"exfiltration": 0.50,
|
| 430 |
+
}.get(stage, 0.20)
|
| 431 |
+
|
| 432 |
+
if strategy == "exfiltration":
|
| 433 |
+
stage_risk += 0.15
|
| 434 |
+
elif strategy == "lateral_movement":
|
| 435 |
+
stage_risk += 0.08
|
| 436 |
+
|
| 437 |
+
return round(min(1.0, stage_risk), 2)
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
def _has_relevant_tool_evidence(
|
| 441 |
+
tool_context: Dict[str, Any],
|
| 442 |
+
strategy: str,
|
| 443 |
+
target: str,
|
| 444 |
+
required_tools: set[str],
|
| 445 |
+
) -> Tuple[bool, set[str]]:
|
| 446 |
+
fusion_found = {
|
| 447 |
+
str(result.get("tool", ""))
|
| 448 |
+
for result in tool_context.get("tool_results", []) or []
|
| 449 |
+
if str(result.get("node", "")) == target
|
| 450 |
+
}
|
| 451 |
+
has_attack_evidence = False
|
| 452 |
+
for evidence in tool_context.get("evidence", []) or []:
|
| 453 |
+
if (
|
| 454 |
+
str(evidence.get("evidence_type", "")) == strategy and
|
| 455 |
+
str(evidence.get("node", "")) == target and
|
| 456 |
+
bool(evidence.get("verified", False))
|
| 457 |
+
):
|
| 458 |
+
has_attack_evidence = True
|
| 459 |
+
break
|
| 460 |
+
|
| 461 |
+
return has_attack_evidence and required_tools.issubset(fusion_found), fusion_found
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
def _required_tool_fusion(task_name: str, strategy: str) -> set[str]:
    """
    Return the tool names that must corroborate ``strategy`` for ``task_name``.

    A task-specific rule in TASK_REQUIRED_TOOL_FUSION wins; otherwise the
    BASE_REQUIRED_TOOL_FUSION entry is used, falling back to an empty set.
    A fresh set is returned so callers may mutate it freely.
    """
    task_rules = TASK_REQUIRED_TOOL_FUSION.get(task_name, {})
    if strategy in task_rules:
        return set(task_rules[strategy])
    return set(BASE_REQUIRED_TOOL_FUSION.get(strategy, set()))
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def _clamp(value: float) -> float:
|
| 472 |
+
"""Strict bounds: never exactly 0.0 or 1.0."""
|
| 473 |
+
return max(0.01, min(0.99, round(value, 2)))
|
| 474 |
+
|
| 475 |
+
|
| 476 |
+
def normalize_episode_score(rewards: List[float]) -> float:
    """
    Normalize episode rewards to a single score in [0.01, 0.99].

    An empty reward list scores a neutral 0.50.  Otherwise the summed reward
    is rescaled linearly between the lowest and highest totals reachable in
    ``len(rewards)`` steps and then clamped, so the result is never exactly
    0 or 1.
    """
    if not rewards:
        return 0.50

    total = sum(rewards)
    n = len(rewards)

    # Per-step rewards are clamped before they enter the episode reward list,
    # so normalization must use the reachable ceiling instead of the raw
    # unclamped sum of bonuses. Otherwise perfect episodes top out around 0.87.
    max_step_reward = _clamp(
        BASE_REWARD + P2_OPTIMAL + P1_TYPE_BONUS + P1_TARGET_BONUS + MAX_MISSION_ADJUSTMENT
    )
    min_step_reward = _clamp(BASE_REWARD + CATASTROPHIC)
    max_poss = n * max_step_reward
    min_poss = n * min_step_reward

    # Degenerate configuration (ceiling equals floor): avoid dividing by zero.
    if max_poss == min_poss:
        return 0.50

    raw = (total - min_poss) / (max_poss - min_poss)
    return _clamp(raw)
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
openenv[core]>=0.2.0
|
| 2 |
+
fastapi>=0.115.0
|
| 3 |
+
uvicorn>=0.24.0
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
|
server/scenarios.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
AdaptShield Scenarios
|
| 3 |
+
|
| 4 |
+
Observation builder and system prompts.
|
| 5 |
+
IMPORTANT: No MITRE codes in alerts. No explicit strategy names.
|
| 6 |
+
Agent must reason from raw metrics — not pattern match on codes.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import random
|
| 10 |
+
from typing import Any, Dict, List
|
| 11 |
+
|
| 12 |
+
# Nodes and defensive actions the environment recognises.
VALID_NODES = ["auth_service", "payment_service", "database", "api_gateway"]
VALID_ACTIONS = ["rate_limit", "isolate", "honeypot", "patch", "monitor"]

# Per-task episode length, description and default mission profile.
TASK_CONFIGS = {
    "direct-triage": {
        "max_turns": 5,
        "description": "Single fixed attacker strategy. Learn baseline threat response.",
        "mission_profile": {
            "mission_id": "login_stability",
            "primary_asset": "auth_service",
            "sla_priority": "availability",
            "risk_tolerance": "medium",
            "objective": "Preserve user login availability while stopping credential abuse.",
        },
    },
    "dual-pivot": {
        "max_turns": 6,
        "description": "Two strategies alternating every 20 episodes.",
        "mission_profile": {
            "mission_id": "checkout_continuity",
            "primary_asset": "payment_service",
            "sla_priority": "availability",
            "risk_tolerance": "medium",
            "objective": "Keep checkout online while containing internal movement.",
        },
    },
    "polymorphic-zero-day": {
        "max_turns": 8,
        "description": "All four strategies with mid-episode shift and noise injection.",
        "mission_profile": {
            "mission_id": "breach_containment",
            "primary_asset": "database",
            "sla_priority": "containment",
            "risk_tolerance": "low",
            "objective": "Protect sensitive data and adapt response if attacker persistence is established.",
        },
    },
}

# World families available to each data split.
WORLD_FAMILY_SPLITS = {
    "train": ["train-a", "train-b"],
    "eval": ["eval-x", "eval-y"],
}

# Operational modes that may be sampled for each task.
TASK_OPERATIONAL_MODES = {
    "direct-triage": ["availability_guarded"],
    "dual-pivot": ["containment_first", "evidence_preservation"],
    "polymorphic-zero-day": [
        "containment_first",
        "forensic_hold",
        "business_continuity",
    ],
}

# Free-text note appended to the mission objective for each world family.
FAMILY_MISSION_NOTES = {
    "train-a": "Primary incident feed emphasizes login telemetry and CMDB-linked service relationships.",
    "train-b": "Primary incident feed emphasizes identity anomalies, service-account movement, and deploy context.",
    "eval-x": "Primary incident feed emphasizes east-west callback patterns, release provenance drift, and egress clustering.",
    "eval-y": "Primary incident feed emphasizes rejection bursts, archive staging, and cross-domain approval mismatches.",
}

# Overrides an operational mode applies on top of a task's mission profile.
MODE_HINTS = {
    "availability_guarded": {
        "sla_priority": "availability",
        "risk_tolerance": "medium",
        "objective_suffix": "Use one confirming signal before disruptive action when login noise overlaps with maintenance chatter.",
    },
    "containment_first": {
        "sla_priority": "containment",
        "risk_tolerance": "low",
        "objective_suffix": "Prefer decisive containment once corroborating evidence suggests compromise is active.",
    },
    "evidence_preservation": {
        "sla_priority": "balanced",
        "risk_tolerance": "medium",
        "objective_suffix": "Preserve attacker visibility during checkout pivots until identity misuse is corroborated across tools.",
    },
    "forensic_hold": {
        "sla_priority": "containment",
        "risk_tolerance": "low",
        "objective_suffix": "Preserve attacker tradecraft long enough to map the callback path before cutting access entirely.",
    },
    "business_continuity": {
        "sla_priority": "availability",
        "risk_tolerance": "medium",
        "objective_suffix": "Reduce blast radius while protecting customer-facing continuity whenever deception can safely buy time.",
    },
}
|
| 100 |
+
|
| 101 |
+
PHASE1_SYSTEM = """You are a Threat Analyst defending a 4-node enterprise network.
|
| 102 |
+
Analyze the network metrics and security alerts to classify the threat.
|
| 103 |
+
|
| 104 |
+
Network nodes: auth_service, payment_service, database, api_gateway
|
| 105 |
+
|
| 106 |
+
Attack strategies to recognize:
|
| 107 |
+
- brute_force: High failed login rate on auth_service, error spikes
|
| 108 |
+
- lateral_movement: Unusual cross-node connections, service account anomalies
|
| 109 |
+
- exfiltration: Anomalous outbound data volume from database node
|
| 110 |
+
- supply_chain: Suspicious update requests or hash mismatches on api_gateway
|
| 111 |
+
- benign: Normal operations — scheduled jobs, routine activity
|
| 112 |
+
|
| 113 |
+
Respond ONLY with valid JSON, no other text:
|
| 114 |
+
{
|
| 115 |
+
"threat_type": "<brute_force|lateral_movement|exfiltration|supply_chain|benign>",
|
| 116 |
+
"confidence": <0.0 to 1.0>,
|
| 117 |
+
"target_node": "<auth_service|payment_service|database|api_gateway>",
|
| 118 |
+
"recommended_action": "<rate_limit|isolate|honeypot|patch|monitor>",
|
| 119 |
+
"reasoning": "<one sentence chain of thought>"
|
| 120 |
+
}"""
|
| 121 |
+
|
| 122 |
+
PHASE2_SYSTEM = """You are a Tactical Executor defending a 4-node enterprise network.
|
| 123 |
+
You receive a threat assessment from the Threat Analyst and must execute the defense.
|
| 124 |
+
You CANNOT see raw network logs — act only on the assessment provided.
|
| 125 |
+
|
| 126 |
+
Available actions:
|
| 127 |
+
- rate_limit: Throttle traffic to node. Light touch, keeps service online. Best for DoS/brute force.
|
| 128 |
+
- isolate: Take node completely offline. Stops spread but causes downtime. Use for lateral movement.
|
| 129 |
+
- honeypot: Redirect attacker to decoy system. Best for data exfiltration attempts.
|
| 130 |
+
- patch: Apply security update. Targeted fix for supply chain attacks.
|
| 131 |
+
- monitor: Observe without acting. Use only when genuinely uncertain or event is benign.
|
| 132 |
+
|
| 133 |
+
Respond ONLY with valid JSON, no other text:
|
| 134 |
+
{
|
| 135 |
+
"action": "<rate_limit|isolate|honeypot|patch|monitor>",
|
| 136 |
+
"target_node": "<auth_service|payment_service|database|api_gateway>",
|
| 137 |
+
"reasoning": "<one sentence chain of thought>"
|
| 138 |
+
}"""
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def build_phase1_obs(
    turn_config: Dict[str, Any],
    history: List[Dict[str, str]],
    task_name: str,
    turn: int,
    max_turns: int,
    episode_id: str,
    mission_profile: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """
    Build Phase 1 observation — full network state visible.

    ``turn_config`` must provide "network_nodes" and "active_alerts";
    "attack_stage" defaults to "none".  Only the three most recent history
    entries are exposed.  The system prompt is extended with mission context
    when a non-empty ``mission_profile`` is given.
    """
    mission_profile = mission_profile or {}
    return {
        "scenario_id": episode_id,
        "task_name": task_name,
        "phase": 1,
        "turn": turn,
        "max_turns": max_turns,
        "network_nodes": turn_config["network_nodes"],
        "active_alerts": turn_config["active_alerts"],
        "attack_stage": turn_config.get("attack_stage", "none"),
        "history": history[-3:],
        "phase1_assessment": None,
        "last_action_result": None,
        "system_context": _with_mission_context(PHASE1_SYSTEM, mission_profile),
        # Copy so a consumer editing the observation cannot alter the
        # module-level VALID_ACTIONS list.
        "available_actions": list(VALID_ACTIONS),
        "reward": 0.0,
        "done": False,
        "metadata": {
            "episode_id": episode_id,
            "normalized_score": 0.50,  # always present from step 1
            "mission_profile": mission_profile,
        },
    }
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def build_phase2_obs(
    phase1_output: Dict[str, Any],
    history: List[Dict[str, str]],
    task_name: str,
    turn: int,
    max_turns: int,
    episode_id: str,
    current_score: float,
    mission_profile: Dict[str, Any] | None = None,
) -> Dict[str, Any]:
    """
    Build Phase 2 observation.

    CRITICAL: network_nodes and active_alerts are EMPTY.
    Phase 2 agent is blind to raw state — sees only Phase 1 assessment.

    ``current_score`` is surfaced as ``metadata["normalized_score"]``; only
    the three most recent history entries are exposed.
    """
    mission_profile = mission_profile or {}
    return {
        "scenario_id": episode_id,
        "task_name": task_name,
        "phase": 2,
        "turn": turn,
        "max_turns": max_turns,
        "network_nodes": {},  # deliberately empty
        "active_alerts": [],  # deliberately empty
        "attack_stage": "hidden",
        "history": history[-3:],
        "phase1_assessment": phase1_output,
        "last_action_result": None,
        "system_context": _with_mission_context(PHASE2_SYSTEM, mission_profile),
        # Copy so a consumer editing the observation cannot alter the
        # module-level VALID_ACTIONS list.
        "available_actions": list(VALID_ACTIONS),
        "reward": 0.0,
        "done": False,
        "metadata": {
            "episode_id": episode_id,
            "normalized_score": current_score,  # always present
            "mission_profile": mission_profile,
        },
    }
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def _with_mission_context(system_prompt: str, mission_profile: Dict[str, Any]) -> str:
|
| 217 |
+
if not mission_profile:
|
| 218 |
+
return system_prompt
|
| 219 |
+
|
| 220 |
+
mission = "\n".join([
|
| 221 |
+
"",
|
| 222 |
+
"Mission context:",
|
| 223 |
+
f"- mission_id: {mission_profile.get('mission_id', 'unknown')}",
|
| 224 |
+
f"- primary_asset: {mission_profile.get('primary_asset', 'unknown')}",
|
| 225 |
+
f"- sla_priority: {mission_profile.get('sla_priority', 'balanced')}",
|
| 226 |
+
f"- risk_tolerance: {mission_profile.get('risk_tolerance', 'medium')}",
|
| 227 |
+
f"- objective: {mission_profile.get('objective', 'Balance security and availability.')}",
|
| 228 |
+
])
|
| 229 |
+
return f"{system_prompt}{mission}"
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def choose_world_family(world_split: str, requested_family: str | None = None) -> str:
|
| 233 |
+
if requested_family:
|
| 234 |
+
return requested_family
|
| 235 |
+
families = WORLD_FAMILY_SPLITS.get(world_split, WORLD_FAMILY_SPLITS["train"])
|
| 236 |
+
return random.choice(families)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def choose_operational_mode(task_name: str, requested_mode: str | None = None) -> str:
|
| 240 |
+
if requested_mode:
|
| 241 |
+
return requested_mode
|
| 242 |
+
modes = TASK_OPERATIONAL_MODES.get(task_name, ["availability_guarded"])
|
| 243 |
+
return random.choice(modes)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
def mission_profile_for(task_name: str, operational_mode: str, world_family: str) -> Dict[str, Any]:
    """
    Build the mission profile for one episode.

    Starts from a copy of the task's default profile, tags it with the world
    family, a human-readable mode hint and the family's scenario note, lets
    the operational mode override ``sla_priority`` / ``risk_tolerance``, and
    extends the objective with the mode suffix followed by the family note.

    Raises:
        KeyError: if ``task_name`` is not a key of TASK_CONFIGS.
    """
    # dict(...) copies the top level so the shared TASK_CONFIGS entry is not mutated.
    base = dict(TASK_CONFIGS[task_name].get("mission_profile", {}))
    mode = MODE_HINTS.get(operational_mode, {})
    base["world_family"] = world_family
    base["operational_mode_hint"] = operational_mode.replace("_", " ")
    base["scenario_style"] = FAMILY_MISSION_NOTES.get(world_family, "")
    if mode.get("sla_priority"):
        base["sla_priority"] = mode["sla_priority"]
    if mode.get("risk_tolerance"):
        base["risk_tolerance"] = mode["risk_tolerance"]

    objective = str(base.get("objective", "")).rstrip()
    suffix = str(mode.get("objective_suffix", "")).strip()
    family_note = str(FAMILY_MISSION_NOTES.get(world_family, "")).strip()
    if suffix:
        objective = f"{objective} {suffix}".strip()
    if family_note:
        objective = f"{objective} {family_note}".strip()
    base["objective"] = objective
    return base
|
smoke_test.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick repo-root smoke test for AdaptShield.
|
| 4 |
+
|
| 5 |
+
Run from the repo root:
|
| 6 |
+
python smoke_test.py
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 16 |
+
|
| 17 |
+
if str(REPO_ROOT) not in sys.path:
|
| 18 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 19 |
+
|
| 20 |
+
import __init__ as adaptshield
|
| 21 |
+
import server.app as server_app
|
| 22 |
+
from models import AdaptShieldAction
|
| 23 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main() -> int:
    """
    Exercise one reset and one full Phase 1 -> Phase 2 turn, printing results.

    Returns 0 when every step completes; any failure surfaces as the
    exception raised by the environment.
    """
    print("AdaptShield smoke test")
    print(f"- package exports: {adaptshield.__all__}")
    print(f"- server app type: {server_app.app.__class__.__name__}")

    env = AdaptShieldEnvironment("direct-triage")
    obs = env.reset()
    print(
        f"- reset: phase={obs.phase} turn={obs.turn} "
        f"score={obs.metadata.get('normalized_score')}"
    )

    # Phase 1: submit a threat assessment.
    obs = env.step(
        AdaptShieldAction(
            threat_type="brute_force",
            confidence=0.9,
            target_node="auth_service",
            recommended_action="rate_limit",
        )
    )
    print(f"- phase 1 -> phase 2: assessment={obs.phase1_assessment}")

    # Phase 2: execute the defensive action.
    obs = env.step(AdaptShieldAction(action="rate_limit", target_node="auth_service"))
    print(
        f"- phase 2 -> next turn: reward={obs.reward} done={obs.done} "
        f"result={obs.last_action_result}"
    )

    print("Smoke test passed.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
|
soc_tools.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Shared SOC investigation helpers for AdaptShield agents."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import urllib.request
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Ordered (tool, node) queries to run for each suspected threat type.
THREAT_TOOL_PLAN = {
    "brute_force": [
        ("log_search", "auth_service"),
        ("cmdb_lookup", "auth_service"),
        ("identity_lookup", "auth_service"),
    ],
    "lateral_movement": [
        ("edr_status", "payment_service"),
        ("log_search", "payment_service"),
        ("identity_lookup", "payment_service"),
        ("cmdb_lookup", "payment_service"),
    ],
    "exfiltration": [
        ("log_search", "database"),
        ("edr_status", "database"),
        ("netflow_lookup", "database"),
        ("cmdb_lookup", "database"),
    ],
    "supply_chain": [
        ("vuln_lookup", "api_gateway"),
        ("log_search", "api_gateway"),
        ("change_calendar_lookup", "api_gateway"),
        ("cmdb_lookup", "api_gateway"),
    ],
    "benign": [("cmdb_lookup", "api_gateway")],
}

# Extra queries tried when the primary plan shows no attack indicators.
FALLBACK_SWEEP = [
    ("edr_status", "payment_service"),
    ("log_search", "database"),
    ("vuln_lookup", "api_gateway"),
]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
    """
    Guess the threat type from raw per-node metrics.

    Checks run in fixed priority order and the first hit wins:
    auth error rate >= 0.10 -> "brute_force"; payment node suspicious or
    cpu >= 55 -> "lateral_movement"; database outbound_mb >= 50 ->
    "exfiltration"; gateway suspicious -> "supply_chain"; else "benign".
    Missing nodes or metrics are treated as zero / not suspicious.
    """
    auth = network_nodes.get("auth_service", {})
    payment = network_nodes.get("payment_service", {})
    database = network_nodes.get("database", {})
    gateway = network_nodes.get("api_gateway", {})

    if float(auth.get("error_rate", 0.0)) >= 0.10:
        return "brute_force"
    if payment.get("status") == "suspicious" or float(payment.get("cpu", 0)) >= 55:
        return "lateral_movement"
    if float(database.get("outbound_mb", 0)) >= 50:
        return "exfiltration"
    if gateway.get("status") == "suspicious":
        return "supply_chain"
    return "benign"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def investigate_local(env: Any, obs: Any, use_tools: bool) -> List[Dict[str, Any]]:
    """Query local environment tool methods before Phase 1 action (non-thorough mode)."""
    return investigate_local_with_depth(env, obs, use_tools=use_tools, thorough=False)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def investigate_local_with_depth(
    env: Any,
    obs: Any,
    use_tools: bool,
    thorough: bool,
) -> List[Dict[str, Any]]:
    """
    Query local tools; thorough mode adds evidence-fusion follow-ups.

    Returns the list of tool results gathered via ``env.call_tool``.  Nothing
    is queried when tools are disabled or the observation is not Phase 1.
    The query plan depends on the task:
      - "direct-triage": one auth log search, only for suspected brute force.
      - "dual-pivot": three payment-service queries for suspected lateral
        movement, otherwise the first planned query for the suspected threat.
      - "polymorphic-zero-day": the full plan, plus the fallback sweep when
        no attack indicators were found.
      - any other task: no queries.
    """
    if not use_tools or getattr(obs, "phase", 1) != 1:
        return []
    task_name = getattr(obs, "task_name", "")
    threat = classify_from_metrics(getattr(obs, "network_nodes", {}))
    if task_name == "direct-triage":
        if threat == "brute_force":
            return [env.call_tool("log_search", node="auth_service")]
        return []
    if task_name == "dual-pivot":
        if threat == "lateral_movement":
            return [
                env.call_tool("edr_status", node="payment_service"),
                env.call_tool("log_search", node="payment_service"),
                env.call_tool("identity_lookup", node="payment_service"),
            ]
        tool_name, node = THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"])[0]
        return [env.call_tool(tool_name, node=node)]

    if task_name != "polymorphic-zero-day":
        return []

    results = []
    for tool_name, node in THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"]):
        results.append(env.call_tool(tool_name, node=node))

    if not has_attack_indicators(results):
        for tool_name, node in FALLBACK_SWEEP:
            # Skip sweep entries the primary plan already ran.
            if (tool_name, node) not in THREAT_TOOL_PLAN.get(threat, []):
                results.append(env.call_tool(tool_name, node=node))
    # NOTE(review): this copy of the file had its indentation stripped, so the
    # nesting of this branch (function level vs. inside the sweep branch) was
    # inferred — confirm against the repository.
    if thorough:
        _complete_evidence_fusion(
            call_tool=lambda tool_name, node: env.call_tool(tool_name, node=node),
            results=results,
        )
    return results
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def investigate_http(
    env_base_url: str,
    session_id: Optional[str],
    obs: Dict[str, Any],
    use_tools: bool,
    thorough: bool = False,
) -> List[Dict[str, Any]]:
    """
    Query SOC HTTP tool endpoints for a persistent /soc session.

    Each query POSTs ``{"session_id", "node"}`` to ``/tools/<tool_name>``.
    Nothing is queried when tools are disabled, there is no session, or the
    observation is not Phase 1.  The per-task query plan is:
      - "direct-triage": one auth log search, only for suspected brute force.
      - "dual-pivot": three payment-service queries for suspected lateral
        movement, otherwise the first planned query for the suspected threat.
      - "polymorphic-zero-day": the full plan, plus the fallback sweep when
        no attack indicators were found.
      - any other task: no queries.
    """
    if not use_tools or not session_id or int(obs.get("phase", 1)) != 1:
        return []
    task_name = obs.get("task_name")
    threat = classify_from_metrics(obs.get("network_nodes", {}))

    def call(tool_name: str, node: str) -> Dict[str, Any]:
        path = f"/tools/{tool_name}"
        payload = {"session_id": session_id, "node": node}
        return http_post(env_base_url, path, payload)

    if task_name == "direct-triage":
        if threat == "brute_force":
            return [call("log_search", "auth_service")]
        return []

    results: List[Dict[str, Any]] = []

    if task_name == "dual-pivot":
        if threat == "lateral_movement":
            return [
                call("edr_status", "payment_service"),
                call("log_search", "payment_service"),
                call("identity_lookup", "payment_service"),
            ]
        tool_name, node = THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"])[0]
        return [call(tool_name, node)]

    if task_name != "polymorphic-zero-day":
        return []

    for tool_name, node in THREAT_TOOL_PLAN.get(threat, THREAT_TOOL_PLAN["benign"]):
        results.append(call(tool_name, node))

    if not has_attack_indicators(results):
        for tool_name, node in FALLBACK_SWEEP:
            # Skip sweep entries the primary plan already ran.
            if (tool_name, node) not in THREAT_TOOL_PLAN.get(threat, []):
                results.append(call(tool_name, node))
    # NOTE(review): this copy of the file had its indentation stripped, so the
    # nesting of this branch (function level vs. inside the sweep branch) was
    # inferred — confirm against the repository.
    if thorough:
        _complete_evidence_fusion(call_tool=call, results=results)
    return results
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def has_attack_indicators(results: List[Dict[str, Any]]) -> bool:
    """Return True when the tool results point to any non-benign threat."""
    return infer_threat_from_tool_results(results)["threat_type"] != "benign"
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def infer_threat_from_tool_results(results: List[Dict[str, Any]]) -> Dict[str, str]:
    """
    Infer threat from observable tool facts only, not grader labels.

    The textual fields of every result ("result_summary", "events",
    "finding", "process_note") are concatenated and lower-cased, then
    matched against fixed phrases.  Checks run in priority order:
    supply_chain, exfiltration, lateral_movement, brute_force, benign.

    Returns a dict with "threat_type", "target_node" and "action".
    """

    def _events_text(result: Dict[str, Any]) -> str:
        # A missing or None "events" field counts as no events; a bare string
        # is one event rather than a sequence of characters.
        events = result.get("events") or []
        if isinstance(events, str):
            events = [events]
        return " ".join(str(event) for event in events)

    text_blob = " ".join(
        [
            str(result.get("result_summary", "")) + " " +
            _events_text(result) + " " +
            str(result.get("finding", "")) + " " +
            str(result.get("process_note", ""))
            for result in results
        ]
    ).lower()

    if any(result.get("risk") == "critical" for result in results) or "hash mismatch" in text_blob:
        return {"threat_type": "supply_chain", "target_node": "api_gateway", "action": "patch"}
    if (
        "sequential reads" in text_blob
        or "compressed archive" in text_blob
        or "egress exceeds" in text_blob
        or "outbound_transfer_burst" in text_blob
    ):
        return {"threat_type": "exfiltration", "target_node": "database", "action": "honeypot"}
    if (
        any(result.get("beaconing") for result in results)
        or "service account" in text_blob
        or "internal sessions" in text_blob
        or "identity_anomaly" in text_blob
        or "source=auth_service" in text_blob
        or "east_west_fanout" in text_blob
    ):
        return {"threat_type": "lateral_movement", "target_node": "payment_service", "action": "isolate"}
    if "failed logins" in text_blob or "password spray" in text_blob:
        return {"threat_type": "brute_force", "target_node": "auth_service", "action": "rate_limit"}
    return {"threat_type": "benign", "target_node": "api_gateway", "action": "monitor"}
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _complete_evidence_fusion(
    call_tool: Callable[[str, str], Dict[str, Any]],
    results: List[Dict[str, Any]],
) -> None:
    """Top up ``results`` in place with the planned tool calls for the inferred threat.

    The threat is inferred from the results gathered so far. Nothing happens
    for a benign inference. Otherwise every ``(tool, node)`` pair listed for
    that threat in ``THREAT_TOOL_PLAN`` that was not already present in
    ``results`` when this function started is invoked through ``call_tool``
    and its result appended.
    """
    inferred_threat = infer_threat_from_tool_results(results)["threat_type"]
    if inferred_threat == "benign":
        return

    # Snapshot of what was queried before any follow-up call is made.
    already_queried = set()
    for entry in results:
        already_queried.add((str(entry.get("tool", "")), str(entry.get("node", ""))))

    for planned in THREAT_TOOL_PLAN.get(inferred_threat, []):
        tool_name, node = planned
        if (tool_name, node) in already_queried:
            continue
        results.append(call_tool(tool_name, node))
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def attach_tool_results(obs: Dict[str, Any], tool_results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a shallow copy of ``obs`` carrying ``tool_results`` under the ``"tool_results"`` key.

    The input observation is left untouched; an existing ``"tool_results"``
    entry is overwritten in the copy.
    """
    return dict(obs, tool_results=tool_results)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def summarize_tool_results(tool_results: List[Dict[str, Any]]) -> str:
    """Render tool results as one compact JSON object per line.

    Each result is first reduced by ``_compact_result`` and then serialized
    without whitespace between separators. An empty input yields a fixed
    placeholder sentence instead.
    """
    if not tool_results:
        return "No SOC tools queried for this turn."

    return "\n".join(
        json.dumps(_compact_result(entry), separators=(",", ":"))
        for entry in tool_results
    )
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
def http_post(env_base_url: str, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
    """Send ``payload`` as a JSON request body to ``env_base_url`` + ``path`` and decode the JSON reply.

    Trailing slashes on the base URL are stripped before ``path`` is
    appended. The request waits at most 60 seconds.
    """
    base = env_base_url.rstrip("/")
    endpoint = f"{base}{path}"
    body = json.dumps(payload).encode()
    request = urllib.request.Request(
        endpoint,
        data=body,
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        raw = response.read()
    return json.loads(raw)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def _compact_result(result: Dict[str, Any]) -> Dict[str, Any]:
|
| 222 |
+
keep = [
|
| 223 |
+
"tool",
|
| 224 |
+
"node",
|
| 225 |
+
"evidence_type",
|
| 226 |
+
"verified",
|
| 227 |
+
"confidence",
|
| 228 |
+
"events",
|
| 229 |
+
"containment",
|
| 230 |
+
"persistence",
|
| 231 |
+
"beaconing",
|
| 232 |
+
"criticality",
|
| 233 |
+
"dependencies",
|
| 234 |
+
"risk",
|
| 235 |
+
"finding",
|
| 236 |
+
"recommended_mitigation",
|
| 237 |
+
"safe_actions",
|
| 238 |
+
]
|
| 239 |
+
return {key: result[key] for key in keep if key in result}
|
tests/test_regression.py
ADDED
|
@@ -0,0 +1,352 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import unittest
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
REPO_ROOT = Path(__file__).resolve().parents[1]
|
| 7 |
+
PACKAGE_ROOT = REPO_ROOT / "adaptshield"
|
| 8 |
+
|
| 9 |
+
if str(REPO_ROOT) not in sys.path:
|
| 10 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 11 |
+
if str(PACKAGE_ROOT) not in sys.path:
|
| 12 |
+
sys.path.insert(0, str(PACKAGE_ROOT))
|
| 13 |
+
|
| 14 |
+
import __init__ as adaptshield
|
| 15 |
+
import server.app as server_app
|
| 16 |
+
import train as train_module
|
| 17 |
+
from client import AdaptshieldEnv
|
| 18 |
+
from models import AdaptShieldAction
|
| 19 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 20 |
+
from server.grader import normalize_episode_score, _required_tool_fusion
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PackageRegressionTests(unittest.TestCase):
    # Smoke checks on the package's public surface and the server module.

    def test_package_import_exports_expected_symbols(self) -> None:
        # Each of these names must be advertised through the package's __all__.
        exported = adaptshield.__all__
        for symbol in ("AdaptShieldAction", "AdaptShieldObservation", "AdaptshieldEnv"):
            self.assertIn(symbol, exported)

    def test_server_app_imports_fastapi_instance(self) -> None:
        # Compared by class name, so this module does not import FastAPI itself.
        app_class_name = server_app.app.__class__.__name__
        self.assertEqual(app_class_name, "FastAPI")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class EnvironmentRegressionTests(unittest.TestCase):
    # Regression tests that drive AdaptShieldEnvironment, the client payload
    # builder, the training prompt bank and the score normalizer directly,
    # without going through the HTTP server.

    def test_phase_flow_accepts_both_action_shapes(self) -> None:
        # One full turn on the easy task: a phase-1 assessment action followed
        # by a phase-2 response action, first without and then with a prior
        # tool call.
        env = AdaptShieldEnvironment("direct-triage")

        phase1_obs = env.reset()
        self.assertEqual(phase1_obs.phase, 1)
        self.assertEqual(phase1_obs.turn, 1)
        self.assertEqual(phase1_obs.metadata["normalized_score"], 0.50)
        self.assertIn("mission_profile", phase1_obs.metadata)
        self.assertEqual(phase1_obs.metadata["world_split"], "train")
        self.assertIn(phase1_obs.metadata["world_family"], {"train-a", "train-b"})

        # Phase-1 shaped action: classification fields only.
        phase2_obs = env.step(
            AdaptShieldAction(
                threat_type="brute_force",
                confidence=0.9,
                target_node="auth_service",
                recommended_action="rate_limit",
            )
        )
        self.assertEqual(phase2_obs.phase, 2)
        self.assertEqual(phase2_obs.phase1_assessment["recommended_action"], "rate_limit")

        # Phase-2 shaped action: just the response action and its target.
        next_turn_obs = env.step(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(next_turn_obs.phase, 1)
        self.assertGreaterEqual(next_turn_obs.reward, 0.65)
        # No tool was queried in this run, and the result text says so.
        self.assertIn("requires stronger SOC evidence", next_turn_obs.last_action_result)
        self.assertIn("business_impact", next_turn_obs.metadata["score_breakdown"])
        self.assertIn("dependency_blast_radius", next_turn_obs.metadata["score_breakdown"])
        self.assertIn("mission_alignment", next_turn_obs.metadata["score_breakdown"])
        self.assertIn("active_defenses", next_turn_obs.metadata)
        self.assertIn("available_tools", next_turn_obs.metadata)
        tool_names = {tool["name"] for tool in next_turn_obs.metadata["available_tools"]}
        self.assertTrue({
            "identity_lookup",
            "change_calendar_lookup",
            "netflow_lookup",
        }.issubset(tool_names))

        # Same turn on a fresh environment, this time after a log_search call.
        env = AdaptShieldEnvironment("direct-triage")
        env.reset()
        env.call_tool("log_search", node="auth_service")
        env.step(
            AdaptShieldAction(
                threat_type="brute_force",
                confidence=0.9,
                target_node="auth_service",
                recommended_action="rate_limit",
            )
        )
        verified_obs = env.step(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertGreaterEqual(verified_obs.reward, 0.9)
        self.assertIn("Optimal: rate_limit", verified_obs.last_action_result)

    def test_client_payload_omits_empty_metadata_and_serializes_enums(self) -> None:
        # _step_payload must emit exactly the populated fields as plain values.
        client = AdaptshieldEnv(base_url="http://localhost:7860")

        phase1_payload = client._step_payload(
            AdaptShieldAction(
                threat_type="benign",
                confidence=0.8,
                target_node="auth_service",
                recommended_action="monitor",
            )
        )
        self.assertEqual(
            phase1_payload,
            {
                "threat_type": "benign",
                "confidence": 0.8,
                "target_node": "auth_service",
                "recommended_action": "monitor",
            },
        )

        phase2_payload = client._step_payload(
            AdaptShieldAction(action="rate_limit", target_node="auth_service")
        )
        self.assertEqual(
            phase2_payload,
            {"action": "rate_limit", "target_node": "auth_service"},
        )

    def test_hard_task_records_verified_tool_evidence(self) -> None:
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        # Reset until the first turn is not flagged benign; give up (and fail)
        # after 8 attempts via the for/else.
        for _ in range(8):
            obs = env.reset()
            turn_config = dict(getattr(env, "_turn_config", {}) or {})
            if not turn_config.get("is_benign", False):
                break
        else:
            self.fail("Expected a non-benign hard-task reset within 8 attempts")

        self.assertIn("available_tools", obs.metadata)
        self.assertNotIn("foothold_established", obs.metadata)

        target = str(turn_config.get("correct_target", "auth_service"))
        # Call every tool the grader lists for this task/strategy and check the
        # public result carries a summary but no "verified"/"evidence_type" keys.
        for tool_name in sorted(_required_tool_fusion("polymorphic-zero-day", str(turn_config.get("strategy", "benign")))):
            tool_result = env.call_tool(tool_name, node=target)
            self.assertNotIn("verified", tool_result)
            self.assertNotIn("evidence_type", tool_result)
            self.assertTrue(tool_result.get("result_summary"))

        # Replay the environment's own expected answer for both phases.
        env.step(
            AdaptShieldAction(
                threat_type=turn_config.get("strategy", "brute_force"),
                confidence=0.9,
                target_node=target,
                recommended_action=turn_config.get("correct_action", "monitor"),
            )
        )
        obs = env.step(
            AdaptShieldAction(
                action=turn_config.get("correct_action", "monitor"),
                target_node=target,
            )
        )
        breakdown = obs.metadata["score_breakdown"]
        self.assertTrue(breakdown["tool_verification_required"])
        self.assertTrue(breakdown["tool_evidence_found"])
        self.assertGreaterEqual(obs.reward, 0.65)

    def test_enterprise_context_tools_return_public_fields_only(self) -> None:
        # Each context tool must expose its descriptive fields and must not
        # include the "verified" / "evidence_type" keys.
        env = AdaptShieldEnvironment("polymorphic-zero-day")
        env.reset()

        identity = env.call_tool("identity_lookup", node="payment_service")
        self.assertIn("account", identity)
        self.assertIn("recent_source_host", identity)
        self.assertNotIn("verified", identity)
        self.assertNotIn("evidence_type", identity)

        change = env.call_tool("change_calendar_lookup", node="api_gateway")
        self.assertIn("scheduled", change)
        self.assertIn("change_status", change)
        self.assertNotIn("verified", change)
        self.assertNotIn("evidence_type", change)

        netflow = env.call_tool("netflow_lookup", node="database")
        self.assertIn("traffic_pattern", netflow)
        self.assertIn("east_west_connections", netflow)
        self.assertNotIn("verified", netflow)
        self.assertNotIn("evidence_type", netflow)

    def test_dual_pivot_requires_tool_confirmation_after_pivot(self) -> None:
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()

        # Play three complete turns by echoing the values in _turn_config.
        for _ in range(3):
            turn_config = dict(getattr(env, "_turn_config", {}) or {})
            env.step(
                AdaptShieldAction(
                    threat_type=str(turn_config.get("strategy", "brute_force")),
                    confidence=0.9,
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                    recommended_action=str(turn_config.get("correct_action", "monitor")),
                )
            )
            obs = env.step(
                AdaptShieldAction(
                    action=str(turn_config.get("correct_action", "monitor")),
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                )
            )
            # NOTE(review): nesting of this assertion (inside vs. after the
            # loop) should be confirmed against the original file.
            self.assertFalse(obs.done)

        # The turn after those three is expected to be lateral movement.
        turn_config = dict(getattr(env, "_turn_config", {}) or {})
        self.assertEqual(turn_config.get("strategy"), "lateral_movement")
        target = str(turn_config.get("correct_target", "payment_service"))

        # Answer that turn without calling any tool first.
        env.step(
            AdaptShieldAction(
                threat_type="lateral_movement",
                confidence=0.9,
                target_node=target,
                recommended_action=str(turn_config.get("correct_action", "isolate")),
            )
        )
        obs = env.step(
            AdaptShieldAction(
                action=str(turn_config.get("correct_action", "isolate")),
                target_node=target,
            )
        )
        self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
        self.assertFalse(obs.metadata["score_breakdown"]["tool_evidence_found"])
        self.assertIn("requires stronger SOC evidence", obs.last_action_result)

        # Second pass on a fresh environment: same three turns, then query
        # three tools on the target before answering.
        env = AdaptShieldEnvironment("dual-pivot", operational_mode="containment_first")
        env.reset()
        for _ in range(3):
            turn_config = dict(getattr(env, "_turn_config", {}) or {})
            env.step(
                AdaptShieldAction(
                    threat_type=str(turn_config.get("strategy", "brute_force")),
                    confidence=0.9,
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                    recommended_action=str(turn_config.get("correct_action", "monitor")),
                )
            )
            env.step(
                AdaptShieldAction(
                    action=str(turn_config.get("correct_action", "monitor")),
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                )
            )

        turn_config = dict(getattr(env, "_turn_config", {}) or {})
        target = str(turn_config.get("correct_target", "payment_service"))
        env.call_tool("edr_status", node=target)
        env.call_tool("log_search", node=target)
        env.call_tool("identity_lookup", node=target)
        env.step(
            AdaptShieldAction(
                threat_type="lateral_movement",
                confidence=0.9,
                target_node=target,
                recommended_action=str(turn_config.get("correct_action", "isolate")),
            )
        )
        obs = env.step(
            AdaptShieldAction(
                action=str(turn_config.get("correct_action", "isolate")),
                target_node=target,
            )
        )
        self.assertTrue(obs.metadata["score_breakdown"]["tool_verification_required"])
        self.assertTrue(obs.metadata["score_breakdown"]["tool_evidence_found"])
        self.assertIn("Optimal: isolate", obs.last_action_result)

    def test_world_family_metadata_and_surfaces_are_selectable(self) -> None:
        # An explicitly requested split/family must be echoed in metadata and
        # show up in the alert wording.
        env = AdaptShieldEnvironment(
            "direct-triage",
            world_split="eval",
            world_family="eval-x",
        )
        obs = env.reset()
        self.assertEqual(obs.metadata["world_split"], "eval")
        self.assertEqual(obs.metadata["world_family"], "eval-x")
        alerts_blob = " ".join(obs.active_alerts).lower()
        self.assertTrue(
            "auth rejection burst" in alerts_blob or
            "credential reuse sweep" in alerts_blob
        )

    def test_operational_modes_change_medium_and_hard_optimal_actions(self) -> None:
        medium_env = AdaptShieldEnvironment(
            "dual-pivot",
            operational_mode="evidence_preservation",
            world_family="train-b",
        )
        medium_env.reset()
        # Play three turns with the expected answers to reach the next turn.
        for _ in range(3):
            turn_config = dict(getattr(medium_env, "_turn_config", {}) or {})
            medium_env.step(
                AdaptShieldAction(
                    threat_type=str(turn_config.get("strategy", "brute_force")),
                    confidence=0.9,
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                    recommended_action=str(turn_config.get("correct_action", "monitor")),
                )
            )
            medium_env.step(
                AdaptShieldAction(
                    action=str(turn_config.get("correct_action", "monitor")),
                    target_node=str(turn_config.get("correct_target", "auth_service")),
                )
            )
        # In this mode the lateral-movement turn expects "honeypot".
        self.assertEqual(getattr(medium_env, "_turn_config", {}).get("strategy"), "lateral_movement")
        self.assertEqual(getattr(medium_env, "_turn_config", {}).get("correct_action"), "honeypot")

        hard_env = AdaptShieldEnvironment(
            "polymorphic-zero-day",
            operational_mode="forensic_hold",
            world_family="eval-y",
        )
        hard_obs = hard_env.reset()
        # Feed a hand-built exfiltration turn through the mode adjustment and
        # check that the expected action changes from "isolate" to "honeypot".
        adjusted = hard_env._apply_operational_mode({
            "strategy": "exfiltration",
            "attack_stage": "exploit",
            "is_benign": False,
            "correct_action": "isolate",
            "correct_target": "database",
            "network_nodes": {"payment_service": {"status": "healthy", "request_rate": 85}},
            "active_alerts": [],
        })
        self.assertEqual(hard_obs.metadata["operational_mode"], "forensic_hold")
        self.assertEqual(adjusted.get("correct_action"), "honeypot")

    def test_prompt_bank_builds_phase_rows_without_gpu_deps(self) -> None:
        # tokenizer=None: the prompt bank is built without a tokenizer object.
        rows = train_module.build_prompt_bank(
            tokenizer=None,
            selected_task="all",
            curriculum=True,
            rollout_episodes=3,
            max_steps=6,
            use_tools=True,
            seed=42,
        )
        self.assertTrue(rows)
        phases = {int(row["phase"]) for row in rows}
        tasks = {str(row["task"]) for row in rows}
        self.assertIn(1, phases)
        self.assertIn(2, phases)
        self.assertTrue(tasks.intersection({"direct-triage", "dual-pivot", "polymorphic-zero-day"}))
        hard_rows = [row for row in rows if row["task"] == "polymorphic-zero-day"]
        self.assertTrue(hard_rows)
        # At least one hard-task row must record two or more tool calls.
        self.assertTrue(any(int(row["tool_calls"]) >= 2 for row in hard_rows))

    def test_normalized_score_uses_reachable_reward_ceiling(self) -> None:
        # Ten rewards of 0.99 must normalize to exactly 0.99.
        rewards = [0.99] * 10
        self.assertEqual(normalize_episode_score(rewards), 0.99)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
if __name__ == "__main__":
|
| 352 |
+
unittest.main()
|
tool_baseline.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Tool-aware AdaptShield baseline for world-modeling demos."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
from typing import Any, Dict, List
|
| 9 |
+
|
| 10 |
+
from baseline import (
|
| 11 |
+
BENCHMARK,
|
| 12 |
+
MAX_STEPS,
|
| 13 |
+
POLICY,
|
| 14 |
+
TASKS,
|
| 15 |
+
action_from_payload,
|
| 16 |
+
log_end,
|
| 17 |
+
log_step,
|
| 18 |
+
phase1_payload as no_tool_phase1_payload,
|
| 19 |
+
phase2_payload as no_tool_phase2_payload,
|
| 20 |
+
print_replay,
|
| 21 |
+
)
|
| 22 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 23 |
+
from soc_tools import infer_threat_from_tool_results, investigate_local
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
MODEL_NAME = "tool-aware-baseline"
|
| 27 |
+
|
| 28 |
+
def log_start(task: str) -> None:
    """Print the ``[START]`` banner for ``task``, flushing stdout immediately."""
    banner = f"[START] task={task} env={BENCHMARK} model={MODEL_NAME}"
    print(banner, flush=True)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def phase2_payload(obs: Any, belief_by_turn: Dict[int, Dict[str, str]]) -> Dict[str, Any]:
    """Build the Phase 2 payload, preferring the tool-derived belief on the hard task.

    When the task is ``polymorphic-zero-day`` and a non-empty belief was
    recorded for the current turn, the belief's action and target are used.
    In every other case the no-tool baseline payload is returned.
    """
    belief = belief_by_turn.get(int(obs.turn), {})
    if obs.task_name != "polymorphic-zero-day" or not belief:
        return no_tool_phase2_payload(obs)

    payload: Dict[str, Any] = {}
    payload["action"] = belief["action"]
    payload["target_node"] = belief["target_node"]
    payload["reasoning"] = "inferred from observable SOC tool fields"
    return payload
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def phase1_payload(obs: Any, belief_by_turn: Dict[int, Dict[str, str]]) -> Dict[str, Any]:
    """Build the Phase 1 payload, preferring the tool-derived belief on the hard task.

    When the task is ``polymorphic-zero-day`` and a non-empty belief was
    recorded for the current turn, the classification is taken from that
    belief with a fixed confidence of 0.86. In every other case the no-tool
    baseline payload is returned.
    """
    belief = belief_by_turn.get(int(obs.turn), {})
    if obs.task_name != "polymorphic-zero-day" or not belief:
        return no_tool_phase1_payload(obs)

    payload: Dict[str, Any] = {}
    payload["threat_type"] = belief["threat_type"]
    payload["confidence"] = 0.86
    payload["target_node"] = belief["target_node"]
    payload["recommended_action"] = belief["action"]
    payload["reasoning"] = "classified from observable SOC tool fields"
    return payload
|
| 58 |
+
|
| 59 |
+
def run_task(task: str, emit_logs: bool = True) -> Dict[str, Any]:
    """Play one episode of ``task`` with the tool-aware baseline and report the outcome.

    In Phase 1 the local SOC tools are queried and the inferred belief is
    stored for the current turn; in Phase 2 that stored belief drives the
    response. The loop stops when the environment reports ``done`` or after
    ``MAX_STEPS`` steps.

    Args:
        task: Task name passed to ``AdaptShieldEnvironment``.
        emit_logs: When true, print the start/step/end log lines and a
            ``[TOOLS]`` line with the last four tool-trace entries.

    Returns:
        A dict with the task name, final normalized score, step count, done
        flag, per-step rewards, final metadata, whether a normalized score
        was present, and the success flag.
    """
    env = AdaptShieldEnvironment(task_name=task)
    obs = env.reset()
    beliefs: Dict[int, Dict[str, str]] = {}
    rewards: List[float] = []
    steps = 0

    if emit_logs:
        log_start(task)

    while not obs.done and steps < MAX_STEPS:
        if obs.phase == 1:
            evidence = investigate_local(env, obs, use_tools=True)
            belief = infer_threat_from_tool_results(evidence)
            beliefs[int(obs.turn)] = belief
            payload = phase1_payload(obs, beliefs)
        else:
            payload = phase2_payload(obs, beliefs)

        obs = env.step(action_from_payload(payload))
        step_reward = float(obs.reward)
        rewards.append(step_reward)
        steps += 1

        if emit_logs:
            log_step(steps, payload, step_reward, obs.done)

    # Non-dict metadata is treated as empty.
    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    score = float(metadata.get("normalized_score", 0.01))
    success = obs.done and 0.01 <= score <= 0.99

    if emit_logs:
        log_end(success, steps, score, rewards)
        tool_trace = metadata.get("tool_trace") or []
        recent_calls = json.dumps(tool_trace[-4:], separators=(",", ":"))
        print(f"[TOOLS] calls={len(tool_trace)} trace={recent_calls}")

    summary: Dict[str, Any] = {
        "task": task,
        "score": score,
        "steps": steps,
        "done": bool(obs.done),
        "rewards": rewards,
        "metadata": metadata,
        "normalized_score_present": "normalized_score" in metadata,
        "success": success,
    }
    return summary
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command line: ``--task`` (a known task or ``all``) and the ``--replay`` flag."""
    cli = argparse.ArgumentParser(description="Run AdaptShield tool-aware baseline.")
    task_choices = TASKS + ["all"]
    cli.add_argument("--task", default="polymorphic-zero-day", choices=task_choices)
    cli.add_argument("--replay", action="store_true")
    return cli.parse_args()
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def main() -> int:
    """Run the selected task(s) with logging enabled and return exit status 0.

    A blank line separates the output of consecutive tasks. With
    ``--replay`` the final metadata of each task is passed to
    ``print_replay``.
    """
    args = parse_args()
    if args.task == "all":
        selected = TASKS
    else:
        selected = [args.task]

    for position, task in enumerate(selected):
        if position > 0:
            print()
        outcome = run_task(task, emit_logs=True)
        if args.replay:
            print_replay(task, outcome["metadata"])

    return 0
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
if __name__ == "__main__":
|
| 128 |
+
raise SystemExit(main())
|
train.py
ADDED
|
@@ -0,0 +1,1332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""AdaptShield onsite GPU training harness with safe local fallback."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import inspect
|
| 8 |
+
import json
|
| 9 |
+
import os
|
| 10 |
+
import random
|
| 11 |
+
import sys
|
| 12 |
+
import time
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Any, Dict, List, Tuple
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 18 |
+
|
| 19 |
+
if str(REPO_ROOT) not in sys.path:
|
| 20 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 21 |
+
|
| 22 |
+
from models import AdaptShieldAction
|
| 23 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 24 |
+
from train_smoke import TASKS, run_smoke_training
|
| 25 |
+
from plot_training import plot as plot_training_output
|
| 26 |
+
from soc_tools import attach_tool_results, investigate_local_with_depth, summarize_tool_results
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# Short model aliases mapped to the checkpoint identifiers that get loaded.
MODEL_CHOICES = {
    "0.5b": "unsloth/Qwen2.5-0.5B-Instruct",
    "1.5b": "unsloth/Qwen2.5-1.5B-Instruct",
    "3b": "unsloth/Qwen2.5-3B-Instruct",
    "7b": "unsloth/Qwen2.5-7B-Instruct",
}
DEFAULT_MODEL = "1.5b"
MAX_SEQ_LEN = 2048
LORA_RANK = 16

# (task name, fraction of the run) pairs; the fractions are accumulated in
# order by task_for_episode to pick the task for a given episode.
CURRICULUM_STAGES = [
    ("direct-triage", 0.30),
    ("dual-pivot", 0.40),
    ("polymorphic-zero-day", 0.30),
]


# System prompt for phase 1 (threat classification).
PHASE1_SYS = """You are a Threat Analyst for a 4-node enterprise network.
Analyze SIEM metrics and alerts. Identify the threat.

Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
Nodes: auth_service, payment_service, database, api_gateway
Actions: rate_limit, isolate, honeypot, patch, monitor
If SOC tool evidence is provided, use it to update your belief before classifying.

Respond ONLY with valid JSON:
{"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}"""


# System prompt for phase 2 (choosing the defensive action from the handoff).
PHASE2_SYS = """You are a Tactical Executor. Act only on the analyst handoff.
You cannot see raw network data in Phase 2.
Use the analyst handoff plus any SOC tool trace from this turn.

Actions: rate_limit, isolate, honeypot, patch, monitor
Nodes: auth_service, payment_service, database, api_gateway

Respond ONLY with valid JSON:
{"action":"...","target_node":"...","reasoning":"..."}"""
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Convert an observation into a plain dictionary.

    Objects that expose ``model_dump`` are dumped with ``mode="json"``;
    anything else is handed to ``dict()``.

    Args:
        obs: Observation object or mapping.

    Returns:
        The observation as a dictionary.
    """
    if hasattr(obs, "model_dump"):
        return obs.model_dump(mode="json")
    return dict(obs)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def make_phase1_prompt(obs: Dict[str, Any]) -> str:
    """Build the phase 1 user prompt from an observation dictionary.

    The prompt lists, in order: network nodes, active alerts, SOC tool
    evidence, the last three history entries and the closing instruction.

    Args:
        obs: Observation dictionary. Reads ``network_nodes``,
            ``active_alerts``, ``tool_results`` and ``history``. A missing or
            ``None`` alerts/tool-results/history value is treated as empty.

    Returns:
        The newline-joined prompt text.
    """
    # ``or []`` also covers keys that are present but explicitly None.
    alerts = obs.get("active_alerts") or []
    history = obs.get("history") or []
    tool_results = obs.get("tool_results") or []
    return "\n".join([
        "Network nodes:",
        json.dumps(obs.get("network_nodes", {}), indent=2),
        "",
        "Active alerts:",
        # str() keeps join() from failing on non-string alert entries.
        "\n".join(str(alert) for alert in alerts),
        "",
        "SOC tool evidence:",
        summarize_tool_results(tool_results),
        "",
        "Recent history:",
        json.dumps(history[-3:], indent=2),
        "",
        "Classify the threat:",
    ])
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def make_phase2_prompt(obs: Dict[str, Any]) -> str:
    """Build the phase 2 user prompt from an observation dictionary.

    Only the analyst handoff and the tool-trace rows recorded for the current
    turn are included.

    Args:
        obs: Observation dictionary. Reads ``metadata["tool_trace"]``,
            ``turn`` and ``phase1_assessment``.

    Returns:
        The newline-joined prompt text.
    """
    metadata = obs.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    current_turn = int(obs.get("turn", 0) or 0)

    def _row_turn(row: Any) -> int:
        # Rows that are not dicts, or whose turn is missing/None/non-numeric,
        # map to -1 so they never match a real turn instead of raising.
        if not isinstance(row, dict):
            return -1
        try:
            return int(row.get("turn", -1))
        except (TypeError, ValueError):
            return -1

    tool_trace = [
        row for row in (metadata.get("tool_trace") or [])
        if _row_turn(row) == current_turn
    ]
    return "\n".join([
        "Threat assessment from analyst:",
        json.dumps(obs.get("phase1_assessment", {}), indent=2),
        "",
        "SOC tool trace for this turn:",
        json.dumps(tool_trace, indent=2),
        "",
        "Choose the defensive action:",
    ])
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def build_messages(obs: Dict[str, Any]) -> List[Dict[str, str]]:
    """Return the system and user chat messages for the observation's phase.

    Phase 1 uses ``PHASE1_SYS`` with the phase 1 prompt; any other phase
    value uses ``PHASE2_SYS`` with the phase 2 prompt.

    Args:
        obs: Observation dictionary. A missing or ``None`` ``phase`` counts
            as phase 1.

    Returns:
        A two-element list of ``{"role", "content"}`` messages.
    """
    raw_phase = obs.get("phase")
    phase = 1 if raw_phase is None else int(raw_phase)
    if phase == 1:
        system_prompt = PHASE1_SYS
        user_prompt = make_phase1_prompt(obs)
    else:
        system_prompt = PHASE2_SYS
        user_prompt = make_phase2_prompt(obs)
    return [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def task_for_episode(
    episode: int,
    total_episodes: int,
    selected_task: str,
    curriculum: bool,
) -> Tuple[str, str]:
    """Pick the task and a stage label for a 1-based episode number.

    Without curriculum, ``selected_task == "all"`` cycles through ``TASKS``
    round-robin and any other value is used as-is. With curriculum, the
    episode's progress through the run is compared against the cumulative
    fractions in ``CURRICULUM_STAGES``.

    Args:
        episode: 1-based episode index.
        total_episodes: Planned number of episodes (values below 1 count as 1).
        selected_task: Task name, or ``"all"``.
        curriculum: Whether to follow ``CURRICULUM_STAGES``.

    Returns:
        ``(task_name, stage_label)``, where the label is ``"round_robin"``,
        ``"fixed"`` or ``"curriculum:<task>"``.
    """
    if not curriculum:
        if selected_task == "all":
            return TASKS[(episode - 1) % len(TASKS)], "round_robin"
        return selected_task, "fixed"

    progress = episode / max(1, total_episodes)
    threshold = 0.0
    for stage_task, fraction in CURRICULUM_STAGES:
        threshold += fraction
        if progress <= threshold:
            return stage_task, f"curriculum:{stage_task}"

    # Progress beyond the summed fractions falls into the final stage.
    final_task = CURRICULUM_STAGES[-1][0]
    return final_task, f"curriculum:{final_task}"
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def save_metrics(
    output_dir: Path,
    rows: List[Dict[str, Any]],
    model_name: str,
    episodes: int,
    curriculum: bool,
    use_tools: bool,
    trainer: str = "pg",
    evaluation_rows: List[Dict[str, Any]] | None = None,
    heldout_evaluation_rows: List[Dict[str, Any]] | None = None,
    prompt_bank_size: int = 0,
    extra: Dict[str, Any] | None = None,
) -> Path:
    """Write the run summary to ``<output_dir>/metrics.json``.

    Args:
        output_dir: Destination directory; created if missing.
        rows: Per-episode rows. Each must carry a ``score`` convertible to float.
        model_name: Model identifier recorded in the payload.
        episodes: Episode count recorded in the payload.
        curriculum: Whether curriculum scheduling was used.
        use_tools: Whether SOC tools were enabled.
        trainer: Trainer label stored under ``trainer``.
        evaluation_rows: Optional rows, written only when not ``None``.
        heldout_evaluation_rows: Optional rows, written only when not ``None``.
        prompt_bank_size: Written only when non-zero.
        extra: Optional keys merged last, so they can override earlier ones.

    Returns:
        Path of the written ``metrics.json`` file.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = output_dir / "metrics.json"
    payload: Dict[str, Any] = {
        "model": model_name,
        "episodes": episodes,
        "curriculum": curriculum,
        "curriculum_stages": CURRICULUM_STAGES,
        "use_tools": use_tools,
        "trainer": trainer,
        "rows": rows,
        "best_score": max((float(row["score"]) for row in rows), default=0.0),
    }
    if evaluation_rows is not None:
        payload["evaluation_rows"] = evaluation_rows
    if heldout_evaluation_rows is not None:
        payload["heldout_evaluation_rows"] = heldout_evaluation_rows
    if prompt_bank_size:
        payload["prompt_bank_size"] = prompt_bank_size
    if extra:
        payload.update(extra)
    # Explicit encoding keeps the file independent of the platform locale.
    metrics_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    return metrics_path
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def maybe_plot(metrics_path: Path, output_dir: Path) -> None:
    """Render ``reward_curve.png`` into ``output_dir`` on a best-effort basis.

    Any exception raised by the plotting helper is reported with ``print``
    and swallowed, so a plotting failure does not abort the caller.

    Args:
        metrics_path: Path of the metrics JSON handed to the plotting helper.
        output_dir: Directory that receives ``reward_curve.png``.
    """
    target = output_dir / "reward_curve.png"
    try:
        plot_training_output(metrics_path, target)
    except Exception as exc:  # deliberate best-effort boundary
        print(f"Plot generation skipped: {exc}")
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def parse_response(text: str, phase: int) -> Dict[str, Any]:
    """Parse model JSON. Invalid output becomes a safe phase-correct action.

    If the text contains a Markdown code fence, the first fenced segment that
    contains ``{`` is parsed (an optional leading ``json`` tag is dropped).
    Text that is not valid JSON, or whose top-level value is not an object,
    yields the fallback payload with ``reasoning == "parse_error"``.

    A ``confidence`` that cannot be converted to float is replaced by ``0.5``
    on its own, so one bad field does not discard a usable answer.

    Args:
        text: Raw model completion.
        phase: ``1`` for the analyst payload; any other value for the
            executor payload.

    Returns:
        Phase 1: ``threat_type``, ``confidence``, ``target_node``,
        ``recommended_action``, ``reasoning``. Otherwise: ``action``,
        ``target_node``, ``reasoning``.
    """
    if "```" in text:
        for part in text.split("```"):
            if "{" in part:
                text = part.strip().removeprefix("json").strip()
                break

    if phase == 1:
        fallback: Dict[str, Any] = {
            "threat_type": "brute_force",
            "confidence": 0.5,
            "target_node": "auth_service",
            "recommended_action": "monitor",
            "reasoning": "parse_error",
        }
    else:
        fallback = {
            "action": "monitor",
            "target_node": "auth_service",
            "reasoning": "parse_error",
        }

    try:
        parsed = json.loads(text)
    except (TypeError, ValueError, RecursionError):
        return fallback
    if not isinstance(parsed, dict):
        return fallback

    if phase == 1:
        try:
            confidence = float(parsed.get("confidence", 0.5))
        except (TypeError, ValueError, OverflowError):
            confidence = 0.5
        return {
            "threat_type": str(parsed.get("threat_type", "brute_force")),
            "confidence": confidence,
            "target_node": str(parsed.get("target_node", "auth_service")),
            "recommended_action": str(parsed.get("recommended_action", "monitor")),
            "reasoning": str(parsed.get("reasoning", "")),
        }
    return {
        "action": str(parsed.get("action", "monitor")),
        "target_node": str(parsed.get("target_node", "auth_service")),
        "reasoning": str(parsed.get("reasoning", "")),
    }
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def render_messages(messages: List[Dict[str, str]], tokenizer: Any | None = None) -> str:
    """Render chat messages into a single prompt string.

    When the tokenizer offers ``apply_chat_template`` it is used with
    ``tokenize=False`` and ``add_generation_prompt=True``. Otherwise each
    message is rendered as ``ROLE:\\ncontent`` and the messages are separated
    by blank lines.

    Args:
        messages: Chat messages with ``role`` and ``content`` keys; a missing
            role defaults to ``user`` and missing content to ``""``.
        tokenizer: Optional tokenizer object.

    Returns:
        The rendered prompt text.
    """
    if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"):
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    rendered = []
    for message in messages:
        role = message.get("role", "user").upper()
        content = message.get("content", "")
        rendered.append(f"{role}:\n{content}")
    return "\n\n".join(rendered)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def generate_response(model: Any, tokenizer: Any, messages: List[Dict[str, str]]) -> Tuple[str, str]:
    """Sample one completion for the given chat messages.

    Args:
        model: Object exposing ``generate`` and, optionally, ``device``.
        tokenizer: Tokenizer used to render, encode and decode.
        messages: Chat messages passed to ``render_messages``.

    Returns:
        ``(prompt, response)``: the rendered prompt and the decoded, stripped
        text of the newly generated tokens only.
    """
    import torch

    prompt = render_messages(messages, tokenizer=tokenizer)

    device = getattr(model, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Fall back to EOS for padding when the tokenizer defines no pad token.
    pad_token_id = getattr(tokenizer, "pad_token_id", None)
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        _normalize_generation_config(model)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=220,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pad_token_id,
        )

    # Drop the echoed prompt tokens; keep only what was generated.
    prompt_length = inputs["input_ids"].shape[1]
    new_ids = output_ids[0][prompt_length:]
    response = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
    return prompt, response
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _current_reference(env: AdaptShieldEnvironment) -> Dict[str, Any]:
    """Read the reference answer for the current turn from the environment.

    Values come from ``env._turn_config``. The attack stage is taken from its
    ``attack_stage`` entry when present. Only when that key is absent is
    ``env._attacker.current_stage()`` consulted, and ``"recon"`` is used if
    no such method is available.

    Args:
        env: Environment exposing ``_turn_config`` and, optionally, ``_attacker``.

    Returns:
        Dict with ``threat_type``, ``target_node``, ``expected_action``,
        ``stage`` and ``is_benign``.
    """
    turn_config = dict(getattr(env, "_turn_config", {}) or {})
    is_benign = bool(turn_config.get("is_benign", False))

    if "attack_stage" in turn_config:
        stage = turn_config["attack_stage"]
    else:
        # Resolve the fallback lazily: a dict.get default would call the
        # attacker every time and fail outright when env has no _attacker.
        attacker = getattr(env, "_attacker", None)
        stage_getter = getattr(attacker, "current_stage", None)
        stage = stage_getter() if callable(stage_getter) else "recon"

    return {
        "threat_type": "benign" if is_benign else str(turn_config.get("strategy", "benign")),
        "target_node": str(turn_config.get("correct_target", "auth_service")),
        "expected_action": str(turn_config.get("correct_action", "monitor")),
        "stage": str(stage),
        "is_benign": is_benign,
    }
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def _align_trainable_dtypes(model: Any, target_dtype: Any | None = None) -> str:
    """Keep LoRA/trainable params on the same compute dtype as the main model.

    Some adapter checkpoints reload trainable LoRA weights as float32, while
    Unsloth GRPO kernels run activations in float16/bfloat16. That mismatch
    trips fast_lora matmuls at runtime. We fix only trainable floating params.

    When ``target_dtype`` is not given it is taken from the first frozen
    floating-point parameter, or failing that from the first floating-point
    parameter of any kind.

    Args:
        model: Module exposing ``parameters()`` and ``named_buffers()``.
        target_dtype: Optional dtype to align to.

    Returns:
        ``"no-floating-params"`` when no dtype can be determined, otherwise
        ``"<dtype> (<n> trainable params aligned)"``.
    """
    if target_dtype is None:
        # Prefer the dtype of the frozen weights.
        target_dtype = next(
            (p.dtype for p in model.parameters() if p.is_floating_point() and not p.requires_grad),
            None,
        )
    if target_dtype is None:
        target_dtype = next(
            (p.dtype for p in model.parameters() if p.is_floating_point()),
            None,
        )
    if target_dtype is None:
        return "no-floating-params"

    converted = 0
    for param in model.parameters():
        if param.requires_grad and param.is_floating_point() and param.dtype != target_dtype:
            # Rebind .data so the parameter object itself is preserved.
            param.data = param.data.to(target_dtype)
            converted += 1

    for buffer_name, buffer in model.named_buffers():
        if "lora_" in buffer_name and buffer.is_floating_point() and buffer.dtype != target_dtype:
            buffer.data = buffer.data.to(target_dtype)

    if getattr(model, "generation_config", None) is not None:
        _normalize_generation_config(model)

    return f"{target_dtype} ({converted} trainable params aligned)"
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
def _normalize_generation_config(model: Any) -> None:
    """Reset ``max_length`` on the model's generation config to ``None``.

    Does nothing when the model has no ``generation_config``. A failure to
    set the attribute is ignored.

    NOTE(review): presumably this stops a checkpoint-provided ``max_length``
    from conflicting with the explicit ``max_new_tokens`` passed at
    generation time -- confirm.

    Args:
        model: Object that may expose a ``generation_config`` attribute.
    """
    generation_config = getattr(model, "generation_config", None)
    if generation_config is None:
        return
    for field in ("max_length",):
        try:
            setattr(generation_config, field, None)
        except Exception:
            # Best effort: a config that rejects the assignment is left as-is.
            continue
|
| 333 |
+
|
| 334 |
+
|
| 335 |
+
def _load_training_model_and_tokenizer(
    model_name: str,
    model_key: str,
    max_seq_length: int,
    compute_dtype: Any,
    seed: int,
) -> Tuple[Any, Any]:
    """Load a 4-bit base model with LoRA adapters ready for training.

    When ``model_name`` is recognised by ``_looks_like_adapter_path``, the
    base model is taken from ``MODEL_CHOICES[model_key]`` and the adapter is
    attached as trainable; the adapter's tokenizer is used if it loads.
    Otherwise ``model_name`` is loaded directly and fresh LoRA adapters of
    rank ``LORA_RANK`` are added.

    Args:
        model_name: Base model identifier, or an adapter location.
        model_key: Key into ``MODEL_CHOICES``; used only for adapter loads.
        max_seq_length: Maximum sequence length passed to the loader.
        compute_dtype: Dtype passed to the loader (``None`` is forwarded as-is).
        seed: Random state for fresh LoRA initialisation.

    Returns:
        ``(model, tokenizer)``.
    """
    from unsloth import FastLanguageModel

    adapter_path = model_name if _looks_like_adapter_path(model_name) else ""
    base_model_name = MODEL_CHOICES[model_key] if adapter_path else model_name
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        dtype=compute_dtype,
    )

    if adapter_path:
        from peft import PeftModel

        model = PeftModel.from_pretrained(
            model,
            adapter_path,
            is_trainable=True,
            autocast_adapter_dtype=False,
        )
        try:
            from transformers import AutoTokenizer

            # NOTE(review): trust_remote_code=True can execute code shipped
            # with the adapter; only point this at trusted checkpoints.
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        except Exception as exc:
            # Best effort: keep the base-model tokenizer, but say so.
            print(f"Adapter tokenizer not loaded ({exc}); using base model tokenizer.")
    else:
        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_RANK,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=LORA_RANK * 2,
            lora_dropout=0.0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )

    return model, tokenizer
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _teacher_payload(phase: int, reference: Dict[str, Any]) -> Dict[str, Any]:
    """Build the reference-policy action payload for one step.

    Args:
        phase: ``1`` for the analyst payload; any other value for the
            executor payload.
        reference: Dict with ``threat_type``, ``target_node`` and
            ``expected_action``.

    Returns:
        Phase 1: a classification payload with confidence ``0.92`` (``0.78``
        for benign turns). Otherwise: an action payload. Both carry
        ``reasoning == "reference policy"``.
    """
    if phase != 1:
        return {
            "action": reference["expected_action"],
            "target_node": reference["target_node"],
            "reasoning": "reference policy",
        }

    threat_type = reference["threat_type"]
    return {
        "threat_type": threat_type,
        "confidence": 0.92 if threat_type != "benign" else 0.78,
        "target_node": reference["target_node"],
        "recommended_action": reference["expected_action"],
        "reasoning": "reference policy",
    }
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def build_prompt_bank(
    tokenizer: Any | None,
    selected_task: str,
    curriculum: bool,
    rollout_episodes: int,
    max_steps: int,
    use_tools: bool,
    seed: int,
    world_split: str = "train",
    world_family: str | None = None,
    hard_multiplier: int = 2,
    borderline_bonus: int = 1,
) -> List[Dict[str, Any]]:
    """Roll out the reference policy and collect prompts with expected answers.

    Each episode is stepped with the teacher payload, so every visited state
    yields one row: the rendered prompt, the reference labels and bookkeeping
    fields. Rows selected by ``_prompt_bank_extra_copies`` are duplicated to
    oversample them.

    Note: seeds the global ``random`` module with ``seed``.

    Args:
        tokenizer: Optional tokenizer used to render prompts.
        selected_task: Task name or ``"all"``.
        curriculum: Whether to schedule tasks by curriculum.
        rollout_episodes: Number of episodes to roll out.
        max_steps: Maximum steps per episode.
        use_tools: Whether SOC tool investigation is enabled.
        seed: Seed for ``random``.
        world_split: World split passed to the environment.
        world_family: Optional world family passed to the environment.
        hard_multiplier: Oversampling factor forwarded to the copy rule.
        borderline_bonus: Extra copies forwarded to the copy rule.

    Returns:
        The prompt rows, including duplicates.
    """
    random.seed(seed)
    rows: List[Dict[str, Any]] = []
    for episode in range(1, rollout_episodes + 1):
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=rollout_episodes,
            selected_task=selected_task,
            curriculum=curriculum,
        )
        env = AdaptShieldEnvironment(
            task_name=task,
            world_split=world_split,
            world_family=world_family,
        )
        obs = env.reset()
        step_count = 0
        while not obs.done and step_count < max_steps:
            phase = int(getattr(obs, "phase", 1))
            tool_results = investigate_local_with_depth(
                env,
                obs,
                use_tools=use_tools,
                thorough=True,
            )
            obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
            messages = build_messages(obs_dict)
            # Read the reference before stepping; it describes the current turn.
            reference = _current_reference(env)
            handoff = obs_dict.get("phase1_assessment") or {}
            base_row = {
                "prompt": render_messages(messages, tokenizer=tokenizer),
                "task": task,
                "stage": stage,
                "phase": phase,
                "turn": int(getattr(obs, "turn", 0) or 0),
                "attack_stage": reference["stage"],
                "world_split": getattr(env, "_world_split", world_split),
                "world_family": getattr(env, "_world_family", world_family or ""),
                "operational_mode": getattr(env, "_operational_mode", ""),
                "expected_threat_type": reference["threat_type"],
                "expected_target_node": reference["target_node"],
                # Each phase fills only its own action column.
                "expected_recommended_action": reference["expected_action"] if phase == 1 else "",
                "expected_action": reference["expected_action"] if phase == 2 else "",
                "tool_calls": len(tool_results),
                "history_length": len(obs_dict.get("history", [])),
                "difficulty_tags": _difficulty_tags(
                    task=task,
                    phase=phase,
                    attack_stage=reference["stage"],
                    tool_calls=len(tool_results),
                    handoff_quality=str(handoff.get("handoff_quality", "")),
                ),
            }
            rows.append(base_row)
            extra_copies = _prompt_bank_extra_copies(
                row=base_row,
                hard_multiplier=hard_multiplier,
                borderline_bonus=borderline_bonus,
            )
            rows.extend(dict(base_row) for _ in range(extra_copies))
            obs = env.step(AdaptShieldAction(**_teacher_payload(phase, reference)))
            step_count += 1
    return rows
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def _difficulty_tags(
    task: str,
    phase: int,
    attack_stage: str,
    tool_calls: int,
    handoff_quality: str,
) -> List[str]:
    """Return the difficulty tags that apply to one prompt row.

    Tags are appended in a fixed order: ``hard``/``medium`` (by task),
    ``phase2``, ``late_stage`` (exploit or exfiltration), ``tool_fusion``
    (three or more tool calls) and ``borderline`` (degraded handoff).

    Args:
        task: Task name.
        phase: Phase of the step.
        attack_stage: Attack stage of the step.
        tool_calls: Number of tool calls made for the step.
        handoff_quality: Handoff quality label from the phase 1 assessment.

    Returns:
        The list of tags; empty when none apply.
    """
    task_tag = {"polymorphic-zero-day": "hard", "dual-pivot": "medium"}.get(task)
    checks = [
        (task_tag is not None, task_tag),
        (phase == 2, "phase2"),
        (attack_stage in {"exploit", "exfiltration"}, "late_stage"),
        (tool_calls >= 3, "tool_fusion"),
        (handoff_quality == "degraded", "borderline"),
    ]
    return [tag for applies, tag in checks if applies]
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def _prompt_bank_extra_copies(
    row: Dict[str, Any],
    hard_multiplier: int,
    borderline_bonus: int,
) -> int:
    """Return how many additional copies of a prompt row to add.

    ``polymorphic-zero-day`` rows get ``hard_multiplier - 1`` copies;
    late-stage ``dual-pivot`` rows get one. Rows tagged ``borderline``, or
    tagged ``phase2``, ``tool_fusion`` and ``late_stage`` together, get
    ``borderline_bonus`` more. Negative contributions are clamped to zero.

    Args:
        row: Prompt row with ``task`` and ``difficulty_tags``.
        hard_multiplier: Total sampling weight for the hardest task.
        borderline_bonus: Extra copies for borderline rows.

    Returns:
        Non-negative number of extra copies.
    """
    tags = set(row.get("difficulty_tags", []) or [])
    task = row.get("task")

    extra = 0
    if task == "polymorphic-zero-day":
        extra += max(0, hard_multiplier - 1)
    elif task == "dual-pivot" and "late_stage" in tags:
        extra += 1

    fused_late_phase2 = {"phase2", "tool_fusion", "late_stage"} <= tags
    if "borderline" in tags or fused_late_phase2:
        extra += max(0, borderline_bonus)
    return extra
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
def _completion_to_text(completion: Any) -> str:
    """Flatten a completion object into plain text.

    Strings are returned unchanged. A dict yields its ``content`` (preferred)
    or ``text`` value. A list yields the concatenation of each item, where
    dict items contribute their ``content``/``text`` value. Anything else,
    including a dict with neither key, is passed through ``str()``.

    Args:
        completion: String, dict, list or other object.

    Returns:
        The extracted text.
    """
    if isinstance(completion, str):
        return completion

    if isinstance(completion, dict):
        for key in ("content", "text"):
            if key in completion:
                return str(completion.get(key, ""))
        return str(completion)

    if isinstance(completion, list):
        return "".join(
            str(item.get("content", item.get("text", ""))) if isinstance(item, dict) else str(item)
            for item in completion
        )

    return str(completion)
|
| 534 |
+
|
| 535 |
+
|
| 536 |
+
def _phase1_reward(
    parsed: Dict[str, Any],
    expected_threat_type: str,
    expected_target_node: str,
    expected_recommended_action: str,
) -> float:
    """Score a parsed phase 1 answer against the reference labels.

    Starting from 0.08: +0.36 for the threat type, +0.20 for the target node,
    +0.18 for the recommended action, +0.05 for a confidence inside [0, 1],
    +0.06 when the threat type is right with confidence >= 0.65, -0.05 when
    it is wrong with confidence >= 0.80, and -0.05 for recommending
    ``monitor`` on a non-benign turn.

    Args:
        parsed: Parsed model answer.
        expected_threat_type: Reference threat type.
        expected_target_node: Reference target node.
        expected_recommended_action: Reference recommended action.

    Returns:
        Reward rounded to two decimals and clamped to [0.01, 0.99].
    """
    threat_correct = parsed.get("threat_type") == expected_threat_type

    reward = 0.08
    if threat_correct:
        reward += 0.36
    if parsed.get("target_node") == expected_target_node:
        reward += 0.20
    if parsed.get("recommended_action") == expected_recommended_action:
        reward += 0.18

    try:
        confidence = float(parsed.get("confidence", 0.5))
    except (TypeError, ValueError, OverflowError):
        confidence = 0.5

    if 0.0 <= confidence <= 1.0:
        reward += 0.05
    if threat_correct and confidence >= 0.65:
        reward += 0.06
    elif not threat_correct and confidence >= 0.80:
        # Penalise confident misclassification.
        reward -= 0.05
    if parsed.get("recommended_action") == "monitor" and expected_threat_type != "benign":
        reward -= 0.05

    return max(0.01, min(0.99, round(reward, 2)))
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
def _phase2_reward(
    parsed: Dict[str, Any],
    expected_action: str,
    expected_target_node: str,
    tool_calls: int,
) -> float:
    """Score a parsed phase 2 answer against the reference labels.

    Starting from 0.08: +0.62 for the action, +0.18 for the target node,
    +0.07 when the action is right and at least two tool calls were made,
    and -0.08 for choosing ``monitor`` when another action was expected.

    Args:
        parsed: Parsed model answer.
        expected_action: Reference action.
        expected_target_node: Reference target node.
        tool_calls: Number of tool calls made for the step.

    Returns:
        Reward rounded to two decimals and clamped to [0.01, 0.99].
    """
    chosen_action = parsed.get("action")
    action_correct = chosen_action == expected_action

    reward = 0.08
    if action_correct:
        reward += 0.62
    if parsed.get("target_node") == expected_target_node:
        reward += 0.18
    if action_correct and tool_calls >= 2:
        reward += 0.07
    if chosen_action == "monitor" and expected_action != "monitor":
        reward -= 0.08

    return max(0.01, min(0.99, round(reward, 2)))
|
| 580 |
+
|
| 581 |
+
|
| 582 |
+
def build_grpo_reward_fn():
    """Create the reward callable that scores completions for GRPO training.

    Returns:
        ``reward_fn(completions, **kwargs) -> List[float]``. It reads the
        per-sample columns ``phase``, ``expected_threat_type``,
        ``expected_target_node``, ``expected_recommended_action``,
        ``expected_action`` and ``tool_calls`` from ``kwargs``, indexed in
        step with ``completions``.
    """

    def reward_fn(completions: List[Any], **kwargs: Any) -> List[float]:
        phases = kwargs.get("phase", [])
        expected_threat_types = kwargs.get("expected_threat_type", [])
        expected_targets = kwargs.get("expected_target_node", [])
        expected_recommended_actions = kwargs.get("expected_recommended_action", [])
        expected_actions = kwargs.get("expected_action", [])
        tool_calls = kwargs.get("tool_calls", [])

        rewards: List[float] = []
        for index, completion in enumerate(completions):
            # Without a phase column every sample is scored as phase 1.
            phase = int(phases[index]) if phases else 1
            parsed = parse_response(_completion_to_text(completion), phase)
            if phase == 1:
                reward = _phase1_reward(
                    parsed=parsed,
                    expected_threat_type=str(expected_threat_types[index]),
                    expected_target_node=str(expected_targets[index]),
                    expected_recommended_action=str(expected_recommended_actions[index]),
                )
            else:
                reward = _phase2_reward(
                    parsed=parsed,
                    expected_action=str(expected_actions[index]),
                    expected_target_node=str(expected_targets[index]),
                    tool_calls=int(tool_calls[index]) if tool_calls else 0,
                )
            rewards.append(reward)
        return rewards

    return reward_fn
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def _filter_supported_kwargs(callable_obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the keyword arguments named in the callable's signature.

    Args:
        callable_obj: Callable whose signature is inspected.
        kwargs: Candidate keyword arguments.

    Returns:
        A new dict holding the entries whose key is a parameter name of
        ``callable_obj``. When the signature cannot be inspected, ``kwargs``
        itself is returned unfiltered.
    """
    try:
        parameters = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return kwargs
    return {key: value for key, value in kwargs.items() if key in parameters}
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def _trainer_log_rows(log_history: List[Dict[str, Any]], selected_task: str) -> List[Dict[str, Any]]:
    """Convert trainer log entries into metric rows.

    Entries without a ``step`` are skipped. The score is the first of
    ``reward``, ``mean_reward``, ``rewards/mean``, ``objective`` and
    ``objective/rlhf_reward`` that is present and convertible to float;
    entries with none of them get ``0.50``. Scores are clamped to
    [0.01, 0.99].

    Args:
        log_history: Trainer log entries.
        selected_task: Task name; ``"all"`` is recorded as ``"mixed"``.

    Returns:
        One row per entry that has a step, with ``episode``, ``task``,
        ``stage``, ``score``, ``loss`` and ``learning_rate``.
    """
    # Candidate keys in priority order.
    reward_keys = (
        "reward",
        "mean_reward",
        "rewards/mean",
        "objective",
        "objective/rlhf_reward",
    )
    task_label = "mixed" if selected_task == "all" else selected_task

    rows: List[Dict[str, Any]] = []
    for entry in log_history:
        step = entry.get("step")
        if step is None:
            continue

        score = None
        for key in reward_keys:
            if key not in entry:
                continue
            try:
                score = float(entry[key])
            except (TypeError, ValueError):
                continue
            break
        if score is None:
            score = 0.50

        rows.append({
            "episode": int(step),
            "task": task_label,
            "stage": "grpo",
            "score": max(0.01, min(0.99, score)),
            # ``or 0.0`` maps explicit None values to zero.
            "loss": float(entry.get("loss", 0.0) or 0.0),
            "learning_rate": float(entry.get("learning_rate", 0.0) or 0.0),
        })
    return rows
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
def evaluate_model_suite(
    model: Any,
    tokenizer: Any,
    selected_task: str,
    eval_episodes: int,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
    seed_start: int | None = None,
) -> List[Dict[str, Any]]:
    """Evaluate the model on each selected task and return averaged rows.

    When ``seed_start`` is given, ``ADAPTSHIELD_SEED`` is set before every
    episode to ``seed_start + task_index * 100 + episode_index``. The previous
    value of the variable is restored after each task, even if an episode
    raises.

    Args:
        model: Model passed to ``run_model_episode``.
        tokenizer: Tokenizer passed to ``run_model_episode``.
        selected_task: Task name, or ``"all"`` for every task in ``TASKS``.
        eval_episodes: Episodes per task.
        max_steps: Maximum steps per episode.
        use_tools: Whether SOC tool investigation is enabled.
        world_split: World split passed to the environment.
        world_family: Optional world family passed to the environment.
        seed_start: Optional base for per-episode seeds.

    Returns:
        One row per task with mean ``score``, ``steps`` and ``tool_calls``.
    """
    tasks = TASKS if selected_task == "all" else [selected_task]
    rows: List[Dict[str, Any]] = []
    for task in tasks:
        scores: List[float] = []
        steps: List[int] = []
        tool_calls: List[int] = []
        original_seed = os.environ.get("ADAPTSHIELD_SEED")
        try:
            for episode_index in range(eval_episodes):
                if seed_start is not None:
                    # len(rows) is the index of the task being evaluated.
                    os.environ["ADAPTSHIELD_SEED"] = str(seed_start + len(rows) * 100 + episode_index)
                _, metrics = run_model_episode(
                    model=model,
                    tokenizer=tokenizer,
                    task=task,
                    max_steps=max_steps,
                    use_tools=use_tools,
                    world_split=world_split,
                    world_family=world_family,
                )
                scores.append(float(metrics["score"]))
                steps.append(int(metrics["steps"]))
                tool_calls.append(int(metrics["tool_calls"]))
        finally:
            # Never leave the evaluation seed behind in the process environment.
            if original_seed is None:
                os.environ.pop("ADAPTSHIELD_SEED", None)
            else:
                os.environ["ADAPTSHIELD_SEED"] = original_seed
        rows.append({
            "episode": len(rows) + 1,
            "task": task,
            "stage": "evaluation",
            "score": round(sum(scores) / len(scores), 3) if scores else 0.50,
            "steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
            "tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
            "eval_episodes": eval_episodes,
            "world_split": world_split,
            "world_family": world_family or "auto",
        })
    return rows
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def run_model_episode(
    model: Any,
    tokenizer: Any,
    task: str,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Run one episode with the model acting in the environment.

    Each step gathers tool results, builds the prompt, samples a response,
    parses it and applies it. If building or applying the action raises, the
    step is recorded with reward ``0.01`` and the error text, and the episode
    stops. That penalty is stored in the sample but is not part of
    ``reward_sum`` / ``mean_reward``.

    Args:
        model: Model used for generation.
        tokenizer: Tokenizer used for generation.
        task: Task name for the environment.
        max_steps: Maximum number of recorded steps.
        use_tools: Whether SOC tool investigation is enabled.
        world_split: World split passed to the environment.
        world_family: Optional world family passed to the environment.

    Returns:
        ``(samples, metrics)``: per-step samples (``prompt``, ``response``,
        ``reward``, ``phase``, ``tool_calls``, ``error``) and episode metrics.

    Raises:
        RuntimeError: If the last observation's metadata has no
            ``normalized_score``.
    """
    env = AdaptShieldEnvironment(
        task_name=task,
        world_split=world_split,
        world_family=world_family,
    )
    obs = env.reset()
    samples: List[Dict[str, Any]] = []
    rewards: List[float] = []
    tool_calls = 0

    while not obs.done and len(samples) < max_steps:
        phase = int(getattr(obs, "phase", 1))
        tool_results = investigate_local_with_depth(
            env,
            obs,
            use_tools=use_tools,
            thorough=True,
        )
        tool_calls += len(tool_results)
        obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
        messages = build_messages(obs_dict)
        prompt, response = generate_response(model, tokenizer, messages)
        payload = parse_response(response, phase)

        error: str | None = None
        try:
            obs = env.step(AdaptShieldAction(**payload))
            reward = float(obs.reward)
        except Exception as exc:
            # An action that is invalid or rejected ends the episode with a
            # minimal reward.
            reward = 0.01
            error = str(exc)

        samples.append({
            "prompt": prompt,
            "response": response,
            "reward": reward,
            "phase": phase,
            "tool_calls": len(tool_results),
            "error": error,
        })
        if error is not None:
            break
        rewards.append(reward)

    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    if "normalized_score" not in metadata:
        raise RuntimeError("normalized_score missing after training episode")

    return samples, {
        "score": float(metadata["normalized_score"]),
        "steps": len(samples),
        "reward_sum": sum(rewards),
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
        "tool_calls": tool_calls,
        "world_split": world_split,
        "world_family": metadata.get("world_family", world_family or "auto"),
        "operational_mode": metadata.get("operational_mode", "unknown"),
    }
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def train_policy_gradient(args: argparse.Namespace) -> None:
    """Train the model online with a REINFORCE-style update, then evaluate it.

    For each episode the function rolls out the current model with
    ``run_model_episode``, uses the mean step reward of that episode as a
    baseline, and performs one optimizer step per collected sample.  The best
    scoring, periodic and final weights are saved under ``args.output``, the
    model is evaluated on the training and held-out world splits, and the
    metrics are written with ``save_metrics``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Raises:
        ImportError: If ``torch`` or ``unsloth`` cannot be imported.
    """
    import torch
    from torch.optim import AdamW

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("AdaptShield policy-gradient GPU training")
    print(f"Task: {args.task}")
    print(f"Curriculum: {args.curriculum}")
    print(f"Use tools: {args.use_tools}")
    print(f"Model: {model_name}")
    print(f"Episodes: {args.episodes}")
    print(f"Output: {output_dir}")
    print()

    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=None,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel
    FastLanguageModel.for_training(model)
    dtype_summary = _align_trainable_dtypes(model)
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)

    rows: List[Dict[str, Any]] = []
    best_score = -1.0
    for episode in range(1, args.episodes + 1):
        started = time.time()
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=args.episodes,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        samples, metrics = run_model_episode(
            model=model,
            tokenizer=tokenizer,
            task=task,
            max_steps=args.max_steps,
            use_tools=args.use_tools,
            world_split=args.train_world_split,
        )
        rewards = [float(sample["reward"]) for sample in samples]
        # Per-episode mean reward acts as the variance-reducing baseline.
        baseline = sum(rewards) / len(rewards) if rewards else 0.0
        total_loss = 0.0

        for sample in samples:
            advantage = float(sample["reward"]) - baseline
            full_text = sample["prompt"] + sample["response"] + tokenizer.eos_token
            inputs = tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_SEQ_LEN,
            ).to("cuda")
            # NOTE(review): labels span the whole sequence, so the loss is
            # averaged over prompt and response tokens alike.
            outputs = model(**inputs, labels=inputs["input_ids"])
            # outputs.loss is a mean negative log-likelihood.  Minimising
            # advantage * NLL raises the likelihood of above-baseline
            # responses and lowers it for below-baseline ones; the previous
            # `* (-advantage)` had the sign inverted.
            loss = outputs.loss * advantage

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += float(loss.item())

        row = {
            "episode": episode,
            "task": task,
            "stage": stage,
            "score": metrics["score"],
            "steps": metrics["steps"],
            "reward_sum": metrics["reward_sum"],
            "mean_reward": metrics["mean_reward"],
            "tool_calls": metrics["tool_calls"],
            "loss": total_loss,
            "seconds": round(time.time() - started, 2),
        }
        rows.append(row)

        print(
            f"episode={episode:03d} task={task:<20} "
            f"stage={stage:<32} "
            f"score={row['score']:.3f} mean_reward={row['mean_reward']:.3f} "
            f"loss={row['loss']:.4f} steps={row['steps']:02d} tools={row['tool_calls']:02d}"
        )

        # Keep a copy of the weights from the best-scoring episode so far.
        if row["score"] > best_score:
            best_score = row["score"]
            model.save_pretrained(output_dir / "best")
            tokenizer.save_pretrained(output_dir / "best")

        if args.save_every and episode % args.save_every == 0:
            model.save_pretrained(output_dir / f"checkpoint-{episode}")
            tokenizer.save_pretrained(output_dir / f"checkpoint-{episode}")

    model.save_pretrained(output_dir / "final")
    tokenizer.save_pretrained(output_dir / "final")

    # Both suites use the held-out seed; only the world split differs.
    evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.train_world_split,
        seed_start=args.heldout_seed,
    )
    heldout_evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.heldout_world_split,
        seed_start=args.heldout_seed,
    )

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name=model_name,
        episodes=args.episodes,
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="pg",
        evaluation_rows=evaluation_rows,
        heldout_evaluation_rows=heldout_evaluation_rows,
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)
    print()
    print(f"Training complete. Best score: {best_score:.3f}")
    print("Post-train online evaluation:")
    for row in evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print("Held-out family evaluation:")
    for row in heldout_evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print(f"Metrics saved to: {metrics_path}")
|
| 948 |
+
|
| 949 |
+
|
| 950 |
+
def train_grpo(args: argparse.Namespace) -> None:
    """Run TRL GRPO training over a prompt bank and record the results.

    Loads the model, builds a prompt bank with ``build_prompt_bank``, trains
    with ``GRPOTrainer``, saves the final weights, evaluates on the training
    and held-out world splits (a failing evaluation yields an empty list
    rather than aborting) and writes metrics with ``save_metrics``.

    Args:
        args: Parsed command-line namespace (see ``parse_args``).

    Raises:
        ImportError: If ``datasets``, ``trl``, ``torch`` or ``unsloth`` is
            missing.
        RuntimeError: If the prompt bank comes back empty.
    """
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    import torch

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("AdaptShield GRPO training")
    print(f"Task: {args.task}")
    print(f"Curriculum: {args.curriculum}")
    print(f"Use tools: {args.use_tools}")
    print(f"Model: {model_name}")
    print(f"Prompt-bank episodes: {args.prompt_bank_episodes}")
    print(f"GRPO epochs: {args.grpo_epochs}")
    print(f"Eval episodes: {args.eval_episodes}")
    print(f"Output: {output_dir}")
    print()

    # Prefer bf16 when the CUDA runtime reports support, otherwise fp16.
    bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    compute_dtype = torch.bfloat16 if bf16_supported else torch.float16
    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=compute_dtype,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Best-effort config tweaks: any attribute that refuses assignment is
    # skipped silently.
    model_config = getattr(model, "config", None)
    if model_config is not None:
        for attribute, value in (("return_dict", True), ("use_cache", False)):
            try:
                setattr(model_config, attribute, value)
            except Exception:
                pass
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        try:
            generation_config.pad_token_id = tokenizer.pad_token_id
        except Exception:
            pass

    FastLanguageModel.for_training(model)
    dtype_summary = _align_trainable_dtypes(model, target_dtype=compute_dtype)
    print(f"Using GRPO compute dtype: {compute_dtype}")
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")

    prompt_bank = build_prompt_bank(
        tokenizer=tokenizer,
        selected_task=args.task,
        curriculum=args.curriculum,
        rollout_episodes=args.prompt_bank_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        seed=args.seed,
        world_split=args.train_world_split,
        hard_multiplier=args.prompt_bank_hard_multiplier,
        borderline_bonus=args.prompt_bank_borderline_bonus,
    )
    if not prompt_bank:
        raise RuntimeError("Prompt bank is empty; cannot start GRPO training.")

    dataset = Dataset.from_list(prompt_bank)
    reward_fn = build_grpo_reward_fn()

    checkpointing = args.save_every > 0
    config_kwargs = {
        "output_dir": str(output_dir),
        "learning_rate": args.lr,
        "per_device_train_batch_size": args.per_device_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_train_epochs": args.grpo_epochs,
        "max_prompt_length": MAX_SEQ_LEN - 256,
        "max_completion_length": 256,
        "num_generations": args.num_generations,
        "logging_steps": 1,
        "save_strategy": "steps" if checkpointing else "no",
        "report_to": "none",
        "remove_unused_columns": False,
        "bf16": bf16_supported,
        "fp16": not bf16_supported,
        "max_grad_norm": 1.0,
        "seed": args.seed,
    }
    if checkpointing:
        config_kwargs["save_steps"] = args.save_every
    grpo_config = GRPOConfig(**_filter_supported_kwargs(GRPOConfig, config_kwargs))

    # Both tokenizer keyword spellings are offered; the filter keeps whichever
    # ones the installed GRPOTrainer accepts.
    trainer_kwargs = {
        "model": model,
        "reward_funcs": [reward_fn],
        "args": grpo_config,
        "train_dataset": dataset,
        "processing_class": tokenizer,
        "tokenizer": tokenizer,
    }
    trainer = GRPOTrainer(**_filter_supported_kwargs(GRPOTrainer, trainer_kwargs))
    trainer.train()

    final_dir = output_dir / "final"
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
    train_rows = _trainer_log_rows(log_history, selected_task=args.task)
    if not train_rows:
        # No usable trainer logs: emit one placeholder row per epoch.
        placeholder_task = "mixed" if args.task == "all" else args.task
        train_rows = [
            {
                "episode": epoch_number,
                "task": placeholder_task,
                "stage": "grpo",
                "score": 0.50,
            }
            for epoch_number in range(1, max(1, args.grpo_epochs) + 1)
        ]

    # Arguments shared by the in-distribution and held-out evaluation calls.
    shared_eval_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "selected_task": args.task,
        "eval_episodes": args.eval_episodes,
        "max_steps": args.max_steps,
        "use_tools": args.use_tools,
        "seed_start": args.heldout_seed,
    }
    try:
        evaluation_rows = evaluate_model_suite(
            world_split=args.train_world_split,
            **shared_eval_kwargs,
        )
    except Exception as exc:
        print(f"GRPO in-distribution evaluation failed: {exc}")
        evaluation_rows = []
    try:
        heldout_evaluation_rows = evaluate_model_suite(
            world_split=args.heldout_world_split,
            **shared_eval_kwargs,
        )
    except Exception as exc:
        print(f"GRPO held-out evaluation failed: {exc}")
        heldout_evaluation_rows = []

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=train_rows,
        model_name=model_name,
        episodes=max(1, len(train_rows)),
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="grpo",
        evaluation_rows=evaluation_rows,
        heldout_evaluation_rows=heldout_evaluation_rows,
        prompt_bank_size=len(prompt_bank),
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
            "base_model": model_name,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)
    print("GRPO training complete.")
    print(f"Prompt bank size: {len(prompt_bank)}")
    for heading, suite_rows in (
        ("Post-train online evaluation:", evaluation_rows),
        ("Held-out family evaluation:", heldout_evaluation_rows),
    ):
        print(heading)
        for row in suite_rows:
            print(
                f" task={row['task']:<20} score={row['score']:.3f} "
                f"steps={row['steps']} tools={row['tool_calls']}"
            )
    if log_history:
        final_keys = sorted(log_history[-1].keys())
        print(f"Trainer log keys: {final_keys}")
    print(f"Metrics saved to: {metrics_path}")
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
def _looks_like_adapter_path(model_name: str) -> bool:
    """Return True when *model_name* is an existing path holding ``adapter_config.json``."""
    candidate = Path(str(model_name))
    if not candidate.exists():
        return False
    return (candidate / "adapter_config.json").exists()
|
| 1140 |
+
|
| 1141 |
+
|
| 1142 |
+
def run_fallback_smoke(args: argparse.Namespace) -> None:
    """Run the dependency-free smoke training path and save its metrics.

    Delegates to ``run_tool_fallback_smoke`` when ``args.use_tools`` is set.
    Otherwise runs ``run_smoke_training`` for at most ``args.smoke_episodes``
    episodes, tags each resulting row with its curriculum stage, and writes
    the metrics (plus an optional plot).
    """
    if args.use_tools:
        run_tool_fallback_smoke(args)
        return

    # The smoke run is capped by both the requested and the smoke budget.
    episode_budget = min(args.episodes, args.smoke_episodes)
    output_dir = Path(args.output)

    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        tasks = []
        for episode in range(1, episode_budget + 1):
            scheduled = task_for_episode(
                episode=episode,
                total_episodes=episode_budget,
                selected_task=args.task,
                curriculum=True,
            )
            tasks.append(scheduled[0])

    rows = run_smoke_training(
        tasks=tasks,
        episodes=episode_budget,
        output=output_dir / "train_smoke.csv",
        seed=args.seed,
        epsilon=0.85,
        epsilon_decay=0.94,
        epsilon_floor=0.08,
        lr=0.35,
        max_steps=args.max_steps,
    )

    metrics_rows = []
    for raw_row in rows:
        _, stage = task_for_episode(
            episode=int(raw_row["episode"]),
            total_episodes=episode_budget,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        # Copy the row so the caller's data is not mutated.
        metrics_rows.append({**raw_row, "stage": stage})

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=metrics_rows,
        model_name="smoke-tabular-policy",
        episodes=episode_budget,
        curriculum=args.curriculum,
        use_tools=False,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
|
| 1196 |
+
|
| 1197 |
+
|
| 1198 |
+
def run_tool_fallback_smoke(args: argparse.Namespace) -> None:
    """No-GPU tool-aware rehearsal. This validates flow, not model learning.

    Runs ``tool_baseline.run_task`` once per episode, cycling through the
    selected tasks, prints a per-episode summary line and saves the collected
    rows with ``save_metrics``.
    """
    from tool_baseline import run_task as run_tool_task

    total = min(args.episodes, args.smoke_episodes)
    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        tasks = [
            task_for_episode(
                episode=number,
                total_episodes=total,
                selected_task=args.task,
                curriculum=True,
            )[0]
            for number in range(1, total + 1)
        ]

    print("AdaptShield tool-aware smoke evaluation")
    print("Mode: no-GPU flow validation, not model learning")
    print(f"Tasks: {', '.join(tasks)}")
    print(f"Episodes: {total}")
    print()

    rows: List[Dict[str, Any]] = []
    for episode in range(1, total + 1):
        # Cycle through the task list when there are more episodes than tasks.
        task = tasks[(episode - 1) % len(tasks)]
        result = run_tool_task(task, emit_logs=False)

        metadata = result.get("metadata", {})
        if isinstance(metadata, dict):
            tool_calls = len(metadata.get("tool_trace", []))
        else:
            tool_calls = 0

        _, stage = task_for_episode(
            episode=episode,
            total_episodes=total,
            selected_task=args.task,
            curriculum=args.curriculum,
        )

        row = {
            "episode": episode,
            "task": task,
            "stage": stage,
            "score": result["score"],
            "steps": result["steps"],
        }
        step_rewards = result["rewards"]
        row["reward_sum"] = sum(step_rewards)
        row["mean_reward"] = sum(step_rewards) / len(step_rewards) if step_rewards else 0.0
        row["tool_calls"] = tool_calls
        row["status"] = "PASS" if result["success"] else "FAIL"
        rows.append(row)

        print(
            f"episode={episode:03d} task={task:<20} "
            f"score={row['score']:.3f} steps={row['steps']:02d} "
            f"tools={tool_calls:02d} {row['status']}"
        )

    output_dir = Path(args.output)
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name="tool-aware-smoke-policy",
        episodes=total,
        curriculum=args.curriculum,
        use_tools=True,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
|
| 1264 |
+
|
| 1265 |
+
|
| 1266 |
+
def parse_args() -> argparse.Namespace:
    """Build the training-harness CLI and parse the process arguments."""
    cli = argparse.ArgumentParser(description="AdaptShield training harness.")

    # Core run selection.
    cli.add_argument("--task", default="direct-triage", choices=TASKS + ["all"])
    cli.add_argument("--model", default=DEFAULT_MODEL, choices=list(MODEL_CHOICES))
    cli.add_argument(
        "--model-path",
        default="",
        help="Optional local/HF adapter path to continue training from.",
    )
    cli.add_argument("--episodes", type=int, default=60)
    cli.add_argument("--max-steps", type=int, default=30)
    cli.add_argument("--output", default="checkpoints/adaptshield")
    cli.add_argument("--seed", type=int, default=42)
    cli.add_argument("--lr", type=float, default=1e-5)
    cli.add_argument("--save-every", type=int, default=20)

    # Mode switches.
    cli.add_argument(
        "--smoke",
        action="store_true",
        help="Force dependency-free smoke mode.",
    )
    cli.add_argument("--smoke-episodes", type=int, default=30)
    cli.add_argument(
        "--curriculum",
        action="store_true",
        help="Train direct -> dual -> hard instead of fixed/round-robin tasks.",
    )
    cli.add_argument(
        "--use-tools",
        action="store_true",
        help="Let GPU training query SOC tools before hard-task actions.",
    )
    cli.add_argument(
        "--plot",
        action="store_true",
        help="Generate reward_curve.png from metrics.json after training.",
    )
    cli.add_argument(
        "--trainer",
        default="auto",
        choices=["auto", "pg", "grpo"],
        help="Training backend: safe policy-gradient fallback or TRL GRPO.",
    )

    # GRPO-specific knobs.
    cli.add_argument(
        "--prompt-bank-episodes",
        type=int,
        default=24,
        help="Reference rollout episodes used to build the GRPO prompt bank.",
    )
    cli.add_argument(
        "--prompt-bank-hard-multiplier",
        type=int,
        default=2,
        help="Duplicate hard-task GRPO prompts this many times to emphasize difficult slices.",
    )
    cli.add_argument(
        "--prompt-bank-borderline-bonus",
        type=int,
        default=1,
        help="Extra copies for degraded-handoff / borderline GRPO prompts.",
    )
    cli.add_argument(
        "--grpo-epochs",
        type=int,
        default=1,
        help="Number of epochs over the prompt bank for GRPO runs.",
    )
    cli.add_argument(
        "--num-generations",
        type=int,
        default=4,
        help="GRPO generations per prompt when TRL path is active.",
    )
    cli.add_argument(
        "--per-device-batch-size",
        type=int,
        default=1,
        help="Per-device batch size for GRPO training.",
    )
    cli.add_argument(
        "--gradient-accumulation-steps",
        type=int,
        default=4,
        help="Gradient accumulation for GRPO training.",
    )

    # Evaluation settings.
    cli.add_argument(
        "--eval-episodes",
        type=int,
        default=2,
        help="Online environment episodes per task after GPU training.",
    )
    cli.add_argument(
        "--train-world-split",
        default="train",
        choices=["train", "eval"],
        help="World split used for training/prompt-bank generation.",
    )
    cli.add_argument(
        "--heldout-world-split",
        default="eval",
        choices=["train", "eval"],
        help="World split used for held-out evaluation.",
    )
    cli.add_argument(
        "--heldout-seed",
        type=int,
        default=314,
        help="Seed offset used for held-out evaluation episodes.",
    )
    return cli.parse_args()
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
def main() -> int:
    """Entry point: choose a training backend and degrade gracefully.

    ``--smoke`` goes straight to the dependency-free path.  Otherwise the
    chosen backend runs; on ``ImportError`` GRPO falls back to the
    policy-gradient trainer, and a missing policy-gradient dependency falls
    back to the smoke run.  Always returns 0.
    """
    args = parse_args()
    if args.smoke:
        run_fallback_smoke(args)
        return 0

    backend = args.trainer
    if backend == "auto":
        # Prefer GRPO whenever its optional dependencies are importable.
        try:
            import datasets  # noqa: F401
            import trl  # noqa: F401
        except ImportError:
            backend = "pg"
        else:
            backend = "grpo"

    run_backend = train_grpo if backend == "grpo" else train_policy_gradient
    try:
        run_backend(args)
    except ImportError as exc:
        print(f"GPU training dependency missing for trainer={backend}: {exc}")
        if backend == "grpo":
            print("Falling back to policy-gradient GPU trainer.")
            try:
                train_policy_gradient(args)
            except ImportError as nested_exc:
                print(f"Policy-gradient fallback also unavailable: {nested_exc}")
            else:
                return 0
        print("Falling back to dependency-free smoke training.")
        run_fallback_smoke(args)
    return 0
|
| 1329 |
+
|
| 1330 |
+
|
| 1331 |
+
if __name__ == "__main__":
|
| 1332 |
+
raise SystemExit(main())
|
train_sft.py
ADDED
|
@@ -0,0 +1,611 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Supervised fine-tuning for AdaptShield chat-style demonstrations."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import random
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, List
|
| 12 |
+
|
| 13 |
+
from train import (
|
| 14 |
+
DEFAULT_MODEL,
|
| 15 |
+
LORA_RANK,
|
| 16 |
+
MAX_SEQ_LEN,
|
| 17 |
+
MODEL_CHOICES,
|
| 18 |
+
_align_trainable_dtypes,
|
| 19 |
+
_filter_supported_kwargs,
|
| 20 |
+
_normalize_generation_config,
|
| 21 |
+
evaluate_model_suite,
|
| 22 |
+
run_model_episode,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def load_jsonl(path: Path) -> List[Dict[str, Any]]:
    """Read a JSON Lines file and return its decoded rows.

    Blank (whitespace-only) lines are skipped.

    Args:
        path: Location of the UTF-8 encoded ``.jsonl`` file.

    Returns:
        One decoded JSON value per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a line is not valid JSON; the message is
            prefixed with ``path:line_number`` so the bad row can be found.
        RuntimeError: If the file contains no non-blank lines.
    """
    rows: List[Dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line_number, raw_line in enumerate(handle, start=1):
            line = raw_line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as exc:
                # Same exception type as before, now with file/line context.
                raise json.JSONDecodeError(
                    f"{path}:{line_number}: {exc.msg}", exc.doc, exc.pos
                ) from exc
    if not rows:
        raise RuntimeError(f"No training rows found in {path}")
    return rows
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def build_loss_plot(log_history: List[Dict[str, Any]], output_path: Path) -> None:
    """Plot the training loss recorded in *log_history* and save it as an image.

    Entries without a ``"loss"`` key, or whose loss cannot be converted to
    ``float``, are ignored.  Nothing is written when matplotlib is missing or
    no usable loss entry exists; a short message is printed instead.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed; skipping loss plot")
        return

    points = []
    for position, record in enumerate(log_history, start=1):
        if "loss" not in record:
            continue
        # Fall back to the entry's position when "step" is absent or falsy.
        step = int(record.get("step", position) or position)
        try:
            loss_value = float(record["loss"])
        except Exception:
            continue
        points.append((step, loss_value))

    if not points:
        print("No loss entries found; skipping loss plot")
        return

    steps = [step for step, _ in points]
    losses = [loss_value for _, loss_value in points]

    plt.figure(figsize=(10, 5))
    plt.plot(steps, losses, color="#0f4c81", linewidth=2, label="training loss")
    plt.xlabel("Training step")
    plt.ylabel("Loss")
    plt.title("AdaptShield SFT Loss Curve")
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def build_reward_plot(rows: List[Dict[str, Any]], output_path: Path) -> None:
    """Plot in-distribution and held-out scores per checkpoint and save the figure.

    Each row must provide ``"checkpoint"``, ``"in_distribution_score"`` and
    ``"heldout_score"``.  Nothing is written when matplotlib is missing or
    *rows* is empty; a short message is printed instead.
    """
    try:
        import matplotlib.pyplot as plt
    except ImportError:
        print("matplotlib not installed; skipping reward plot")
        return

    if not rows:
        print("No held-out reward rows found; skipping reward plot")
        return

    checkpoint_labels = []
    in_distribution_scores = []
    heldout_scores = []
    for row in rows:
        checkpoint_labels.append(str(row["checkpoint"]))
    for row in rows:
        in_distribution_scores.append(float(row["in_distribution_score"]))
    for row in rows:
        heldout_scores.append(float(row["heldout_score"]))

    positions = range(len(rows))
    series = (
        (in_distribution_scores, "#136f63", "o", "in-distribution mean reward"),
        (heldout_scores, "#8a3ffc", "s", "held-out family mean reward"),
    )

    plt.figure(figsize=(10, 5))
    for values, colour, marker, label in series:
        plt.plot(
            positions,
            values,
            color=colour,
            linewidth=2.5,
            marker=marker,
            label=label,
        )
    plt.xticks(positions, checkpoint_labels, rotation=35, ha="right")
    plt.xlabel("Checkpoint")
    plt.ylabel("normalized_score")
    plt.title("AdaptShield In-Distribution vs Held-out Reward Curve")
    plt.ylim(0.0, 1.0)
    plt.grid(alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches="tight")
    plt.close()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def render_example(example: Dict[str, Any], tokenizer: Any) -> str:
    """Turn one training example into the text string used for training.

    Examples carrying ``"messages"`` are rendered through the tokenizer's chat
    template (untokenized, without a generation prompt); any other example
    must carry ``"text"``, which is returned as a string.
    """
    if "messages" not in example:
        return str(example["text"])

    conversation = example["messages"]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _checkpoint_sort_key(path: Path) -> tuple[int, str]:
    """Return a sort key that orders checkpoint directories chronologically.

    ``checkpoint-<N>`` sorts by ``N``; directories with any other name use
    10**7, ``checkpoint-`` names with a non-integer suffix use 10**8, and
    ``final`` uses 10**9 so that it sorts last.  The directory name is the
    tie-breaker.
    """
    name = path.name
    if name == "final":
        return (10**9, name)
    if name.startswith("checkpoint-"):
        try:
            return (int(name.split("-", 1)[1]), name)
        except ValueError:
            # int() on a str only raises ValueError; other errors should surface.
            return (10**8, name)
    return (10**7, name)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def checkpoint_dirs(output_dir: Path) -> List[Path]:
    """List the ``checkpoint-*`` and ``final`` sub-directories of *output_dir*.

    The result is ordered with ``_checkpoint_sort_key``.
    """

    def _is_checkpoint(candidate: Path) -> bool:
        # Only directories named "final" or prefixed "checkpoint-" qualify.
        if not candidate.is_dir():
            return False
        return candidate.name.startswith("checkpoint-") or candidate.name == "final"

    matches = list(filter(_is_checkpoint, output_dir.iterdir()))
    matches.sort(key=_checkpoint_sort_key)
    return matches
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def evaluate_suite_with_seed(
|
| 149 |
+
model: Any,
|
| 150 |
+
tokenizer: Any,
|
| 151 |
+
selected_task: str,
|
| 152 |
+
eval_episodes: int,
|
| 153 |
+
max_steps: int,
|
| 154 |
+
use_tools: bool,
|
| 155 |
+
seed_start: int,
|
| 156 |
+
world_split: str,
|
| 157 |
+
world_family: str | None = None,
|
| 158 |
+
) -> List[Dict[str, Any]]:
|
| 159 |
+
tasks = ["direct-triage", "dual-pivot", "polymorphic-zero-day"] if selected_task == "all" else [selected_task]
|
| 160 |
+
rows: List[Dict[str, Any]] = []
|
| 161 |
+
original_seed = os.environ.get("ADAPTSHIELD_SEED")
|
| 162 |
+
try:
|
| 163 |
+
for task_index, task in enumerate(tasks):
|
| 164 |
+
scores: List[float] = []
|
| 165 |
+
steps: List[int] = []
|
| 166 |
+
tool_calls: List[int] = []
|
| 167 |
+
for episode_index in range(eval_episodes):
|
| 168 |
+
os.environ["ADAPTSHIELD_SEED"] = str(seed_start + task_index * 100 + episode_index)
|
| 169 |
+
try:
|
| 170 |
+
_, metrics = run_model_episode(
|
| 171 |
+
model=model,
|
| 172 |
+
tokenizer=tokenizer,
|
| 173 |
+
task=task,
|
| 174 |
+
max_steps=max_steps,
|
| 175 |
+
use_tools=use_tools,
|
| 176 |
+
world_split=world_split,
|
| 177 |
+
world_family=world_family,
|
| 178 |
+
)
|
| 179 |
+
except Exception as exc:
|
| 180 |
+
print(f" eval episode failed (task={task}, ep={episode_index}): {exc}")
|
| 181 |
+
continue
|
| 182 |
+
scores.append(float(metrics["score"]))
|
| 183 |
+
steps.append(int(metrics["steps"]))
|
| 184 |
+
tool_calls.append(int(metrics["tool_calls"]))
|
| 185 |
+
rows.append({
|
| 186 |
+
"task": task,
|
| 187 |
+
"score": round(sum(scores) / len(scores), 3) if scores else 0.50,
|
| 188 |
+
"steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
|
| 189 |
+
"tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
|
| 190 |
+
"eval_episodes": eval_episodes,
|
| 191 |
+
"successful_episodes": len(scores),
|
| 192 |
+
"seed_start": seed_start,
|
| 193 |
+
"world_split": world_split,
|
| 194 |
+
"world_family": world_family or "auto",
|
| 195 |
+
})
|
| 196 |
+
finally:
|
| 197 |
+
if original_seed is None:
|
| 198 |
+
os.environ.pop("ADAPTSHIELD_SEED", None)
|
| 199 |
+
else:
|
| 200 |
+
os.environ["ADAPTSHIELD_SEED"] = original_seed
|
| 201 |
+
return rows
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def _free_gpu(*objects: Any) -> None:
|
| 205 |
+
"""Best-effort release of GPU memory between checkpoint evaluations."""
|
| 206 |
+
import gc
|
| 207 |
+
|
| 208 |
+
for obj in objects:
|
| 209 |
+
try:
|
| 210 |
+
del obj
|
| 211 |
+
except Exception:
|
| 212 |
+
pass
|
| 213 |
+
gc.collect()
|
| 214 |
+
try:
|
| 215 |
+
import torch
|
| 216 |
+
|
| 217 |
+
if torch.cuda.is_available():
|
| 218 |
+
torch.cuda.empty_cache()
|
| 219 |
+
torch.cuda.synchronize()
|
| 220 |
+
except Exception:
|
| 221 |
+
pass
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def _load_checkpoint_for_eval(
    checkpoint_dir: Path,
    base_model_name: str,
    max_seq_length: int,
) -> tuple[Any, Any]:
    """Load an adapter checkpoint robustly, falling back to PEFT if needed.

    A directory that has ``adapter_config.json`` but no ``config.json`` is
    treated as adapter-only. Anything else is first loaded directly from the
    checkpoint directory; if that raises, or the directory is adapter-only,
    the base model is loaded and the adapter is attached with ``PeftModel``.
    On that path the tokenizer saved with the checkpoint is preferred, and
    the base tokenizer is kept if loading it fails.

    Returns:
        The ``(model, tokenizer)`` pair.
    """
    from unsloth import FastLanguageModel

    has_adapter_config = (checkpoint_dir / "adapter_config.json").exists()
    is_adapter_only = has_adapter_config and not (checkpoint_dir / "config.json").exists()

    load_options = {
        "max_seq_length": max_seq_length,
        "load_in_4bit": True,
        "dtype": None,
    }

    if not is_adapter_only:
        try:
            return FastLanguageModel.from_pretrained(
                model_name=str(checkpoint_dir),
                **load_options,
            )
        except Exception as exc:
            print(f" direct load failed for {checkpoint_dir.name}: {exc}; "
                  "falling back to base+adapter loader.")

    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        **load_options,
    )
    from peft import PeftModel

    adapted_model = PeftModel.from_pretrained(
        base_model,
        str(checkpoint_dir),
        is_trainable=False,
        autocast_adapter_dtype=False,
    )
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(str(checkpoint_dir), trust_remote_code=True)
    except Exception:
        # Keep the base model's tokenizer when the checkpoint has none usable.
        pass
    return adapted_model, tokenizer
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def evaluate_saved_checkpoints(
    output_dir: Path,
    model_key: str,
    max_seq_length: int,
    selected_task: str,
    eval_episodes: int,
    max_steps: int,
    use_tools: bool,
    heldout_seed: int,
    train_world_split: str,
    heldout_world_split: str,
) -> List[Dict[str, Any]]:
    """Score every saved checkpoint on the training and held-out world splits.

    Each checkpoint directory under *output_dir* is loaded, evaluated once per
    split with the same seed window (``heldout_seed + index * 1000``) and then
    released before the next one is loaded.

    Returns:
        One row per checkpoint holding the checkpoint name, the mean
        in-distribution and held-out scores, and the per-task rows behind
        them. A checkpoint that fails to load or evaluate yields a row with
        both scores set to ``0.0`` and an ``"error"`` message instead.
    """
    base_model_name = MODEL_CHOICES[model_key]
    rows: List[Dict[str, Any]] = []
    for index, checkpoint_dir in enumerate(checkpoint_dirs(output_dir)):
        print(f"Held-out evaluating checkpoint: {checkpoint_dir.name}")
        model = None
        tokenizer = None
        try:
            model, tokenizer = _load_checkpoint_for_eval(
                checkpoint_dir=checkpoint_dir,
                base_model_name=base_model_name,
                max_seq_length=max_seq_length,
            )
            _normalize_generation_config(model)
            _align_trainable_dtypes(model)
            in_distribution_rows = evaluate_suite_with_seed(
                model=model,
                tokenizer=tokenizer,
                selected_task=selected_task,
                eval_episodes=eval_episodes,
                max_steps=max_steps,
                use_tools=use_tools,
                seed_start=heldout_seed + index * 1000,
                world_split=train_world_split,
            )
            heldout_rows = evaluate_suite_with_seed(
                model=model,
                tokenizer=tokenizer,
                selected_task=selected_task,
                eval_episodes=eval_episodes,
                max_steps=max_steps,
                use_tools=use_tools,
                seed_start=heldout_seed + index * 1000,
                world_split=heldout_world_split,
            )
            in_distribution_score = round(
                sum(float(row["score"]) for row in in_distribution_rows) / max(1, len(in_distribution_rows)),
                3,
            )
            heldout_score = round(
                sum(float(row["score"]) for row in heldout_rows) / max(1, len(heldout_rows)),
                3,
            )
            rows.append({
                "checkpoint": checkpoint_dir.name,
                "in_distribution_score": in_distribution_score,
                "heldout_score": heldout_score,
                "in_distribution_rows": in_distribution_rows,
                "heldout_rows": heldout_rows,
            })
        except Exception as exc:
            print(f" checkpoint eval failed for {checkpoint_dir.name}: {exc}")
            rows.append({
                "checkpoint": checkpoint_dir.name,
                "in_distribution_score": 0.0,
                "heldout_score": 0.0,
                "error": str(exc),
            })
        finally:
            # Drop our references BEFORE collecting: while these locals still
            # point at the model, the collection/cache-emptying pass cannot
            # release it and the next checkpoint is loaded on top of it.
            model = None
            tokenizer = None
            _free_gpu()
    return rows
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def train_sft(args: argparse.Namespace) -> None:
    """Run LoRA supervised fine-tuning, then evaluate and report the result.

    Steps, in order:

    1. Load the JSONL dataset (optionally truncated to ``args.max_rows``).
    2. Load the base model in 4-bit, attach LoRA adapters and train with
       ``SFTTrainer``.
    3. Save the adapter and tokenizer to ``<output>/final`` and write
       ``loss_curve.png``.
    4. Evaluate the trained model on the training and held-out world splits.
    5. Unless ``args.skip_reward_curve`` is set, re-evaluate every saved
       checkpoint and write ``reward_curve.png``.

    ``sft_metrics.json`` is rewritten after each stage so partial results
    survive a later failure. Evaluation and plotting failures are reported on
    stdout and do not abort the run.

    Args:
        args: Parsed command-line options as produced by ``main``.
    """
    from unsloth import FastLanguageModel
    from datasets import Dataset
    from trl import SFTTrainer
    import torch

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    dataset_path = Path(args.dataset)
    rows = load_jsonl(dataset_path)
    if args.max_rows and args.max_rows > 0:
        rows = rows[: args.max_rows]

    model_name = MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    print("AdaptShield SFT training")
    print(f"Dataset: {dataset_path}")
    print(f"Rows: {len(rows)}")
    print(f"Model: {model_name}")
    print(f"Epochs: {args.epochs}")
    print(f"Batch size: {args.per_device_batch_size}")
    print(f"Grad accumulation: {args.gradient_accumulation_steps}")
    print(f"Learning rate: {args.lr}")
    print(f"Output: {output_dir}")
    print()

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
        dtype=None,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=LORA_RANK,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_alpha=LORA_RANK * 2,
        lora_dropout=0.0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=args.seed,
    )
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    _normalize_generation_config(model)
    _align_trainable_dtypes(model)

    # Every row gets a rendered "text" field; a "text" key already present in
    # the row takes precedence because the row is unpacked last.
    prepared_rows = [{"text": render_example(row, tokenizer), **row} for row in rows]
    dataset = Dataset.from_list(prepared_rows)

    bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())

    # Use SFTConfig when this trl version exports it, else TrainingArguments.
    try:
        from trl import SFTConfig
        train_config_cls = SFTConfig
    except ImportError:
        from transformers import TrainingArguments
        train_config_cls = TrainingArguments

    config_kwargs = {
        "output_dir": str(output_dir),
        "learning_rate": args.lr,
        "per_device_train_batch_size": args.per_device_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_train_epochs": args.epochs,
        "logging_steps": 1,
        "save_strategy": "steps",
        "save_steps": args.save_steps,
        "report_to": "none",
        "seed": args.seed,
        "bf16": bf16_supported,
        "fp16": not bf16_supported,
        "max_seq_length": args.max_seq_length,
        "dataset_text_field": "text",
        "dataset_num_proc": 1,
        "packing": False,
    }
    train_args = train_config_cls(
        **_filter_supported_kwargs(train_config_cls, config_kwargs)
    )

    # Both spellings of the tokenizer argument and the dataset options are
    # offered; the kwargs are passed through _filter_supported_kwargs before
    # reaching SFTTrainer.
    trainer_kwargs = {
        "model": model,
        "train_dataset": dataset,
        "args": train_args,
        "processing_class": tokenizer,
        "tokenizer": tokenizer,
        "dataset_text_field": "text",
        "dataset_num_proc": 1,
        "max_seq_length": args.max_seq_length,
        "packing": False,
    }
    trainer = SFTTrainer(**_filter_supported_kwargs(SFTTrainer, trainer_kwargs))
    trainer.train()

    final_dir = output_dir / "final"
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
    loss_plot_path = output_dir / "loss_curve.png"
    try:
        build_loss_plot(log_history, loss_plot_path)
    except Exception as exc:
        print(f"Loss plot generation skipped: {exc}")

    metrics: Dict[str, Any] = {
        "trainer": "sft",
        "model": model_name,
        "dataset": str(dataset_path),
        "rows": len(rows),
        "epochs": args.epochs,
        "learning_rate": args.lr,
        "evaluation_rows": [],
        "heldout_evaluation_rows": [],
        "heldout_seed": args.heldout_seed,
        "train_world_split": args.train_world_split,
        "heldout_world_split": args.heldout_world_split,
        "reward_curve_rows": [],
        "log_history": log_history,
    }
    metrics_path = output_dir / "sft_metrics.json"

    def _flush_metrics() -> None:
        """Persist the metrics gathered so far."""
        metrics_path.write_text(json.dumps(metrics, indent=2), encoding="utf-8")

    _flush_metrics()

    try:
        metrics["evaluation_rows"] = evaluate_suite_with_seed(
            model=model,
            tokenizer=tokenizer,
            selected_task=args.eval_task,
            eval_episodes=args.eval_episodes,
            max_steps=args.eval_max_steps,
            use_tools=args.use_tools,
            seed_start=args.heldout_seed,
            world_split=args.train_world_split,
        )
    except Exception as exc:
        print(f"In-distribution evaluation failed: {exc}")
    _flush_metrics()

    try:
        metrics["heldout_evaluation_rows"] = evaluate_suite_with_seed(
            model=model,
            tokenizer=tokenizer,
            selected_task=args.eval_task,
            eval_episodes=args.eval_episodes,
            max_steps=args.eval_max_steps,
            use_tools=args.use_tools,
            seed_start=args.heldout_seed,
            world_split=args.heldout_world_split,
        )
    except Exception as exc:
        print(f"Held-out evaluation failed: {exc}")
    _flush_metrics()

    reward_curve_rows: List[Dict[str, Any]] = []
    if args.skip_reward_curve:
        print("Skipping per-checkpoint reward curve (--skip-reward-curve).")
    else:
        # Free training-time model before reloading checkpoints to avoid OOM.
        # _free_gpu() can only reclaim objects nothing references any more, so
        # every local that still pins the model (the trainer and the kwargs
        # dict it was built from) has to be cleared first. None of them is
        # used again below.
        trainer_kwargs = None
        trainer = None
        model = None
        _free_gpu()
        try:
            reward_curve_rows = evaluate_saved_checkpoints(
                output_dir=output_dir,
                model_key=args.model,
                max_seq_length=args.max_seq_length,
                selected_task=args.eval_task,
                eval_episodes=args.eval_episodes,
                max_steps=args.eval_max_steps,
                use_tools=args.use_tools,
                heldout_seed=args.heldout_seed,
                train_world_split=args.train_world_split,
                heldout_world_split=args.heldout_world_split,
            )
        except Exception as exc:
            print(f"Per-checkpoint reward curve failed: {exc}")
    metrics["reward_curve_rows"] = reward_curve_rows
    _flush_metrics()

    reward_plot_path = output_dir / "reward_curve.png"
    if reward_curve_rows:
        try:
            build_reward_plot(reward_curve_rows, reward_plot_path)
        except Exception as exc:
            print(f"Reward plot generation skipped: {exc}")

    evaluation_rows = metrics["evaluation_rows"]

    print("SFT complete.")
    print(f"Saved adapter to: {final_dir}")
    print(f"Loss curve: {loss_plot_path}")
    print(f"Reward curve: {reward_plot_path}")
    print(f"Metrics: {metrics_path}")
    print("Post-train evaluation:")
    for row in evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print("Held-out checkpoint reward curve:")
    for row in reward_curve_rows:
        print(
            f" checkpoint={row['checkpoint']:<16} "
            f"in_dist={row['in_distribution_score']:.3f} "
            f"heldout={row['heldout_score']:.3f}"
        )
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
def main() -> None:
    """Parse the SFT command-line options and start training."""
    cli = argparse.ArgumentParser(description="AdaptShield supervised fine-tuning")
    add = cli.add_argument

    # Data, model and output location.
    add(
        "--dataset",
        default="data/adaptshield_sft.jsonl",
        help="Path to JSONL dataset from generate_sft_data.py",
    )
    add("--model", default=DEFAULT_MODEL, choices=list(MODEL_CHOICES.keys()))
    add("--output", default="checkpoints/sft-run")

    # Optimisation, seeding and world splits.
    add("--epochs", type=float, default=1.0)
    add("--lr", type=float, default=2e-4)
    add("--seed", type=int, default=42)
    add("--heldout-seed", type=int, default=314)
    add("--train-world-split", default="train", choices=["train", "eval"])
    add("--heldout-world-split", default="eval", choices=["train", "eval"])
    add("--max-rows", type=int, default=0)
    add("--max-seq-length", type=int, default=MAX_SEQ_LEN)
    add("--per-device-batch-size", type=int, default=2)
    add("--gradient-accumulation-steps", type=int, default=4)
    add("--save-steps", type=int, default=40)

    # Post-training evaluation.
    add(
        "--eval-task",
        default="all",
        choices=["all", "direct-triage", "dual-pivot", "polymorphic-zero-day"],
    )
    add("--eval-episodes", type=int, default=2)
    add("--eval-max-steps", type=int, default=20)
    add(
        "--use-tools",
        action="store_true",
        help="Use SOC tools during post-train evaluation.",
    )
    add(
        "--skip-reward-curve",
        action="store_true",
        help="Skip the per-checkpoint reward curve sweep (faster, avoids OOM).",
    )

    train_sft(cli.parse_args())
|
| 608 |
+
|
| 609 |
+
|
| 610 |
+
# Run the command-line entry point only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
train_smoke.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Dependency-free training-readiness smoke test for AdaptShield."""
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import argparse
|
| 7 |
+
import csv
|
| 8 |
+
import random
|
| 9 |
+
import sys
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from typing import Any, Dict, Iterable, List, Tuple
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
REPO_ROOT = Path(__file__).resolve().parent
|
| 15 |
+
|
| 16 |
+
if str(REPO_ROOT) not in sys.path:
|
| 17 |
+
sys.path.insert(0, str(REPO_ROOT))
|
| 18 |
+
|
| 19 |
+
from models import AdaptShieldAction
|
| 20 |
+
from server.adaptshield_environment import AdaptShieldEnvironment
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Task names the smoke test can run.
TASKS = [
    "direct-triage",
    "dual-pivot",
    "polymorphic-zero-day",
]

# Fixed (target_node, action) response for each threat label.
POLICY = dict(
    brute_force=("auth_service", "rate_limit"),
    lateral_movement=("payment_service", "isolate"),
    exfiltration=("database", "honeypot"),
    supply_chain=("api_gateway", "patch"),
    benign=("api_gateway", "monitor"),
)

# Candidate (target_node, action) pairs: the POLICY responses, in POLICY order.
ACTION_SPACE = list(POLICY.values())
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def classify_from_metrics(network_nodes: Dict[str, Dict[str, Any]]) -> str:
    """Map per-node metrics to a threat label using fixed thresholds.

    The rules are checked in this order and the first match wins:

    * ``auth_service`` error rate of at least 0.10 -> ``"brute_force"``
    * ``payment_service`` status ``"suspicious"`` or CPU of at least 55
      -> ``"lateral_movement"``
    * ``database`` outbound traffic of at least 50 -> ``"exfiltration"``
    * ``api_gateway`` status ``"suspicious"`` -> ``"supply_chain"``

    Missing nodes or metrics count as zero / not suspicious; when nothing
    matches the result is ``"benign"``.
    """

    def node(name: str) -> Dict[str, Any]:
        return network_nodes.get(name, {})

    auth_error_rate = float(node("auth_service").get("error_rate", 0.0))
    if auth_error_rate >= 0.10:
        return "brute_force"

    payment = node("payment_service")
    if payment.get("status") == "suspicious":
        return "lateral_movement"
    if float(payment.get("cpu", 0)) >= 55:
        return "lateral_movement"

    outbound_mb = float(node("database").get("outbound_mb", 0))
    if outbound_mb >= 50:
        return "exfiltration"

    if node("api_gateway").get("status") == "suspicious":
        return "supply_chain"
    return "benign"
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class TabularDefensePolicy:
    """Tiny epsilon-greedy policy used only to verify trainability.

    Keeps one value per (threat label, (target_node, action)) pair. Phase 1
    answers come straight from the fixed ``POLICY`` table; phase 2 picks a
    response epsilon-greedily from the value table, which ``update`` moves
    towards the observed reward.
    """

    def __init__(self, epsilon: float, lr: float) -> None:
        self.epsilon = epsilon
        self.lr = lr
        # Every known threat starts with a value of 0.50 for each response.
        table: Dict[str, Dict[Tuple[str, str], float]] = {}
        for threat in POLICY:
            table[threat] = dict.fromkeys(ACTION_SPACE, 0.50)
        self.q = table

    def choose_phase1(self, obs: Any) -> Dict[str, Any]:
        """Build the phase-1 assessment payload from the observed metrics."""
        label = classify_from_metrics(obs.network_nodes)
        node, response = POLICY[label]
        return {
            "threat_type": label,
            "confidence": 0.90,
            "target_node": node,
            "recommended_action": response,
            "reasoning": "smoke-train metric policy",
        }

    def choose_phase2(self, obs: Any) -> Tuple[Dict[str, Any], str, Tuple[str, str]]:
        """Pick a response for the threat named in the phase-1 assessment.

        Returns the action payload, the threat label and the selected
        ``(target_node, action)`` pair.
        """
        assessment = obs.phase1_assessment or {}
        threat = str(assessment.get("threat_type", "benign"))
        # Unknown threat labels share the value row of "benign".
        values = self.q.get(threat, self.q["benign"])

        exploring = random.random() < self.epsilon
        if exploring:
            picked = random.choice(ACTION_SPACE)
        else:
            top_value = max(values.values())
            # Ties between equally valued responses are broken at random.
            tied = [pair for pair, value in values.items() if value == top_value]
            picked = random.choice(tied)
        target, action = picked

        payload = {
            "action": action,
            "target_node": target,
            "reasoning": "epsilon-greedy smoke policy",
        }
        return payload, threat, (target, action)

    def update(self, threat: str, selected: Tuple[str, str], reward: float) -> None:
        """Move the value of *selected* for *threat* towards *reward*."""
        if threat not in self.q:
            self.q[threat] = dict.fromkeys(ACTION_SPACE, 0.50)
        values = self.q[threat]
        previous = values.get(selected, 0.50)
        values[selected] = previous + self.lr * (reward - previous)

    def decay(self, rate: float, floor: float) -> None:
        """Shrink epsilon by *rate*, never letting it drop below *floor*."""
        decayed = self.epsilon * rate
        self.epsilon = decayed if decayed > floor else floor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def run_episode(task: str, policy: TabularDefensePolicy, max_steps: int) -> Dict[str, Any]:
    """Play one episode of *task* with *policy* and summarise it.

    Phase-1 observations are answered with the policy's assessment; all other
    observations get an epsilon-greedy response, after which the policy is
    updated with the reward of that step. The loop stops when the environment
    reports ``done`` or after *max_steps* steps.

    Raises:
        RuntimeError: if the final observation's metadata has no
            ``normalized_score``.
    """
    env = AdaptShieldEnvironment(task_name=task)
    obs = env.reset()
    rewards: List[float] = []

    for _ in range(max_steps):
        if obs.done:
            break
        if obs.phase == 1:
            obs = env.step(AdaptShieldAction(**policy.choose_phase1(obs)))
        else:
            payload, threat, selected = policy.choose_phase2(obs)
            obs = env.step(AdaptShieldAction(**payload))
            policy.update(threat, selected, float(obs.reward))
        rewards.append(float(obs.reward))

    steps = len(rewards)
    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    score_present = "normalized_score" in metadata
    if not score_present:
        raise RuntimeError("normalized_score missing during smoke training")

    total_reward = sum(rewards)
    return {
        "task": task,
        "score": float(metadata.get("normalized_score", 0.01)),
        "reward_sum": total_reward,
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
        "steps": steps,
        "done": bool(obs.done),
        "normalized_score_present": score_present,
    }
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def write_rows(path: Path, rows: Iterable[Dict[str, Any]]) -> None:
    """Write *rows* to *path* as a CSV file with a header line.

    The parent directory is created when missing. The column order is taken
    from the keys of the first row. When *rows* is empty no file is written.

    The file is always written as UTF-8 so the result does not depend on the
    platform's default encoding.

    Args:
        path: Destination CSV file.
        rows: Dict rows sharing the keys of the first row; any iterable is
            accepted and is consumed once.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    rows = list(rows)
    if not rows:
        return

    # newline="" lets the csv module control line endings itself.
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def trend(values: List[float]) -> Tuple[float, float]:
    """Return the mean of the first and of the last window of *values*.

    The window is one fifth of the series (at least one element). An empty
    series yields ``(0.0, 0.0)``.
    """
    if not values:
        return 0.0, 0.0

    span = len(values) // 5 or 1
    head = values[:span]
    tail = values[-span:]
    return sum(head) / span, sum(tail) / span
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def run_smoke_training(
    tasks: List[str],
    episodes: int,
    output: Path,
    seed: int,
    epsilon: float,
    epsilon_decay: float,
    epsilon_floor: float,
    lr: float,
    max_steps: int,
) -> List[Dict[str, Any]]:
    """Run the tabular smoke-training loop and report the score trend.

    Episodes cycle through *tasks* in order. After each episode epsilon is
    decayed by *epsilon_decay* down to *epsilon_floor*. All episode rows are
    written to *output* as CSV.

    The final verdict line is derived from the recorded rows: it reads
    ``PASS`` only when no episode was recorded with status ``FAIL``.

    Args:
        tasks: Task names to cycle through.
        episodes: Number of episodes to run.
        output: Destination CSV path.
        seed: Seed for the ``random`` module.
        epsilon: Initial exploration rate.
        epsilon_decay: Multiplicative decay applied after each episode.
        epsilon_floor: Lower bound for epsilon.
        lr: Learning rate of the tabular value update.
        max_steps: Step cap per episode.

    Returns:
        One result row per episode.
    """
    random.seed(seed)
    policy = TabularDefensePolicy(epsilon=epsilon, lr=lr)
    rows: List[Dict[str, Any]] = []

    print("AdaptShield smoke training")
    print(f"Tasks: {', '.join(tasks)}")
    print(f"Episodes: {episodes}")
    print(f"Output: {output}")
    print()

    for episode in range(1, episodes + 1):
        task = tasks[(episode - 1) % len(tasks)]
        result = run_episode(task=task, policy=policy, max_steps=max_steps)
        result.update({
            "episode": episode,
            # Epsilon in effect during this episode (before decay).
            "epsilon": round(policy.epsilon, 4),
            "status": "PASS" if result["done"] and result["normalized_score_present"] else "FAIL",
        })
        rows.append(result)
        policy.decay(epsilon_decay, epsilon_floor)

        print(
            f"episode={episode:03d} task={task:<20} "
            f"score={result['score']:.3f} steps={result['steps']:02d} "
            f"epsilon={result['epsilon']:.3f} {result['status']}"
        )

    write_rows(output, rows)

    scores = [float(row["score"]) for row in rows]
    first, last = trend(scores)
    # Report a failure when any episode ended without completing.
    verdict = "PASS" if all(row["status"] == "PASS" for row in rows) else "FAIL"
    print()
    print(f"First-window avg score: {first:.3f}")
    print(f"Last-window avg score: {last:.3f}")
    print(f"Score delta: {last - first:+.3f}")
    print(f"Saved CSV: {output}")
    print(f"Smoke training verdict: {verdict}")
    return rows
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the smoke test."""
    cli = argparse.ArgumentParser(description="Run cheap AdaptShield training smoke test.")
    cli.add_argument("--task", default="direct-triage", choices=TASKS + ["all"])

    # (flag, type, default) for the plain value options, in help order.
    value_options = (
        ("--episodes", int, 30),
        ("--output", str, "training_runs/train_smoke.csv"),
        ("--seed", int, 42),
        ("--epsilon", float, 0.85),
        ("--epsilon-decay", float, 0.94),
        ("--epsilon-floor", float, 0.08),
        ("--lr", float, 0.35),
        ("--max-steps", int, 30),
    )
    for flag, value_type, default in value_options:
        cli.add_argument(flag, type=value_type, default=default)

    return cli.parse_args()
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def main() -> int:
    """Run the smoke training from the command line and return exit code 0."""
    options = parse_args()

    # "all" expands to every known task; anything else is a single task name.
    if options.task == "all":
        selected_tasks = TASKS
    else:
        selected_tasks = [options.task]

    run_smoke_training(
        tasks=selected_tasks,
        episodes=options.episodes,
        output=Path(options.output),
        seed=options.seed,
        epsilon=options.epsilon,
        epsilon_decay=options.epsilon_decay,
        epsilon_floor=options.epsilon_floor,
        lr=options.lr,
        max_steps=options.max_steps,
    )
    return 0
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# When run as a script, exit the process with the code returned by main().
if __name__ == "__main__":
    raise SystemExit(main())
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|