import math
import os
import logging
from typing import Callable, List, Dict, Any, Optional

import numpy as np
import torch
import torch.nn.functional as F

from .models import BaseModel

logger = logging.getLogger(__name__)

# Signature: (question, context, answer_tag, reference_text) -> full_prompt_str
PromptFormatter = Callable[[str, str, str, str], str]

# Logprob assigned to an answer token when no logprob entry is available for
# its position (position out of range, or an empty candidate dict).
_MISSING_TOKEN_LOGPROB = -20.0
# When the answer token is absent from a non-empty candidate dict, it is scored
# as the lowest listed logprob minus this penalty.
_UNLISTED_TOKEN_PENALTY = 5.0


def default_prompt_formatter(
    question: str,
    context: str,
    answer_tag: str,
    reference_text: str,
) -> str:
    """
    Basic prompt layout used by default.

    You can override this via the `prompt_formatter` argument if a specific
    model needs a different template (e.g., chat template).
    """
    return f"{question}\nContext: {context}\n{answer_tag} {reference_text}"


def _join_prefix_continuation(prefix: str, continuation: str) -> str:
    """Join prefix and continuation with a single space when needed."""
    if not prefix:
        return continuation
    if not continuation:
        return prefix
    if prefix[-1].isspace() or continuation[0].isspace():
        return prefix + continuation
    return prefix + " " + continuation


def _continuation_prompt_formatter(
    _question: str,
    context: str,
    _answer_tag: str,
    reference_text: str,
) -> str:
    """Prompt formatter that concatenates context and continuation text."""
    return _join_prefix_continuation(context, reference_text)


def _safe_exp(value: float) -> float:
    """Return ``math.exp(value)``, mapping floating-point overflow to +inf."""
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")


def _logprob_value(entry: Any) -> float:
    """Extract a float from an object exposing ``.logprob`` or from a raw number."""
    return float(getattr(entry, "logprob", entry))


def _sequence_summary(
    total_logprob: float,
    target_len: int,
    per_token: List[float],
) -> Dict[str, Any]:
    """
    Build the per-sequence result dict shared by `score_continuation` and
    `score_continuation_batch`.

    All averaged metrics are derived from this sequence's own `total_logprob`
    and `target_len`. A non-finite total or a zero-length target yields the
    "failed" values (avg_logprob=-inf, avg_nll=inf, perplexity=inf).
    """
    if target_len > 0 and math.isfinite(total_logprob):
        avg_logprob = total_logprob / target_len
        avg_nll = -avg_logprob
        perplexity = _safe_exp(avg_nll)
    else:
        avg_logprob = float("-inf")
        avg_nll = float("inf")
        perplexity = float("inf")
    return {
        "avg_logprob": avg_logprob,
        "perplexity": perplexity,
        "total_logprob": total_logprob,
        "avg_nll": avg_nll,
        "per_token": per_token,
        "target_len": target_len,
        "sequence_logprobs": [total_logprob],
    }


