AttrLLM / loader /perplexity.py
Qingpeng Kong
clean initial state
3e72399
import math
import os
import logging
from typing import Callable, List, Dict, Any, Optional
import numpy as np
import torch
import torch.nn.functional as F
from .models import BaseModel
logger = logging.getLogger(__name__)
# Type alias for callables that render the text sent to the model.
# Signature: (question, context, answer_tag, reference_text) -> full_prompt_str
# All four arguments are passed positionally by the scorers in this module.
PromptFormatter = Callable[[str, str, str, str], str]
def default_prompt_formatter(
    question: str,
    context: str,
    answer_tag: str,
    reference_text: str,
) -> str:
    """
    Render the default three-line prompt.

    The layout is: the question, then ``Context: <context>``, then the
    answer tag followed by one space and the reference text. Pass a
    different callable as `prompt_formatter` to the scorers when a model
    needs another template (e.g., a chat template).
    """
    lines = (
        f"{question}",
        f"Context: {context}",
        f"{answer_tag} {reference_text}",
    )
    return "\n".join(lines)
def _join_prefix_continuation(prefix: str, continuation: str) -> str:
"""Join prefix and continuation with a single space when needed."""
if not prefix:
return continuation
if not continuation:
return prefix
if prefix[-1].isspace() or continuation[0].isspace():
return prefix + continuation
return prefix + " " + continuation
def _continuation_prompt_formatter(
    _question: str,
    context: str,
    _answer_tag: str,
    reference_text: str,
) -> str:
    """
    PromptFormatter-compatible adapter for plain continuation scoring.

    The question and answer tag are ignored; the result is the context
    joined to the reference text by `_join_prefix_continuation`.
    """
    joined = _join_prefix_continuation(context, reference_text)
    return joined
