Code PRM Critic — Qwen3.5-4B LoRA

A Process Reward Model (PRM) critic that estimates the action-value function Q(s_t, a_t) at each step of a code agent's trajectory. Given a multi-turn agent conversation, the model assigns a scalar Q-value to each assistant turn, indicating how likely that action is to lead to a successfully resolved software engineering task.

Based on the critic training approach from arXiv:2505.13652.

Model Details

  • Base model: Qwen/Qwen3.5-4B (vision encoder removed, LM head replaced)
  • Fine-tuning: LoRA (r=8, α=16) on q_proj, v_proj, out_proj
  • Value head: Linear(hidden_size → 1, bias=False), bfloat16
  • Training signal: MSE loss on <value> token positions against TD(λ) targets
  • Context length: 32,000 tokens
  • Hardware: 2× NVIDIA A40 (44 GB each), ~19.6 hours

Training Data

Trained on SWE-agent trajectories from MiniMax 2.7 runs on SWE-Smith tasks — 3,923 trajectories across three cohorts, split 80/20 into train/val. Each trajectory is labeled with a binary outcome (resolved / unresolved). Trajectories with outcome error, incomplete, or completed_only are excluded.

Cohort                   Total  Resolved  Unresolved
swesmith-minimax_2_7_0    2058      1218         827
swesmith-minimax_2_7_1    1555       924         620
swesmith-minimax_2_7_2     310       175         130
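The outcome filter described above can be sketched as follows (the `outcome` field name is assumed from the description):

```python
EXCLUDED_OUTCOMES = {"error", "incomplete", "completed_only"}

def keep_trajectory(traj: dict) -> bool:
    # Keep only trajectories with a binary resolved/unresolved outcome
    return traj.get("outcome") not in EXCLUDED_OUTCOMES

trajs = [
    {"outcome": "resolved"},
    {"outcome": "error"},
    {"outcome": "unresolved"},
]
kept = [t for t in trajs if keep_trajectory(t)]
# → keeps the resolved and unresolved trajectories, drops the errored one
```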

TD(λ) targets: target_t = λ^(T−1−t) × reward, with λ = 0.7, reward ∈ {0.0, 1.0}, and T the number of assistant turns. A special <value> token is appended after each assistant turn; the scalar predicted at that position is the Q-value for that turn.
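The target computation above can be sketched in a few lines (λ and the binary reward as defined; `n_turns` is the number of assistant turns T):

```python
def td_lambda_targets(n_turns: int, reward: float, lam: float = 0.7) -> list:
    # target_t = lam^(T-1-t) * reward: the final turn's target equals the raw
    # reward, and earlier turns are discounted geometrically toward it
    return [lam ** (n_turns - 1 - t) * reward for t in range(n_turns)]

targets = [round(x, 3) for x in td_lambda_targets(4, 1.0)]
# → [0.343, 0.49, 0.7, 1.0]
```

An unresolved trajectory (reward = 0.0) gets all-zero targets, so the discounting only shapes the positive signal.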

Training Results

Epoch  Train Loss  Eval Loss
    1        6.92       4.58
    2        0.86       0.85
    3        0.63       0.68
    4        0.67       0.67

Training used a cosine LR schedule (lr=2e-6, 7 warmup steps), weight decay 0.1, and gradient checkpointing.
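The training signal (MSE at <value> token positions against TD(λ) targets) can be sketched as below; `value_loss` is a hypothetical helper, with `hidden` and `value_positions` shaped as in the inference code further down:

```python
import torch
import torch.nn as nn

def value_loss(hidden, value_head, value_positions, targets):
    # hidden: [1, L, H] last hidden states from the backbone
    # targets: one TD(lambda) target per assistant turn
    preds = value_head(hidden).squeeze(-1)[0]               # [L] scalar per token
    preds_at_values = preds[torch.tensor(value_positions)]  # one per <value> token
    return nn.functional.mse_loss(preds_at_values, targets)

# Toy shapes: sequence length 6, hidden size 8, two assistant turns
hidden = torch.randn(1, 6, 8)
head = nn.Linear(8, 1, bias=False)
loss = value_loss(hidden, head, [2, 5], torch.tensor([0.7, 1.0]))
```

Only the <value> positions contribute to the loss; every other token position is ignored.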

How to Use

This adapter requires a small custom wrapper to attach the value head. Install dependencies first:

pip install transformers peft torch

Load and score a trajectory:

import torch
from transformers import AutoTokenizer, Qwen3_5ForConditionalGeneration
from peft import PeftModel

CHECKPOINT = "zkjzou99/code-prm-critic-lora-32k"
BASE_MODEL  = "Qwen/Qwen3.5-4B"
VALUE_TOKEN = "<value>"

# --- Load tokenizer ---
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT)

# --- Load backbone ---
backbone = Qwen3_5ForConditionalGeneration.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",
)
del backbone.model.visual          # vision encoder unused
backbone.model.language_model.resize_token_embeddings(len(tokenizer))
backbone.model.language_model = PeftModel.from_pretrained(
    backbone.model.language_model, CHECKPOINT
)

# --- Load value head ---
import torch.nn as nn
value_head = nn.Linear(backbone.config.text_config.hidden_size, 1, bias=False)
value_head.load_state_dict(
    torch.hub.load_state_dict_from_url(
        f"https://huggingface.co/{CHECKPOINT}/resolve/main/value_head.pt",
        map_location="cpu",
    )
)
value_head = value_head.to(torch.bfloat16).cuda()
backbone = backbone.cuda().eval()


# --- Score a trajectory ---
# messages: list of {"role": "system"/"user"/"assistant"/"tool", "content": "..."}
# (standard format from the training data)

IM_START, IM_END = "<|im_start|>", "<|im_end|>"
value_id = tokenizer.convert_tokens_to_ids(VALUE_TOKEN)

def build_input(messages):
    tokens = []
    value_positions = []
    for msg in messages:
        role, content = msg["role"], msg.get("content", "")
        if role == "tool":
            text = f"<|im_start|>user\n<tool_response>\n{content}\n</tool_response><|im_end|>\n"
        else:
            text = f"{IM_START}{role}\n{content}{IM_END}\n"
        ids = tokenizer.encode(text, add_special_tokens=False)
        tokens.extend(ids)
        if role == "assistant":
            value_positions.append(len(tokens))
            tokens.append(value_id)
    return tokens, value_positions

tokens, value_positions = build_input(messages)
input_ids = torch.tensor([tokens]).cuda()
attention_mask = torch.ones_like(input_ids)

with torch.inference_mode():
    hidden = backbone.model.language_model(
        input_ids=input_ids, attention_mask=attention_mask
    ).last_hidden_state  # [1, L, H]
    q_values = value_head(hidden).squeeze(-1)[0]  # [L]

scores = [(i, q_values[pos].item()) for i, pos in enumerate(value_positions)]
print(scores)
# → [(0, 0.543), (1, 0.237), ..., (27, 0.439)]
# Higher Q-value = action more likely to lead to task resolution

Single-example output (resolved trajectory, 28 assistant turns):

trajectory_id : bottlepy__bottle.a8dfef30.lm_rewrite__6wkup0zn
resolved      : True
n_turns       : 28
Q-values      : [0.543, 0.237, 1.258, 2.906, 1.305, 1.781, 0.551, 0.355, 0.244,
                 0.135, 0.531, 0.742, 0.034, 1.18, 0.223, 1.078, 0.299, -0.264,
                 -0.008, 0.26, -0.855, -0.438, 0.938, 0.867, 0.551, -0.645, -0.046, 0.439]
final Q-value : 0.4395

Q-values are raw scalars (not sigmoid-squashed). Higher values indicate greater estimated probability of leading to task resolution. The final turn's Q-value is a natural summary score for the full trajectory.
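Using the final turn's Q-value as the trajectory score makes re-ranking candidate trajectories straightforward. A minimal sketch (candidate IDs and Q-value lists are illustrative; in practice each `q_values` list would come from the scoring code above):

```python
def rank_by_final_q(candidates):
    # candidates: list of (trajectory_id, q_values) pairs;
    # sort best-first by the final turn's Q-value
    return sorted(candidates, key=lambda c: c[1][-1], reverse=True)

ranked = rank_by_final_q([
    ("candidate_a", [0.1, 0.44]),
    ("candidate_b", [0.9, -0.20]),
])
best_id = ranked[0][0]
# → "candidate_a" (final Q-value 0.44 beats -0.20)
```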

Citation

@article{code-prm-2025,
  title  = {Training LLM-based Critics for Code Agent Trajectories},
  author = {},
  year   = {2025},
  url    = {https://arxiv.org/abs/2505.13652}
}

Framework Versions

  • PEFT 0.18.1
  • Transformers ≥ 4.52
  • PyTorch 2.11.0+cu128