# Prompt builders for fixed-length evaluation tasks.
#
# Every builder in this module produces a single-row ``input_ids`` tensor of
# exactly ``prompt_length`` tokens plus an all-ones attention mask. The prompt
# is padded up to that length by repeating (and trimming) a filler sentence;
# the task-specific text is placed at fixed positions around the filler.
from __future__ import annotations

from dataclasses import dataclass, replace
from typing import Any

import torch

# Filler sentences. Each is tokenized once and repeated/trimmed to fill the
# token budget left over after the task text has been reserved.
DEFAULT_REASONING_FILLER = (
    "Archived finance notes mention approvals, invoices, compliance dates, and transport budgets across several quarters. "
)
DEFAULT_INSTRUCTION_FILLER = (
    "Operations guidance references staffing, inventory checks, maintenance windows, and shipping manifests in long planning documents. "
)
DEFAULT_RETRIEVAL_HAYSTACK_UNIT = (
    "Background memo about permit backlogs, bridge closures, zoning appeals, and archive indexing. "
)

# Needle ("passkey") retrieval defaults. The templates are filled with
# ``str.format`` using the ``needle_key`` / ``needle_value`` placeholders.
DEFAULT_RETRIEVAL_NEEDLE_KEY = "archive code"
DEFAULT_RETRIEVAL_NEEDLE_VALUE = "RIVER-58142"
DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE = "Important detail: the {needle_key} is {needle_value}. Remember it exactly.\n"
# NOTE(review): the doubled comma in "analysis, , punctuation" looks like a
# word or tag was stripped from this prompt at some point. The text is part of
# the runtime prompt, so it is left untouched here; confirm the intended
# wording before changing it.
DEFAULT_RETRIEVAL_QUESTION_TEMPLATE = (
    "Question: What is the {needle_key}? "
    "Return only the exact value on a single line. "
    "Do not include analysis, , punctuation, or any extra words.\n"
    "Answer:"
)


@dataclass(frozen=True)
class NeedlePromptBuild:
    """Result of :func:`build_needle_prompt_inputs`."""

    # Token ids of the full prompt, built as a single-row (1, prompt_length) tensor.
    input_ids: torch.Tensor
    # All-ones mask created with ``torch.ones_like(input_ids)``.
    attention_mask: torch.Tensor
    # The needle value the model is expected to reproduce.
    answer_text: str


def _resolve_device(harness: Any) -> torch.device:
    """Return ``harness.adapter.device``.

    Raises:
        ValueError: if the harness has no adapter, or the adapter has no
            (non-``None``) ``device`` attribute.
    """
    adapter = getattr(harness, "adapter", None)
    device = getattr(adapter, "device", None)
    if device is None:
        raise ValueError("Harness adapter device is unavailable for prompt construction.")
    return device


def _encode_text(tokenizer: Any, text: str) -> list[int]:
    """Tokenize ``text`` without special tokens and return plain ``int`` ids.

    The tokenizer is called as ``tokenizer(text, add_special_tokens=False)``
    and must return a mapping with an ``"input_ids"`` entry.

    Raises:
        ValueError: if the text encodes to an empty sequence.
    """
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    if not token_ids:
        raise ValueError(f"text encoded to an empty token sequence: {text!r}")
    return [int(token_id) for token_id in token_ids]


def _bos_prefix_ids(tokenizer: Any) -> list[int]:
    """Return ``[bos_token_id]``, or ``[]`` when the tokenizer defines none."""
    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    if bos_token_id is None:
        return []
    return [int(bos_token_id)]


