#!/usr/bin/env python3
"""Space runner for live Qwen3.5 attention-subset DotCache replays.

Reads a live request from stdin, runs one DotCache serving pass through
``Qwen35AttentionSubsetDotCacheHarness`` and prints a JSON payload.
"""

from __future__ import annotations

import argparse
import os
import sys
from itertools import cycle, islice
from pathlib import Path

# Put the directory one level above this script's folder on sys.path so the
# project packages imported below resolve when the script is run directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

from benchmarks.bench_qwen35_attention_subset_dotcache_serving import _build_dotcache_config  # noqa: E402
from dotcache.integrations.qwen35 import Qwen35AttentionSubsetDotCacheHarness  # noqa: E402
from engines.live_request import DEFAULT_LIVE_DECODE_STEPS, resolve_live_runtime_settings  # noqa: E402
from scripts.space_runner_common import (  # noqa: E402
    configure_model_cache_env,
    decode_generated_text,
    load_request_from_stdin,
    print_json,
    tok_per_sec_from_latency,
)
from scripts.space_task_prompts import _apply_selector_task_context, _task_specs  # noqa: E402

# Text repeated to fill exact-length prompts when the request supplies no prompt text.
BACKEND_TRUTH_PROMPT_UNIT = "Cache locality matters for fast decoding."


def _build_exact_length_inputs(harness: Qwen35AttentionSubsetDotCacheHarness, *, prompt_unit: str, prompt_length: int):
    """Build ``(input_ids, attention_mask)`` containing exactly ``prompt_length`` tokens.

    The tokenizer's BOS token is placed first when the tokenizer defines one.
    The remaining positions are filled by repeating the token ids of
    ``prompt_unit`` (tokenized without special tokens) and cutting the
    repetition off at the requested length.

    Both returned tensors have shape ``(1, prompt_length)`` and dtype
    ``torch.long``. They are created on ``harness.adapter.device``, and the
    attention mask is all ones.

    Raises:
        ValueError: if the harness has no tokenizer, ``prompt_length`` is not
            positive, or ``prompt_unit`` tokenizes to an empty sequence.
    """
    import torch

    if harness.tokenizer is None:
        raise ValueError("tokenizer is unavailable for exact-length prompt construction")
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")
    tokenizer = harness.tokenizer
    unit_ids = [int(token_id) for token_id in tokenizer(prompt_unit, add_special_tokens=False)["input_ids"]]
    if not unit_ids:
        raise ValueError("prompt text tokenized to an empty sequence")

    token_ids: list[int] = []
    if tokenizer.bos_token_id is not None:
        token_ids.append(int(tokenizer.bos_token_id))
    # prompt_length >= 1 and the BOS prefix is at most one token, so the
    # remaining count is never negative.
    token_ids.extend(islice(cycle(unit_ids), prompt_length - len(token_ids)))

    device = harness.adapter.device
    input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
    return input_ids, attention_mask


