Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass, replace | |
| from typing import Any | |
| import torch | |
# Sentence tiled as padding in front of the reasoning-task instructions.
DEFAULT_REASONING_FILLER = "Archived finance notes mention approvals, invoices, compliance dates, and transport budgets across several quarters. "

# Sentence tiled as padding in front of the instruction-following task.
DEFAULT_INSTRUCTION_FILLER = "Operations guidance references staffing, inventory checks, maintenance windows, and shipping manifests in long planning documents. "

# Sentence tiled around the needle in the retrieval ("haystack") task.
DEFAULT_RETRIEVAL_HAYSTACK_UNIT = "Background memo about permit backlogs, bridge closures, zoning appeals, and archive indexing. "

# Key/value pair hidden in the haystack; the value is the expected answer.
DEFAULT_RETRIEVAL_NEEDLE_KEY = "archive code"
DEFAULT_RETRIEVAL_NEEDLE_VALUE = "RIVER-58142"

# Format templates; placeholders are {needle_key} and {needle_value}.
DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE = "Important detail: the {needle_key} is {needle_value}. Remember it exactly.\n"
DEFAULT_RETRIEVAL_QUESTION_TEMPLATE = (
    "Question: What is the {needle_key}? Return only the exact value on a single line. "
    "Do not include analysis, <think>, punctuation, or any extra words.\nAnswer:"
)
@dataclass
class NeedlePromptBuild:
    """Result of building a needle-retrieval prompt.

    Instances are created with keyword arguments by
    ``build_needle_prompt_inputs``, so the class must be a dataclass for the
    generated ``__init__`` to exist.
    """

    # Prompt token ids, built as a single-row ``torch.long`` tensor.
    input_ids: torch.Tensor
    # All-ones mask with the same shape as ``input_ids``.
    attention_mask: torch.Tensor
    # The needle value the model is expected to return.
    answer_text: str
def _resolve_device(harness: Any) -> torch.device:
    """Return ``harness.adapter.device``.

    Raises:
        ValueError: if the harness has no adapter, or the adapter has no
            device (or it is ``None``).
    """
    resolved = getattr(getattr(harness, "adapter", None), "device", None)
    if resolved is not None:
        return resolved
    raise ValueError("Harness adapter device is unavailable for prompt construction.")
