| --- |
| library_name: numpy |
| base_model: google/gemma-4-31B-it |
| tags: |
| - gemma |
| - gemma-4-31b |
| - linear-probe |
| - token-level |
| - span-max |
| - vulnerability-detection |
| --- |
| |
| # GemmaForge Token-Level Probe: Gemma 4 31B (span-max) |
|
|
| Frozen `google/gemma-4-31B-it` (60 decoder layers, hidden_size 5376) with a |
| **linear span-max probe** on the per-token hidden states of decoder layer |
| **15**. The probe is trained with the span-max loss of Obeso, Arditi et al. 2025 |
| (arXiv 2509.03531 §3) on `data/dataset.jsonl` (N=1374, |
| 687 pos / 687 neg, token-level char-range labels propagated from SVEN diffs). |
| |
| Companion models: |
| - [`peaktwilight/gemmaforge-gemma4-probe`](https://huggingface.co/peaktwilight/gemmaforge-gemma4-probe): sample-level (last-token) probe on Gemma 4 E2B |
| - [`peaktwilight/gemmaforge-31b-probe`](https://huggingface.co/peaktwilight/gemmaforge-31b-probe): sample-level (last-token) probe on Gemma 4 31B |
| |
| ## Files |
| |
| - `probe_spanmax_31b.npz`: `(w, b, layer)`; `sigmoid(w @ hidden[layer + 1][0, t, :] + b)` is the per-token risk |
| - `probe_spanmax_31b_card.json`: per-layer token/example AUC, training config |
| - `token_probs_31b.npz`: per-row token probabilities (`probs_row_NNNN`) on `data/dataset.jsonl` |
| - `token_offsets_31b.npz`: per-row `(T, 2)` char offsets (`offsets_row_NNNN`) |
| - `spans.json`: positive `(example_id, tok_start, tok_end)` triples for the trainer's span-max pool |
| - `token_report_31b.md` / `token_report_31b.json`: full three-level eval report |
|
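The cached probabilities and offsets can be joined to turn per-token scores into highlighted character ranges. The following is a minimal sketch, not the repository's own code: it assumes each `probs_row_NNNN` entry is a length-`T` vector aligned with the `(T, 2)` array in `offsets_row_NNNN`, that offsets are half-open `[start, end)` character ranges, and that 0.5 is a reasonable threshold. The helper name `risky_char_spans` is hypothetical, and synthetic arrays stand in for the real files.

```python
import numpy as np

def risky_char_spans(probs, offsets, threshold=0.5):
    """Merge runs of above-threshold tokens into (char_start, char_end, max_prob) spans."""
    probs = np.asarray(probs, dtype=np.float64)
    offsets = np.asarray(offsets, dtype=np.int64)
    if probs.shape[0] != offsets.shape[0]:
        raise ValueError("probs and offsets must have the same number of tokens")
    spans = []
    current = None  # [start, end, max_prob] of the run being built
    for p, (start, end) in zip(probs, offsets):
        # Zero-width tokens (for example special tokens) are treated as non-risky.
        if p >= threshold and end > start:
            if current is None:
                current = [int(start), int(end), float(p)]
            else:
                current[1] = int(end)
                current[2] = max(current[2], float(p))
        elif current is not None:
            spans.append(tuple(current))
            current = None
    if current is not None:
        spans.append(tuple(current))
    return spans

# Synthetic stand-ins for one row of token_probs_31b.npz / token_offsets_31b.npz.
demo_probs = np.array([0.1, 0.2, 0.8, 0.9, 0.3, 0.7])
demo_offsets = np.array([[0, 3], [4, 8], [9, 11], [11, 17], [18, 20], [21, 25]])
demo_spans = risky_char_spans(demo_probs, demo_offsets)
```

With the real artifacts, the two arrays for a row would come from `np.load(...)` on the respective `.npz` files; the exact zero-padding of the `NNNN` suffix should be checked against the keys in the file (`list(npz.keys())`) rather than assumed.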
|
| ## Training setup |
|
|
| Setting | Value |
|---|---|
| | Base model | `google/gemma-4-31B-it` | |
| | Decoder layers / hidden | 60 / 5376 | |
| | Layers probed | [15, 30, 45, 59] (25/50/75/100% depth) | |
| | Winning layer | 15 | |
| Loss | span-max (alpha=10, omega 0→1 linear) |
| | Optimizer / epochs / batch | AdamW (lr=1e-3) / 30 / 8 examples | |
| | Activation dtype on disk | float16 | |
|
|
| ### Layer sweep (winner = best example-level AUC) |
|
|
| | Layer | tok_AUC | ex_AUC | |
|---:|---:|---:|
| | 15 | 0.730 | 0.692 **<-- winner** | |
| | 30 | 0.721 | 0.657 | |
| | 45 | 0.752 | 0.647 | |
| | 59 | 0.683 | 0.676 | |
|
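The selection rule in the heading can be checked directly against the sweep numbers. This small sketch copies the four rows from the table; the dictionary layout and variable names are illustrative only.

```python
# Layer sweep results copied from the table above.
sweep = {
    15: {"tok_auc": 0.730, "ex_auc": 0.692},
    30: {"tok_auc": 0.721, "ex_auc": 0.657},
    45: {"tok_auc": 0.752, "ex_auc": 0.647},
    59: {"tok_auc": 0.683, "ex_auc": 0.676},
}

# The winner is the layer with the best example-level AUC.
winner = max(sweep, key=lambda layer: sweep[layer]["ex_auc"])

# For comparison: the layer with the best token-level AUC.
best_token_layer = max(sweep, key=lambda layer: sweep[layer]["tok_auc"])
```

Layer 45 has the highest token-level AUC in the sweep, but the example-level criterion picks layer 15.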
|
| ## Eval headline (`data/dataset.jsonl`, N=1374, pos=687) |
|
|
| | Split | `all` AUC | `proximal_all` AUC | `span_max` AUC | `dilated_span_max` AUC | |
|---|---:|---:|---:|---:|
| | `random_stratified` | 0.879 | 0.729 | 0.669 | 0.565 | |
| | `group_repo` | 0.812 | 0.714 | 0.623 | 0.495 | |
| | `heldout_cwe::CWE-089` | 0.960 | 0.713 | 0.944 | 0.941 | |
| | `heldout_cwe::CWE-125` | 0.813 | 0.714 | 0.512 | 0.332 | |
| | `heldout_cwe::CWE-078` | 0.881 | 0.748 | 0.724 | 0.686 | |
| | `heldout_cwe::CWE-476` | 0.794 | 0.706 | 0.497 | 0.285 | |
| | `heldout_cwe::CWE-079` | 0.796 | 0.677 | 0.492 | 0.299 | |
| | `heldout_lang::test=c` | 0.791 | 0.697 | 0.497 | 0.292 | |
| | `heldout_lang::test=cpp` | 0.778 | 0.696 | 0.494 | 0.307 | |
| | `heldout_lang::test=python` | 0.896 | 0.711 | 0.824 | 0.781 | |
|
|
| Notes: |
| - `all` measures the probe on every token (streaming-UI view); best for per-token highlighting. |
| - `span_max` collapses to one decision per example, so it is directly comparable to sample-level probes. |
| - `span` is NaN on this corpus: SVEN has no sanitizer-annotated negatives, so the span level is single-class. The protocol synthesises a whole-file negative span for label=0 examples inside `span_max` only. |
|
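The `span_max` reduction described in the notes can be sketched as follows: each example is scored by the maximum per-token probability inside its span, and label=0 examples use a synthetic span covering the whole file. This is an illustration under assumptions, not the evaluation code from the pipeline: the helper names `span_max_score` and `auc` are hypothetical, spans are taken to be half-open token ranges `[tok_start, tok_end)`, and the AUC is a plain rank-based estimate.

```python
import numpy as np

def span_max_score(probs, span=None):
    """Max per-token probability inside `span`; `None` means the whole file."""
    probs = np.asarray(probs, dtype=np.float64)
    start, end = (0, probs.shape[0]) if span is None else span
    if not 0 <= start < end <= probs.shape[0]:
        raise ValueError("span must be a non-empty range inside the token sequence")
    return float(probs[start:end].max())

def auc(scores, labels):
    """Probability that a random positive outscores a random negative (ties count 0.5)."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels)
    pos, neg = scores[labels == 1], scores[labels == 0]
    diff = pos[:, None] - neg[None, :]
    return float(((diff > 0) + 0.5 * (diff == 0)).mean())

# Synthetic examples: (per-token probabilities, label, span or None).
examples = [
    ([0.1, 0.6, 0.9, 0.2], 1, (1, 3)),  # positive, labelled span
    ([0.3, 0.2, 0.4], 1, (0, 2)),       # positive, labelled span
    ([0.2, 0.5, 0.1], 0, None),         # negative, whole-file span
    ([0.1, 0.1], 0, None),              # negative, whole-file span
]
scores = [span_max_score(p, s) for p, _, s in examples]
labels = [y for _, y, _ in examples]
demo_auc = auc(scores, labels)
```

Because every example ends up with a single score, the result can be compared with a last-token probe evaluated on the same split.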
|
| ## Reproduce inference |
|
|
| ```python |
| from huggingface_hub import hf_hub_download |
| import numpy as np |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| |
| probe = np.load(hf_hub_download("mmtf/gemmaforge-31b-token-probe", "probe_spanmax_31b.npz")) |
| w, b, layer = probe["w"], float(probe["b"]), int(probe["layer"]) |
| |
| # Load the frozen base model; the probe only reads its hidden states. |
| model = AutoModelForCausalLM.from_pretrained( |
|     "google/gemma-4-31B-it", torch_dtype=torch.bfloat16, |
|     device_map="auto", attn_implementation="eager", |
| ) |
| tok = AutoTokenizer.from_pretrained("google/gemma-4-31B-it") |
| |
| ids = tok("def vuln(x): return os.system(x)", return_tensors="pt").input_ids.to(model.device) |
| with torch.inference_mode(): |
|     out = model(ids, output_hidden_states=True, use_cache=False) |
| # hidden_states[0] is the embedding output, so layer + 1 is the output of 0-indexed decoder layer `layer`. |
| h = out.hidden_states[layer + 1][0].float().cpu().numpy() |
| per_token_risk = 1.0 / (1.0 + np.exp(-(h @ w + b))) |
| ``` |
|
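The probe step itself needs only NumPy and can be exercised without loading the 31B model. The sketch below applies `sigmoid(h @ w + b)` to a `(T, hidden_size)` array using a numerically stable sigmoid, which is a substitution for the direct formula in the snippet above and not part of the released code. The function name `token_risk` and the random stand-in weights are assumptions for illustration.

```python
import numpy as np

def token_risk(h, w, b):
    """Per-token sigmoid(h @ w + b), stable for logits of large magnitude."""
    h = np.asarray(h, dtype=np.float64)
    w = np.asarray(w, dtype=np.float64).reshape(-1)
    b = float(np.asarray(b).reshape(-1)[0])  # accept a scalar or a 1-element array
    logits = h @ w + b
    risk = np.empty_like(logits)
    nonneg = logits >= 0
    # For non-negative logits exp(-x) cannot overflow.
    risk[nonneg] = 1.0 / (1.0 + np.exp(-logits[nonneg]))
    # For negative logits use exp(x) / (1 + exp(x)) instead.
    e = np.exp(logits[~nonneg])
    risk[~nonneg] = e / (1.0 + e)
    return risk

# Random stand-ins: 4 tokens with an 8-dimensional hidden state.
rng = np.random.default_rng(0)
demo_h = rng.normal(size=(4, 8))
demo_w = rng.normal(size=8)
demo_risk = token_risk(demo_h, demo_w, 0.0)
```

With the real probe, `h` would be the layer-15 hidden states from the snippet above and `w`, `b` the arrays stored in `probe_spanmax_31b.npz`.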
|
| ## Pipeline |
|
|
| Full training + eval pipeline: <https://github.com/peaktwilight/gemmaforge> |
|
|