def score_continuation(
    model: BaseModel,
    prefix: str,
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute teacher-forced logprob and perplexity of `continuation` given `prefix`.

    Scoring is delegated to `score_reference_autoregressive_vllm` with a
    single-element batch, so `model` must expose a `.llm` attribute.

    Args:
        model: Model wrapper; must have an `llm` attribute.
        prefix: Conditioning text.
        continuation: Text whose tokens are scored.
        max_new_tokens: Forwarded to the vLLM scorer.

    Returns:
        A dict with the keys "avg_logprob", "perplexity", "total_logprob",
        "avg_nll", "per_token", "target_len" and "sequence_logprobs".
        The same key set is returned for an empty `continuation`
        (with -inf / inf sentinels and `target_len` 0).

    Raises:
        RuntimeError: If `continuation` is non-empty and `model` has no `llm`.
    """
    if not continuation:
        # Nothing to score. Keep the key set identical to the normal path so
        # callers can read "total_logprob" / "avg_nll" unconditionally.
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "total_logprob": float("-inf"),
            "avg_nll": float("inf"),
            "per_token": [],
            "target_len": 0,
            "sequence_logprobs": [],
        }

    if not hasattr(model, "llm"):
        raise RuntimeError(
            "score_continuation requires a VLLMModel with .llm "
            "(use loader.get_model_vllm inside the vLLM container)."
        )

    result = score_reference_autoregressive_vllm(
        model,
        question="",
        masked_inputs=[prefix],
        reference_text=continuation,
        answer_tag="",
        prompt_formatter=_continuation_prompt_formatter,
        max_new_tokens=max_new_tokens,
    )

    token_logprobs = result.get("token_logprobs", [])
    per_token = token_logprobs[0] if token_logprobs else []
    target_len = int(result.get("target_len", 0) or 0)
    seq_logprobs = result.get("sequence_logprobs", [])

    if seq_logprobs:
        total_logprob = float(seq_logprobs[0])
    else:
        # Reconstruct the total from the average when no per-sequence value exists.
        avg_lp = float(result.get("avg_logprob", float("-inf")))
        total_logprob = avg_lp * target_len if target_len else float("-inf")

    if target_len > 0 and math.isfinite(total_logprob):
        avg_nll = -total_logprob / target_len
        try:
            ppl = math.exp(avg_nll)
        except OverflowError:
            # exp() of a very large NLL does not fit in a float.
            ppl = float("inf")
    else:
        avg_nll = float("inf")
        ppl = float("inf")

    return {
        "avg_logprob": result.get("avg_logprob", float("-inf")),
        # Prefer the scorer's value; fall back to the locally derived one.
        "perplexity": result.get("perplexity", ppl),
        "total_logprob": total_logprob,
        "avg_nll": avg_nll,
        "per_token": per_token,
        "target_len": target_len,
        "sequence_logprobs": seq_logprobs,
    }
# --- Real (teacher-forced) scorer using Hugging Face generate() ---
def score_reference_autoregressive_hf(
    model,  # HF AutoModelForCausalLM (on CUDA or CPU)
    tokenizer,  # matching HF tokenizer
    question: str,  # e.g., "Count the number of r's in strawberry."
    masked_inputs: List[str],  # batch of masked contexts (strings)
    reference_token_ids: List[int],  # tokenized reference answer ids
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
) -> Dict[str, Any]:
    """
    Teacher-forced next-token log-probs of the reference answer across a batch of masked contexts.

    For each step j, the first j+1 reference tokens are decoded back to text,
    placed into the prompt for every masked context, and the model is asked
    for exactly one new token; the log-probability of reference token j+1 is
    read from the returned scores. The first reference token is therefore
    only used as conditioning and is never scored itself.

    Side effects: sets `tokenizer.pad_token` when missing (adding a "[PAD]"
    token if there is no EOS token), sets `tokenizer.padding_side = "left"`,
    and tries to set `model.config.pad_token_id`.

    Returns:
        {
            "avg_logprob": float,                 # mean over tokens, then mean over batch
            "perplexity": float,                  # exp(-avg_logprob)
            "target_len": int,                    # number of next-token steps (len(ref_ids)-1)
            "sequence_logprobs": List[float],     # total logprob per masked input (len = batch)
            "token_logprobs": List[List[float]]   # per-token logprobs per masked input
        }
    """
    # --- ensure tokenizer/model have a pad token ---
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            # Best effort: only meaningful if model is an HF CausalLM with embeddings.
            try:
                model.resize_token_embeddings(len(tokenizer))
            except Exception:
                pass
    # Left padding keeps the last real token of every prompt in the final
    # position, which is where generate() predicts from.
    tokenizer.padding_side = "left"
    try:
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception:
        # Best effort: not every model object exposes a writable config.
        pass

    B = len(masked_inputs)
    if B == 0:
        # Nothing to score; avoid tokenizing an empty batch / averaging nothing.
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [],
            "token_logprobs": [],
        }

    # Total log-prob of the reference per masked input.
    batch_seq_logprobs = np.zeros(B, dtype=np.float64)
    T = max(0, len(reference_token_ids) - 1)  # number of next-token predictions
    if T == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": batch_seq_logprobs.tolist(),
            "token_logprobs": [[] for _ in range(B)],
        }

    per_step_logprobs: List[np.ndarray] = []  # T arrays of shape (B,)
    for j in range(T):
        # Decode the entire prefix (tokens 0..j) to keep spacing/special tokens consistent.
        prefix_str = tokenizer.decode(
            reference_token_ids[: j + 1], skip_special_tokens=False
        )
        prompts = [
            prompt_formatter(question, ctx, answer_tag, prefix_str)
            for ctx in masked_inputs
        ]
        # NOTE(review): truncation=True may cut over-long prompts; which side is
        # cut depends on the tokenizer's truncation settings — confirm the
        # answer prefix survives for long contexts.
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=1,  # predict exactly one next token
                output_scores=True,
                return_dict_in_generate=True,
                # Use the pad token guaranteed above; eos_token_id may be None.
                pad_token_id=tokenizer.pad_token_id,
            )
        # out["scores"] holds one (batch, vocab) tensor per generated token; we generated one.
        logits = out["scores"][0]
        # Upcast so half-precision models give stable log-probs and so the
        # result can be converted to NumPy (bfloat16 cannot).
        logprobs = F.log_softmax(logits.float(), dim=-1)
        target_next_id = reference_token_ids[j + 1]
        step_lp = (
            logprobs[:, target_next_id].detach().cpu().numpy().astype(np.float64)
        )  # (B,)
        batch_seq_logprobs += step_lp
        per_step_logprobs.append(step_lp)
        # clean up to keep memory low
        del inputs, out, logits, logprobs

    # Average over tokens, then average across the batch.
    avg_lp = float((batch_seq_logprobs / T).mean())
    try:
        ppl = math.exp(-avg_lp)
    except OverflowError:
        ppl = float("inf")

    return {
        "avg_logprob": avg_lp,
        "perplexity": ppl,
        "target_len": T,
        "sequence_logprobs": batch_seq_logprobs.tolist(),
        # (T, B) -> one list of T per-token log-probs for each masked input.
        "token_logprobs": np.stack(per_step_logprobs, axis=1).tolist(),
    }
def score_reference_autoregressive_vllm(
    vllm_model,  # your loader.VLLMModel instance
    question: str,
    masked_inputs: List[str],
    reference_text: str,  # e.g., " The answer is 3."
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
    top_k: int = 5,
    debug: bool = False,
    debug_index: int = 0,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    vLLM-native teacher-forced perplexity over a *given* reference answer.

    Batching behavior:
    - Build one full prompt per masked context.
    - Send ALL prompts to vLLM in a single `llm.generate` call.
    - Read the prompt-level logprobs vLLM returns for each token position.
    - For each prompt, slice out the part corresponding to the answer
      and sum the logprobs of those answer tokens.

    The answer region of a prompt is the token suffix by which the
    tokenization of the full prompt is longer than the tokenization of the
    same prompt rendered with an empty reference. All answer regions are
    truncated to the shortest non-empty one so every example is scored over
    the same number of tokens. Examples with an empty answer region get a
    sequence logprob of -inf and an empty per-token list.

    Fallbacks for positions without a usable logprob: -20.0 when the position
    is out of range or has no candidates, otherwise the smallest returned
    candidate logprob minus 5.0.

    If debug=True, detailed info is printed for the example at `debug_index`:
    - Full prompt string
    - Base vs full token lengths
    - Each answer token + its logprob and running perplexity.

    Returns:
        {
            "avg_logprob": float,                 # mean over tokens, then mean over batch
            "perplexity": float,                  # exp(-avg_logprob)
            "target_len": int,                    # common answer length in tokens
            "sequence_logprobs": List[float],     # one total per masked input
            "token_logprobs": List[List[float]],  # per-token values per masked input
        }

    Raises:
        AssertionError: If `vllm_model` has no `llm` attribute.
        RuntimeError: If the tokenizer cannot be obtained or a request output
            carries no `prompt_logprobs`.
    """
    missing_logprob = -20.0  # used when a position has no logprob information at all
    unlisted_penalty = 5.0  # subtracted from the lowest listed candidate

    def to_float(v):
        # v might be a Logprob object or a raw float
        return float(getattr(v, "logprob", v))

    def safe_perplexity(avg_logprob: float) -> float:
        # exp() overflows for very negative averages; report that as inf.
        try:
            return math.exp(-avg_logprob)
        except OverflowError:
            return float("inf")

    # 1) Get underlying vLLM engine & tokenizer
    assert hasattr(vllm_model, "llm"), f"Expected VLLMModel wrapper, got {type(vllm_model)}"
    llm = vllm_model.llm
    try:
        tokenizer = llm.get_tokenizer()
    except Exception as e:
        raise RuntimeError("Could not access vLLM tokenizer; check vLLM version / API") from e

    B = len(masked_inputs)
    if B == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [],
            "token_logprobs": [],
        }

    # 2) Build all prompts and tokenizations in a batch
    full_prompts: List[str] = []
    base_ids_list: List[List[int]] = []
    full_ids_list: List[List[int]] = []
    ref_ids_list: List[List[int]] = []
    for ctx in masked_inputs:
        # base prompt: no answer text; full prompt: includes gold answer text
        base_prompt = prompt_formatter(question, ctx, answer_tag, "")
        full_prompt = prompt_formatter(question, ctx, answer_tag, reference_text)
        base_ids = tokenizer(base_prompt, add_special_tokens=False)["input_ids"]
        full_ids = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]
        # Tokens of the answer region; empty when the full prompt is not
        # longer than the base prompt (such examples are scored as -inf).
        ref_ids = full_ids[len(base_ids):] if len(full_ids) > len(base_ids) else []
        full_prompts.append(full_prompt)
        base_ids_list.append(base_ids)
        full_ids_list.append(full_ids)
        ref_ids_list.append(ref_ids)

    # If no valid ref tokens at all, bail out
    valid_lengths = [len(ref_ids) for ref_ids in ref_ids_list if ref_ids]
    if not valid_lengths:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [float("-inf")] * B,
            "token_logprobs": [[] for _ in range(B)],
        }

    # Make all answers share a common length (minimum across valid examples).
    ref_len = min(valid_lengths)
    ref_ids_list = [ref_ids[:ref_len] for ref_ids in ref_ids_list]

    # 3) Call vLLM *once* on all full prompts (batched)
    from vllm import SamplingParams

    sp = SamplingParams(
        max_tokens=max_new_tokens or 1,  # keep tiny generation; default minimal
        temperature=0.0,
        top_p=1.0,
        logprobs=top_k,  # store top-k logprobs per position
        prompt_logprobs=1,  # request logprobs for prompt tokens
    )
    results = llm.generate(full_prompts, sp)

    # 4) For each prompt, sum the logprobs over the answer segment
    per_seq_logprobs: List[float] = []
    per_seq_token_logprobs: List[List[float]] = []
    for idx_ex, (base_ids, full_ids, ref_ids) in enumerate(
        zip(base_ids_list, full_ids_list, ref_ids_list)
    ):
        if not ref_ids:
            per_seq_logprobs.append(float("-inf"))
            per_seq_token_logprobs.append([])
            continue

        req_out = results[idx_ex]
        prompt_logprobs = getattr(req_out, "prompt_logprobs", None)
        if prompt_logprobs is None:
            raise RuntimeError(
                "req_out.prompt_logprobs is None. Adjust to match your vLLM version."
            )

        # Align lengths: vLLM may add BOS token (one extra leading entry).
        if len(prompt_logprobs) == len(full_ids) + 1:
            prompt_logprobs = prompt_logprobs[1:]
        elif len(prompt_logprobs) != len(full_ids):
            # Any other mismatch: keep only the overlapping head of both.
            L = min(len(prompt_logprobs), len(full_ids))
            prompt_logprobs = prompt_logprobs[:L]
            full_ids = full_ids[:L]

        show_debug = debug and idx_ex == debug_index
        if show_debug:
            print("\n=== DEBUG vLLM PPL example", idx_ex, "===")
            print("Full prompt:\n", full_prompts[idx_ex])
            print("\nBase token length:", len(base_ids))
            print("Full token length:", len(full_ids))
            print("Answer token length (ref_len):", ref_len)
            print("\nAnswer tokens and logprobs:")

        seq_lp = 0.0
        token_logprobs: List[float] = []
        for offset, token_id in enumerate(ref_ids):
            pos = len(base_ids) + offset  # index in full sequence
            in_range = pos < len(prompt_logprobs)
            if not in_range:
                token_lp = missing_logprob
            else:
                cand_dict = prompt_logprobs[pos] or {}
                if token_id in cand_dict:
                    token_lp = to_float(cand_dict[token_id])
                elif cand_dict:
                    floor = min(to_float(v) for v in cand_dict.values())
                    token_lp = floor - unlisted_penalty
                else:
                    token_lp = missing_logprob

            seq_lp += token_lp
            token_logprobs.append(token_lp)

            if show_debug:
                tok_str = tokenizer.convert_ids_to_tokens([token_id])[0]
                if not in_range:
                    print(f" pos={pos:<3} token={tok_str!r:>10} logprob={token_lp: .4f} (OUT OF RANGE)")
                else:
                    avg_lp_so_far = seq_lp / (offset + 1)
                    ppl_so_far = safe_perplexity(avg_lp_so_far)
                    print(
                        f" pos={pos:<3} token={tok_str!r:>10} "
                        f"logprob={token_lp: .4f} "
                        f"cum_avg_lp={avg_lp_so_far: .4f} "
                        f"cum_ppl={ppl_so_far: .4f}"
                    )

        per_seq_logprobs.append(seq_lp)
        per_seq_token_logprobs.append(token_logprobs)

    # 5) Aggregate across examples (ref_len >= 1 is guaranteed at this point).
    T = int(ref_len)
    arr = np.array(per_seq_logprobs, dtype=np.float64)
    avg_lp = float((arr / T).mean())
    return {
        "avg_logprob": avg_lp,
        "perplexity": safe_perplexity(avg_lp),
        "target_len": T,
        "sequence_logprobs": per_seq_logprobs,
        "token_logprobs": per_seq_token_logprobs,
    }