def _encode_text(tokenizer: Any, text: str) -> list[int]:
    """Tokenize ``text`` without special tokens and return plain ``int`` ids.

    ``tokenizer`` is called as ``tokenizer(text, add_special_tokens=False)``
    and its result is indexed with ``"input_ids"``.

    Raises:
        ValueError: if the text encodes to no tokens at all.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    raw_ids = encoding["input_ids"]
    if not raw_ids:
        raise ValueError(f"text encoded to an empty token sequence: {text!r}")
    return list(map(int, raw_ids))
def _repeat_trim_ids(unit_ids: list[int], target_length: int) -> list[int]:
    """Tile ``unit_ids`` end to end and cut the result to ``target_length``.

    Returns a new list; a ``target_length`` of zero yields ``[]`` even when
    ``unit_ids`` is empty.

    Raises:
        ValueError: if ``target_length`` is negative, or if it is positive
            while ``unit_ids`` is empty.
    """
    if target_length < 0:
        raise ValueError("target_length must be non-negative")
    if target_length == 0:
        return []
    if not unit_ids:
        raise ValueError("unit_ids must be non-empty when target_length is positive")

    # Whole copies of the unit, followed by the leading part of one more copy.
    whole_copies, leftover = divmod(target_length, len(unit_ids))
    tiled = list(unit_ids) * whole_copies
    tiled.extend(unit_ids[:leftover])
    return tiled
def _build_suffix_task_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    filler_unit: str,
    suffix_text: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a prompt laid out as ``[BOS] + filler + suffix``.

    The filler is ``filler_unit`` tokenized and tiled so that the whole
    sequence has ``prompt_length`` tokens. The BOS id is included only when
    the tokenizer exposes a non-``None`` ``bos_token_id``.

    Returns:
        ``(input_ids, attention_mask)``: single-row ``torch.long`` tensors on
        ``device``; the mask is all ones.

    Raises:
        ValueError: if ``prompt_length`` is not positive, or if BOS plus the
            suffix already take ``prompt_length`` tokens or more.
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")

    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    prefix_ids = [] if bos_token_id is None else [int(bos_token_id)]
    filler_unit_ids = _encode_text(tokenizer, filler_unit)
    tail_ids = _encode_text(tokenizer, suffix_text)

    reserved = len(prefix_ids) + len(tail_ids)
    if prompt_length <= reserved:
        raise ValueError(f"prompt_length={prompt_length} is too small for reserved suffix of {reserved} tokens")

    padding_ids = _repeat_trim_ids(filler_unit_ids, prompt_length - reserved)
    sequence = [*prefix_ids, *padding_ids, *tail_ids]
    input_ids = torch.tensor([sequence], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return input_ids, attention_mask
def build_needle_prompt_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    needle_position_fraction: float,
    haystack_unit: str,
    needle_key: str,
    needle_value: str,
    needle_template: str,
    question_template: str,
) -> NeedlePromptBuild:
    """Build a retrieval prompt with a needle sentence placed inside filler.

    Layout: ``[BOS] + filler_before + needle + filler_after + question``,
    exactly ``prompt_length`` tokens long. ``needle_position_fraction`` is the
    share of the filler budget placed before the needle (0 puts the needle
    first, 1 puts it right before the question).

    Returns:
        A ``NeedlePromptBuild`` with single-row ``torch.long`` ``input_ids``,
        an all-ones ``attention_mask`` and ``answer_text`` set to
        ``needle_value``.

    Raises:
        ValueError: if ``prompt_length`` is not positive, the fraction is
            outside ``[0, 1]``, or BOS + needle + question leave no room for
            filler.
        AssertionError: if the assembled sequence does not have exactly
            ``prompt_length`` tokens.
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")
    fraction = float(needle_position_fraction)
    if not 0.0 <= fraction <= 1.0:
        raise ValueError("needle_position_fraction must be in [0, 1]")

    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    prefix_ids = [] if bos_token_id is None else [int(bos_token_id)]
    hay_ids = _encode_text(tokenizer, haystack_unit)
    needle_text = needle_template.format(needle_key=needle_key, needle_value=needle_value)
    needle_ids = _encode_text(tokenizer, needle_text)
    question_text = question_template.format(needle_key=needle_key)
    question_ids = _encode_text(tokenizer, question_text)

    reserved = len(prefix_ids) + len(needle_ids) + len(question_ids)
    if prompt_length <= reserved:
        raise ValueError(
            f"prompt_length={prompt_length} is too small for bos+needle+question payload of {reserved} tokens"
        )

    # Split the remaining token budget around the needle, clamped to the budget.
    filler_budget = prompt_length - reserved
    tokens_before = int(round(filler_budget * fraction))
    tokens_before = max(0, min(filler_budget, tokens_before))
    tokens_after = filler_budget - tokens_before

    sequence = [
        *prefix_ids,
        *_repeat_trim_ids(hay_ids, tokens_before),
        *needle_ids,
        *_repeat_trim_ids(hay_ids, tokens_after),
        *question_ids,
    ]
    if len(sequence) != prompt_length:
        raise AssertionError(f"constructed prompt length {len(sequence)} did not match target {prompt_length}")

    input_ids = torch.tensor([sequence], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return NeedlePromptBuild(
        input_ids=input_ids,
        attention_mask=attention_mask,
        answer_text=needle_value,
    )
def _build_reasoning_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the fixed arithmetic prompt padded to ``prompt_length`` tokens.

    The instructions are placed after tiled ``DEFAULT_REASONING_FILLER`` text
    by ``_build_suffix_task_inputs``.

    Returns:
        ``(input_ids, attention_mask, answer)`` where ``answer`` is ``"48"``
        (17 + 26 - 9 + 14).
    """
    expected = "48"
    instruction_lines = (
        "A clerk solves a budget worksheet.",
        "Compute 17 + 26 - 9 + 14.",
        "You may think silently, but the only visible output must be exactly one line.",
        "That line must start with FINAL: followed by the integer answer.",
        "Do not echo the prompt, do not include analysis, and do not output angle brackets.",
        "Answer:",
        "FINAL:",
    )
    token_tensor, mask_tensor = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_REASONING_FILLER,
        suffix_text="\n".join(instruction_lines),
    )
    return token_tensor, mask_tensor, expected
def _build_instruction_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the fixed two-line-output prompt padded to ``prompt_length`` tokens.

    The instructions are placed after tiled ``DEFAULT_INSTRUCTION_FILLER``
    text by ``_build_suffix_task_inputs``.

    Returns:
        ``(input_ids, attention_mask, answer)`` where ``answer`` is the
        two-line string ``"STATUS: READY\\nCOLOR: BLUE"``.
    """
    expected = "STATUS: READY\nCOLOR: BLUE"
    instruction_lines = (
        "Follow these instructions exactly.",
        "1. The only visible output must be exactly two lines.",
        "2. First line: STATUS: READY",
        "3. Second line: COLOR: BLUE",
        "4. Do not repeat the prompt.",
        "5. Do not add any extra words, punctuation, explanation, or <think> block.",
        "Answer:",
    )
    token_tensor, mask_tensor = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_INSTRUCTION_FILLER,
        suffix_text="\n".join(instruction_lines),
    )
    return token_tensor, mask_tensor, expected
