Instructions to use zhiqix/PUM with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use zhiqix/PUM with PEFT:
Task type is invalid.
- Notebooks
- Google Colab
- Kaggle
PUM: Prefix Utility Model
This repository contains the released checkpoint of PUM, a Prefix Utility Model for gain-based evaluation of LLM reasoning prefixes.
PUM scores partial reasoning prefixes according to their estimated utility for improving future problem-solving success. Given a math problem and one or more candidate reasoning prefixes, the model assigns a scalar score to each prefix. Higher scores indicate higher estimated prefix utility under the current scorer.
This model is associated with the paper:
From Correctness to Utility: Gain-Based Prefix Evaluation for LLM Reasoning
Model Structure
This repository is not a standard text-generation model. It is a custom scorer composed of:
- Base model: `Qwen/Qwen3-4B-Instruct-2507`
- LoRA adapter: `backbone/`
- Scalar MLP value head: `value_head.pt`
Expected repository structure:
.
├── backbone/
│ ├── adapter_config.json
│ └── adapter_model.safetensors
├── value_head.pt
├── model_args.json
├── train_args.json
├── tokenizer_config.json
├── special_tokens_map.json
├── added_tokens.json
├── chat_template.jinja
├── vocab.json
├── merges.txt
└── README.md
Installation
pip install torch transformers peft huggingface_hub accelerate safetensors
Minimal Usage
Save the following code as run_pum.py and run it directly.
import os
import json
import math
from typing import Dict, Any, List
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
# Hugging Face Hub repository id passed to snapshot_download by default.
REPO_ID: str = "zhiqix/PUM"
# System instruction inserted verbatim into every scoring input by build_text.
SYSTEM_PROMPT: str = "Please reason step by step, and put your final answer within \\boxed{}."
def get_hidden_size(config):
    """Return the model width declared on ``config``.

    The attributes ``hidden_size``, ``n_embd``, ``d_model`` and ``dim`` are
    probed in that order, and the value of the first one present is returned
    as-is.

    Raises:
        ValueError: if ``config`` has none of those attributes.
    """
    candidates = ("hidden_size", "n_embd", "d_model", "dim")
    found = next((name for name in candidates if hasattr(config, name)), None)
    if found is None:
        raise ValueError("Cannot infer hidden size from model config.")
    return getattr(config, found)
def build_text(problem: str, prefix: str) -> str:
    """Render the PUM scoring input for one (problem, prefix) pair.

    Each of the ``<system>``, ``<problem>`` and ``<prefix>`` sections is a tag
    line, then its content, then a blank line. The text ends with ``Score:``.
    ``None`` inputs become empty strings. Any other value is passed through ``str()``.
    """
    problem_text = "" if problem is None else str(problem)
    prefix_text = "" if prefix is None else str(prefix)
    sections = (
        ("system", SYSTEM_PROMPT),
        ("problem", problem_text),
        ("prefix", prefix_text),
    )
    rendered = "".join(f"<{tag}>\n{content}\n\n" for tag, content in sections)
    return rendered + "Score:"
class PUMScorer(nn.Module):
    """Prefix-utility scorer: a causal-LM backbone plus a scalar value head.

    The backbone's final-layer hidden state at each row's last attended
    position is passed through dropout and then ``value_head``. This gives
    one score per row.

    Args:
        backbone: model whose ``config`` exposes a hidden size and whose
            forward accepts ``output_hidden_states=True`` and returns an
            object with ``hidden_states``.
        dropout_p: dropout probability on the pooled feature. It has no
            effect once the module is in ``eval()`` mode.
        use_mlp_head: ``True`` builds a Linear-ReLU-Linear head; ``False``
            builds a single Linear layer. The layout must match the keys of
            the saved value-head state dict.
    """

    def __init__(self, backbone: nn.Module, dropout_p: float = 0.1, use_mlp_head: bool = True):
        super().__init__()
        self.backbone = backbone
        backbone_cfg = self.backbone.config
        # The KV cache only helps generation; scoring is a single forward pass.
        if hasattr(backbone_cfg, "use_cache"):
            backbone_cfg.use_cache = False
        width = get_hidden_size(backbone_cfg)
        self.dropout = nn.Dropout(dropout_p)
        if not use_mlp_head:
            self.value_head = nn.Linear(width, 1)
        else:
            # Positions 0 and 2 hold the Linear layers, giving state-dict keys
            # "0.weight", "0.bias", "2.weight", "2.bias".
            self.value_head = nn.Sequential(
                nn.Linear(width, width),
                nn.ReLU(),
                nn.Linear(width, 1),
            )

    def _last_non_pad_index(self, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return, per row, the index of the right-most non-zero mask entry.

        This works for left or right padding. A row with no non-zero entry
        maps to the last position.
        """
        seq_len = attention_mask.size(1)
        # argmax returns the first maximum, so on the reversed mask it finds
        # the distance of the right-most 1 from the end.
        offset_from_end = attention_mask.flip(dims=[1]).long().argmax(dim=1)
        return (seq_len - 1) - offset_from_end

    @torch.no_grad()
    def score(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return one scalar score per row of a (batch, seq_len) input.

        Args:
            input_ids: token ids, shape (batch, seq_len).
            attention_mask: same shape as ``input_ids``; non-zero marks real tokens.

        Returns:
            Tensor of shape (batch,).
        """
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )
        last_layer = out.hidden_states[-1]
        rows = torch.arange(last_layer.size(0), device=last_layer.device)
        cols = self._last_non_pad_index(attention_mask)
        features = self.dropout(last_layer[rows, cols])
        return self.value_head(features).squeeze(-1)
def load_json(path: str) -> Dict[str, Any]:
    """Read the UTF-8 encoded JSON file at ``path`` and return the parsed value."""
    with open(path, encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload
def build_value_head(backbone: nn.Module, dropout_p: float, value_head_state: Dict[str, torch.Tensor]):
    """Create a PUMScorer whose head layout matches ``value_head_state``, then load it.

    Keys {"0.weight", "0.bias", "2.weight", "2.bias"} select the MLP head.
    Keys {"weight", "bias"} select the single Linear head.

    Raises:
        ValueError: if the key set matches neither layout.
    """
    state_keys = set(value_head_state.keys())
    if state_keys == {"0.weight", "0.bias", "2.weight", "2.bias"}:
        wants_mlp = True
    elif state_keys == {"weight", "bias"}:
        wants_mlp = False
    else:
        raise ValueError(f"Unrecognized value head state keys: {state_keys}")
    scorer = PUMScorer(backbone, dropout_p=dropout_p, use_mlp_head=wants_mlp)
    scorer.value_head.load_state_dict(value_head_state)
    return scorer
def load_pum(repo_id: str = REPO_ID):
    """Download a PUM checkpoint and assemble the scorer.

    Runs on CUDA in bfloat16 when a GPU is available; otherwise on CPU in float32.

    Returns:
        Tuple ``(scorer, tokenizer, device, ckpt_dir)``. ``scorer`` is in eval
        mode on ``device``. ``ckpt_dir`` is the local snapshot directory.
    """
    ckpt_dir = snapshot_download(repo_id)
    model_args = load_json(os.path.join(ckpt_dir, "model_args.json"))
    base_name = model_args["model_name_or_path"]
    # NOTE(review): remote code is trusted when the key is absent. Consider
    # defaulting to False if the base architecture is built into transformers.
    remote_code_ok = model_args.get("trust_remote_code", True)
    dropout_p = model_args.get("dropout", 0.1)

    has_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if has_cuda else "cpu")
    dtype = torch.bfloat16 if has_cuda else torch.float32

    # Prefer tokenizer files shipped with the checkpoint; otherwise use the base model's.
    if os.path.exists(os.path.join(ckpt_dir, "tokenizer_config.json")):
        tokenizer_source = ckpt_dir
    else:
        tokenizer_source = base_name
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=remote_code_ok,
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        base_name,
        trust_remote_code=remote_code_ok,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    backbone = PeftModel.from_pretrained(base_model, os.path.join(ckpt_dir, "backbone"))

    # NOTE(review): torch.load unpickles a file downloaded from the Hub.
    # Consider weights_only=True if value_head.pt holds only tensors/dicts.
    raw_state = torch.load(os.path.join(ckpt_dir, "value_head.pt"), map_location="cpu")
    if isinstance(raw_state, dict) and "value_head" in raw_state:
        value_head_state = raw_state["value_head"]
    else:
        value_head_state = raw_state

    scorer = build_value_head(
        backbone=backbone,
        dropout_p=dropout_p,
        value_head_state=value_head_state,
    )
    scorer.to(device=device, dtype=dtype)
    scorer.eval()
    return scorer, tokenizer, device, ckpt_dir
@torch.no_grad()
def score_prefixes(
    scorer: "PUMScorer",
    tokenizer,
    device: torch.device,
    problem: str,
    prefixes: List[str],
    max_length: int = 4096,
):
    """Score each candidate reasoning prefix for a single problem.

    Args:
        scorer: a loaded scorer exposing ``score(input_ids, attention_mask)``.
        tokenizer: the matching tokenizer. Inputs are padded to a common length.
        device: device the scorer lives on; input tensors are moved there.
        problem: problem statement shared by all prefixes.
        prefixes: candidate prefixes (any iterable of strings).
        max_length: maximum tokens per input; longer inputs are truncated.

    Returns:
        List of float scores, one per prefix, in input order. Empty input
        returns ``[]`` without calling the tokenizer or model.

    Note:
        Truncation follows the tokenizer's ``truncation_side``. If that is
        "right" (check for your tokenizer), an over-long input loses its
        trailing ``Score:`` marker. The score is then read at a different token.
    """
    texts = [build_text(problem, prefix) for prefix in prefixes]
    if not texts:
        # Avoid feeding an empty batch to the tokenizer/model.
        return []
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch = {name: tensor.to(device) for name, tensor in batch.items()}
    scores = scorer.score(batch["input_ids"], batch["attention_mask"])
    return scores.float().cpu().tolist()
def sigmoid(x: float) -> float:
    """Return the logistic function ``1 / (1 + exp(-x))``.

    Each branch only exponentiates a non-positive number. This avoids the
    ``OverflowError`` that ``math.exp(-x)`` raises for very negative ``x``.
    """
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)
def main():
    """Demo: load PUM, score two prefixes for one problem, and compare them."""
    scorer, tokenizer, device, ckpt_dir = load_pum(REPO_ID)
    problem = "If $\\det \\mathbf{A} = 4$ and $\\det \\mathbf{B} = -5,$ then find $\\det(\\mathbf{A}\\mathbf{B}).$"
    prefixes = [
        "Recall that the determinant of a product satisfies $\\det(AB)=\\det(A)\\det(B)$. Substituting the values gives $4\\cdot(-5)$.",
        "Recall that the determinant of a product satisfies $\\det(AB)=\\det(A)+\\det(B)$. Substituting the values gives $4+(-5)$."
    ]
    scores = score_prefixes(
        scorer=scorer,
        tokenizer=tokenizer,
        device=device,
        problem=problem,
        prefixes=prefixes,
        max_length=4096,
    )
    print("Downloaded checkpoint:", ckpt_dir)
    print("\nPrefix scores:")
    for i, score in enumerate(scores):
        print(f"prefix[{i}] score = {score:.6f}")
    logit_0_minus_1 = scores[0] - scores[1]
    prob_0_beats_1 = sigmoid(logit_0_minus_1)
    print("\nPairwise comparison:")
    print(f"logit(prefix[0] - prefix[1]) = {logit_0_minus_1:.6f}")
    print(f"P(prefix[0] > prefix[1] | non-tie) = {prob_0_beats_1:.4f}")
    # Exact ties can happen when scores are low precision (load_pum uses
    # bfloat16 on CUDA). Report them instead of silently crediting prefix[1].
    if scores[0] == scores[1]:
        print("Predicted better prefix: tie")
    else:
        winner = 0 if scores[0] > scores[1] else 1
        print(f"Predicted better prefix: prefix[{winner}]")


if __name__ == "__main__":
    main()
Run:
python run_pum.py
Input Template
For each candidate prefix, PUM builds the following input. Each section is followed by a blank line:
<system>
Please reason step by step, and put your final answer within \boxed{}.

<problem>
{problem}

<prefix>
{prefix}

Score:
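The final-layer hidden state at the last non-padding token (normally the end of `Score:`) is used as the pooled representation and passed to the scalar value head. Inputs longer than `max_length` tokens are truncated. With right-side truncation this removes the trailing `Score:` marker, so keep inputs within the limit.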
Output Meaning
The model outputs a scalar score for each prefix.
- Higher score means higher estimated prefix utility.
- For two prefixes A and B of the same problem, the pairwise logit is `score(A) - score(B)`.
- The probability that A is preferred over B, given that the pair is not a tie, is `sigmoid(score(A) - score(B))`.
The scalar scores are mainly meaningful when comparing prefixes under the same problem and the same input template.
Intended Use
PUM is intended for research on:
- prefix-level reasoning evaluation;
- process supervision for mathematical reasoning;
- preference learning;
- reward modeling;
- best-of-N selection;
- search-time guidance;
- reinforcement learning with prefix-level utility signals.
Limitations
PUM is trained for mathematical reasoning prefix evaluation. It should not be treated as a general-purpose verifier or a fully calibrated correctness model.
The model estimates prefix utility rather than local step correctness. A prefix may look locally plausible but still receive a low score if it is estimated to reduce downstream solving success.
Citation
If you find this model useful, please cite:
@article{zhou2026from,
title = {From Correctness to Utility: Gain-Based Prefix Evaluation for LLM Reasoning},
author = {Yuhang Zhou and Yixin Cao and Guangnan Ye},
journal = {arXiv preprint arXiv:2606.07190},
year = {2026}
}
- Downloads last month
- -
Model tree for zhiqix/PUM
Base model
Qwen/Qwen3-4B-Instruct-2507