def score_continuation(
    model: BaseModel,
    prefix: str,
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute teacher-forced logprob and perplexity of `continuation` given `prefix`.

    Uses vLLM prompt logprobs; `model` must expose the vLLM engine as `.llm`.

    Returns:
        Dict with keys "avg_logprob", "perplexity", "total_logprob", "avg_nll",
        "per_token", "target_len" and "sequence_logprobs". An empty
        `continuation` short-circuits (before `model` is inspected) to the
        failed values with empty "per_token" / "sequence_logprobs".

    Raises:
        RuntimeError: if `continuation` is non-empty and `model` has no `.llm`.
    """
    if not continuation:
        empty = _sequence_summary(float("-inf"), 0, [])
        empty["sequence_logprobs"] = []
        return empty

    if not hasattr(model, "llm"):
        raise RuntimeError(
            "score_continuation requires a VLLMModel with .llm "
            "(use loader.get_model_vllm inside the vLLM container)."
        )

    result = score_reference_autoregressive_vllm(
        model,
        question="",
        masked_inputs=[prefix],
        reference_text=continuation,
        answer_tag="",
        prompt_formatter=_continuation_prompt_formatter,
        max_new_tokens=max_new_tokens,
    )

    token_logprobs = result.get("token_logprobs", [])
    per_token = token_logprobs[0] if token_logprobs else []
    target_len = int(result.get("target_len", 0) or 0)
    seq_logprobs = result.get("sequence_logprobs", [])

    if seq_logprobs:
        total_logprob = float(seq_logprobs[0])
    else:
        # No per-sequence total was returned; rebuild it from the average.
        avg_lp = float(result.get("avg_logprob", float("-inf")))
        total_logprob = avg_lp * target_len if target_len else float("-inf")

    summary = _sequence_summary(total_logprob, target_len, per_token)
    # Pass the scorer's own list through unchanged.
    summary["sequence_logprobs"] = seq_logprobs
    return summary


# --- Real (teacher-forced) scorer using Hugging Face generate() ---
def score_reference_autoregressive_hf(
    model,  # HF AutoModelForCausalLM (on CUDA or CPU)
    tokenizer,  # matching HF tokenizer
    question: str,  # e.g., "Count the number of r's in strawberry."
    masked_inputs: List[str],  # batch of masked contexts (strings)
    reference_token_ids: List[int],  # tokenized reference answer ids
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
) -> Dict[str, Any]:
    """
    Teacher-forced next-token log-probs of the reference answer across a batch
    of masked contexts.

    For each step j, the reference prefix (tokens 0..j) is decoded to text,
    placed into the prompt for every masked context, and one token is generated
    so the score of reference token j+1 can be read off.

    Side effects: may set `tokenizer.pad_token` (and add a "[PAD]" token),
    sets `tokenizer.padding_side = "left"`, and attempts to set
    `model.config.pad_token_id`.

    Returns:
        {
          "avg_logprob": float,                # mean over tokens, then mean over batch
          "perplexity": float,                 # exp(-avg_logprob)
          "target_len": int,                   # number of next-token steps (len(ref_ids)-1)
          "sequence_logprobs": List[float],    # total logprob per masked input (len = batch)
          "token_logprobs": List[List[float]]  # per-token logprobs per masked input
        }
        With fewer than two reference tokens, or an empty batch, "target_len"
        is 0 and the aggregate values are -inf / inf.
    """
    # --- ensure tokenizer/model have a pad token ---
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            # only call if model is an HF CausalLM with embeddings
            try:
                model.resize_token_embeddings(len(tokenizer))
            except Exception as exc:  # best-effort: not every model supports this
                logger.debug("resize_token_embeddings skipped: %s", exc)

    tokenizer.padding_side = "left"
    try:
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception as exc:  # best-effort: model may not expose a mutable config
        logger.debug("could not set model.config.pad_token_id: %s", exc)

    B = len(masked_inputs)
    T = max(0, len(reference_token_ids) - 1)  # number of next-token predictions

    if T == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": np.zeros(B, dtype=np.float64).tolist(),
            "token_logprobs": [[] for _ in range(B)],
        }

    if B == 0:
        # Nothing to score; mirrors the empty-batch result of the vLLM scorer.
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [],
            "token_logprobs": [],
        }

    # step_logprobs[b, j] = logprob of reference token j+1 for masked input b
    step_logprobs = np.zeros((B, T), dtype=np.float64)

    for j in range(T):
        # Decode the entire prefix (tokens 0..j) to keep spacing/special-tokens
        # consistent.
        prefix_str = tokenizer.decode(
            reference_token_ids[: j + 1], skip_special_tokens=False
        )

        # Build prompts for this step
        prompts = [
            prompt_formatter(question, ctx, answer_tag, prefix_str)
            for ctx in masked_inputs
        ]
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        # NOTE(review): `scores` may be post-processed by generation-config
        # logits processors; confirm the config does not warp them.
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=1,  # predict exactly one next token
                output_scores=True,
                return_dict_in_generate=True,
                # Use the pad token ensured above; eos_token_id is None on the
                # path where a "[PAD]" token had to be added.
                pad_token_id=tokenizer.pad_token_id,
            )

        # out["scores"] has one entry per new token (here: 1), each [batch, vocab]
        logits = out["scores"][0]
        # Cast to float32 so half-precision scores convert to numpy cleanly.
        logprobs = F.log_softmax(logits.float(), dim=-1)

        target_next_id = reference_token_ids[j + 1]
        step_logprobs[:, j] = logprobs[:, target_next_id].detach().cpu().numpy()

        # clean up to keep memory low
        del inputs, out, logits, logprobs

    batch_seq_logprobs = step_logprobs.sum(axis=1)

    # Average over tokens, then average across the batch
    avg_lp = float((batch_seq_logprobs / T).mean())
    return {
        "avg_logprob": avg_lp,
        "perplexity": _safe_exp(-avg_lp),
        "target_len": T,
        "sequence_logprobs": batch_seq_logprobs.tolist(),
        "token_logprobs": step_logprobs.tolist(),
    }


def score_reference_autoregressive_vllm(
    vllm_model,  # your loader.VLLMModel instance
    question: str,
    masked_inputs: List[str],
    reference_text: str,  # e.g., " The answer is 3."
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
    top_k: int = 5,
    debug: bool = False,
    debug_index: int = 0,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    vLLM-native teacher-forced perplexity over a *given* reference answer.

    Batching behavior:
      - Build one full prompt per masked context.
      - Send ALL prompts to vLLM in a single `llm.generate` call.
      - Read the prompt logprobs returned for each token position.
      - For each prompt, slice out the part corresponding to the answer and
        sum the logprobs of those answer tokens.

    The answer region of a prompt is the token suffix by which the full prompt
    (with `reference_text`) is longer than the base prompt (without it). All
    answer regions are truncated to the shortest one in the batch, so
    "target_len" is shared. An example with no answer tokens gets a sequence
    logprob of -inf, which also drives the batch-level "avg_logprob" to -inf.

    If debug=True, we print detailed info for the example at `debug_index`:
      - Full prompt string
      - Base vs full token lengths
      - Each answer token + its logprob and running perplexity.

    Returns:
        Dict with "avg_logprob" (mean over tokens, then over the batch),
        "perplexity" (exp(-avg_logprob)), "target_len", "sequence_logprobs"
        (one total per masked input) and "token_logprobs" (per-token values
        per masked input).

    Raises:
        RuntimeError: if the tokenizer cannot be obtained from the engine, or a
            request output carries no `prompt_logprobs`.
    """
    # 1) Get underlying vLLM engine & tokenizer
    assert hasattr(vllm_model, "llm"), f"Expected VLLMModel wrapper, got {type(vllm_model)}"
    llm = vllm_model.llm
    try:
        tokenizer = llm.get_tokenizer()
    except Exception as e:
        raise RuntimeError("Could not access vLLM tokenizer; check vLLM version / API") from e

    B = len(masked_inputs)
    if B == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [],
            "token_logprobs": [],
        }

    # 2) Build all prompts and tokenizations in a batch
    full_prompts: List[str] = []
    base_ids_list: List[List[int]] = []
    full_ids_list: List[List[int]] = []
    ref_ids_list: List[List[int]] = []
    ref_token_count: Optional[int] = None

    for ctx in masked_inputs:
        # base prompt: no answer text
        base_prompt = prompt_formatter(question, ctx, answer_tag, "")
        # full prompt: includes gold answer text
        full_prompt = prompt_formatter(question, ctx, answer_tag, reference_text)

        base_ids = tokenizer(base_prompt, add_special_tokens=False)["input_ids"]
        full_ids = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]

        if full_ids and len(full_ids) > len(base_ids):
            ref_ids = full_ids[len(base_ids):]  # tokens of the answer region
            # make all answers share a common length (minimum across examples)
            if ref_token_count is None:
                ref_token_count = len(ref_ids)
            else:
                ref_token_count = min(ref_token_count, len(ref_ids))
        else:
            # Safety check: the full prompt added no tokens over the base.
            ref_ids = []

        base_ids_list.append(base_ids)
        full_ids_list.append(full_ids)
        ref_ids_list.append(ref_ids)
        full_prompts.append(full_prompt)

    # If no valid ref tokens at all, bail out
    if not ref_token_count:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [float("-inf")] * B,
            "token_logprobs": [[] for _ in range(B)],
        }

    # Truncate all ref_ids to the common length (ref_len >= 1 from here on)
    ref_len = ref_token_count
    ref_ids_list = [ref_ids[:ref_len] for ref_ids in ref_ids_list]

    # 3) Call vLLM *once* on all full prompts (batched)
    from vllm import SamplingParams

    sp = SamplingParams(
        max_tokens=max_new_tokens or 1,  # keep tiny generation; default minimal
        temperature=0.0,
        top_p=1.0,
        logprobs=top_k,  # store top-k logprobs per position
        prompt_logprobs=1,  # request logprobs for prompt tokens
    )
    results = llm.generate(full_prompts, sp)

    # 4) For each prompt, sum the logprobs over the answer segment
    per_seq_logprobs: List[float] = []
    per_seq_token_logprobs: List[List[float]] = []

    for idx_ex, (base_ids, full_ids, ref_ids) in enumerate(
        zip(base_ids_list, full_ids_list, ref_ids_list)
    ):
        if not ref_ids or not full_ids:
            per_seq_logprobs.append(float("-inf"))
            per_seq_token_logprobs.append([])
            continue

        req_out = results[idx_ex]
        prompt_logprobs = getattr(req_out, "prompt_logprobs", None)
        if prompt_logprobs is None:
            raise RuntimeError(
                "req_out.prompt_logprobs is None. Adjust to match your vLLM version."
            )

        # Align lengths: vLLM may add BOS token
        if len(prompt_logprobs) == len(full_ids) + 1:
            prompt_logprobs = prompt_logprobs[1:]
        elif len(prompt_logprobs) != len(full_ids):
            L = min(len(prompt_logprobs), len(full_ids))
            prompt_logprobs = prompt_logprobs[:L]
            full_ids = full_ids[:L]

        show_debug = debug and idx_ex == debug_index
        seq_lp = 0.0
        token_logprobs: List[float] = []

        # Optional debug header
        if show_debug:
            print("\n=== DEBUG vLLM PPL example", idx_ex, "===")
            print("Full prompt:\n", full_prompts[idx_ex])
            print("\nBase token length:", len(base_ids))
            print("Full token length:", len(full_ids))
            print("Answer token length (ref_len):", ref_len)
            print("\nAnswer tokens and logprobs:")

        for offset, token_id in enumerate(ref_ids):
            pos = len(base_ids) + offset  # index in full sequence
            in_range = pos < len(prompt_logprobs)

            if not in_range:
                token_lp = _MISSING_TOKEN_LOGPROB
            else:
                cand_dict = prompt_logprobs[pos] or {}
                if token_id in cand_dict:
                    token_lp = _logprob_value(cand_dict[token_id])
                elif cand_dict:
                    # Token not listed: score it below the lowest listed candidate.
                    floor = min(_logprob_value(v) for v in cand_dict.values())
                    token_lp = floor - _UNLISTED_TOKEN_PENALTY
                else:
                    token_lp = _MISSING_TOKEN_LOGPROB

            seq_lp += token_lp
            token_logprobs.append(token_lp)

            if show_debug:
                tok_str = tokenizer.convert_ids_to_tokens([token_id])[0]
                if not in_range:
                    print(f" pos={pos:<3} token={tok_str!r:>10} logprob={token_lp: .4f} (OUT OF RANGE)")
                else:
                    avg_lp_so_far = seq_lp / (offset + 1)
                    ppl_so_far = _safe_exp(-avg_lp_so_far)
                    print(
                        f" pos={pos:<3} token={tok_str!r:>10} "
                        f"logprob={token_lp: .4f} "
                        f"cum_avg_lp={avg_lp_so_far: .4f} "
                        f"cum_ppl={ppl_so_far: .4f}"
                    )

        per_seq_logprobs.append(seq_lp)
        per_seq_token_logprobs.append(token_logprobs)

    # 5) Aggregate across examples
    T = int(ref_len)
    avg_lp = float((np.array(per_seq_logprobs, dtype=np.float64) / T).mean())
    return {
        "avg_logprob": avg_lp,
        "perplexity": _safe_exp(-avg_lp),
        "target_len": T,
        "sequence_logprobs": per_seq_logprobs,
        "token_logprobs": per_seq_token_logprobs,
    }