def _build_args_namespace(settings, *, head_dim: int) -> argparse.Namespace:
    """Build the argparse-style namespace consumed by ``_build_dotcache_config``.

    Only a few fields are taken from ``settings``:
    - the model id and K/V bit widths;
    - the recent/sink window sizes;
    - the relevance top-k;
    - the tokens-per-page count;
    - the learned page-selector path, prompt family, prompt variant and profile.

    ``head_dim`` is passed through unchanged. Every other option is pinned to
    the fixed literal value written below, so live replays use one fixed
    configuration regardless of the request.
    """
    return argparse.Namespace(
        model_id=settings.model_id,
        group_size=32,
        bits_k=settings.bits_k,
        bits_v=settings.bits_v,
        default_mode_k="M0",
        default_mode_v="M0",
        key_policy_tier="exact",
        value_policy_tier="exact",
        key_mode_override=[],
        value_mode_override=[],
        key_layer_sensitivity=[],
        value_layer_sensitivity=[],
        key_policy_override=[],
        value_policy_override=[],
        quant_scheme_k="affine",
        quant_scheme_v="affine",
        escape_dtype="float16",
        recent_page_escape_dtype="float16",
        recent_window=settings.recent_window_tokens,
        execution_recent_window=settings.execution_recent_window_tokens,
        execution_sink_window=settings.execution_sink_window_tokens,
        execution_recent_window_layer=[],
        execution_recent_window_context_layer=[],
        execution_relevance_top_k=settings.execution_relevance_top_k,
        execution_relevance_mode="envelope",
        execution_relevance_top_k_layer=[],
        execution_relevance_top_k_context_layer=[],
        execution_full_context_layer=[],
        execution_disable_grouped_batching_layer=[],
        execution_recent_old_bonus_window=0,
        execution_recent_old_bonus_strength=0.0,
        execution_recent_old_bonus_layer=[],
        execution_secondary_relevance_mode="",
        execution_secondary_relevance_top_k=0,
        execution_secondary_relevance_min_overlap=0.0,
        execution_secondary_relevance_layer=[],
        execution_recent_neighbor_rescue_top_k=0,
        execution_recent_neighbor_rescue_anchor_window=0,
        execution_recent_neighbor_rescue_min_anchor_pages=0,
        execution_recent_neighbor_rescue_layer=[],
        execution_exact_promote_top_k=0,
        execution_exact_promote_min_margin_threshold=0.0,
        execution_exact_promote_max_context=0,
        execution_exact_promote_margin_threshold=0.0,
        execution_exact_promote_layer=[],
        execution_exact_promote_union_rescue_top_k=0,
        execution_grouped_decode_compact=False,
        execution_grouped_mix_compact=False,
        execution_grouped_mix_disable_packed_cuda=False,
        execution_freeze_chunk_budget_during_decode=False,
        execution_builtin_selector_cache=False,
        execution_builtin_selector_score_all_pages=False,
        execution_builtin_selector_candidate_only=False,
        execution_builtin_selector_score_all_pages_min_candidate_fraction=0.0,
        execution_value_escape_layer=[],
        execution_value_escape_mode="M3",
        execution_value_escape_old_only=False,
        execution_value_escape_top_k=0,
        execution_value_escape_prewarm=False,
        execution_value_escape_prewarm_min_context=0,
        execution_exact_refine_top_k=0,
        execution_exact_refine_layer=[],
        m2_sketch_dim_k=8,
        m2_center_k=False,
        m2_segment_count_k=1,
        m2_adaptive_segments_k=False,
        m2_adaptive_min_improvement_k=0.1,
        m2_prefilter_top_k=0,
        m2_prefilter_min_pages=8,
        prefer_m4_project_k=False,
        lut_refine_steps=6,
        preconditioner="none",
        precondition_strength=2.0,
        m1_segment_count_k=1,
        m1_segment_count_v=1,
        m1_fallback_to_m0=True,
        m1_error_threshold=0.35,
        m1_token_p95_error_threshold=1000000.0,
        tokens_per_page=settings.tokens_per_page,
        learned_page_selector_path=settings.learned_page_selector_path,
        learned_page_selector_prompt_family=settings.learned_page_selector_prompt_family,
        learned_page_selector_prompt_variant=settings.learned_page_selector_prompt_variant,
        learned_page_selector_profile=settings.learned_page_selector_profile,
        learned_page_selector_scope="KV",
        learned_page_selector_target_candidate="M3/affine/4/float16",
        learned_page_selector_logit_offset=0.0,
        prepared_chunk_cache_budget_ratio=None,
        prepared_chunk_cache_min_bytes=None,
        prepared_chunk_cache_max_bytes=None,
        head_dim=head_dim,
    )


def _task_prompt_inputs(harness: Qwen35AttentionSubsetDotCacheHarness, settings):
    """Return the compact-task spec whose ``task_name`` equals ``settings.compact_task_name``.

    All task specs are built by ``_task_specs`` at ``settings.context_length``,
    with fixed per-family new-token budgets. For any
    ``settings.benchmark_variant`` other than ``"exact"``, selector task
    context is applied to ``harness`` for the matched task's family and
    variant before the spec is returned.

    Raises:
        ValueError: if no task spec matches ``settings.compact_task_name``.
    """
    task_args = argparse.Namespace(
        max_new_tokens_retrieval=64,
        max_new_tokens_reasoning=64,
        max_new_tokens_instruction=32,
    )
    task_specs = _task_specs(
        harness,
        prompt_length=settings.context_length,
        args=task_args,
    )
    task_spec = next(
        (spec for spec in task_specs if spec["task_name"] == settings.compact_task_name),
        None,
    )
    if task_spec is None:
        raise ValueError(f"Unsupported compact-task replay task: {settings.compact_task_name}")
    if settings.benchmark_variant != "exact":
        _apply_selector_task_context(
            harness,
            profile=settings.benchmark_variant,
            task_family=str(task_spec["task_family"]),
            task_variant=str(task_spec["task_variant"]),
        )
    return task_spec
def _resolve_head_dim(text_config) -> int:
    """Return the attention head dimension declared by a transformers text config.

    An explicit ``head_dim`` attribute takes precedence. When that attribute
    is missing or ``None``, the value is derived as
    ``hidden_size // num_attention_heads``. The fallback is computed only when
    it is needed, so a config that defines ``head_dim`` but lacks the other
    two fields is still accepted.
    """
    head_dim = getattr(text_config, "head_dim", None)
    if head_dim is None:
        return int(text_config.hidden_size) // int(text_config.num_attention_heads)
    return int(head_dim)


