Code PRM Critic — Qwen3.5-4B LoRA
A Process Reward Model (PRM) critic that estimates the action-value function Q(s_t, a_t) at each step of a code agent's trajectory. Given a multi-turn agent conversation, the model assigns a scalar Q-value to each assistant turn, indicating how likely that action is to lead to a successfully resolved software engineering task.
Based on the critic training approach from arXiv:2505.13652.
Model Details
- Base model: Qwen/Qwen3.5-4B (vision encoder removed, LM head replaced)
- Fine-tuning: LoRA (r=8, α=16) on q_proj, v_proj, out_proj (see the config sketch below)
- Value head: Linear(hidden_size → 1, bias=False), bfloat16
- Training signal: MSE loss on <value> token positions against TD(λ) targets
- Context length: 32,000 tokens
- Hardware: 2× NVIDIA A40 (44 GB each), ~19.6 hours
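The exact training configuration is not shipped with this adapter; a minimal PEFT config matching the settings above might look like the sketch below (dropout, bias, and task type are assumptions, not reported on this card):

```python
from peft import LoraConfig

# Sketch of the adapter configuration implied by the settings above.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "out_proj"],
    lora_dropout=0.0,   # assumption: dropout not reported on this card
    bias="none",        # assumption
    task_type="FEATURE_EXTRACTION",  # backbone is used as a feature extractor for the value head
)
```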
Training Data
Trained on SWE-agent trajectories from MiniMax 2.7 runs on SWE-Smith tasks — 3,923 trajectories across three cohorts, split 80/20 into train/val. Each trajectory is labeled with a binary outcome (resolved / unresolved). Trajectories with outcome error, incomplete, or completed_only are excluded.
| Cohort | Total | Resolved | Unresolved |
|---|---|---|---|
| swesmith-minimax_2_7_0 | 2058 | 1218 | 827 |
| swesmith-minimax_2_7_1 | 1555 | 924 | 620 |
| swesmith-minimax_2_7_2 | 310 | 175 | 130 |
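The outcome filtering and 80/20 split described above can be sketched as follows; the `outcome` field name and record layout are illustrative only, not the published dataset schema:

```python
import random

# Outcomes excluded from training, per the description above.
EXCLUDED_OUTCOMES = {"error", "incomplete", "completed_only"}

def filter_and_split(trajectories, val_fraction=0.2, seed=0):
    """Keep only resolved/unresolved trajectories, then split 80/20 into train/val."""
    kept = [t for t in trajectories if t["outcome"] not in EXCLUDED_OUTCOMES]
    random.Random(seed).shuffle(kept)
    n_val = int(len(kept) * val_fraction)
    return kept[n_val:], kept[:n_val]  # train, val
```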
TD(λ) targets: target_t = λ^(T-1-t) × reward, λ=0.7, reward ∈ {0.0, 1.0}. A special <value> token is appended after each assistant turn; the scalar predicted at that position is the Q-value for that turn.
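In code, the target computation above amounts to a geometric decay of the final binary reward back toward earlier turns:

```python
def td_lambda_targets(n_turns: int, resolved: bool, lam: float = 0.7) -> list[float]:
    """target_t = lam^(T-1-t) * reward, with reward 1.0 for resolved and 0.0 otherwise."""
    reward = 1.0 if resolved else 0.0
    return [lam ** (n_turns - 1 - t) * reward for t in range(n_turns)]

# Example: a resolved trajectory with 4 assistant turns
# td_lambda_targets(4, True) -> [0.343, 0.49, 0.7, 1.0]
```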
Training Results
| Epoch | Train Loss | Eval Loss |
|---|---|---|
| 1 | 6.92 | 4.58 |
| 2 | 0.86 | 0.85 |
| 3 | 0.63 | 0.68 |
| 4 | 0.67 | 0.67 |
Cosine LR schedule, lr=2e-6, warmup 7 steps, weight decay=0.1, gradient checkpointing enabled.
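As an illustration only, these settings correspond roughly to the following transformers TrainingArguments; the batch size is an assumption and the actual training script is not included here:

```python
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="code-prm-critic",
    num_train_epochs=4,
    learning_rate=2e-6,
    lr_scheduler_type="cosine",
    warmup_steps=7,
    weight_decay=0.1,
    gradient_checkpointing=True,
    bf16=True,
    per_device_train_batch_size=1,  # assumption for 32k-token contexts on A40s
)
```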
How to Use
This adapter requires a small custom wrapper to attach the value head. Install dependencies first:
pip install transformers peft torch
Load and score a trajectory:
import torch
from transformers import AutoTokenizer, Qwen3_5ForConditionalGeneration
from peft import PeftModel
CHECKPOINT = "zkjzou99/code-prm-critic-lora-32k"
BASE_MODEL = "Qwen/Qwen3.5-4B"
VALUE_TOKEN = "<value>"
# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)
# --- Load backbone ---
backbone = Qwen3_5ForConditionalGeneration.from_pretrained(
BASE_MODEL,
torch_dtype=torch.bfloat16,
attn_implementation="sdpa",
)
del backbone.model.visual # vision encoder unused
backbone.model.language_model.resize_token_embeddings(len(tokenizer))
backbone.model.language_model = PeftModel.from_pretrained(
backbone.model.language_model, CHECKPOINT
)
# --- Load value head ---
import torch.nn as nn
value_head = nn.Linear(backbone.config.text_config.hidden_size, 1, bias=False)
value_head.load_state_dict(
torch.hub.load_state_dict_from_url(
f"https://huggingface.co/{CHECKPOINT}/resolve/main/value_head.pt",
map_location="cpu",
)
)
value_head = value_head.to(torch.bfloat16).cuda()
backbone = backbone.cuda().eval()
# --- Score a trajectory ---
# messages: list of {"role": "system"/"user"/"assistant"/"tool", "content": "..."}
# (standard format from the training data)
IM_START, IM_END = "<|im_start|>", "<|im_end|>"
value_id = tokenizer.convert_tokens_to_ids(VALUE_TOKEN)
def build_input(messages):
tokens = []
value_positions = []
for msg in messages:
role, content = msg["role"], msg.get("content", "")
if role == "tool":
text = f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>\n"
else:
text = f"{IM_START}{role}\n{content}{IM_END}\n"
ids = tokenizer.encode(text, add_special_tokens=False)
tokens.extend(ids)
if role == "assistant":
value_positions.append(len(tokens))
tokens.append(value_id)
return tokens, value_positions
tokens, value_positions = build_input(messages)
input_ids = torch.tensor([tokens]).cuda()
attention_mask = torch.ones_like(input_ids)
with torch.inference_mode():
hidden = backbone.model.language_model(
input_ids=input_ids, attention_mask=attention_mask
).last_hidden_state # [1, L, H]
q_values = value_head(hidden).squeeze(-1)[0] # [L]
scores = [(i, q_values[pos].item()) for i, pos in enumerate(value_positions)]
print(scores)
# → [(0, 0.543), (1, 0.237), ..., (27, 0.439)]
# Higher Q-value = action more likely to lead to task resolution
Single-example output (resolved trajectory, 28 assistant turns):
trajectory_id : bottlepy__bottle.a8dfef30.lm_rewrite__6wkup0zn
resolved : True
n_turns : 28
Q-values : [0.543, 0.237, 1.258, 2.906, 1.305, 1.781, 0.551, 0.355, 0.244,
0.135, 0.531, 0.742, 0.034, 1.18, 0.223, 1.078, 0.299, -0.264,
-0.008, 0.26, -0.855, -0.438, 0.938, 0.867, 0.551, -0.645, -0.046, 0.439]
final Q-value : 0.4395
Q-values are raw scalars (not sigmoid-squashed). Higher values indicate greater estimated probability of leading to task resolution. The final turn's Q-value is a natural summary score for the full trajectory.
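A natural downstream use is ranking candidate trajectories for the same task by their final Q-value, e.g. for best-of-n selection. The helper below is a sketch that reuses `build_input`, `backbone`, and `value_head` from the loading code above:

```python
def trajectory_score(messages) -> float:
    """Score a full trajectory by the Q-value of its last assistant turn."""
    tokens, value_positions = build_input(messages)
    input_ids = torch.tensor([tokens]).cuda()
    with torch.inference_mode():
        hidden = backbone.model.language_model(
            input_ids=input_ids,
            attention_mask=torch.ones_like(input_ids),
        ).last_hidden_state
        q_values = value_head(hidden).squeeze(-1)[0]
    return q_values[value_positions[-1]].item()

# Best-of-n: pick the candidate trajectory with the highest final Q-value
# best = max(candidate_trajectories, key=trajectory_score)
```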
Citation
@article{code-prm-2025,
title = {Training LLM-based Critics for Code Agent Trajectories},
author = {},
year = {2025},
url = {https://arxiv.org/abs/2505.13652}
}
Framework Versions
- PEFT 0.18.1
- Transformers ≥ 4.52
- PyTorch 2.11.0+cu128