| import math |
| import os |
| import logging |
| from typing import Callable, List, Dict, Any, Optional |
|
|
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
|
|
| from .models import BaseModel |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Call signature shared by the prompt formatters in this file:
# (question, context, answer_tag, reference_text) -> prompt string.
PromptFormatter = Callable[[str, str, str, str], str]
|
|
|
|
def default_prompt_formatter(
    question: str,
    context: str,
    answer_tag: str,
    reference_text: str,
) -> str:
    """
    Render the default three-line prompt.

    Layout: the question, then a ``Context:`` line holding `context`, then
    `answer_tag` followed by a single space and `reference_text`.

    Supply a different callable through the `prompt_formatter` argument when
    a specific model needs another template (e.g., a chat template).
    """
    prompt_lines = (
        f"{question}",
        f"Context: {context}",
        f"{answer_tag} {reference_text}",
    )
    return "\n".join(prompt_lines)
|
|
|
|
| def _join_prefix_continuation(prefix: str, continuation: str) -> str: |
| """Join prefix and continuation with a single space when needed.""" |
| if not prefix: |
| return continuation |
| if not continuation: |
| return prefix |
| if prefix[-1].isspace() or continuation[0].isspace(): |
| return prefix + continuation |
| return prefix + " " + continuation |
|
|
|
|
def _continuation_prompt_formatter(
    _question: str,
    context: str,
    _answer_tag: str,
    reference_text: str,
) -> str:
    """
    Prompt formatter that treats `context` as a raw prefix.

    The question and answer tag are ignored; the prompt is `context` joined
    to `reference_text` by `_join_prefix_continuation`.
    """
    joined = _join_prefix_continuation(prefix=context, continuation=reference_text)
    return joined