def _select_prompt_inputs(harness, settings):
    """Choose the replay prompt for ``settings``.

    Returns ``(input_ids, attention_mask, decode_steps)``. The checks run in
    this order, and the first that applies wins:
    1. A custom prompt (``settings.is_custom_prompt``).
    2. The ``compact_task`` benchmark suite.
    3. An exact-length synthetic prompt (``settings.use_exact_length_prompt``).

    Raises:
        ValueError: if a custom prompt tokenizes to more than
            ``settings.context_length`` tokens, or if none of the supported
            replay modes applies.
    """
    if settings.is_custom_prompt:
        input_ids, attention_mask = harness.tokenize_prompt(settings.prompt_text)
        prompt_length = int(input_ids.shape[1])
        if prompt_length > settings.context_length:
            raise ValueError(
                f"Custom prompt tokenized to {prompt_length} tokens, which exceeds the selected context limit "
                f"of {settings.context_length}. Increase Context length or shorten the prompt."
            )
        return input_ids, attention_mask, int(settings.decode_steps)

    if settings.benchmark_suite == "compact_task":
        task_spec = _task_prompt_inputs(harness, settings)
        return task_spec["input_ids"], task_spec["attention_mask"], int(task_spec["decode_steps"])

    if settings.use_exact_length_prompt:
        input_ids, attention_mask = _build_exact_length_inputs(
            harness,
            prompt_unit=settings.prompt_text or BACKEND_TRUTH_PROMPT_UNIT,
            prompt_length=settings.context_length,
        )
        return input_ids, attention_mask, int(settings.decode_steps)

    raise ValueError(
        "Live replay for this benchmark section is not wired in the Space yet. "
        "Use the preset-backed compare for the valid benchmark row."
    )


def main() -> int:
    """Replay one live request through the DotCache harness and print a JSON payload.

    Steps:
    1. Read the request from stdin and resolve the runtime settings.
       ``DOTCACHE_SPACE_LIVE_DECODE_STEPS`` and ``DOTCACHE_SPACE_MAX_LIVE_CONTEXT``
       are passed to ``resolve_live_runtime_settings``.
    2. Load the model config and the harness. Backend, device, dtype and
       weight quantization come from ``DOTCACHE_SPACE_*`` environment
       variables.
    3. Run a single profiled DotCache serving pass on the selected prompt.
    4. Print the generated text, throughput, latency, KV bytes and a small
       trace.

    Returns 0 on success; failures propagate as exceptions.
    """
    configure_model_cache_env()
    request = load_request_from_stdin()
    settings = resolve_live_runtime_settings(
        request,
        decode_steps=int(os.getenv("DOTCACHE_SPACE_LIVE_DECODE_STEPS", str(DEFAULT_LIVE_DECODE_STEPS))),
        max_live_context=int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096")),
    )

    # Imported lazily so the heavy dependency is only loaded when a replay actually runs.
    from transformers import AutoConfig

    model_config = AutoConfig.from_pretrained(settings.model_id, trust_remote_code=False)
    # Multimodal configs nest the language-model settings under ``text_config``.
    text_config = getattr(model_config, "text_config", model_config)
    head_dim = _resolve_head_dim(text_config)

    args = _build_args_namespace(settings, head_dim=head_dim)
    dotcache_config = _build_dotcache_config(args, head_dim=head_dim)
    harness = Qwen35AttentionSubsetDotCacheHarness.from_pretrained(
        settings.model_id,
        dotcache_config=dotcache_config,
        backend=os.getenv("DOTCACHE_SPACE_BACKEND", "auto"),
        device=os.getenv("DOTCACHE_SPACE_DEVICE"),
        torch_dtype=os.getenv("DOTCACHE_SPACE_TORCH_DTYPE", "float16"),
        weight_quantization=os.getenv("DOTCACHE_SPACE_WEIGHT_QUANTIZATION", "none"),
    )

    input_ids, attention_mask, decode_steps = _select_prompt_inputs(harness, settings)
    record = harness.run_attention_subset_dotcache_serving(
        input_ids=input_ids,
        attention_mask=attention_mask,
        decode_steps=decode_steps,
        profile_backend=True,
    )
    generated_ids = list(record.get("dotcache_generated_ids") or [])
    text = decode_generated_text(harness.tokenizer, generated_ids, limit=decode_steps)

    latency = float(record.get("dotcache_decode_ms_per_step") or 0.0)
    prefill_ms = float(record.get("dotcache_prefill_ms") or 0.0)
    # Fall back to the actual prompt size rather than the context limit, which
    # can differ for custom and compact-task prompts.
    prompt_length = int(record.get("prompt_length") or input_ids.shape[1])
    payload = {
        "text": text,
        "tok_per_sec": tok_per_sec_from_latency(latency),
        "latency_ms_per_token": latency,
        "kv_bytes": int(record.get("resident_bytes") or record.get("kv_resident_bytes") or 0),
        "trace": [
            {"name": "prefill_ms", "value": prefill_ms, "unit": "ms"},
            {"name": "prompt_length", "value": prompt_length, "unit": "tokens"},
            {"name": "decode_steps", "value": int(record.get("decode_steps") or decode_steps), "unit": "tokens"},
            {"name": "benchmark_suite", "value": settings.benchmark_suite, "unit": "label"},
            {"name": "benchmark_variant", "value": settings.benchmark_variant, "unit": "label"},
        ],
    }
    print_json(payload)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())