PUM: Prefix Utility Model

This repository contains the released checkpoint of PUM, a Prefix Utility Model for gain-based evaluation of LLM reasoning prefixes.

PUM scores partial reasoning prefixes by their estimated utility, that is, how much a prefix is expected to improve the chance of eventually solving the problem. Given a math problem and one or more candidate reasoning prefixes, the model assigns a scalar score to each prefix. Higher scores indicate higher estimated prefix utility.

This model is associated with the paper:

From Correctness to Utility: Gain-Based Prefix Evaluation for LLM Reasoning

Model Structure

This is not a standard text-generation model. It is a custom scorer composed of:

  1. Base model: Qwen/Qwen3-4B-Instruct-2507
  2. LoRA adapter: backbone/
  3. Scalar MLP value head: value_head.pt
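The value head is small compared with the 4B backbone. As a rough sketch, the parameter count of each head variant that the loading code accepts (an MLP `Linear(h, h) → ReLU → Linear(h, 1)`, or a single `Linear(h, 1)`) follows directly from the hidden size `h`. The function name below is hypothetical, and `h` should be read from the base model's config rather than hard-coded.

```python
def value_head_param_count(hidden_size: int, use_mlp_head: bool = True) -> int:
    """Count the weights and biases of the scalar value head for a given hidden size."""
    if hidden_size <= 0:
        raise ValueError("hidden_size must be positive")
    if use_mlp_head:
        # Linear(h, h): h*h weights + h biases; Linear(h, 1): h weights + 1 bias.
        return hidden_size * hidden_size + hidden_size + hidden_size + 1
    # Single Linear(h, 1): h weights + 1 bias.
    return hidden_size + 1
```

For example, `value_head_param_count(config.hidden_size)` with the backbone's config gives the size of the MLP head.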

Expected repository structure:

.
├── backbone/
│   ├── adapter_config.json
│   └── adapter_model.safetensors
├── value_head.pt
├── model_args.json
├── train_args.json
├── tokenizer_config.json
├── special_tokens_map.json
├── added_tokens.json
├── chat_template.jinja
├── vocab.json
├── merges.txt
└── README.md
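After downloading the checkpoint (for example with `snapshot_download`), a quick check that the expected files are present can avoid a confusing failure later during loading. This is a minimal standard-library sketch; `missing_checkpoint_files` and `REQUIRED_FILES` are hypothetical names, and the list covers only the files the usage script reads unconditionally (the tokenizer files are optional there because it falls back to the base model's tokenizer).

```python
from pathlib import Path
from typing import List, Union

# Files the loading code reads directly. Tokenizer files are not listed because the
# loader falls back to the base model's tokenizer when tokenizer_config.json is absent.
REQUIRED_FILES = [
    "model_args.json",
    "value_head.pt",
    "backbone/adapter_config.json",
    "backbone/adapter_model.safetensors",
]


def missing_checkpoint_files(ckpt_dir: Union[str, Path]) -> List[str]:
    """Return the required relative paths that do not exist as files under ckpt_dir."""
    root = Path(ckpt_dir)
    return [rel for rel in REQUIRED_FILES if not (root / rel).is_file()]
```

An empty result means every required file was found.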

Installation

pip install torch transformers peft huggingface_hub accelerate safetensors
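To confirm the installation before downloading the base model, installed versions can be checked with the standard library's `importlib.metadata`. The helper below is a hypothetical convenience; its package list simply mirrors the pip command above.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names from the pip install command above.
REQUIRED_PACKAGES = ("torch", "transformers", "peft", "huggingface_hub", "accelerate", "safetensors")


def installed_versions(packages: Iterable[str] = REQUIRED_PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'NOT INSTALLED'}")
```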

Minimal Usage

Save the following code as run_pum.py and run it directly.

import os
import json
import math
from typing import Dict, Any, List

import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


REPO_ID = "zhiqix/PUM"
SYSTEM_PROMPT = "Please reason step by step, and put your final answer within \\boxed{}."


def get_hidden_size(config):
    for k in ["hidden_size", "n_embd", "d_model", "dim"]:
        if hasattr(config, k):
            return getattr(config, k)
    raise ValueError("Cannot infer hidden size from model config.")


def build_text(problem: str, prefix: str) -> str:
    problem = "" if problem is None else str(problem)
    prefix = "" if prefix is None else str(prefix)
    return (
        f"<system>\n{SYSTEM_PROMPT}\n\n"
        f"<problem>\n{problem}\n\n"
        f"<prefix>\n{prefix}\n\n"
        f"Score:"
    )


class PUMScorer(nn.Module):
    def __init__(self, backbone: nn.Module, dropout_p: float = 0.1, use_mlp_head: bool = True):
        super().__init__()
        self.backbone = backbone
        if hasattr(self.backbone.config, "use_cache"):
            self.backbone.config.use_cache = False

        hidden_size = get_hidden_size(self.backbone.config)
        self.dropout = nn.Dropout(dropout_p)

        if use_mlp_head:
            self.value_head = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )
        else:
            self.value_head = nn.Linear(hidden_size, 1)

    def _last_non_pad_index(self, attention_mask: torch.Tensor) -> torch.Tensor:
        # Index of the last attended (mask == 1) token in each row; works for
        # both left and right padding.
        flipped = attention_mask.flip(dims=[1])
        last_from_end = flipped.long().argmax(dim=1)
        return attention_mask.size(1) - 1 - last_from_end

    @torch.no_grad()
    def score(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True,
            use_cache=False,
        )
        hidden = outputs.hidden_states[-1]
        idx = self._last_non_pad_index(attention_mask)
        pooled = hidden[torch.arange(hidden.size(0), device=hidden.device), idx]
        pooled = self.dropout(pooled)
        return self.value_head(pooled).squeeze(-1)


def load_json(path: str) -> Dict[str, Any]:
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def build_value_head(backbone: nn.Module, dropout_p: float, value_head_state: Dict[str, torch.Tensor]):
    state_keys = set(value_head_state.keys())
    mlp_keys = {"0.weight", "0.bias", "2.weight", "2.bias"}
    linear_keys = {"weight", "bias"}

    if state_keys == mlp_keys:
        scorer = PUMScorer(backbone, dropout_p=dropout_p, use_mlp_head=True)
    elif state_keys == linear_keys:
        scorer = PUMScorer(backbone, dropout_p=dropout_p, use_mlp_head=False)
    else:
        raise ValueError(f"Unrecognized value head state keys: {state_keys}")

    scorer.value_head.load_state_dict(value_head_state)
    return scorer


def load_pum(repo_id: str = REPO_ID):
    ckpt_dir = snapshot_download(repo_id)

    model_args = load_json(os.path.join(ckpt_dir, "model_args.json"))
    model_name_or_path = model_args["model_name_or_path"]
    trust_remote_code = model_args.get("trust_remote_code", True)
    dropout_p = model_args.get("dropout", 0.1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    tokenizer_source = (
        ckpt_dir
        if os.path.exists(os.path.join(ckpt_dir, "tokenizer_config.json"))
        else model_name_or_path
    )

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_source,
        trust_remote_code=trust_remote_code,
        use_fast=False,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

    base_model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        trust_remote_code=trust_remote_code,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )

    backbone = PeftModel.from_pretrained(
        base_model,
        os.path.join(ckpt_dir, "backbone"),
    )

    state = torch.load(os.path.join(ckpt_dir, "value_head.pt"), map_location="cpu")
    value_head_state = state["value_head"] if isinstance(state, dict) and "value_head" in state else state

    scorer = build_value_head(
        backbone=backbone,
        dropout_p=dropout_p,
        value_head_state=value_head_state,
    )

    scorer.to(device=device, dtype=dtype)
    scorer.eval()

    return scorer, tokenizer, device, ckpt_dir


@torch.no_grad()
def score_prefixes(
    scorer: PUMScorer,
    tokenizer,
    device: torch.device,
    problem: str,
    prefixes: List[str],
    max_length: int = 4096,
):
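    texts = [build_text(problem, prefix) for prefix in prefixes]
    # Note: with right-side truncation (the usual tokenizer default), inputs longer
    # than max_length lose their tail, including the trailing "Score:" anchor.
    batch = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    batch = {k: v.to(device) for k, v in batch.items()}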

    scores = scorer.score(batch["input_ids"], batch["attention_mask"])
    return scores.detach().float().cpu().tolist()


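def sigmoid(x: float) -> float:
    # Numerically stable logistic function: avoids OverflowError from
    # math.exp when x is a large negative number.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)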


def main():
    scorer, tokenizer, device, ckpt_dir = load_pum(REPO_ID)

    problem = "If $\\det \\mathbf{A} = 4$ and $\\det \\mathbf{B} = -5,$ then find $\\det(\\mathbf{A}\\mathbf{B}).$"

    prefixes = [
        "Recall that the determinant of a product satisfies $\\det(AB)=\\det(A)\\det(B)$. Substituting the values gives $4\\cdot(-5)$.",
        "Recall that the determinant of a product satisfies $\\det(AB)=\\det(A)+\\det(B)$. Substituting the values gives $4+(-5)$."
    ]

    scores = score_prefixes(
        scorer=scorer,
        tokenizer=tokenizer,
        device=device,
        problem=problem,
        prefixes=prefixes,
        max_length=4096,
    )

    print("Downloaded checkpoint:", ckpt_dir)
    print("\nPrefix scores:")
    for i, score in enumerate(scores):
        print(f"prefix[{i}] score = {score:.6f}")

    logit_0_minus_1 = scores[0] - scores[1]
    prob_0_beats_1 = sigmoid(logit_0_minus_1)

    print("\nPairwise comparison:")
    print(f"logit(prefix[0] - prefix[1]) = {logit_0_minus_1:.6f}")
    print(f"P(prefix[0] > prefix[1] | non-tie) = {prob_0_beats_1:.4f}")

    winner = 0 if scores[0] > scores[1] else 1
    print(f"Predicted better prefix: prefix[{winner}]")


if __name__ == "__main__":
    main()

Run:

python run_pum.py
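The example scores all prefixes for a problem in one padded batch, which may exhaust GPU memory when there are many long prefixes. One option, sketched here under the assumption that a batch-scoring callable like `score_prefixes` is available, is to score in fixed-size chunks and concatenate the results. `score_in_chunks` and `chunk_size` are hypothetical names, and scores may differ slightly from single-batch scoring because of numerical effects of padding in bfloat16.

```python
from typing import Callable, List, Sequence

# Hypothetical batch scorer: takes a list of prefixes, returns one score per prefix.
BatchScorer = Callable[[List[str]], List[float]]


def score_in_chunks(prefixes: Sequence[str], score_batch: BatchScorer, chunk_size: int = 8) -> List[float]:
    """Score prefixes chunk_size at a time and return scores in the original order."""
    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    scores: List[float] = []
    for start in range(0, len(prefixes), chunk_size):
        chunk = list(prefixes[start:start + chunk_size])
        chunk_scores = score_batch(chunk)
        if len(chunk_scores) != len(chunk):
            raise ValueError("score_batch must return one score per prefix")
        scores.extend(chunk_scores)
    return scores
```

With the script above, `score_batch` could be `lambda chunk: score_prefixes(scorer, tokenizer, device, problem, chunk)`.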

Input Template

For each candidate prefix, PUM uses the following template:

<system>
Please reason step by step, and put your final answer within \boxed{}.

<problem>
{problem}

<prefix>
{prefix}

Score:

The final-layer hidden state at the last non-padding token (normally the final token of "Score:") is used as the pooled representation and passed to the scalar value head.
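In the usage script, that position is found by `PUMScorer._last_non_pad_index`, which flips the attention mask and takes the first nonzero position from the end. The same logic on plain Python lists, shown here as an illustration with a hypothetical function name, makes it clear that it handles both right padding (used by the script) and left padding.

```python
from typing import List, Sequence


def last_non_pad_index(attention_mask: Sequence[Sequence[int]]) -> List[int]:
    """For each row, return the position of the last 1 in the attention mask."""
    indices: List[int] = []
    for row in attention_mask:
        # Scan from the end, mirroring flip(...).argmax(...) in the tensor version.
        for offset, value in enumerate(reversed(row)):
            if value:
                indices.append(len(row) - 1 - offset)
                break
        else:
            # The tensor version would silently return the final position here.
            raise ValueError("row has no non-padding tokens")
    return indices
```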

Output Meaning

The model outputs a scalar score for each prefix.

  • Higher score means higher estimated prefix utility.
  • For two prefixes A and B, the pairwise logit is score(A) - score(B).

The probability that A is preferred over B, given that the pair is not a tie, can be computed as:

sigmoid(score(A) - score(B))

The scores are mainly meaningful for comparing prefixes of the same problem under the same input template; absolute values should not be interpreted in isolation.
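A minimal standard-library sketch of the pairwise quantities above, assuming `scores` holds PUM scores for prefixes of one problem under the same template. `pairwise_prob` and `preference_matrix` are hypothetical helpers; the logistic function is written in a numerically stable form so that large score gaps do not overflow `math.exp`.

```python
import math
from typing import List, Sequence


def pairwise_prob(score_a: float, score_b: float) -> float:
    """sigmoid(score_a - score_b): estimated P(A preferred over B | not a tie)."""
    d = score_a - score_b
    if d >= 0:
        return 1.0 / (1.0 + math.exp(-d))
    e = math.exp(d)
    return e / (1.0 + e)


def preference_matrix(scores: Sequence[float]) -> List[List[float]]:
    """Entry [i][j] is pairwise_prob(scores[i], scores[j]); the diagonal is 0.5."""
    return [[pairwise_prob(si, sj) for sj in scores] for si in scores]
```

By construction, entries [i][j] and [j][i] sum to 1, so the matrix carries no more information than the score differences themselves.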

Intended Use

PUM is intended for research on:

  • prefix-level reasoning evaluation;
  • process supervision for mathematical reasoning;
  • preference learning;
  • reward modeling;
  • best-of-N selection;
  • search-time guidance;
  • reinforcement learning with prefix-level utility signals.
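For best-of-N selection or search-time guidance, PUM scores could be used to keep only the most promising partial solutions. The sketch below is one assumed way to wire this up, not a procedure taken from the paper: given candidate prefixes for a single problem and their scores, keep the top `k` (with `k = 1` reducing to best-of-N). `keep_top_k` is a hypothetical name.

```python
import heapq
from typing import List, Sequence, Tuple


def keep_top_k(prefixes: Sequence[str], scores: Sequence[float], k: int) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (prefix, score) pairs, best first."""
    if len(prefixes) != len(scores):
        raise ValueError("prefixes and scores must have the same length")
    if k < 1:
        raise ValueError("k must be >= 1")
    # Equivalent to sorted(..., key=score, reverse=True)[:k]; ties keep input order.
    return heapq.nlargest(k, zip(prefixes, scores), key=lambda pair: pair[1])
```

In a search loop, the surviving prefixes would be extended by the policy model and rescored at the next step.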

Limitations

PUM is trained for mathematical reasoning prefix evaluation. It should not be treated as a general-purpose verifier or a fully calibrated correctness model.

The model estimates prefix utility rather than local step correctness. A prefix may look locally plausible but still receive a low score if it is estimated to reduce downstream solving success.

Citation

If you find this model useful, please cite:

@article{zhou2026from,
  title = {From Correctness to Utility: Gain-Based Prefix Evaluation for LLM Reasoning},
  author = {Yuhang Zhou and Yixin Cao and Guangnan Ye},
  journal = {arXiv preprint arXiv:2606.07190},
  year = {2026}
}