|
|
|
|
def score_continuation(
    model: "BaseModel",
    prefix: str,
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Compute the teacher-forced logprob and perplexity of `continuation` given `prefix`.

    Scoring is delegated to `score_reference_autoregressive_vllm` with a
    single-element batch, so `model` must expose a `.llm` attribute.

    Args:
        model: Wrapper object exposing `.llm`.
        prefix: Text that precedes the continuation.
        continuation: Text whose tokens are scored.
        max_new_tokens: Forwarded unchanged to the vLLM scorer.

    Returns:
        A dict with the keys
            "avg_logprob", "perplexity": as reported by the scorer,
            "total_logprob": summed logprob of the continuation tokens,
            "avg_nll": ``-total_logprob / target_len`` (``inf`` when undefined),
            "per_token": per-token logprobs for this prompt,
            "target_len": number of scored tokens,
            "sequence_logprobs": the scorer's per-sequence totals.
        An empty `continuation` yields the same keys filled with
        ``-inf`` / ``inf`` sentinels, ``target_len == 0`` and empty lists.

    Raises:
        RuntimeError: if `continuation` is non-empty and `model` has no `.llm`.
    """
    if not continuation:
        # Same key set as the normal return so callers can index either result.
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "total_logprob": float("-inf"),
            "avg_nll": float("inf"),
            "per_token": [],
            "target_len": 0,
            "sequence_logprobs": [],
        }

    if not hasattr(model, "llm"):
        raise RuntimeError(
            "score_continuation requires a VLLMModel with .llm "
            "(use loader.get_model_vllm inside the vLLM container)."
        )

    result = score_reference_autoregressive_vllm(
        model,
        question="",
        masked_inputs=[prefix],
        reference_text=continuation,
        answer_tag="",
        prompt_formatter=_continuation_prompt_formatter,
        max_new_tokens=max_new_tokens,
    )

    token_logprobs = result.get("token_logprobs", [])
    per_token = token_logprobs[0] if token_logprobs else []
    target_len = int(result.get("target_len", 0) or 0)
    seq_logprobs = result.get("sequence_logprobs", [])

    if seq_logprobs:
        total_logprob = float(seq_logprobs[0])
    else:
        # No per-sequence total available: rebuild it from the average.
        avg_lp = float(result.get("avg_logprob", float("-inf")))
        total_logprob = avg_lp * target_len if target_len else float("-inf")

    if target_len > 0 and math.isfinite(total_logprob):
        avg_nll = -total_logprob / target_len
        try:
            ppl = math.exp(avg_nll)
        except OverflowError:
            # math.exp raises instead of returning inf for very large inputs.
            ppl = float("inf")
    else:
        avg_nll = float("inf")
        ppl = float("inf")

    return {
        "avg_logprob": result.get("avg_logprob", float("-inf")),
        "perplexity": result.get("perplexity", ppl),
        "total_logprob": total_logprob,
        "avg_nll": avg_nll,
        "per_token": per_token,
        "target_len": target_len,
        "sequence_logprobs": seq_logprobs,
    }
|
|
| |
def score_reference_autoregressive_hf(
    model,
    tokenizer,
    question: str,
    masked_inputs: List[str],
    reference_token_ids: List[int],
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
) -> Dict[str, Any]:
    """
    Teacher-forced next-token log-probs of the reference answer across a batch of masked contexts.

    For each step ``j`` the first ``j + 1`` reference tokens are decoded to
    text, placed in the prompt for every context, and the model's
    distribution over the next token is read at ``reference_token_ids[j + 1]``.
    The first reference token only seeds the prompt and is never scored.

    Side effects: sets ``tokenizer.pad_token`` when it is missing, sets
    ``tokenizer.padding_side = "left"``, and tries to set
    ``model.config.pad_token_id``.

    Returns:
        {
          "avg_logprob": float,                # mean over tokens, then mean over batch
          "perplexity": float,                 # exp(-avg_logprob)
          "target_len": int,                   # number of next-token steps (len(ref_ids)-1)
          "sequence_logprobs": List[float],    # total logprob per masked input (len = batch)
          "token_logprobs": List[List[float]]  # per-token logprobs per masked input
        }
        When there is nothing to score (fewer than two reference tokens, or
        no masked inputs) "avg_logprob" is -inf, "perplexity" is inf and
        "target_len" is 0.
    """
    if tokenizer.pad_token is None:
        if tokenizer.eos_token is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})
            # A new token was added, so the embedding table may need to grow.
            try:
                model.resize_token_embeddings(len(tokenizer))
            except Exception as exc:
                # Best effort: keep going, but leave a trace instead of
                # swallowing the failure silently.
                logger.warning("Could not resize token embeddings after adding [PAD]: %s", exc)

    # Left padding keeps the last prompt token in the final position of every
    # row, which is the position generate() continues from.
    tokenizer.padding_side = "left"
    try:
        model.config.pad_token_id = tokenizer.pad_token_id
    except Exception as exc:
        logger.debug("Could not set model.config.pad_token_id: %s", exc)

    B = len(masked_inputs)
    batch_seq_logprobs = np.zeros(B, dtype=np.float64)

    T = max(0, len(reference_token_ids) - 1)
    if T == 0 or B == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": batch_seq_logprobs.tolist(),
            "token_logprobs": [[] for _ in range(B)],
        }

    per_step_logprobs: List[np.ndarray] = []

    for step in range(T):
        # NOTE(review): the prefix is decoded to text and re-tokenized as part
        # of the prompt, so the model may not see exactly
        # reference_token_ids[: step + 1] -- confirm this is acceptable.
        prefix_str = tokenizer.decode(
            reference_token_ids[: step + 1], skip_special_tokens=False
        )

        prompts = [
            prompt_formatter(question, ctx, answer_tag, prefix_str)
            for ctx in masked_inputs
        ]

        # NOTE(review): truncation=True with the tokenizer's default
        # truncation side may drop the end of long prompts -- confirm.
        inputs = tokenizer(
            prompts, return_tensors="pt", padding=True, truncation=True
        ).to(model.device)

        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=1,
                # Greedy decoding so that sampling settings from the model's
                # generation config do not reshape the scores read below.
                do_sample=False,
                output_scores=True,
                return_dict_in_generate=True,
                # Use the pad token configured above; eos_token_id may be None.
                pad_token_id=tokenizer.pad_token_id,
            )

        # One new token was generated, so there is a single score tensor with
        # one row per prompt.
        step_scores = out["scores"][0]
        # Upcast before the softmax so half-precision scores convert to NumPy.
        logprobs = F.log_softmax(step_scores.float(), dim=-1)

        target_next_id = reference_token_ids[step + 1]
        step_lp = logprobs[:, target_next_id].detach().cpu().numpy()

        batch_seq_logprobs += step_lp
        per_step_logprobs.append(step_lp)

        # Release per-step tensors before the next forward pass.
        del inputs, out, step_scores, logprobs

    avg_lp = float((batch_seq_logprobs / T).mean())
    try:
        ppl = math.exp(-avg_lp)
    except OverflowError:
        # math.exp raises instead of returning inf for very large inputs.
        ppl = float("inf")

    # Stack the T per-step vectors of length B into a (B, T) nested list.
    token_logprobs = np.stack(per_step_logprobs, axis=1).tolist()

    return {
        "avg_logprob": avg_lp,
        "perplexity": ppl,
        "target_len": T,
        "sequence_logprobs": batch_seq_logprobs.tolist(),
        "token_logprobs": token_logprobs,
    }
|
|
def score_reference_autoregressive_vllm(
    vllm_model,
    question: str,
    masked_inputs: List[str],
    reference_text: str,
    *,
    answer_tag: str = "Answer:",
    prompt_formatter: PromptFormatter = default_prompt_formatter,
    top_k: int = 5,
    debug: bool = False,
    debug_index: int = 0,
    max_new_tokens: Optional[int] = None,
) -> Dict[str, Any]:
    """
    vLLM-native teacher-forced perplexity over a *given* reference answer.

    Batching behavior:
      - Build one full prompt per masked context.
      - Send ALL prompts to vLLM in a single `llm.generate` call.
      - vLLM returns prompt-level logprobs for each token position.
      - For each prompt, slice out the part corresponding to the answer
        and sum the logprobs of those answer tokens.

    The answer tokens of a prompt are the tokens of the full prompt that
    follow the tokens of the same prompt rendered with an empty reference.
    All prompts are scored over the same number of answer tokens: the
    smallest answer length found in the batch.

    If debug=True, we print detailed info for the example at `debug_index`:
      - Full prompt string
      - Base vs full token lengths
      - Each answer token + its logprob and running perplexity.

    Args:
        vllm_model: Wrapper object exposing `.llm`.
        question, answer_tag: Passed through to `prompt_formatter`.
        masked_inputs: One context string per prompt.
        reference_text: Answer text whose tokens are scored.
        prompt_formatter: Builds the prompt string.
        top_k: Passed to `SamplingParams(logprobs=...)`.
        debug, debug_index: See above.
        max_new_tokens: Passed to `SamplingParams(max_tokens=...)`; 1 when falsy.

    Returns:
        {
          "avg_logprob": float,                # mean over tokens, then mean over batch
          "perplexity": float,                 # exp(-avg_logprob)
          "target_len": int,                   # answer tokens scored per prompt
          "sequence_logprobs": List[float],    # summed logprob per prompt
          "token_logprobs": List[List[float]]  # per-token logprobs per prompt
        }
        A prompt without answer tokens gets -inf and an empty token list;
        because "avg_logprob" is a batch mean, it then becomes -inf too.

    Raises:
        AssertionError: if `vllm_model` has no `.llm` attribute.
        RuntimeError: if the tokenizer cannot be obtained or vLLM returns no
            prompt logprobs.
    """
    # Explicit raise (not `assert`) so the check survives `python -O`.
    if not hasattr(vllm_model, "llm"):
        raise AssertionError(f"Expected VLLMModel wrapper, got {type(vllm_model)}")

    llm = vllm_model.llm
    try:
        tokenizer = llm.get_tokenizer()
    except Exception as e:
        raise RuntimeError("Could not access vLLM tokenizer; check vLLM version / API") from e

    B = len(masked_inputs)
    if B == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [],
            "token_logprobs": [],
        }

    def to_float(v):
        # Entries may be plain numbers or objects carrying a `.logprob`.
        return float(getattr(v, "logprob", v))

    def perplexity_from(avg_logprob: float) -> float:
        # math.exp raises OverflowError for very large inputs; report inf.
        try:
            return math.exp(-avg_logprob)
        except OverflowError:
            return float("inf")

    # Logprob charged when no candidates are available for a position.
    missing_logprob = -20.0
    # Penalty below the lowest returned candidate when the target token is
    # absent from the candidates.
    floor_margin = 5.0

    full_prompts: List[str] = []
    base_ids_list: List[List[int]] = []
    full_ids_list: List[List[int]] = []
    ref_ids_list: List[List[int]] = []

    # Smallest answer length seen so far; None until one prompt has an answer.
    ref_token_count: Optional[int] = None

    for ctx in masked_inputs:
        # The base prompt (empty reference) marks where the answer starts.
        base_prompt = prompt_formatter(question, ctx, answer_tag, "")
        full_prompt = prompt_formatter(question, ctx, answer_tag, reference_text)

        base_ids = tokenizer(base_prompt, add_special_tokens=False)["input_ids"]
        full_ids = tokenizer(full_prompt, add_special_tokens=False)["input_ids"]

        if not full_ids or len(full_ids) <= len(base_ids):
            # No answer tokens; keep the slot so indices stay aligned with
            # the vLLM outputs.
            ref_ids: List[int] = []
        else:
            ref_ids = full_ids[len(base_ids):]
            if ref_token_count is None:
                ref_token_count = len(ref_ids)
            else:
                ref_token_count = min(ref_token_count, len(ref_ids))

        base_ids_list.append(base_ids)
        full_ids_list.append(full_ids)
        ref_ids_list.append(ref_ids)
        full_prompts.append(full_prompt)

    if ref_token_count is None or ref_token_count == 0:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "target_len": 0,
            "sequence_logprobs": [float("-inf")] * B,
            "token_logprobs": [[] for _ in range(B)],
        }

    # Score every prompt over the same number of answer tokens.
    ref_len = ref_token_count
    ref_ids_list = [ref_ids[:ref_len] for ref_ids in ref_ids_list]

    # Imported here so the module can be loaded without vLLM installed.
    from vllm import SamplingParams

    sp = SamplingParams(
        max_tokens=max_new_tokens or 1,
        temperature=0.0,
        top_p=1.0,
        logprobs=top_k,
        prompt_logprobs=1,
    )

    results = llm.generate(full_prompts, sp)

    per_seq_logprobs: List[float] = []
    per_seq_token_logprobs: List[List[float]] = []

    for idx_ex, (base_ids, full_ids, ref_ids) in enumerate(
        zip(base_ids_list, full_ids_list, ref_ids_list)
    ):
        if not ref_ids or not full_ids:
            per_seq_logprobs.append(float("-inf"))
            per_seq_token_logprobs.append([])
            continue

        req_out = results[idx_ex]
        prompt_logprobs = getattr(req_out, "prompt_logprobs", None)
        if prompt_logprobs is None:
            raise RuntimeError(
                "req_out.prompt_logprobs is None. Adjust to match your vLLM version."
            )

        # Line the logprob list up with full_ids, which was tokenized without
        # special tokens.
        if len(prompt_logprobs) == len(full_ids) + 1:
            # One extra leading entry -- presumably a BOS token added when
            # vLLM tokenizes the prompt itself (TODO confirm).
            prompt_logprobs = prompt_logprobs[1:]
        elif len(prompt_logprobs) != len(full_ids):
            L = min(len(prompt_logprobs), len(full_ids))
            prompt_logprobs = prompt_logprobs[:L]
            full_ids = full_ids[:L]

        show_debug = debug and idx_ex == debug_index
        seq_lp = 0.0
        token_logprobs: List[float] = []

        if show_debug:
            print("\n=== DEBUG vLLM PPL example", idx_ex, "===")
            print("Full prompt:\n", full_prompts[idx_ex])
            print("\nBase token length:", len(base_ids))
            print("Full token length:", len(full_ids))
            print("Answer token length (ref_len):", ref_len)
            print("\nAnswer tokens and logprobs:")

        for offset, token_id in enumerate(ref_ids):
            pos = len(base_ids) + offset
            in_range = pos < len(prompt_logprobs)

            if not in_range:
                token_lp = missing_logprob
            else:
                cand_dict = prompt_logprobs[pos] or {}
                if token_id in cand_dict:
                    token_lp = to_float(cand_dict[token_id])
                elif cand_dict:
                    floor = min(to_float(v) for v in cand_dict.values())
                    token_lp = floor - floor_margin
                else:
                    token_lp = missing_logprob

            seq_lp += token_lp
            token_logprobs.append(token_lp)

            if show_debug:
                tok_str = tokenizer.convert_ids_to_tokens([token_id])[0]
                if not in_range:
                    print(f" pos={pos:<3} token={tok_str!r:>10} logprob={token_lp: .4f} (OUT OF RANGE)")
                else:
                    avg_lp_so_far = seq_lp / (offset + 1)
                    ppl_so_far = perplexity_from(avg_lp_so_far)
                    print(
                        f" pos={pos:<3} token={tok_str!r:>10} "
                        f"logprob={token_lp: .4f} "
                        f"cum_avg_lp={avg_lp_so_far: .4f} "
                        f"cum_ppl={ppl_so_far: .4f}"
                    )

        per_seq_logprobs.append(seq_lp)
        per_seq_token_logprobs.append(token_logprobs)

    # ref_len is at least 1 here; the zero case returned above.
    T = int(ref_len)
    arr = np.array(per_seq_logprobs, dtype=np.float64)
    avg_lp = float((arr / T).mean())
    ppl = perplexity_from(avg_lp)

    return {
        "avg_logprob": avg_lp,
        "perplexity": ppl,
        "target_len": T,
        "sequence_logprobs": per_seq_logprobs,
        "token_logprobs": per_seq_token_logprobs,
    }
|
|
|
|
def score_continuation_batch(
    model: "BaseModel",
    prefixes: List[str],
    continuation: str,
    *,
    max_new_tokens: Optional[int] = None,
    batch_size: Optional[int] = None,
) -> List[Dict[str, Any]]:
    """
    Batched variant of score_continuation. Returns list aligned to prefixes.

    Prefixes are scored in chunks through `score_reference_autoregressive_vllm`.
    When a chunk raises, it is split into smaller chunks and retried; a
    prefix that still fails on its own receives a ``-inf`` fallback entry
    instead of propagating the exception.

    Args:
        model: Wrapper object passed to the vLLM scorer.
        prefixes: Prefix strings; one result is produced per entry.
        continuation: Text scored after every prefix.
        max_new_tokens: Forwarded unchanged to the vLLM scorer.
        batch_size: Chunk size. When falsy, read from the environment variable
            ``VLLM_SCORE_BATCH_SIZE``, then ``ATTRLLM_VLLM_BATCH_SIZE``,
            then 128. Values below 1 are raised to 1.

    Returns:
        One dict per prefix with the keys "avg_logprob", "perplexity",
        "total_logprob", "avg_nll", "per_token", "target_len" and
        "sequence_logprobs". "avg_logprob" and "perplexity" are computed
        from that prefix's own total, not from the chunk average.
    """
    if not prefixes:
        return []

    chunk_size = batch_size or int(
        os.getenv("VLLM_SCORE_BATCH_SIZE", os.getenv("ATTRLLM_VLLM_BATCH_SIZE", "128"))
    )
    chunk_size = max(1, int(chunk_size))
    outputs: List[Optional[Dict[str, Any]]] = [None] * len(prefixes)

    def _failed_output() -> Dict[str, Any]:
        return {
            "avg_logprob": float("-inf"),
            "perplexity": float("inf"),
            "total_logprob": float("-inf"),
            "avg_nll": float("inf"),
            "per_token": [],
            "target_len": 0,
            "sequence_logprobs": [float("-inf")],
        }

    def _score_chunk(chunk_prefixes: List[str], offset: int, chunk_batch_size: int) -> None:
        if not chunk_prefixes:
            return
        try:
            res = score_reference_autoregressive_vllm(
                model,
                question="",
                masked_inputs=chunk_prefixes,
                reference_text=continuation,
                answer_tag="",
                prompt_formatter=_continuation_prompt_formatter,
                max_new_tokens=max_new_tokens,
            )
            seq_lp = res.get("sequence_logprobs", [])
            per_token = res.get("token_logprobs", [])
            target_len = int(res.get("target_len", 0) or 0)

            if len(seq_lp) != len(chunk_prefixes):
                raise RuntimeError(
                    "vLLM returned mismatched sequence_logprobs length: "
                    f"expected={len(chunk_prefixes)} got={len(seq_lp)}"
                )

            for i, lp in enumerate(seq_lp):
                total_logprob = float(lp)
                if target_len > 0 and math.isfinite(total_logprob):
                    avg_logprob = total_logprob / target_len
                    avg_nll = -avg_logprob
                    try:
                        ppl = math.exp(avg_nll)
                    except OverflowError:
                        # Keep an extreme score from triggering the backoff
                        # below and rescoring the whole chunk.
                        ppl = float("inf")
                else:
                    avg_logprob = float("-inf")
                    avg_nll = float("inf")
                    ppl = float("inf")
                # Per-prefix statistics: the scorer's own "avg_logprob" and
                # "perplexity" are chunk-wide means and would make the result
                # depend on which other prefixes share the chunk.
                outputs[offset + i] = {
                    "avg_logprob": avg_logprob,
                    "perplexity": ppl,
                    "total_logprob": total_logprob,
                    "avg_nll": avg_nll,
                    "per_token": per_token[i] if i < len(per_token) else [],
                    "target_len": target_len,
                    "sequence_logprobs": [total_logprob],
                }
        except Exception as exc:
            # Deliberate best effort: any failure is retried on smaller
            # chunks, down to a single prefix.
            n = len(chunk_prefixes)
            if n == 1:
                logger.warning(
                    "score_continuation_batch failed for single prefix; returning -inf fallback. error=%s",
                    exc,
                )
                outputs[offset] = _failed_output()
                return

            next_batch_size = max(1, min(chunk_batch_size // 2, n // 2))
            logger.warning(
                "[BACKOFF_ACTIVE] score_continuation_batch chunk failed; retrying smaller chunks. "
                "chunk_size=%d batch_size=%d next_batch_size=%d error=%s",
                n,
                chunk_batch_size,
                next_batch_size,
                exc,
            )
            for sub_start in range(0, n, next_batch_size):
                sub_chunk = chunk_prefixes[sub_start : sub_start + next_batch_size]
                _score_chunk(sub_chunk, offset + sub_start, next_batch_size)

    for start in range(0, len(prefixes), chunk_size):
        _score_chunk(prefixes[start : start + chunk_size], start, chunk_size)

    # Safety net: no slot may be left unfilled.
    for i, item in enumerate(outputs):
        if item is None:
            outputs[i] = _failed_output()

    return outputs
|
|