opensoc-env / eval /eval.py
shivam2k3's picture
OpenSOC v1
bb6a031
"""Evaluate baseline and trained defender on the frozen hold-out set.
Two models are compared by default:
* **Baseline**: vanilla Qwen2.5-3B-Instruct, no SFT, no GRPO.
* **Trained**: Qwen2.5-3B-Instruct + SFT warm-start + GRPO curriculum.
Both are scored on `data/holdout.jsonl` using the verifier's ground-truth
labels. Reported metrics (printed and saved to `--out-dir`):
* Accuracy, macro F1, and per-class precision/recall/F1
* 5x5 confusion matrix
* Dismiss-on-malicious rate (the cardinal SOC failure mode)
* Over-react rate (containment on benign)
Inference path
--------------
We use Unsloth's `FastLanguageModel.from_pretrained(... load_in_4bit=True)`
with greedy `model.generate` decoding to keep eval under 10 minutes on a T4.
Two sanity baselines (always-dismiss, and a verifier oracle that echoes the
gold label) are always scored first. When GPU deps aren't available (e.g. in
the Hugging Face Space build environment) or `--smoke-only` is passed, only
these two baselines run, which serves as a smoke test of the data and
metrics pipeline.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
from typing import List, Tuple
# Directory containing this file. Its parent (the repository root) is put
# first on sys.path so the `eval.*`, `schema`, and `train.*` imports below
# resolve even when this file is executed directly as a script.
_HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(_HERE))
from eval.metrics import ( # noqa: E402
accuracy,
confusion_matrix,
dismiss_on_malicious_rate,
over_react_rate,
per_class_f1,
)
from schema import Alert, Event, IncidentCategory, TriageAction # noqa: E402
from train.prompt_format import ( # noqa: E402
SYSTEM_PROMPT,
parse_defender_response,
render_defender_prompt,
)
def _load_holdout(path: str) -> List[dict]:
    """Read a JSON Lines hold-out file into a list of records.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of
    the file) are skipped instead of crashing ``json.loads``.

    Args:
        path: Path to a UTF-8 ``.jsonl`` file with one JSON record per line.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``"<path>:<lineno>:"`` (1-based line
            number) so the offending record can be located quickly.
    """
    items = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                items.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise with the same type so existing handlers still
                # match, but tell the user which line of which file is bad.
                raise json.JSONDecodeError(
                    f"{path}:{lineno}: {err.msg}", err.doc, err.pos
                ) from err
    return items
def _to_alert_events(rec: dict) -> Tuple[Alert, List[Event]]:
    """Rebuild the schema objects for one hold-out record.

    Args:
        rec: A decoded hold-out record. It must carry an ``"alert"`` mapping
            (with ``alert_id``, ``category``, ``severity``, ``summary`` and
            optional ``host``/``user``) and an ``"events"`` list whose items
            are keyword arguments for ``Event``.

    Returns:
        ``(alert, events)`` ready to be rendered into a defender prompt.
        Missing ``host``/``user`` default to the empty string.
    """
    raw = rec["alert"]
    alert = Alert(
        alert_id=raw["alert_id"],
        category=IncidentCategory(raw["category"]),
        severity=raw["severity"],
        summary=raw["summary"],
        host=raw.get("host", ""),
        user=raw.get("user", ""),
    )
    return alert, list(map(lambda payload: Event(**payload), rec["events"]))
def _print_metrics(label: str, preds: List[str], truths: List[str]) -> dict:
    """Score one set of predictions, print a report, and return the numbers.

    Args:
        label: Name of the system being scored; used as the report header
            and stored under ``"label"``.
        preds: Predicted action strings, aligned index-by-index with ``truths``.
        truths: Ground-truth action strings from the hold-out set.

    Returns:
        A dict holding the label, accuracy, macro F1, dismiss-on-malicious
        rate, over-react rate, the per-class breakdown and the confusion
        matrix (this is what ``main`` collects into ``summary.json``).
    """
    matrix = confusion_matrix(preds, truths)
    macro, breakdown = per_class_f1(matrix)
    report = {
        "label": label,
        "accuracy": accuracy(preds, truths),
        "macro_f1": macro,
        "dismiss_on_malicious": dismiss_on_malicious_rate(preds, truths),
        "over_react_rate": over_react_rate(preds, truths),
        "per_class": breakdown,
        "confusion_matrix": matrix,
    }

    header_lines = (
        f"\n=== {label} ===",
        f" accuracy: {report['accuracy']:.3f}",
        f" macro F1: {report['macro_f1']:.3f}",
        f" dismiss-on-malicious: {report['dismiss_on_malicious']:.3f}",
        f" over-react on benign: {report['over_react_rate']:.3f}",
        " per-class:",
    )
    for text in header_lines:
        print(text)
    for name, stats in breakdown.items():
        print(
            f" {name:<18} P={stats['precision']:.2f} R={stats['recall']:.2f} "
            f"F1={stats['f1']:.2f} (n={int(stats['support'])})"
        )
    return report
# ---------------------------------------------------------------------------
# Inference adapters
# ---------------------------------------------------------------------------
class _VerifierOracle:
    """Cheating 'model' that answers with the gold label of each record.

    It reads ``ground_truth`` and ``triggering_log_id`` straight from the
    hold-out record, so it should score 100% accuracy and 0%
    dismiss-on-malicious. This makes it a sanity check on the parsing and
    metric pipeline rather than a real model.
    """

    name = "verifier_oracle"

    def predict(self, alert: Alert, events: List[Event], gold: dict) -> str:
        # ``alert`` and ``events`` are accepted only to match the model
        # interface; the answer comes entirely from ``gold``.
        action = gold["ground_truth"]
        cited = gold["triggering_log_id"]
        return "\n".join(
            (f"Action: {action}", f"CitedLog: {cited}", "Rationale: oracle")
        )
class _AlwaysDismissBaseline:
    """Trivial baseline whose answer is 'dismiss' for every incident.

    Its response is constant (citing log ``L1-0``), which makes it a floor
    for accuracy and a worst case for the dismiss-on-malicious metric.
    """

    name = "always_dismiss"

    def predict(self, alert: Alert, events: List[Event], gold: dict) -> str:
        # Inputs are ignored on purpose: the answer never depends on them.
        return "\n".join(
            ("Action: dismiss", "CitedLog: L1-0", "Rationale: trivial baseline")
        )
def _try_load_unsloth_model(model_name: str, adapter_path: str | None):
    """Load a 4-bit model (optionally with an adapter) via Unsloth for inference.

    Args:
        model_name: Hub id or local path of the base model.
        adapter_path: Optional adapter directory loaded on top of the base
            model. If it is given but does not exist, a warning is printed
            to stderr and the base weights are used unchanged.

    Returns:
        ``(model, tokenizer)`` switched to inference mode, or ``None`` when
        Unsloth cannot be used in this environment (not installed, or
        installed but refusing to import because no supported GPU exists).
    """
    try:
        from unsloth import FastLanguageModel
    except (ImportError, NotImplementedError) as err:
        # Unsloth raises NotImplementedError at import time when it finds no
        # supported GPU (NOTE(review): confirm on Unsloth upgrades). Treat
        # that like a missing install so CPU-only runs fall back gracefully.
        print(f"(skip) unsloth unavailable: {err}", file=sys.stderr)
        return None
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=model_name,
        max_seq_length=2048,
        dtype=None,
        load_in_4bit=True,
    )
    if adapter_path:
        if os.path.exists(adapter_path):
            model.load_adapter(adapter_path, adapter_name="default", is_trainable=False)
        else:
            # Don't silently evaluate base weights under a 'trained' label.
            print(
                f"(warn) adapter not found at {adapter_path}; using base weights",
                file=sys.stderr,
            )
    FastLanguageModel.for_inference(model)
    return model, tokenizer
def _generate(model_pair, alert, events, max_new_tokens: int = 128) -> str:
    """Greedily decode one defender response for an incident.

    Args:
        model_pair: ``(model, tokenizer)`` as returned by
            ``_try_load_unsloth_model``.
        alert: The alert to triage.
        events: Log events attached to the alert; rendered together with
            ``alert`` by ``render_defender_prompt``.
        max_new_tokens: Decode budget for the response (default 128).

    Returns:
        Only the newly generated text (prompt tokens are sliced off), with
        special tokens removed.
    """
    model, tokenizer = model_pair
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": render_defender_prompt(alert, events)},
    ]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The chat template already contains the model's special tokens; letting
    # the tokenizer add them again can duplicate e.g. a BOS token.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    # Greedy decoding. `temperature` is deliberately not passed: it only
    # applies when sampling, and transformers warns if it is set alongside
    # do_sample=False.
    out = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True)
# ---------------------------------------------------------------------------
# Main eval
# ---------------------------------------------------------------------------
def _collect_predictions(holdout: List[dict], respond) -> List[str]:
    """Return the parsed triage action for every hold-out record.

    Args:
        holdout: Records as returned by ``_load_holdout``.
        respond: Callable ``(alert, events, rec) -> str`` that produces a
            raw defender response for one incident.

    Returns:
        One action string per record, in order. A response with no
        parseable action is recorded as ``"dismiss"``, so malformed output
        is scored as a dismissal rather than dropped.
    """
    preds = []
    for rec in holdout:
        alert, events = _to_alert_events(rec)
        parsed = parse_defender_response(respond(alert, events, rec))
        preds.append(parsed.action.value if parsed.action else "dismiss")
    return preds


def main() -> None:
    """Score sanity baselines and, when possible, real models on the hold-out set.

    The always-dismiss and verifier-oracle baselines always run. The
    zero-shot base model and the GRPO-trained adapter run only without
    ``--smoke-only`` and only when Unsloth loads. Relative ``--holdout``,
    ``--out-dir`` and ``--trained-adapter`` paths are resolved against the
    repository root (the parent of this file's directory). Absolute paths
    are used as given. All metric dicts are written to
    ``<out-dir>/summary.json``.
    """
    import gc

    parser = argparse.ArgumentParser()
    parser.add_argument("--baseline", default="unsloth/Qwen2.5-3B-Instruct")
    parser.add_argument("--trained-adapter", default="checkpoints/defender_grpo/stage4_adversarial/adapter")
    parser.add_argument("--holdout", default="data/holdout.jsonl")
    parser.add_argument("--out-dir", default="eval/results")
    parser.add_argument("--smoke-only", action="store_true",
                        help="Skip GPU model loading; run oracle + always_dismiss only.")
    args = parser.parse_args()

    repo_root = os.path.dirname(_HERE)
    holdout_path = os.path.join(repo_root, args.holdout)
    out_dir = os.path.join(repo_root, args.out_dir)
    os.makedirs(out_dir, exist_ok=True)

    holdout = _load_holdout(holdout_path)
    truths = [r["ground_truth"] for r in holdout]
    print(f"Loaded {len(holdout)} hold-out incidents from {holdout_path}")

    summaries = []

    # --- Sanity baselines (no GPU needed) ---
    summaries.append(
        _print_metrics(
            "always_dismiss",
            _collect_predictions(holdout, _AlwaysDismissBaseline().predict),
            truths,
        )
    )
    summaries.append(
        _print_metrics(
            "verifier_oracle",
            _collect_predictions(holdout, _VerifierOracle().predict),
            truths,
        )
    )

    # --- Real models ---
    if not args.smoke_only:
        baseline_pair = _try_load_unsloth_model(args.baseline, adapter_path=None)
        if baseline_pair is not None:
            preds_baseline = _collect_predictions(
                holdout,
                lambda alert, events, _rec: _generate(baseline_pair, alert, events),
            )
            summaries.append(_print_metrics("baseline_zero_shot", preds_baseline, truths))
            # Drop the zero-shot model so its weights can be reclaimed before
            # a second copy of the base model is loaded for the adapter run.
            del baseline_pair
            gc.collect()

            adapter_full = os.path.join(repo_root, args.trained_adapter)
            if os.path.exists(adapter_full):
                trained_pair = _try_load_unsloth_model(args.baseline, adapter_path=adapter_full)
                if trained_pair is not None:
                    preds_trained = _collect_predictions(
                        holdout,
                        lambda alert, events, _rec: _generate(trained_pair, alert, events),
                    )
                    summaries.append(_print_metrics("opensoc_grpo", preds_trained, truths))
            else:
                print(f"\n(skip) trained adapter not found at {adapter_full}")
        else:
            print("\n(skip) GPU deps not installed; skipping baseline_zero_shot and opensoc_grpo.")

    out_json = os.path.join(out_dir, "summary.json")
    # NOTE(review): assumes every metric value (incl. the confusion matrix)
    # is JSON-serialisable; verify against eval.metrics' return types.
    with open(out_json, "w", encoding="utf-8") as f:
        json.dump(summaries, f, indent=2)
    print(f"\nSaved summary to {out_json}")


if __name__ == "__main__":
    main()