def score_continuation_batch(
    model: BaseModel,
    prefixes: List[str],
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Batched variant of score_continuation. Returns list aligned to prefixes.

    Prefixes are scored in chunks of `batch_size` (default: env
    VLLM_SCORE_BATCH_SIZE, then ATTRLLM_VLLM_BATCH_SIZE, then 128). A chunk
    that raises is retried in progressively smaller sub-chunks; a single prefix
    that still fails gets the -inf fallback entry instead of aborting the call.

    Each entry has the same keys as `score_continuation`, and its
    "avg_logprob" / "perplexity" / "avg_nll" are computed from that prefix's
    own total logprob, so they do not depend on how prefixes were chunked.
    "target_len" is the answer length shared by the chunk the prefix was
    scored in.

    Raises:
        RuntimeError: if `prefixes` is non-empty and `model` has no `.llm`
            (same contract as `score_continuation`).
    """
    if not prefixes:
        return []

    # Fail fast on a wrong model type instead of backing off chunk by chunk
    # and returning nothing but -inf fallbacks.
    if not hasattr(model, "llm"):
        raise RuntimeError(
            "score_continuation requires a VLLMModel with .llm "
            "(use loader.get_model_vllm inside the vLLM container)."
        )

    B = batch_size or int(
        os.getenv("VLLM_SCORE_BATCH_SIZE", os.getenv("ATTRLLM_VLLM_BATCH_SIZE", "128"))
    )
    B = max(1, int(B))
    outputs: List[Optional[Dict[str, Any]]] = [None] * len(prefixes)

    def _failed_output() -> Dict[str, Any]:
        return _sequence_summary(float("-inf"), 0, [])

    def _score_chunk(chunk_prefixes: List[str], offset: int, chunk_batch_size: int) -> None:
        if not chunk_prefixes:
            return
        try:
            res = score_reference_autoregressive_vllm(
                model,
                question="",
                masked_inputs=chunk_prefixes,
                reference_text=continuation,
                answer_tag="",
                prompt_formatter=_continuation_prompt_formatter,
                max_new_tokens=max_new_tokens,
            )
            seq_lp = res.get("sequence_logprobs", [])
            per_token = res.get("token_logprobs", [])
            target_len = int(res.get("target_len", 0) or 0)
            if len(seq_lp) != len(chunk_prefixes):
                raise RuntimeError(
                    "vLLM returned mismatched sequence_logprobs length: "
                    f"expected={len(chunk_prefixes)} got={len(seq_lp)}"
                )
            for i, lp in enumerate(seq_lp):
                outputs[offset + i] = _sequence_summary(
                    float(lp),
                    target_len,
                    per_token[i] if i < len(per_token) else [],
                )
        except Exception as exc:
            # vLLM occasionally asserts in large-batch prompt_logprob scoring.
            # Back off to smaller chunks instead of failing the whole request.
            n = len(chunk_prefixes)
            if n == 1:
                logger.warning(
                    "score_continuation_batch failed for single prefix; returning -inf fallback. error=%s",
                    exc,
                )
                outputs[offset] = _failed_output()
                return
            next_batch_size = max(1, min(chunk_batch_size // 2, n // 2))
            logger.warning(
                "[BACKOFF_ACTIVE] score_continuation_batch chunk failed; retrying smaller chunks. "
                "chunk_size=%d batch_size=%d next_batch_size=%d error=%s",
                n,
                chunk_batch_size,
                next_batch_size,
                exc,
            )
            for sub_start in range(0, n, next_batch_size):
                sub_chunk = chunk_prefixes[sub_start : sub_start + next_batch_size]
                _score_chunk(sub_chunk, offset + sub_start, next_batch_size)

    for start in range(0, len(prefixes), B):
        _score_chunk(prefixes[start : start + B], start, B)

    # Guarantee alignment even in worst-case failures.
    return [item if item is not None else _failed_output() for item in outputs]