def _ids_to_model_inputs(
    token_ids: list[int],
    device: torch.device,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Wrap a flat id list as single-row ``input_ids`` plus an all-ones mask."""
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return input_ids, attention_mask


def _repeat_trim_ids(unit_ids: list[int], target_length: int) -> list[int]:
    """Repeat ``unit_ids`` end to end and cut the result to ``target_length``.

    Returns:
        A new list of exactly ``target_length`` ids (empty for a zero target).

    Raises:
        ValueError: if ``target_length`` is negative, or if it is positive
            while ``unit_ids`` is empty.
    """
    if target_length < 0:
        raise ValueError("target_length must be non-negative")
    if target_length == 0:
        return []
    if not unit_ids:
        raise ValueError("unit_ids must be non-empty when target_length is positive")
    # Ceiling division: the smallest number of whole repetitions that covers
    # the target, so a single multiply-and-slice replaces an extend loop.
    repeats = -(-target_length // len(unit_ids))
    return (list(unit_ids) * repeats)[:target_length]


def _build_suffix_task_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    filler_unit: str,
    suffix_text: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a prompt laid out as ``[bos] + filler + suffix``.

    The filler is ``filler_unit`` repeated and trimmed so that the whole
    prompt is exactly ``prompt_length`` tokens long.

    Returns:
        ``(input_ids, attention_mask)`` as produced for a single-row batch.

    Raises:
        ValueError: if ``prompt_length`` is not positive, if either text
            encodes to no tokens, or if BOS + suffix leave no room for at
            least one filler token.
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")
    bos_ids = _bos_prefix_ids(tokenizer)
    filler_ids = _encode_text(tokenizer, filler_unit)
    suffix_ids = _encode_text(tokenizer, suffix_text)
    reserved = len(bos_ids) + len(suffix_ids)
    if reserved >= prompt_length:
        raise ValueError(f"prompt_length={prompt_length} is too small for reserved suffix of {reserved} tokens")
    token_ids = bos_ids + _repeat_trim_ids(filler_ids, prompt_length - reserved) + suffix_ids
    return _ids_to_model_inputs(token_ids, device)


def build_needle_prompt_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    needle_position_fraction: float,
    haystack_unit: str,
    needle_key: str,
    needle_value: str,
    needle_template: str,
    question_template: str,
) -> NeedlePromptBuild:
    """Build a needle-in-a-haystack prompt of exactly ``prompt_length`` tokens.

    Layout: ``[bos] + filler_before + needle + filler_after + question``.
    The filler budget (everything not taken by BOS, needle and question) is
    split so that ``round(budget * needle_position_fraction)`` tokens precede
    the needle and the remainder follows it.

    Args:
        tokenizer: callable used as ``tokenizer(text, add_special_tokens=False)``;
            an optional ``bos_token_id`` attribute is prepended when present.
        device: device on which the tensors are created.
        prompt_length: total number of tokens in the finished prompt.
        needle_position_fraction: share of the filler placed before the
            needle, in ``[0, 1]``.
        haystack_unit: filler text that is repeated and trimmed.
        needle_key: name substituted into both templates.
        needle_value: value substituted into the needle template; also the
            expected answer.
        needle_template: format string with ``{needle_key}`` and
            ``{needle_value}`` placeholders.
        question_template: format string with a ``{needle_key}`` placeholder.

    Returns:
        A :class:`NeedlePromptBuild` holding the tensors and ``needle_value``.

    Raises:
        ValueError: for a non-positive ``prompt_length``, a fraction outside
            ``[0, 1]``, text that encodes to no tokens, or a ``prompt_length``
            that leaves no room for at least one filler token.
        AssertionError: if the assembled prompt does not have the requested
            length (internal consistency check).
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")
    if not 0.0 <= float(needle_position_fraction) <= 1.0:
        raise ValueError("needle_position_fraction must be in [0, 1]")

    bos_ids = _bos_prefix_ids(tokenizer)
    haystack_ids = _encode_text(tokenizer, haystack_unit)
    needle_ids = _encode_text(
        tokenizer,
        needle_template.format(needle_key=needle_key, needle_value=needle_value),
    )
    question_ids = _encode_text(tokenizer, question_template.format(needle_key=needle_key))

    reserved = len(bos_ids) + len(needle_ids) + len(question_ids)
    if reserved >= prompt_length:
        raise ValueError(
            f"prompt_length={prompt_length} is too small for bos+needle+question payload of {reserved} tokens"
        )

    filler_budget = prompt_length - reserved
    filler_before_tokens = int(round(filler_budget * float(needle_position_fraction)))
    # Defensive clamp so the two filler parts always sum to the budget.
    filler_before_tokens = max(0, min(filler_budget, filler_before_tokens))
    filler_after_tokens = filler_budget - filler_before_tokens

    token_ids = (
        bos_ids
        + _repeat_trim_ids(haystack_ids, filler_before_tokens)
        + needle_ids
        + _repeat_trim_ids(haystack_ids, filler_after_tokens)
        + question_ids
    )
    if len(token_ids) != prompt_length:
        raise AssertionError(f"constructed prompt length {len(token_ids)} did not match target {prompt_length}")

    input_ids, attention_mask = _ids_to_model_inputs(token_ids, device)
    return NeedlePromptBuild(
        input_ids=input_ids,
        attention_mask=attention_mask,
        answer_text=needle_value,
    )


def _build_reasoning_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the arithmetic prompt: reasoning filler followed by the question.

    Returns:
        ``(input_ids, attention_mask, expected_answer)``.
    """
    answer = "48"  # 17 + 26 - 9 + 14
    suffix = (
        "A clerk solves a budget worksheet.\n"
        "Compute 17 + 26 - 9 + 14.\n"
        "You may think silently, but the only visible output must be exactly one line.\n"
        "That line must start with FINAL: followed by the integer answer.\n"
        "Do not echo the prompt, do not include analysis, and do not output angle brackets.\n"
        "Answer:\n"
        "FINAL:"
    )
    input_ids, attention_mask = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_REASONING_FILLER,
        suffix_text=suffix,
    )
    return input_ids, attention_mask, answer


def _build_instruction_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the two-line instruction-following prompt.

    Returns:
        ``(input_ids, attention_mask, expected_answer)``.
    """
    answer = "STATUS: READY\nCOLOR: BLUE"
    # NOTE(review): "explanation, or block." reads as if a word was dropped
    # before "block"; kept as-is because it is runtime prompt text. Confirm
    # the intended wording.
    suffix = (
        "Follow these instructions exactly.\n"
        "1. The only visible output must be exactly two lines.\n"
        "2. First line: STATUS: READY\n"
        "3. Second line: COLOR: BLUE\n"
        "4. Do not repeat the prompt.\n"
        "5. Do not add any extra words, punctuation, explanation, or block.\n"
        "Answer:"
    )
    input_ids, attention_mask = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_INSTRUCTION_FILLER,
        suffix_text=suffix,
    )
    return input_ids, attention_mask, answer


