GemmaForge Token-Level Probe: Gemma 4 31B (span-max)

A linear span-max probe on the per-token hidden states of decoder layer 15 of the frozen google/gemma-4-31B-it (60 decoder layers, hidden_size 5376). The probe is trained with the span-max loss of Obeso, Arditi et al. 2025 (arXiv 2509.03531, §3) on data/dataset.jsonl (N=1374, 687 pos / 687 neg), using token-level char-range labels propagated from SVEN diffs.

Companion models:

Files

  • probe_spanmax_31b.npz: (w, b, layer); sigmoid(w @ hidden[layer + 1][0, t, :] + b) is the per-token risk
  • probe_spanmax_31b_card.json: per-layer token/example AUC, training config
  • token_probs_31b.npz: per-row token probabilities (probs_row_NNNN) on data/dataset.jsonl
  • token_offsets_31b.npz: per-row (T, 2) char offsets (offsets_row_NNNN)
  • spans.json: positive (example_id, tok_start, tok_end) triples for the trainer's span-max pool
  • token_report_31b.md / token_report_31b.json: full three-level eval report
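
The two per-row archives can be joined to recover which characters each score covers. The following is a minimal sketch, not code from the pipeline. It assumes the row keys are zero-padded to four digits (for example probs_row_0007) and that offsets are half-open [start, end) character ranges into the row's source text; the card confirms neither. The helper names load_row and top_tokens are hypothetical.

```python
import numpy as np

def load_row(probs_npz, offsets_npz, row, width=4):
    """Fetch one row from both archives. The zero-padded key format is an assumption."""
    key = f"{row:0{width}d}"
    return probs_npz[f"probs_row_{key}"], offsets_npz[f"offsets_row_{key}"]

def top_tokens(probs, offsets, text, k=3):
    """Return the k highest-risk tokens as (risk, start, end, snippet) tuples."""
    probs = np.asarray(probs, dtype=np.float64)
    offsets = np.asarray(offsets, dtype=np.int64)
    if probs.shape[0] != offsets.shape[0]:
        raise ValueError("probs and offsets must describe the same tokens")
    order = np.argsort(-probs, kind="stable")[:k]  # highest risk first
    return [
        (float(probs[i]), int(offsets[i, 0]), int(offsets[i, 1]),
         text[int(offsets[i, 0]):int(offsets[i, 1])])
        for i in order
    ]
```

With the real files, probs_npz and offsets_npz would be the objects returned by np.load on token_probs_31b.npz and token_offsets_31b.npz, and text would be the matching row's source from data/dataset.jsonl.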

Training setup

Base model                    google/gemma-4-31B-it
Decoder layers / hidden       60 / 5376
Layers probed                 [15, 30, 45, 59] (25/50/75/100% depth)
Winning layer                 15
Loss                          span-max (alpha=10, omega 0→1 linear)
Optimizer / epochs / batch    AdamW (lr=1e-3) / 30 / 8 examples
Activation dtype on disk      float16
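
To make the loss row concrete, here is one plausible reading of a span-max objective with a linear omega schedule. It is a sketch under stated assumptions, not the trainer's code. It assumes the loss blends a per-token binary cross-entropy with a cross-entropy on the highest-scoring token of each span, weighted by omega. The role of alpha is not described in this card and is left out, and whether the ramp advances per step or per epoch is also an assumption. All function names are hypothetical.

```python
import numpy as np

def omega_linear(step, total_steps):
    """Linear ramp from 0 at the first step to 1 at the last (assumed schedule)."""
    if total_steps <= 1:
        return 1.0
    return min(max(step / (total_steps - 1), 0.0), 1.0)

def bce(p, y, eps=1e-7):
    """Binary cross-entropy on probabilities, clipped for numerical safety."""
    p = np.clip(p, eps, 1.0 - eps)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

def blended_span_max_loss(probs, token_labels, spans, omega):
    """Blend token-wise BCE with BCE on each span's max probability.

    probs:        (T,) per-token probabilities
    token_labels: (T,) labels in {0, 1}
    spans:        list of (start, end, label) with half-open token ranges
    """
    probs = np.asarray(probs, dtype=np.float64)
    token_labels = np.asarray(token_labels, dtype=np.float64)
    token_term = float(bce(probs, token_labels).mean())
    span_terms = [float(bce(probs[s:e].max(), float(lab))) for s, e, lab in spans]
    span_term = float(np.mean(span_terms)) if span_terms else 0.0
    return (1.0 - omega) * token_term + omega * span_term
```

Under this reading, a positive span only needs one confident token once omega reaches 1, whereas the token-wise term at omega = 0 penalises every low-scoring token inside the span.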

Layer sweep (winner = best example-level AUC)

Layer   tok_AUC   ex_AUC
15      0.730     0.692   <-- winner
30      0.721     0.657
45      0.752     0.647
59      0.683     0.676
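
The selection rule is simple enough to state in code. This short sketch copies the numbers from the table above; the dictionary layout and the pick_layer helper are illustrative and not the schema of probe_spanmax_31b_card.json.

```python
# Layer sweep results copied from the table above.
SWEEP = {
    15: {"tok_auc": 0.730, "ex_auc": 0.692},
    30: {"tok_auc": 0.721, "ex_auc": 0.657},
    45: {"tok_auc": 0.752, "ex_auc": 0.647},
    59: {"tok_auc": 0.683, "ex_auc": 0.676},
}

def pick_layer(sweep, metric="ex_auc"):
    """Return the layer with the highest value of the chosen metric."""
    return max(sweep, key=lambda layer: sweep[layer][metric])

print(pick_layer(SWEEP))             # example-level AUC picks layer 15
print(pick_layer(SWEEP, "tok_auc"))  # token-level AUC would pick layer 45
```

The choice of metric matters here: layer 45 has the best token-level AUC but the worst example-level AUC of the four.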

Eval headline (data/dataset.jsonl, N=1374, pos=687)

Split                        all AUC   proximal_all AUC   span_max AUC   dilated_span_max AUC
random_stratified            0.879     0.729              0.669          0.565
group_repo                   0.812     0.714              0.623          0.495
heldout_cwe::CWE-089         0.960     0.713              0.944          0.941
heldout_cwe::CWE-125         0.813     0.714              0.512          0.332
heldout_cwe::CWE-078         0.881     0.748              0.724          0.686
heldout_cwe::CWE-476         0.794     0.706              0.497          0.285
heldout_cwe::CWE-079         0.796     0.677              0.492          0.299
heldout_lang::test=c         0.791     0.697              0.497          0.292
heldout_lang::test=cpp       0.778     0.696              0.494          0.307
heldout_lang::test=python    0.896     0.711              0.824          0.781

Notes:

  • all measures the probe on every token (streaming-UI view); best for per-token highlighting.
  • span_max collapses to one decision per example, which makes it directly comparable to sample-level probes.
  • span is NaN on this corpus: SVEN has no sanitizer-annotated negatives, so the span level is single-class. The protocol synthesises a whole-file negative span for label=0 examples inside span_max only.
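
The span_max pooling and the single-class NaN can be illustrated with a small sketch. It assumes that a positive example is scored by the maximum probability inside its labelled token span and a negative example by the maximum over the whole file, following the note above. The AUC is a plain rank statistic rather than the report's own implementation, and both function names are hypothetical.

```python
import numpy as np

def span_max_score(probs, span=None):
    """Max probability inside a half-open token span, or over the whole file if span is None."""
    probs = np.asarray(probs, dtype=np.float64)
    if span is not None:
        start, end = span
        probs = probs[start:end]
    return float(probs.max())

def auc(scores, labels):
    """Rank-based AUC (Mann-Whitney), ties count as half; NaN if only one class is present."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    pos, neg = scores[labels == 1], scores[labels == 0]
    if len(pos) == 0 or len(neg) == 0:
        return float("nan")
    diff = pos[:, None] - neg[None, :]
    wins = (diff > 0).sum() + 0.5 * (diff == 0).sum()
    return float(wins / (len(pos) * len(neg)))
```

A level that contains only positive spans has no negatives to rank against, so auc returns NaN; adding a whole-file score for each label=0 example restores the second class.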

Reproduce inference

from huggingface_hub import hf_hub_download
import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

probe = np.load(hf_hub_download("mmtf/gemmaforge-31b-token-probe", "probe_spanmax_31b.npz"))
w, b, layer = probe["w"], float(probe["b"]), int(probe["layer"])

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-4-31B-it", torch_dtype=torch.bfloat16,
    device_map="auto", attn_implementation="eager",
)
tok = AutoTokenizer.from_pretrained("google/gemma-4-31B-it")

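# Tokenize a sample snippet and run one forward pass, keeping every layer's hidden states.
ids = tok("def vuln(x): return os.system(x)", return_tensors="pt").input_ids.to(model.device)
with torch.inference_mode():
    out = model(ids, output_hidden_states=True, use_cache=False)
# hidden_states[0] is the embedding output, so decoder layer `layer` (0-indexed) is at index layer + 1.
h = out.hidden_states[layer + 1][0].float().cpu().numpy()
# Linear probe followed by a sigmoid: one risk score per token.
per_token_risk = 1.0 / (1.0 + np.exp(-(h @ w + b)))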
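
For highlighting, the per-token scores usually need to be mapped back to character ranges. The sketch below merges runs of consecutive high-risk tokens into character spans. It assumes offsets are available as (start, end) pairs, for instance from calling a fast tokenizer with return_offsets_mapping=True. The 0.5 threshold is an arbitrary illustrative value, not one calibrated for this probe, and risky_char_spans is a hypothetical helper.

```python
def risky_char_spans(risks, offsets, threshold=0.5):
    """Merge consecutive tokens with risk >= threshold into (start, end, peak_risk) spans."""
    spans = []
    current = None  # [start, end, peak] of the run being built
    for risk, (start, end) in zip(risks, offsets):
        if risk >= threshold:
            if current is None:
                current = [start, end, risk]
            else:
                current[1] = max(current[1], end)
                current[2] = max(current[2], risk)
        elif current is not None:
            spans.append(tuple(current))
            current = None
    if current is not None:
        spans.append(tuple(current))
    return spans
```

Special tokens such as the beginning-of-sequence marker typically carry an empty (0, 0) offset, so they may need to be filtered out before merging.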

Pipeline

Full training + eval pipeline: https://github.com/peaktwilight/gemmaforge
