DotCache-Arena / scripts /space_dotcache_runner.py
Deano Calver
Fix live space task prompt imports
7cbeabe
Raw
History Blame Contribute Delete
11.7 kB
#!/usr/bin/env python3
from __future__ import annotations
import argparse
import os
import sys
from pathlib import Path
# Put the repository root (two levels up from this file) at the front of the import
# path, once, so the project-local imports that follow resolve when the script is
# executed directly rather than as part of an installed package.
REPO_ROOT = Path(__file__).resolve().parents[1]
if all(entry != str(REPO_ROOT) for entry in sys.path):
    sys.path.insert(0, str(REPO_ROOT))
from benchmarks.bench_qwen35_attention_subset_dotcache_serving import _build_dotcache_config # noqa: E402
from dotcache.integrations.qwen35 import Qwen35AttentionSubsetDotCacheHarness # noqa: E402
from engines.live_request import DEFAULT_LIVE_DECODE_STEPS, resolve_live_runtime_settings # noqa: E402
from scripts.space_runner_common import ( # noqa: E402
configure_model_cache_env,
decode_generated_text,
load_request_from_stdin,
print_json,
tok_per_sec_from_latency,
)
from scripts.space_task_prompts import _apply_selector_task_context, _task_specs # noqa: E402
# Fallback text for exact-length replays: when the request carries no prompt text,
# `main` passes this sentence to `_build_exact_length_inputs`, which repeats (and
# truncates) its tokenization to reach the requested context length.
BACKEND_TRUTH_PROMPT_UNIT = "Cache locality matters for fast decoding."
def _build_exact_length_inputs(harness: Qwen35AttentionSubsetDotCacheHarness, *, prompt_unit: str, prompt_length: int):
import torch
if harness.tokenizer is None:
raise ValueError("tokenizer is unavailable for exact-length prompt construction")
if prompt_length <= 0:
raise ValueError("prompt_length must be positive")
tokenizer = harness.tokenizer
unit_ids = tokenizer(prompt_unit, add_special_tokens=False)["input_ids"]
if not unit_ids:
raise ValueError("prompt text tokenized to an empty sequence")
token_ids: list[int] = []
if tokenizer.bos_token_id is not None:
token_ids.append(int(tokenizer.bos_token_id))
while len(token_ids) < prompt_length:
token_ids.extend(int(token_id) for token_id in unit_ids)
token_ids = token_ids[:prompt_length]
device = harness.adapter.device
input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
return input_ids, attention_mask
def _build_args_namespace(settings, *, head_dim: int):
return argparse.Namespace(
model_id=settings.model_id,
group_size=32,
bits_k=settings.bits_k,
bits_v=settings.bits_v,
default_mode_k="M0",
default_mode_v="M0",
key_policy_tier="exact",
value_policy_tier="exact",
key_mode_override=[],
value_mode_override=[],
key_layer_sensitivity=[],
value_layer_sensitivity=[],
key_policy_override=[],
value_policy_override=[],
quant_scheme_k="affine",
quant_scheme_v="affine",
escape_dtype="float16",
recent_page_escape_dtype="float16",
recent_window=settings.recent_window_tokens,
execution_recent_window=settings.execution_recent_window_tokens,
execution_sink_window=settings.execution_sink_window_tokens,
execution_recent_window_layer=[],
execution_recent_window_context_layer=[],
execution_relevance_top_k=settings.execution_relevance_top_k,
execution_relevance_mode="envelope",
execution_relevance_top_k_layer=[],
execution_relevance_top_k_context_layer=[],
execution_full_context_layer=[],
execution_disable_grouped_batching_layer=[],
execution_recent_old_bonus_window=0,
execution_recent_old_bonus_strength=0.0,
execution_recent_old_bonus_layer=[],
execution_secondary_relevance_mode="",
execution_secondary_relevance_top_k=0,
execution_secondary_relevance_min_overlap=0.0,
execution_secondary_relevance_layer=[],
execution_recent_neighbor_rescue_top_k=0,
execution_recent_neighbor_rescue_anchor_window=0,
execution_recent_neighbor_rescue_min_anchor_pages=0,
execution_recent_neighbor_rescue_layer=[],
execution_exact_promote_top_k=0,
execution_exact_promote_min_margin_threshold=0.0,
execution_exact_promote_max_context=0,
execution_exact_promote_margin_threshold=0.0,
execution_exact_promote_layer=[],
execution_exact_promote_union_rescue_top_k=0,
execution_grouped_decode_compact=False,
execution_grouped_mix_compact=False,
execution_grouped_mix_disable_packed_cuda=False,
execution_freeze_chunk_budget_during_decode=False,
execution_builtin_selector_cache=False,
execution_builtin_selector_score_all_pages=False,
execution_builtin_selector_candidate_only=False,
execution_builtin_selector_score_all_pages_min_candidate_fraction=0.0,
execution_value_escape_layer=[],
execution_value_escape_mode="M3",
execution_value_escape_old_only=False,
execution_value_escape_top_k=0,
execution_value_escape_prewarm=False,
execution_value_escape_prewarm_min_context=0,
execution_exact_refine_top_k=0,
execution_exact_refine_layer=[],
m2_sketch_dim_k=8,
m2_center_k=False,
m2_segment_count_k=1,
m2_adaptive_segments_k=False,
m2_adaptive_min_improvement_k=0.1,
m2_prefilter_top_k=0,
m2_prefilter_min_pages=8,
prefer_m4_project_k=False,
lut_refine_steps=6,
preconditioner="none",
precondition_strength=2.0,
m1_segment_count_k=1,
m1_segment_count_v=1,
m1_fallback_to_m0=True,
m1_error_threshold=0.35,
m1_token_p95_error_threshold=1000000.0,
tokens_per_page=settings.tokens_per_page,
learned_page_selector_path=settings.learned_page_selector_path,
learned_page_selector_prompt_family=settings.learned_page_selector_prompt_family,
learned_page_selector_prompt_variant=settings.learned_page_selector_prompt_variant,
learned_page_selector_profile=settings.learned_page_selector_profile,
learned_page_selector_scope="KV",
learned_page_selector_target_candidate="M3/affine/4/float16",
learned_page_selector_logit_offset=0.0,
prepared_chunk_cache_budget_ratio=None,
prepared_chunk_cache_min_bytes=None,
prepared_chunk_cache_max_bytes=None,
head_dim=head_dim,
)
def _task_prompt_inputs(harness: Qwen35AttentionSubsetDotCacheHarness, settings):
    """Return the compact-task spec whose ``task_name`` equals ``settings.compact_task_name``.

    Specs come from ``_task_specs`` built at ``settings.context_length`` with fixed
    ``max_new_tokens_*`` values (64 retrieval, 64 reasoning, 32 instruction). When
    ``settings.benchmark_variant`` is anything other than ``"exact"``, the harness is first
    configured through ``_apply_selector_task_context``. That call receives the variant as
    its ``profile`` and the matched spec's family and variant as strings.

    Raises:
        ValueError: if no spec has the requested task name.
    """
    token_limits = argparse.Namespace(
        max_new_tokens_retrieval=64,
        max_new_tokens_reasoning=64,
        max_new_tokens_instruction=32,
    )
    candidates = _task_specs(harness, prompt_length=settings.context_length, args=token_limits)
    matched = next(
        (spec for spec in candidates if spec["task_name"] == settings.compact_task_name),
        None,
    )
    if matched is None:
        raise ValueError(f"Unsupported compact-task replay task: {settings.compact_task_name}")

    if settings.benchmark_variant != "exact":
        _apply_selector_task_context(
            harness,
            profile=settings.benchmark_variant,
            task_family=str(matched["task_family"]),
            task_variant=str(matched["task_variant"]),
        )
    return matched
