DotCache-Arena / scripts /space_task_prompts.py
Deano Calver
Fix live space task prompt imports
7cbeabe
Raw
History Blame Contribute Delete
10.6 kB
from __future__ import annotations
from dataclasses import dataclass, replace
from typing import Any
import torch
# --- Filler text -----------------------------------------------------------
# Each filler unit is tokenized once and then repeated/trimmed to pad a prompt
# out to the requested token length. The trailing space keeps repeated copies
# from running together.

# Padding used ahead of the arithmetic (reasoning) task suffix.
DEFAULT_REASONING_FILLER = "Archived finance notes mention approvals, invoices, compliance dates, and transport budgets across several quarters. "

# Padding used ahead of the instruction-following task suffix.
DEFAULT_INSTRUCTION_FILLER = "Operations guidance references staffing, inventory checks, maintenance windows, and shipping manifests in long planning documents. "

# Haystack padding placed around the needle in the retrieval task.
DEFAULT_RETRIEVAL_HAYSTACK_UNIT = "Background memo about permit backlogs, bridge closures, zoning appeals, and archive indexing. "

# --- Retrieval (needle-in-a-haystack) defaults -------------------------------
# The key names the fact and the value is the exact string the model must
# return; both are substituted into the templates below via ``str.format``.
DEFAULT_RETRIEVAL_NEEDLE_KEY = "archive code"
DEFAULT_RETRIEVAL_NEEDLE_VALUE = "RIVER-58142"

# Sentence that plants the fact; expects ``needle_key`` and ``needle_value``.
DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE = "Important detail: the {needle_key} is {needle_value}. Remember it exactly.\n"

# Closing question appended after the haystack; expects ``needle_key`` only.
DEFAULT_RETRIEVAL_QUESTION_TEMPLATE = (
    "Question: What is the {needle_key}? "
    "Return only the exact value on a single line. "
    "Do not include analysis, <think>, punctuation, or any extra words.\n"
    "Answer:"
)
@dataclass(frozen=True)
class NeedlePromptBuild:
    """Immutable bundle describing one needle-in-a-haystack prompt.

    Instances are returned by ``build_needle_prompt_inputs``; the dataclass is
    frozen, so fields cannot be reassigned after construction.
    """

    # Token ids of the assembled prompt (the builder in this module wraps the
    # id list in an outer list, giving a single-row tensor).
    input_ids: torch.Tensor
    # Mask with the same shape as ``input_ids``; the builder fills it with ones.
    attention_mask: torch.Tensor
    # The needle value that counts as the correct answer for this prompt.
    answer_text: str
def _resolve_device(harness: Any) -> torch.device:
    """Look up the device that prompt tensors should be created on.

    Args:
        harness: Object expected to expose ``adapter.device``.

    Returns:
        The value of ``harness.adapter.device``.

    Raises:
        ValueError: If the harness has no adapter, the adapter has no
            ``device`` attribute, or that attribute is ``None``.
    """
    # Chained getattr with None defaults collapses every "missing" case
    # (no adapter, no device, device is None) into a single check.
    resolved = getattr(getattr(harness, "adapter", None), "device", None)
    if resolved is not None:
        return resolved
    raise ValueError("Harness adapter device is unavailable for prompt construction.")
def _encode_text(tokenizer: Any, text: str) -> list[int]:
    """Tokenize ``text`` without special tokens and return plain ``int`` ids.

    Args:
        tokenizer: Callable invoked as ``tokenizer(text, add_special_tokens=False)``
            whose result is indexable by ``"input_ids"``.
        text: The string to encode.

    Returns:
        The token ids, each coerced with ``int``.

    Raises:
        ValueError: If the tokenizer yields an empty id sequence for ``text``.
    """
    encoding = tokenizer(text, add_special_tokens=False)
    raw_ids = encoding["input_ids"]
    if raw_ids:
        # Coerce to built-in ints so downstream list arithmetic is uniform.
        return list(map(int, raw_ids))
    raise ValueError(f"text encoded to an empty token sequence: {text!r}")