def _task_specs(
    harness: Any,
    *,
    prompt_length: int,
    args: Any,
) -> list[dict[str, Any]]:
    """Build the retrieval, reasoning and instruction task specs.

    Args:
        harness: object exposing ``tokenizer`` and ``adapter.device``.
        prompt_length: token length used for every task prompt.
        args: object providing ``max_new_tokens_retrieval``,
            ``max_new_tokens_reasoning`` and ``max_new_tokens_instruction``.

    Returns:
        Three dicts (retrieval, reasoning, instruction, in that order), each
        carrying the task labels, a preview string, the prompt tensors, the
        decode budget, the expected answer and the stop sequences.

    Raises:
        ValueError: if the harness has no tokenizer or no adapter device, or
            if any prompt cannot be built at ``prompt_length``.
    """
    tokenizer = getattr(harness, "tokenizer", None)
    if tokenizer is None:
        raise ValueError("tokenizer is unavailable")
    device = _resolve_device(harness)

    retrieval = build_needle_prompt_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        needle_position_fraction=0.5,
        haystack_unit=DEFAULT_RETRIEVAL_HAYSTACK_UNIT,
        needle_key=DEFAULT_RETRIEVAL_NEEDLE_KEY,
        needle_value=DEFAULT_RETRIEVAL_NEEDLE_VALUE,
        needle_template=DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE,
        question_template=DEFAULT_RETRIEVAL_QUESTION_TEMPLATE,
    )
    reasoning_ids, reasoning_mask, reasoning_answer = _build_reasoning_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
    )
    instruction_ids, instruction_mask, instruction_answer = _build_instruction_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
    )

    retrieval_spec: dict[str, Any] = {
        "task_name": "retrieval_passkey",
        "task_family": "retrieval",
        "task_variant": "passkey",
        "task_prompt_preview": "Question: What is the archive code? Answer with the exact value only.",
        "input_ids": retrieval.input_ids,
        "attention_mask": retrieval.attention_mask,
        "decode_steps": int(args.max_new_tokens_retrieval),
        "expected_answer": retrieval.answer_text,
        "stop_sequences": (retrieval.answer_text,),
    }
    reasoning_spec: dict[str, Any] = {
        "task_name": "reasoning_arithmetic",
        "task_family": "reasoning",
        "task_variant": "arithmetic",
        "task_prompt_preview": "Start with 17. Add 26. Subtract 9. Add 14. What is the final total?",
        "input_ids": reasoning_ids,
        "attention_mask": reasoning_mask,
        "decode_steps": int(args.max_new_tokens_reasoning),
        "expected_answer": reasoning_answer,
        # NOTE(review): the reasoning prompt itself already ends with
        # "FINAL:"; whether this stop sequence can ever match depends on how
        # the caller compares stop sequences with the output. Confirm.
        "stop_sequences": (f"FINAL: {reasoning_answer}",),
    }
    instruction_spec: dict[str, Any] = {
        "task_name": "instruction_constraints",
        "task_family": "instruction",
        "task_variant": "constraints",
        "task_prompt_preview": "Reply with exactly two lines: STATUS: READY and COLOR: BLUE.",
        "input_ids": instruction_ids,
        "attention_mask": instruction_mask,
        "decode_steps": int(args.max_new_tokens_instruction),
        "expected_answer": instruction_answer,
        "stop_sequences": (instruction_answer,),
    }
    return [retrieval_spec, reasoning_spec, instruction_spec]