def _task_specs(
    harness: Any,
    *,
    prompt_length: int,
    args: Any,
) -> list[dict[str, Any]]:
    """Build the three task specifications at a common prompt length.

    Reads ``harness.tokenizer`` and the adapter device, then builds the
    retrieval, reasoning and instruction prompts in that order. Decode budgets
    come from ``args.max_new_tokens_retrieval``,
    ``args.max_new_tokens_reasoning`` and ``args.max_new_tokens_instruction``.

    Returns:
        A list of three dicts, each with the keys ``task_name``,
        ``task_family``, ``task_variant``, ``task_prompt_preview``,
        ``input_ids``, ``attention_mask``, ``decode_steps``,
        ``expected_answer`` and ``stop_sequences``.

    Raises:
        ValueError: if the harness has no tokenizer (device and prompt-size
            errors from the builders propagate as well).
    """
    tokenizer = getattr(harness, "tokenizer", None)
    if tokenizer is None:
        raise ValueError("tokenizer is unavailable")
    shared = {"device": _resolve_device(harness), "prompt_length": prompt_length}

    # Build every prompt before reading any decode budget from ``args``.
    needle_build = build_needle_prompt_inputs(
        tokenizer,
        needle_position_fraction=0.5,
        haystack_unit=DEFAULT_RETRIEVAL_HAYSTACK_UNIT,
        needle_key=DEFAULT_RETRIEVAL_NEEDLE_KEY,
        needle_value=DEFAULT_RETRIEVAL_NEEDLE_VALUE,
        needle_template=DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE,
        question_template=DEFAULT_RETRIEVAL_QUESTION_TEMPLATE,
        **shared,
    )
    arithmetic_ids, arithmetic_mask, arithmetic_answer = _build_reasoning_inputs(tokenizer, **shared)
    constraint_ids, constraint_mask, constraint_answer = _build_instruction_inputs(tokenizer, **shared)

    retrieval_spec = {
        "task_name": "retrieval_passkey",
        "task_family": "retrieval",
        "task_variant": "passkey",
        "task_prompt_preview": "Question: What is the archive code? Answer with the exact value only.",
        "input_ids": needle_build.input_ids,
        "attention_mask": needle_build.attention_mask,
        "decode_steps": int(args.max_new_tokens_retrieval),
        "expected_answer": needle_build.answer_text,
        "stop_sequences": (needle_build.answer_text,),
    }
    reasoning_spec = {
        "task_name": "reasoning_arithmetic",
        "task_family": "reasoning",
        "task_variant": "arithmetic",
        "task_prompt_preview": "Start with 17. Add 26. Subtract 9. Add 14. What is the final total?",
        "input_ids": arithmetic_ids,
        "attention_mask": arithmetic_mask,
        "decode_steps": int(args.max_new_tokens_reasoning),
        "expected_answer": arithmetic_answer,
        "stop_sequences": (f"FINAL: {arithmetic_answer}",),
    }
    instruction_spec = {
        "task_name": "instruction_constraints",
        "task_family": "instruction",
        "task_variant": "constraints",
        "task_prompt_preview": "Reply with exactly two lines: STATUS: READY and COLOR: BLUE.",
        "input_ids": constraint_ids,
        "attention_mask": constraint_mask,
        "decode_steps": int(args.max_new_tokens_instruction),
        "expected_answer": constraint_answer,
        "stop_sequences": (constraint_answer,),
    }
    return [retrieval_spec, reasoning_spec, instruction_spec]
def _apply_selector_task_context(
    harness: Any,
    *,
    profile: str,
    task_family: str,
    task_variant: str,
) -> None:
    """Write the task family/variant into the adapter's dotcache config.

    Does nothing when ``profile`` is ``"dense"``, or when the harness has no
    ``adapter`` or the adapter has no ``dotcache_config``. For the ``"exact"``
    profile both selector fields are set to ``None``; for any other profile
    they are set to ``str(task_family)`` and ``str(task_variant)``.

    A dataclass config is copied with ``dataclasses.replace``; any other
    config object is mutated in place. The resulting config is then assigned
    to both ``adapter.dotcache_config`` and ``adapter.model_kv_cache.config``.
    """
    if profile == "dense":
        return
    if not hasattr(harness, "adapter"):
        return
    adapter = harness.adapter
    if not hasattr(adapter, "dotcache_config"):
        return

    if profile == "exact":
        family_tag = variant_tag = None
    else:
        family_tag, variant_tag = str(task_family), str(task_variant)

    config = adapter.dotcache_config
    if hasattr(config, "__dataclass_fields__"):
        config = replace(
            config,
            learned_page_selector_prompt_family=family_tag,
            learned_page_selector_prompt_variant=variant_tag,
        )
    else:
        config.learned_page_selector_prompt_family = family_tag
        config.learned_page_selector_prompt_variant = variant_tag

    adapter.dotcache_config = config
    adapter.model_kv_cache.config = config