def _repeat_trim_ids(unit_ids: list[int], target_length: int) -> list[int]:
    """Tile ``unit_ids`` end to end and cut the result to ``target_length``.

    Args:
        unit_ids: The id sequence to repeat.
        target_length: Exact number of ids wanted in the result.

    Returns:
        A new list of ``target_length`` ids; empty when ``target_length`` is 0
        (in which case ``unit_ids`` may itself be empty).

    Raises:
        ValueError: If ``target_length`` is negative, or if it is positive
            while ``unit_ids`` is empty.
    """
    if target_length < 0:
        raise ValueError("target_length must be non-negative")
    if target_length == 0:
        return []
    if not unit_ids:
        raise ValueError("unit_ids must be non-empty when target_length is positive")
    # Ceiling division: the fewest whole copies that reach target_length.
    copies_needed = -(-target_length // len(unit_ids))
    tiled = list(unit_ids) * copies_needed
    return tiled[:target_length]
def _build_suffix_task_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    filler_unit: str,
    suffix_text: str,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Build a fixed-length prompt laid out as ``[BOS?] + filler + suffix``.

    The task text (``suffix_text``) always sits at the end of the prompt; the
    space in front of it is padded with repeated copies of ``filler_unit`` so
    the total is exactly ``prompt_length`` tokens.

    Args:
        tokenizer: Tokenizer passed to ``_encode_text``; its ``bos_token_id``
            is prepended when present and not ``None``.
        device: Device on which both tensors are created.
        prompt_length: Total number of tokens in the finished prompt.
        filler_unit: Text whose token ids are tiled as padding.
        suffix_text: Task text placed after the padding.

    Returns:
        ``(input_ids, attention_mask)``: two ``torch.long`` tensors built from
        a single-row nested list; the mask is all ones.

    Raises:
        ValueError: If ``prompt_length`` is not positive, if either text
            encodes to nothing, or if BOS plus suffix leave no room for at
            least one filler token.
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")

    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    prefix_ids = [] if bos_token_id is None else [int(bos_token_id)]

    padding_unit_ids = _encode_text(tokenizer, filler_unit)
    task_ids = _encode_text(tokenizer, suffix_text)

    reserved = len(prefix_ids) + len(task_ids)
    padding_budget = prompt_length - reserved
    # A zero budget is rejected too: at least one filler token is required.
    if padding_budget <= 0:
        raise ValueError(f"prompt_length={prompt_length} is too small for reserved suffix of {reserved} tokens")

    sequence = [
        *prefix_ids,
        *_repeat_trim_ids(padding_unit_ids, padding_budget),
        *task_ids,
    ]
    input_ids = torch.tensor([sequence], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return input_ids, attention_mask
def build_needle_prompt_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
    needle_position_fraction: float,
    haystack_unit: str,
    needle_key: str,
    needle_value: str,
    needle_template: str,
    question_template: str,
) -> NeedlePromptBuild:
    """Build a fixed-length needle-in-a-haystack retrieval prompt.

    Layout: ``[BOS?] + haystack + needle + haystack + question``. The filler
    budget left after BOS, needle and question is split around the needle
    according to ``needle_position_fraction``.

    Args:
        tokenizer: Tokenizer passed to ``_encode_text``; its ``bos_token_id``
            is prepended when present and not ``None``.
        device: Device on which both tensors are created.
        prompt_length: Exact number of tokens in the finished prompt.
        needle_position_fraction: Share of the filler budget placed before the
            needle, in ``[0, 1]`` (0 puts the needle first, 1 puts it last).
        haystack_unit: Text whose token ids are tiled as filler.
        needle_key: Name of the fact, substituted into both templates.
        needle_value: Value of the fact; also returned as the expected answer.
        needle_template: Format string taking ``needle_key`` and
            ``needle_value``.
        question_template: Format string taking ``needle_key``.

    Returns:
        A ``NeedlePromptBuild`` holding the prompt ids, an all-ones attention
        mask, and ``needle_value`` as ``answer_text``.

    Raises:
        ValueError: If ``prompt_length`` is not positive, the fraction is
            outside ``[0, 1]``, any text encodes to nothing, or BOS + needle +
            question leave no room for at least one filler token.
        AssertionError: If the assembled prompt length does not equal
            ``prompt_length`` (internal consistency check).
    """
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")
    fraction = float(needle_position_fraction)
    # Written as a negated chained comparison so NaN is rejected as well.
    if not (0.0 <= fraction <= 1.0):
        raise ValueError("needle_position_fraction must be in [0, 1]")

    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    prefix_ids = [] if bos_token_id is None else [int(bos_token_id)]

    filler_unit_ids = _encode_text(tokenizer, haystack_unit)
    needle_text = needle_template.format(needle_key=needle_key, needle_value=needle_value)
    needle_ids = _encode_text(tokenizer, needle_text)
    question_text = question_template.format(needle_key=needle_key)
    question_ids = _encode_text(tokenizer, question_text)

    reserved = len(prefix_ids) + len(needle_ids) + len(question_ids)
    filler_budget = prompt_length - reserved
    if filler_budget <= 0:
        raise ValueError(
            f"prompt_length={prompt_length} is too small for bos+needle+question payload of {reserved} tokens"
        )

    # Split the filler around the needle, clamping the rounded share so it
    # always stays within the available budget.
    tokens_before = int(round(filler_budget * fraction))
    tokens_before = max(0, min(filler_budget, tokens_before))
    tokens_after = filler_budget - tokens_before

    sequence = [
        *prefix_ids,
        *_repeat_trim_ids(filler_unit_ids, tokens_before),
        *needle_ids,
        *_repeat_trim_ids(filler_unit_ids, tokens_after),
        *question_ids,
    ]
    if len(sequence) != prompt_length:
        raise AssertionError(f"constructed prompt length {len(sequence)} did not match target {prompt_length}")

    input_ids = torch.tensor([sequence], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return NeedlePromptBuild(
        input_ids=input_ids,
        attention_mask=attention_mask,
        answer_text=needle_value,
    )
def _build_reasoning_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the arithmetic reasoning prompt padded to ``prompt_length`` tokens.

    The task text is placed at the end of the prompt, after filler generated
    from ``DEFAULT_REASONING_FILLER``. The prompt ends with ``FINAL:`` so the
    expected continuation is the integer answer.

    Args:
        tokenizer: Tokenizer forwarded to ``_build_suffix_task_inputs``.
        device: Device on which the tensors are created.
        prompt_length: Total number of tokens in the finished prompt.

    Returns:
        ``(input_ids, attention_mask, expected_answer)`` where the expected
        answer is the string ``"48"`` (17 + 26 - 9 + 14).

    Raises:
        ValueError: Propagated from ``_build_suffix_task_inputs`` when the
            prompt length cannot accommodate the task text.
    """
    expected_answer = "48"
    task_text = (
        "A clerk solves a budget worksheet.\n"
        "Compute 17 + 26 - 9 + 14.\n"
        "You may think silently, but the only visible output must be exactly one line.\n"
        "That line must start with FINAL: followed by the integer answer.\n"
        "Do not echo the prompt, do not include analysis, and do not output angle brackets.\n"
        "Answer:\n"
        "FINAL:"
    )
    prompt_ids, prompt_mask = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_REASONING_FILLER,
        suffix_text=task_text,
    )
    return prompt_ids, prompt_mask, expected_answer
def _build_instruction_inputs(
    tokenizer: Any,
    *,
    device: torch.device,
    prompt_length: int,
) -> tuple[torch.Tensor, torch.Tensor, str]:
    """Build the instruction-following prompt padded to ``prompt_length`` tokens.

    The numbered instructions are placed at the end of the prompt, after
    filler generated from ``DEFAULT_INSTRUCTION_FILLER``.

    Args:
        tokenizer: Tokenizer forwarded to ``_build_suffix_task_inputs``.
        device: Device on which the tensors are created.
        prompt_length: Total number of tokens in the finished prompt.

    Returns:
        ``(input_ids, attention_mask, expected_answer)`` where the expected
        answer is the exact two-line reply the instructions ask for.

    Raises:
        ValueError: Propagated from ``_build_suffix_task_inputs`` when the
            prompt length cannot accommodate the task text.
    """
    expected_answer = "STATUS: READY\nCOLOR: BLUE"
    task_text = (
        "Follow these instructions exactly.\n"
        "1. The only visible output must be exactly two lines.\n"
        "2. First line: STATUS: READY\n"
        "3. Second line: COLOR: BLUE\n"
        "4. Do not repeat the prompt.\n"
        "5. Do not add any extra words, punctuation, explanation, or <think> block.\n"
        "Answer:"
    )
    prompt_ids, prompt_mask = _build_suffix_task_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        filler_unit=DEFAULT_INSTRUCTION_FILLER,
        suffix_text=task_text,
    )
    return prompt_ids, prompt_mask, expected_answer
def _task_specs(
    harness: Any,
    *,
    prompt_length: int,
    args: Any,
) -> list[dict[str, Any]]:
    """Assemble the retrieval, reasoning and instruction task specifications.

    All three prompts are built at the same ``prompt_length`` using the
    harness tokenizer and the device reported by ``_resolve_device``. The
    retrieval needle is placed at the midpoint of the filler (fraction 0.5).

    Args:
        harness: Object exposing ``tokenizer`` and ``adapter.device``.
        prompt_length: Token length shared by every prompt.
        args: Object exposing ``max_new_tokens_retrieval``,
            ``max_new_tokens_reasoning`` and ``max_new_tokens_instruction``;
            each is converted with ``int`` into the task's ``decode_steps``.

    Returns:
        A list of three dicts, in the order retrieval, reasoning, instruction.
        Each carries task naming fields, a human-readable prompt preview, the
        prompt tensors, the decode budget, the expected answer and a
        one-element tuple of stop sequences.

    Raises:
        ValueError: If the harness has no tokenizer or device, or if a prompt
            cannot be built at ``prompt_length``.
    """
    tokenizer = getattr(harness, "tokenizer", None)
    if tokenizer is None:
        raise ValueError("tokenizer is unavailable")
    device = _resolve_device(harness)

    # Build every prompt up front, before any ``args`` attribute is read.
    retrieval = build_needle_prompt_inputs(
        tokenizer,
        device=device,
        prompt_length=prompt_length,
        needle_position_fraction=0.5,
        haystack_unit=DEFAULT_RETRIEVAL_HAYSTACK_UNIT,
        needle_key=DEFAULT_RETRIEVAL_NEEDLE_KEY,
        needle_value=DEFAULT_RETRIEVAL_NEEDLE_VALUE,
        needle_template=DEFAULT_RETRIEVAL_NEEDLE_TEMPLATE,
        question_template=DEFAULT_RETRIEVAL_QUESTION_TEMPLATE,
    )
    shared_kwargs = {"device": device, "prompt_length": prompt_length}
    reasoning_ids, reasoning_mask, reasoning_answer = _build_reasoning_inputs(tokenizer, **shared_kwargs)
    instruction_ids, instruction_mask, instruction_answer = _build_instruction_inputs(tokenizer, **shared_kwargs)

    retrieval_spec: dict[str, Any] = {
        "task_name": "retrieval_passkey",
        "task_family": "retrieval",
        "task_variant": "passkey",
        "task_prompt_preview": "Question: What is the archive code? Answer with the exact value only.",
        "input_ids": retrieval.input_ids,
        "attention_mask": retrieval.attention_mask,
        "decode_steps": int(args.max_new_tokens_retrieval),
        "expected_answer": retrieval.answer_text,
        "stop_sequences": (retrieval.answer_text,),
    }
    reasoning_spec: dict[str, Any] = {
        "task_name": "reasoning_arithmetic",
        "task_family": "reasoning",
        "task_variant": "arithmetic",
        "task_prompt_preview": "Start with 17. Add 26. Subtract 9. Add 14. What is the final total?",
        "input_ids": reasoning_ids,
        "attention_mask": reasoning_mask,
        "decode_steps": int(args.max_new_tokens_reasoning),
        "expected_answer": reasoning_answer,
        "stop_sequences": (f"FINAL: {reasoning_answer}",),
    }
    instruction_spec: dict[str, Any] = {
        "task_name": "instruction_constraints",
        "task_family": "instruction",
        "task_variant": "constraints",
        "task_prompt_preview": "Reply with exactly two lines: STATUS: READY and COLOR: BLUE.",
        "input_ids": instruction_ids,
        "attention_mask": instruction_mask,
        "decode_steps": int(args.max_new_tokens_instruction),
        "expected_answer": instruction_answer,
        "stop_sequences": (instruction_answer,),
    }
    return [retrieval_spec, reasoning_spec, instruction_spec]
def _apply_selector_task_context(
    harness: Any,
    *,
    profile: str,
    task_family: str,
    task_variant: str,
) -> None:
    """Record the current task on the adapter's dotcache configuration.

    Does nothing when ``profile`` is ``"dense"`` or when the harness lacks an
    adapter carrying a ``dotcache_config``. Otherwise the config's
    ``learned_page_selector_prompt_family`` / ``..._prompt_variant`` fields are
    set to the task family and variant, or cleared to ``None`` when
    ``profile`` is ``"exact"``.

    Args:
        harness: Object that may expose ``adapter.dotcache_config`` and
            ``adapter.model_kv_cache``.
        profile: Cache profile name; ``"dense"`` and ``"exact"`` are special.
        task_family: Task family label, stored via ``str``.
        task_variant: Task variant label, stored via ``str``.

    Returns:
        None. The resulting config is assigned to both
        ``harness.adapter.dotcache_config`` and
        ``harness.adapter.model_kv_cache.config``.
    """
    if profile == "dense":
        return
    if not hasattr(harness, "adapter"):
        return
    adapter = harness.adapter
    if not hasattr(adapter, "dotcache_config"):
        return

    if profile == "exact":
        family_label = None
        variant_label = None
    else:
        family_label = str(task_family)
        variant_label = str(task_variant)

    config = adapter.dotcache_config
    if hasattr(config, "__dataclass_fields__"):
        # Dataclass configs are copied with the new values rather than
        # mutated, which also works when the dataclass is frozen.
        config = replace(
            config,
            learned_page_selector_prompt_family=family_label,
            learned_page_selector_prompt_variant=variant_label,
        )
    else:
        # Any other config object is updated in place.
        config.learned_page_selector_prompt_family = family_label
        config.learned_page_selector_prompt_variant = variant_label

    # Keep the adapter and its KV cache pointing at the same config object.
    adapter.dotcache_config = config
    adapter.model_kv_cache.config = config