def _apply_selector_task_context(
    harness: Any,
    *,
    profile: str,
    task_family: str,
    task_variant: str,
) -> None:
    """Write the task family/variant into the adapter's ``dotcache_config``.

    Nothing happens for the ``"dense"`` profile, or when the harness has no
    ``adapter`` or the adapter has no ``dotcache_config``. For the ``"exact"``
    profile both selector fields are set to ``None``; for any other profile
    they are set to ``str(task_family)`` / ``str(task_variant)``.

    A dataclass config is copied with ``dataclasses.replace``; any other
    config object is mutated in place. The resulting config is then assigned
    to both ``adapter.dotcache_config`` and ``adapter.model_kv_cache.config``.
    """
    if profile == "dense":
        return
    if not hasattr(harness, "adapter") or not hasattr(harness.adapter, "dotcache_config"):
        return

    adapter = harness.adapter
    if profile == "exact":
        prompt_family = None
        prompt_variant = None
    else:
        prompt_family = str(task_family)
        prompt_variant = str(task_variant)

    current_config = adapter.dotcache_config
    if hasattr(current_config, "__dataclass_fields__"):
        # Dataclass configs may be frozen, so build an updated copy.
        updated_config = replace(
            current_config,
            learned_page_selector_prompt_family=prompt_family,
            learned_page_selector_prompt_variant=prompt_variant,
        )
    else:
        current_config.learned_page_selector_prompt_family = prompt_family
        current_config.learned_page_selector_prompt_variant = prompt_variant
        updated_config = current_config

    adapter.dotcache_config = updated_config
    # NOTE(review): unlike ``adapter`` / ``dotcache_config`` above, the
    # presence of ``model_kv_cache`` is not checked; presumably it always
    # exists whenever ``dotcache_config`` does. Confirm.
    adapter.model_kv_cache.config = updated_config