multiclass-folded / predict.py
simonmorley's picture
Multi-class softmax folded detector — initial release (V8-V14 + V16, 2026-05-13)
4a9a4d9 verified
#!/usr/bin/env python3
# SPDX-License-Identifier: Apache-2.0
"""Multi-class softmax folded detector — scoreability-gated inference.
The recommended consumption surface for the 9-class V8-V14 + V16 multi-class
folded detector. Wraps the raw `CalibratedClassifierCV` estimator with two
production-side gates:
1. **Scoreability gate**: refuses to score bundles where neither
``responses.parquet`` nor ``packets.pcap`` has content. Bundles with
no observed RPC traffic AND no captured network packets cannot be
classified usefully; the gate returns an explicit "unscoreable"
verdict instead of producing a spurious argmax.
2. **Feature-coverage flag**: emits a ``feature_coverage`` string
describing which bundle modalities contributed features
(``"full"``, ``"resp_only"``, ``"pcap_only"``, or ``"none"``). V16
gossip-abuse predictions are load-bearing on ``pcap.*`` features;
V8-V14 are load-bearing on ``responses.*``. Callers should
downweight predictions where the modality coverage doesn't match
the predicted class.
Callers who want raw model output without these gates should load
``model.joblib`` directly — see the "Bypassing the gate" section of the
model card.
Usage::
from predict import load_model, score_bundle
payload = load_model("/path/to/model.joblib")
record = score_bundle("/path/to/bundle_dir", payload)
print(record["argmax_class"], record["class_probs"])
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
import joblib
import numpy as np
import pyarrow.parquet as pq
# nr-bundle-spec — the reference parser. Pip-install via
# pip install git+https://github.com/NullRabbitLabs/nr-bundle-spec.git
from bundle_spec import BundleManifest
def load_model(model_path: str | Path) -> dict[str, Any]:
    """Deserialize the multi-class folded lineage-dict payload.

    Args:
        model_path: Filesystem path to the ``model.joblib`` artifact.

    Returns:
        The object reconstructed by ``joblib.load``. ``score_bundle``
        expects it to be a dict carrying ``"model"``, ``"feature_names"``
        and ``"class_order"`` entries.
    """
    payload = joblib.load(model_path)
    return payload
def _modality_state(bundle_dir: Path) -> tuple[bool, int, bool]:
    """Inspect which input modalities a bundle directory actually carries.

    Args:
        bundle_dir: Bundle directory that may contain ``responses.parquet``
            and/or ``packets.pcap``.

    Returns:
        ``(has_responses_with_rows, n_responses_rows, has_packets_pcap)``.
        ``has_packets_pcap`` is True only when ``packets.pcap`` is a
        regular file with a non-zero size.
    """
    n_resp = 0
    responses_path = bundle_dir / "responses.parquet"
    if responses_path.is_file():
        # The row count is stored in the Parquet footer; reading only the
        # metadata avoids materialising every column just to count rows.
        n_resp = pq.read_metadata(str(responses_path)).num_rows

    pcap_path = bundle_dir / "packets.pcap"
    # A zero-byte capture carries no packets, so it must neither satisfy
    # the scoreability gate nor count towards feature coverage.
    has_pcap = pcap_path.is_file() and pcap_path.stat().st_size > 0

    return n_resp > 0, n_resp, has_pcap
def _feature_coverage(has_resp: bool, has_pcap: bool) -> str:
    """Summarise which modalities fed the feature vector.

    Returns ``"full"`` when both responses and packets are present,
    ``"resp_only"`` or ``"pcap_only"`` when exactly one is, and ``"none"``
    when neither is. Both flags are interpreted by truthiness.
    """
    coverage_by_presence = {
        (True, True): "full",
        (True, False): "resp_only",
        (False, True): "pcap_only",
        (False, False): "none",
    }
    return coverage_by_presence[bool(has_resp), bool(has_pcap)]
def _extract_features(bundle_dir: Path, feature_names: list[str]) -> np.ndarray:
    """Build the model's ``(1, len(feature_names))`` feature matrix for a bundle.

    Uses the ``nr_training`` feature extractor when it is importable; that is
    the same pipeline the model was trained against. When ``nr_training`` is
    not on the import path (typical for deployment environments), it falls
    back to ``_extract_resp_features_fallback``. The fallback fills only a
    handful of ``resp.*`` aggregates and leaves every other feature at
    ``0.0``.

    Args:
        bundle_dir: Bundle directory containing ``manifest.json`` and the
            modality files.
        feature_names: Feature columns, in the order the model expects.

    Returns:
        A float array of shape ``(1, len(feature_names))``. Features the
        extractor did not produce are ``0.0``.
    """
    import sys

    # Make an ``nr_training`` checkout next to this file importable. Insert
    # only once: re-inserting on every call grows ``sys.path`` without bound
    # in a long-running scorer.
    here = str(Path(__file__).resolve().parent)
    if here not in sys.path:
        sys.path.insert(0, here)

    # nr_training is the substrate-side feature extractor. It is absent in
    # most deployment envs; install it from the nr-substrate working repo for
    # exact-equivalence extraction matching training. Only the imports are
    # guarded. The fallback exists for "nr_training is not installed", not
    # to mask errors raised while extracting features.
    try:
        from nr_training.contracts import BundleManifest as _BM
        from nr_training.datasets.loader import Bundle, _sha256
        from nr_training.features import batch_extract
    except ImportError:
        return _extract_resp_features_fallback(bundle_dir, feature_names)

    mfp = bundle_dir / "manifest.json"
    m = _BM.model_validate_json(mfp.read_text())
    b = Bundle(corpus_id=m.corpus_id, bundle_dir=bundle_dir, manifest=m,
               manifest_sha256=_sha256(mfp), pcap_sha256=None)
    fvs = batch_extract([b])
    extracted = fvs[0].features
    return np.array([[extracted.get(n, 0.0) for n in feature_names]], dtype=float)


def _extract_resp_features_fallback(bundle_dir: Path, feature_names: list[str]) -> np.ndarray:
    """Approximate the feature vector from ``responses.parquet`` alone.

    Only ``resp.req_bytes_max``, ``resp.resp_bytes_max`` and
    ``resp.amp_ratio_{max,mean,median}`` are filled, and only when they
    appear in ``feature_names``. Every other feature stays ``0.0``, as do
    all features when the parquet file is missing or has no rows.
    """
    features = dict.fromkeys(feature_names, 0.0)
    rp = bundle_dir / "responses.parquet"
    if rp.is_file():
        table = pq.read_table(rp)
        if table.num_rows > 0:
            req = table.column("request_size_bytes").to_numpy()
            resp = table.column("response_size_bytes").to_numpy()
            # Zero-byte requests get an amplification ratio of 0.0 rather
            # than inf/nan; errstate silences warnings from the discarded
            # branch of np.where.
            with np.errstate(divide="ignore", invalid="ignore"):
                ratios = np.where(req > 0, resp / req, 0.0)
            aggregates = {
                "resp.req_bytes_max": float(req.max()),
                "resp.resp_bytes_max": float(resp.max()),
                "resp.amp_ratio_max": float(ratios.max()),
                "resp.amp_ratio_mean": float(ratios.mean()),
                "resp.amp_ratio_median": float(np.median(ratios)),
            }
            for name, value in aggregates.items():
                if name in features:
                    features[name] = value
    return np.array([[features[n] for n in feature_names]], dtype=float)
def _ground_truth_str(manifest: Any) -> str:
    """Render ``manifest.ground_truth_label`` as a plain string.

    Labels exposing ``.value`` (enum-like) contribute that value; anything
    else goes through ``str()``.
    """
    label = manifest.ground_truth_label
    return label.value if hasattr(label, "value") else str(label)


def _unscoreable_record(
    reason: str,
    *,
    manifest: Any = None,
    feature_coverage: str | None = None,
    n_responses_rows: int | None = None,
) -> dict[str, Any]:
    """Build an ``"unscoreable"`` record with the same keys as a scored one.

    Fields that cannot be determined, such as the manifest fields when no
    manifest was found, are ``None``.
    """
    return {
        "verdict": "unscoreable",
        "argmax_class": None,
        "argmax_p": None,
        "class_probs": None,
        "reason": reason,
        "corpus_id": None if manifest is None else manifest.corpus_id,
        "primitive_id": None if manifest is None else manifest.primitive_id,
        "ground_truth": None if manifest is None else _ground_truth_str(manifest),
        "feature_coverage": feature_coverage,
        "coverage_warning": None,
        "n_responses_rows": n_responses_rows,
    }


def _coverage_warning(coverage: str, argmax_class: str, argmax_p: float) -> str | None:
    """Return a modality-mismatch warning for the prediction, or ``None``.

    Only ``resp_only`` coverage is flagged. The model was trained on
    full-modality bundles, so pcap-absent inputs are outside its training
    distribution.

    - An argmax of ``V16`` is always flagged, because V16 is load-bearing on
      ``pcap.*`` features that are missing here.
    - Any other class is flagged when its probability is below 0.8.
    """
    if coverage != "resp_only":
        return None
    if argmax_class == "V16":
        return (
            "argmax=V16 with resp_only coverage. V16 is load-bearing on "
            "pcap.* features; an argmax of V16 on a pcap-absent bundle "
            "is likely a misclassification driven by the missing-modality "
            "signal, not a true gossip-abuse detection. Provide bundles "
            "with raw packets.pcap for V16 inference."
        )
    if argmax_p < 0.8:
        return (
            f"argmax={argmax_class} with P={argmax_p:.3f} on resp_only "
            "coverage; multi-class folded was trained on full-modality "
            "bundles, so predictions on pcap-absent inputs are out-of-"
            "distribution. For reliable V8-V14 inference, provide bundles "
            "with raw packets.pcap present."
        )
    return None


def score_bundle(bundle_dir: str | Path, payload: dict[str, Any]) -> dict[str, Any]:
    """Score one bundle through the multi-class folded model.

    Args:
        bundle_dir: Bundle directory containing ``manifest.json`` plus
            ``responses.parquet`` and/or ``packets.pcap``.
        payload: Dict returned by ``load_model``. Reads ``"model"`` (must
            expose ``predict_proba``), ``"feature_names"`` and
            ``"class_order"``.

    Returns:
        A record containing every key below. Values that could not be
        determined are ``None``.

        - ``verdict``: ``"<class_name>"`` or ``"unscoreable"``.
        - ``argmax_class``: argmax class name (None if unscoreable).
        - ``argmax_p``: probability of the argmax class (None if unscoreable).
        - ``class_probs``: dict of P(class) for every class in class_order.
        - ``reason``: human-readable explanation when unscoreable.
        - ``feature_coverage``: ``"full"`` / ``"resp_only"`` / ``"pcap_only"``
          / ``"none"`` (None when the manifest is missing).
        - ``coverage_warning``: modality-mismatch warning, or None.
        - ``corpus_id``, ``primitive_id``, ``ground_truth``: from manifest.
        - ``n_responses_rows``: number of rows in responses.parquet.

    Raises:
        ValueError: if the model returns a different number of class
            probabilities than ``payload["class_order"]`` lists.
    """
    bundle_dir = Path(bundle_dir)
    manifest_path = bundle_dir / "manifest.json"
    if not manifest_path.is_file():
        return _unscoreable_record(f"manifest.json not found at {manifest_path}")

    manifest = BundleManifest.model_validate_json(manifest_path.read_text())
    has_resp, n_resp, has_pcap = _modality_state(bundle_dir)

    # Scoreability gate: _modality_state must report either response rows
    # or a packet capture; otherwise there is nothing to classify.
    if not (has_resp or has_pcap):
        return _unscoreable_record(
            (
                "Neither responses.parquet (with rows) nor packets.pcap is "
                "present in the bundle. The multi-class folded detector "
                "cannot classify bundles with no observed RPC traffic AND "
                "no captured network packets. Bundles in this state are "
                "typically passive-workload captures (e.g. validator running "
                "idle with no clients) — use a non-bundle telemetry path "
                "for those workloads."
            ),
            manifest=manifest,
            feature_coverage="none",
            n_responses_rows=n_resp,
        )

    feature_names = payload["feature_names"]
    class_order = payload["class_order"]
    X = _extract_features(bundle_dir, feature_names)
    proba = payload["model"].predict_proba(X)[0]
    # class_order labels proba positionally; a length mismatch would either
    # drop classes silently or index past the end, so fail loudly instead.
    if len(proba) != len(class_order):
        raise ValueError(
            f"model returned {len(proba)} class probabilities but "
            f"payload['class_order'] lists {len(class_order)} classes"
        )

    argmax = int(np.argmax(proba))
    argmax_class = class_order[argmax]
    argmax_p = float(proba[argmax])
    coverage = _feature_coverage(has_resp, has_pcap)

    return {
        "verdict": argmax_class,
        "argmax_class": argmax_class,
        "argmax_p": argmax_p,
        "class_probs": {cls: float(p) for cls, p in zip(class_order, proba)},
        "reason": None,
        "corpus_id": manifest.corpus_id,
        "primitive_id": manifest.primitive_id,
        "ground_truth": _ground_truth_str(manifest),
        "feature_coverage": coverage,
        "coverage_warning": _coverage_warning(coverage, argmax_class, argmax_p),
        "n_responses_rows": n_resp,
    }