def _resolve_head_dim(text_config) -> int:
    """Return the per-head attention dimension declared by a model text config.

    An explicit ``head_dim`` wins. ``hidden_size // num_attention_heads`` is used only when
    ``head_dim`` is absent or ``None``, and is computed lazily. A ``getattr`` default would
    be evaluated eagerly, which would also require those two fields to exist, and a
    ``head_dim=None`` value would crash in ``int()``.
    """
    head_dim = getattr(text_config, "head_dim", None)
    if head_dim is None:
        head_dim = int(text_config.hidden_size) // int(text_config.num_attention_heads)
    return int(head_dim)


def _resolve_live_replay_mode(settings) -> str:
    """Classify ``settings`` into one of the live replay paths this runner supports.

    Precedence: a custom prompt, then the ``compact_task`` suite, then exact-length
    synthetic prompts.

    Raises:
        ValueError: if the requested benchmark section has no live replay path.
    """
    if settings.is_custom_prompt:
        return "custom_prompt"
    if settings.benchmark_suite == "compact_task":
        return "compact_task"
    if settings.use_exact_length_prompt:
        return "exact_length"
    raise ValueError(
        "Live replay for this benchmark section is not wired in the Space yet. "
        "Use the preset-backed compare for the valid benchmark row."
    )


def _prepare_live_inputs(harness, settings, mode: str):
    """Return ``(input_ids, attention_mask, decode_steps)`` for the given replay ``mode``.

    Raises:
        ValueError: if a custom prompt tokenizes to more than ``settings.context_length``
            tokens, or ``mode`` is not one produced by ``_resolve_live_replay_mode``.
    """
    if mode == "custom_prompt":
        input_ids, attention_mask = harness.tokenize_prompt(settings.prompt_text)
        prompt_length = int(input_ids.shape[1])
        if prompt_length > settings.context_length:
            raise ValueError(
                f"Custom prompt tokenized to {prompt_length} tokens, which exceeds the selected context limit "
                f"of {settings.context_length}. Increase Context length or shorten the prompt."
            )
        return input_ids, attention_mask, int(settings.decode_steps)
    if mode == "compact_task":
        task_spec = _task_prompt_inputs(harness, settings)
        # The task spec carries its own decode budget; settings.decode_steps is not used here.
        return task_spec["input_ids"], task_spec["attention_mask"], int(task_spec["decode_steps"])
    if mode == "exact_length":
        input_ids, attention_mask = _build_exact_length_inputs(
            harness,
            prompt_unit=settings.prompt_text or BACKEND_TRUTH_PROMPT_UNIT,
            prompt_length=settings.context_length,
        )
        return input_ids, attention_mask, int(settings.decode_steps)
    raise ValueError(f"Unknown live replay mode: {mode!r}")