def score_continuation_batch(
    model: BaseModel,
    prefixes: List[str],
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Batched variant of score_continuation. Returns list aligned to prefixes.

    Prefixes are scored in chunks through `score_reference_autoregressive_vllm`.
    If a chunk raises, it is split into smaller chunks and retried; a prefix
    that still fails on its own receives a failure record (-inf / inf values,
    `target_len` 0) instead of propagating the exception.

    Args:
        model: Model wrapper passed through to the vLLM scorer.
        prefixes: Conditioning texts, one result is produced per entry.
        continuation: Text scored after every prefix.
        max_new_tokens: Forwarded to the vLLM scorer.
        batch_size: Chunk size. When falsy, read from the environment variable
            VLLM_SCORE_BATCH_SIZE, then ATTRLLM_VLLM_BATCH_SIZE, else 128.

    Returns:
        One dict per prefix with the keys "avg_logprob", "perplexity",
        "total_logprob", "avg_nll", "per_token", "target_len" and
        "sequence_logprobs". "avg_logprob" and "perplexity" are computed per
        prefix, so they do not depend on how prefixes were chunked.
    """
    if not prefixes:
        return []

    B = batch_size or int(os.getenv("VLLM_SCORE_BATCH_SIZE",
                                    os.getenv("ATTRLLM_VLLM_BATCH_SIZE", "128")))
    B = max(1, int(B))
    outputs: List[Optional[Dict[str, Any]]] = [None] * len(prefixes)

    def _failed_output() -> Dict[str, Any]:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "total_logprob": float("-inf"),
            "avg_nll": float("inf"),
            "per_token": [],
            "target_len": 0,
            "sequence_logprobs": [float("-inf")],
        }

    def _score_chunk(chunk_prefixes: List[str], offset: int, chunk_batch_size: int) -> None:
        if not chunk_prefixes:
            return
        try:
            res = score_reference_autoregressive_vllm(
                model,
                question="",
                masked_inputs=chunk_prefixes,
                reference_text=continuation,
                answer_tag="",
                prompt_formatter=_continuation_prompt_formatter,
                max_new_tokens=max_new_tokens,
            )
            seq_lp = res.get("sequence_logprobs", [])
            per_token = res.get("token_logprobs", [])
            target_len = int(res.get("target_len", 0) or 0)
            if len(seq_lp) != len(chunk_prefixes):
                raise RuntimeError(
                    "vLLM returned mismatched sequence_logprobs length: "
                    f"expected={len(chunk_prefixes)} got={len(seq_lp)}"
                )
            for i, lp in enumerate(seq_lp):
                total_logprob = float(lp)
                if target_len > 0 and math.isfinite(total_logprob):
                    avg_logprob = total_logprob / target_len
                    avg_nll = -total_logprob / target_len
                    try:
                        ppl = math.exp(avg_nll)
                    except OverflowError:
                        # Must not escape: it would be mistaken for a vLLM
                        # failure below and trigger a pointless backoff.
                        ppl = float("inf")
                else:
                    avg_logprob = float("-inf")
                    avg_nll = float("inf")
                    ppl = float("inf")
                outputs[offset + i] = {
                    # Per-prefix values; the scorer's own "avg_logprob" /
                    # "perplexity" are aggregates over the whole chunk.
                    "avg_logprob": avg_logprob,
                    "perplexity": ppl,
                    "total_logprob": total_logprob,
                    "avg_nll": avg_nll,
                    "per_token": per_token[i] if i < len(per_token) else [],
                    "target_len": target_len,
                    "sequence_logprobs": [total_logprob],
                }
        except Exception as exc:
            # vLLM occasionally asserts in large-batch prompt_logprob scoring.
            # Back off to smaller chunks instead of failing the whole request.
            n = len(chunk_prefixes)
            if n == 1:
                logger.warning(
                    "score_continuation_batch failed for single prefix; returning -inf fallback. error=%s",
                    exc,
                )
                outputs[offset] = _failed_output()
                return
            next_batch_size = max(1, min(chunk_batch_size // 2, n // 2))
            logger.warning(
                "[BACKOFF_ACTIVE] score_continuation_batch chunk failed; retrying smaller chunks. "
                "chunk_size=%d batch_size=%d next_batch_size=%d error=%s",
                n,
                chunk_batch_size,
                next_batch_size,
                exc,
            )
            for sub_start in range(0, n, next_batch_size):
                sub_chunk = chunk_prefixes[sub_start : sub_start + next_batch_size]
                _score_chunk(sub_chunk, offset + sub_start, next_batch_size)

    for start in range(0, len(prefixes), B):
        _score_chunk(prefixes[start : start + B], start, B)

    # Guarantee alignment even in worst-case failures.
    for i, item in enumerate(outputs):
        if item is None:
            outputs[i] = _failed_output()
    return outputs  # type: ignore[return-value]