---
library_name: numpy
base_model: google/gemma-4-31B-it
tags:
- gemma
- gemma-4-31b
- linear-probe
- token-level
- span-max
- vulnerability-detection
---
# GemmaForge Token-Level Probe: Gemma 4 31B (span-max)
A **linear span-max probe** on the per-token hidden states of decoder layer
**15** of a frozen `google/gemma-4-31B-it` (60 decoder layers, hidden_size 5376).
The probe is trained with the span-max loss of Obeso, Arditi et al. 2025
(arXiv 2509.03531, §3) on `data/dataset.jsonl` (N=1374, 687 pos / 687 neg,
token-level char-range labels propagated from SVEN diffs).
Companion models:
- [`peaktwilight/gemmaforge-gemma4-probe`](https://huggingface.co/peaktwilight/gemmaforge-gemma4-probe): sample-level (last-token) probe on Gemma 4 E2B
- [`peaktwilight/gemmaforge-31b-probe`](https://huggingface.co/peaktwilight/gemmaforge-31b-probe): sample-level (last-token) probe on Gemma 4 31B
## Files
- `probe_spanmax_31b.npz`: `(w, b, layer)`; `sigmoid(w @ hidden[layer + 1][0, t, :] + b)` is the per-token risk
- `probe_spanmax_31b_card.json`: per-layer token/example AUC, training config
- `token_probs_31b.npz`: per-row token probabilities (`probs_row_NNNN`) on `data/dataset.jsonl`
- `token_offsets_31b.npz`: per-row `(T, 2)` char offsets (`offsets_row_NNNN`)
- `spans.json`: positive `(example_id, tok_start, tok_end)` triples for the trainer's span-max pool
- `token_report_31b.md` / `token_report_31b.json`: full three-level eval report
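The cached per-row arrays can be read back without loading the model. The following is a minimal sketch, assuming the keys follow `probs_row_NNNN` / `offsets_row_NNNN` with a zero-padded four-digit row index (the padding width is an assumption; inspect `.files` on the loaded archive if a lookup fails). `load_row` is a hypothetical helper, not part of the repository.

```python
import numpy as np

def load_row(probs_npz, offsets_npz, row):
    """Return (probs, offsets) for one dataset row.

    Assumes keys are zero-padded to four digits, e.g. 'probs_row_0007'.
    """
    probs = np.asarray(probs_npz[f"probs_row_{row:04d}"], dtype=np.float64)
    offsets = np.asarray(offsets_npz[f"offsets_row_{row:04d}"])
    # Expect one probability and one (char_start, char_end) pair per token.
    if probs.ndim != 1 or offsets.shape != (probs.shape[0], 2):
        raise ValueError(
            f"row {row}: shapes {probs.shape} and {offsets.shape} do not align"
        )
    return probs, offsets
```

With the files downloaded locally, `load_row(np.load("token_probs_31b.npz"), np.load("token_offsets_31b.npz"), 0)` would return the token probabilities and char offsets for the first row, provided the key format assumption holds.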
## Training setup
| | |
|---|---|
| Base model | `google/gemma-4-31B-it` |
| Decoder layers / hidden | 60 / 5376 |
| Layers probed | [15, 30, 45, 59] (25/50/75/100% depth) |
| Winning layer | 15 |
| Loss | span-max (alpha=10, omega 0→1 linear) |
| Optimizer / epochs / batch | AdamW (lr=1e-3) / 30 / 8 examples |
| Activation dtype on disk | float16 |
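To make the loss row concrete, here is a rough sketch of one plausible form of a span-max objective: a per-token binary cross-entropy blended with a cross-entropy on the maximum probability inside each labelled span, with the blend weight `omega` ramped linearly from 0 to 1 over training. This is an illustration under stated assumptions, not the trainer's actual code: the blend formula, the exclusive `tok_end` convention, and the function names are all assumptions, and the role of `alpha` is not described in this card, so it is omitted.

```python
import numpy as np

def bce(p, y, eps=1e-7):
    """Elementwise binary cross-entropy on probabilities."""
    p = np.clip(p, eps, 1.0 - eps)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def span_max_loss(probs, token_labels, spans, omega):
    """Blend per-token BCE with BCE on the max probability inside each span.

    probs: (T,) token probabilities; token_labels: (T,) 0/1 labels;
    spans: list of (tok_start, tok_end, label), end exclusive (assumed).
    """
    token_term = bce(probs, token_labels).mean()
    span_terms = [bce(probs[s:e].max(), float(y)) for s, e, y in spans]
    span_term = float(np.mean(span_terms)) if span_terms else 0.0
    return (1.0 - omega) * token_term + omega * span_term

def omega_schedule(epoch, n_epochs):
    """Linear ramp from 0 at the first epoch to 1 at the last."""
    return epoch / max(n_epochs - 1, 1)
```

Under this form, early epochs push every labelled token toward its label, while later epochs only require that at least one token inside a positive span fires.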
### Layer sweep (winner = best example-level AUC)
| Layer | tok_AUC | ex_AUC |
|---:|---:|---:|
| 15 | 0.730 | 0.692 **<-- winner** |
| 30 | 0.721 | 0.657 |
| 45 | 0.752 | 0.647 |
| 59 | 0.683 | 0.676 |
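The selection rule in the heading can be checked directly against the table. A small sketch, with the rows copied from the sweep above (the variable names are illustrative):

```python
# (layer, tok_AUC, ex_AUC) rows copied from the layer sweep table.
sweep = [
    (15, 0.730, 0.692),
    (30, 0.721, 0.657),
    (45, 0.752, 0.647),
    (59, 0.683, 0.676),
]

# The winner is chosen by example-level AUC, not token-level AUC.
winner = max(sweep, key=lambda row: row[2])[0]
best_tok = max(sweep, key=lambda row: row[1])[0]
```

Layer 45 has the highest token-level AUC, but layer 15 wins because the criterion is example-level AUC.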
## Eval headline (`data/dataset.jsonl`, N=1374, pos=687)
| Split | `all` AUC | `proximal_all` AUC | `span_max` AUC | `dilated_span_max` AUC |
|---|---:|---:|---:|---:|
| `random_stratified` | 0.879 | 0.729 | 0.669 | 0.565 |
| `group_repo` | 0.812 | 0.714 | 0.623 | 0.495 |
| `heldout_cwe::CWE-089` | 0.960 | 0.713 | 0.944 | 0.941 |
| `heldout_cwe::CWE-125` | 0.813 | 0.714 | 0.512 | 0.332 |
| `heldout_cwe::CWE-078` | 0.881 | 0.748 | 0.724 | 0.686 |
| `heldout_cwe::CWE-476` | 0.794 | 0.706 | 0.497 | 0.285 |
| `heldout_cwe::CWE-079` | 0.796 | 0.677 | 0.492 | 0.299 |
| `heldout_lang::test=c` | 0.791 | 0.697 | 0.497 | 0.292 |
| `heldout_lang::test=cpp` | 0.778 | 0.696 | 0.494 | 0.307 |
| `heldout_lang::test=python` | 0.896 | 0.711 | 0.824 | 0.781 |
Notes:
- `all` measures the probe on every token (streaming-UI view); best for per-token highlighting.
- `span_max` collapses to one decision per example, so it is directly comparable to sample-level probes.
- `span` is NaN on this corpus: SVEN has no sanitizer-annotated negatives, so the span level is single-class. The protocol synthesises a whole-file negative span for label=0 examples inside `span_max` only.
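The collapse to one decision per example, and the reason a single-class level yields NaN, can be sketched as follows. This is a simplification that takes the max over all tokens of each example; the actual protocol pools over annotated spans and synthesised whole-file negatives, and its implementation lives in the pipeline repository. Both function names are hypothetical.

```python
import numpy as np

def example_scores(token_probs_per_example):
    """Collapse each example's token probabilities to one score via max."""
    return np.array([float(np.max(p)) for p in token_probs_per_example])

def roc_auc(scores, labels):
    """Rank-based AUC (Mann-Whitney U), counting ties as one half."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels).astype(bool)
    pos, neg = scores[labels], scores[~labels]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")  # single-class input: AUC is undefined
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))
```

Calling `roc_auc(example_scores(rows), labels)` gives an example-level AUC, and passing labels that are all 1 (or all 0) returns NaN, which mirrors the `span` level on this corpus.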
## Reproduce inference
```python
from huggingface_hub import hf_hub_download
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the probe weights, bias, and the index of the probed decoder layer.
probe = np.load(hf_hub_download("mmtf/gemmaforge-31b-token-probe", "probe_spanmax_31b.npz"))
w, b, layer = probe["w"], float(probe["b"]), int(probe["layer"])

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-31B-it", torch_dtype=torch.bfloat16,
    device_map="auto", attn_implementation="eager",
)
tok = AutoTokenizer.from_pretrained("google/gemma-4-31B-it")

ids = tok("def vuln(x): return os.system(x)", return_tensors="pt").input_ids.to(model.device)
with torch.inference_mode():
    out = model(ids, output_hidden_states=True, use_cache=False)

# hidden_states[0] is the embedding output, so decoder layer `layer` sits at index layer + 1.
h = out.hidden_states[layer + 1][0].float().cpu().numpy()  # (T, hidden_size)
per_token_risk = 1.0 / (1.0 + np.exp(-(h @ w + b)))  # one risk score per token
```
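For highlighting, per-token risk is usually mapped back to character ranges. The sketch below merges consecutive high-risk tokens into `(char_start, char_end)` spans. It assumes `offsets` holds one `(start, end)` pair per token, for example from a fast tokenizer called with `return_offsets_mapping=True` or from `token_offsets_31b.npz`. The helper name and the 0.5 threshold are illustrative assumptions, not a calibrated cutoff from this card.

```python
import numpy as np

def risky_char_spans(risk, offsets, threshold=0.5):
    """Merge consecutive tokens with risk >= threshold into char spans."""
    spans = []
    current = None
    for r, (start, end) in zip(risk, offsets):
        if end <= start:
            # Zero-width tokens (e.g. special tokens) map to no text; skip them.
            continue
        if r >= threshold:
            if current is None:
                current = [int(start), int(end)]
            else:
                current[1] = int(end)  # extend the open span
        elif current is not None:
            spans.append(tuple(current))
            current = None
    if current is not None:
        spans.append(tuple(current))
    return spans
```

Given the source string `text`, each returned pair can be sliced as `text[start:end]` to show the flagged region.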
## Pipeline
Full training + eval pipeline: <https://github.com/peaktwilight/gemmaforge>