def _observed_prompt_length(input_ids, default: int) -> int:
    """Return the size of the last axis of ``input_ids``, or ``default`` if it has no shape."""
    shape = getattr(input_ids, "shape", None)
    if not shape:
        return int(default)
    return int(shape[-1])


def main() -> int:
    """Run one live DotCache serving replay and emit its result payload.

    Steps:
    1. Load the request via ``load_request_from_stdin``.
    2. Resolve runtime settings, with decode-step and context caps taken from the
       ``DOTCACHE_SPACE_*`` environment variables.
    3. Reject unsupported replay sections before any model loading.
    4. Build the DotCache harness and run ``run_attention_subset_dotcache_serving``
       on the prompt for the chosen mode.
    5. Emit text, throughput, latency, KV bytes and a trace via ``print_json``.

    Returns:
        ``0`` on success; failures propagate as exceptions.
    """
    configure_model_cache_env()
    request = load_request_from_stdin()
    settings = resolve_live_runtime_settings(
        request,
        decode_steps=int(os.getenv("DOTCACHE_SPACE_LIVE_DECODE_STEPS", str(DEFAULT_LIVE_DECODE_STEPS))),
        max_live_context=int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096")),
    )
    # Fail fast: this depends only on settings, so don't pay for config/model loading first.
    mode = _resolve_live_replay_mode(settings)

    from transformers import AutoConfig

    model_config = AutoConfig.from_pretrained(settings.model_id, trust_remote_code=False)
    # Multimodal configs nest the language-model settings under `text_config`.
    text_config = getattr(model_config, "text_config", model_config)
    head_dim = _resolve_head_dim(text_config)
    args = _build_args_namespace(settings, head_dim=head_dim)
    dotcache_config = _build_dotcache_config(args, head_dim=head_dim)
    harness = Qwen35AttentionSubsetDotCacheHarness.from_pretrained(
        settings.model_id,
        dotcache_config=dotcache_config,
        backend=os.getenv("DOTCACHE_SPACE_BACKEND", "auto"),
        device=os.getenv("DOTCACHE_SPACE_DEVICE"),
        torch_dtype=os.getenv("DOTCACHE_SPACE_TORCH_DTYPE", "float16"),
        weight_quantization=os.getenv("DOTCACHE_SPACE_WEIGHT_QUANTIZATION", "none"),
    )

    input_ids, attention_mask, decode_steps = _prepare_live_inputs(harness, settings, mode)
    record = harness.run_attention_subset_dotcache_serving(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decode_steps=decode_steps,
        profile_backend=True,
    )
    generated_ids = list(record.get("dotcache_generated_ids") or [])
    text = decode_generated_text(harness.tokenizer, generated_ids, limit=decode_steps)

    latency = float(record.get("dotcache_decode_ms_per_step") or 0.0)
    prefill_ms = float(record.get("dotcache_prefill_ms") or 0.0)
    # Prefer the harness's own count; otherwise report the length actually fed in, not the
    # requested context limit (a custom prompt may be shorter than that limit).
    prompt_length = int(record.get("prompt_length") or _observed_prompt_length(input_ids, settings.context_length))
    payload = {
        "text": text,
        "tok_per_sec": tok_per_sec_from_latency(latency),
        "latency_ms_per_token": latency,
        "kv_bytes": int(record.get("resident_bytes") or record.get("kv_resident_bytes") or 0),
        "trace": [
            {"name": "prefill_ms", "value": prefill_ms, "unit": "ms"},
            {"name": "prompt_length", "value": prompt_length, "unit": "tokens"},
            {"name": "decode_steps", "value": int(record.get("decode_steps") or decode_steps), "unit": "tokens"},
            {"name": "benchmark_suite", "value": settings.benchmark_suite, "unit": "label"},
            {"name": "benchmark_variant", "value": settings.benchmark_variant, "unit": "label"},
        ],
    }
    print_json(payload)
    return 0
# Script entry point: propagate main()'s return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())