#!/usr/bin/env python3 from __future__ import annotations import os import sys from pathlib import Path REPO_ROOT = Path(__file__).resolve().parents[1] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from dotcache.config import DotCacheConfig # noqa: E402 from dotcache.integrations.llama import LlamaDotCacheHarness, resolve_hf_auth_kwargs # noqa: E402 from engines.live_request import ( # noqa: E402 DEFAULT_LIVE_DECODE_STEPS, resolve_live_runtime_settings, selective_exact_k_overrides, ) from scripts.space_runner_common import ( # noqa: E402 configure_model_cache_env, decode_generated_text, load_request_from_stdin, print_json, tok_per_sec_from_latency, ) def _build_exact_length_inputs(harness: LlamaDotCacheHarness, *, prompt_unit: str, prompt_length: int): import torch if harness.tokenizer is None: raise ValueError("tokenizer is unavailable for exact-length prompt construction") if prompt_length <= 0: raise ValueError("prompt_length must be positive") tokenizer = harness.tokenizer unit_ids = tokenizer(prompt_unit, add_special_tokens=False)["input_ids"] if not unit_ids: raise ValueError("prompt text tokenized to an empty sequence") token_ids: list[int] = [] if tokenizer.bos_token_id is not None: token_ids.append(int(tokenizer.bos_token_id)) while len(token_ids) < prompt_length: token_ids.extend(int(token_id) for token_id in unit_ids) token_ids = token_ids[:prompt_length] device = harness.adapter.device input_ids = torch.tensor([token_ids], dtype=torch.long, device=device) attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device) return input_ids, attention_mask def _build_dotcache_config(*, settings, head_dim: int) -> DotCacheConfig: return DotCacheConfig( head_dim=head_dim, group_size=32, bits_k=settings.bits_k, bits_v=settings.bits_v, tokens_per_page=settings.tokens_per_page, default_mode_k="M0", default_mode_v="M0", key_mode_overrides=tuple(selective_exact_k_overrides(settings.use_selective_exact_k)), quant_scheme_k="affine", quant_scheme_v="affine", escape_dtype="float16", recent_page_escape_dtype="float16", execution_recent_window=settings.recent_window_tokens, execution_sink_window=settings.sink_window_tokens, execution_relevance_top_k=settings.shortlist_top_k, execution_relevance_mode="envelope", ) def _engine_payload(*, text: str, tok_per_sec: float, latency_ms_per_token: float, kv_bytes: int, prompt_length: int, decode_steps: int): return { "text": text, "tok_per_sec": tok_per_sec, "latency_ms_per_token": latency_ms_per_token, "kv_bytes": kv_bytes, "trace": [ {"name": "prompt_length", "value": int(prompt_length), "unit": "tokens"}, {"name": "decode_steps", "value": int(decode_steps), "unit": "tokens"}, ], } def main() -> int: configure_model_cache_env() request = load_request_from_stdin() settings = resolve_live_runtime_settings( request, decode_steps=int(os.getenv("DOTCACHE_SPACE_LIVE_DECODE_STEPS", str(DEFAULT_LIVE_DECODE_STEPS))), max_live_context=int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096")), ) from transformers import AutoConfig model_config = AutoConfig.from_pretrained(settings.model_id, **resolve_hf_auth_kwargs()) head_dim = int(getattr(model_config, "head_dim", int(model_config.hidden_size) // int(model_config.num_attention_heads))) harness = LlamaDotCacheHarness.from_pretrained( settings.model_id, _build_dotcache_config(settings=settings, head_dim=head_dim), backend=os.getenv("DOTCACHE_SPACE_BACKEND", "auto"), device=os.getenv("DOTCACHE_SPACE_DEVICE"), torch_dtype=os.getenv("DOTCACHE_SPACE_TORCH_DTYPE", "float16"), ) if settings.use_exact_length_prompt: input_ids, attention_mask = _build_exact_length_inputs( harness, prompt_unit=settings.prompt_text, prompt_length=settings.context_length, ) else: input_ids, attention_mask = harness.tokenize_prompt(settings.prompt_text) prompt_length = int(input_ids.shape[1]) if prompt_length > settings.context_length: raise ValueError( f"Custom prompt tokenized to {prompt_length} tokens, which exceeds the selected context limit " f"of {settings.context_length}. Increase Context length or shorten the prompt." ) record = harness.generate_greedy( input_ids=input_ids, attention_mask=attention_mask, max_new_tokens=settings.decode_steps + 1, profile=False, ) prompt_length = int(record.get("prompt_length") or input_ids.shape[1]) decode_steps = int(record.get("decode_steps") or settings.decode_steps) dense_ids = list(record.get("dense_generated_ids") or []) dotcache_ids = list(record.get("dotcache_generated_ids") or dense_ids) baseline = _engine_payload( text=decode_generated_text(harness.tokenizer, dense_ids, limit=settings.decode_steps), tok_per_sec=tok_per_sec_from_latency(float(record.get("dense_decode_ms_per_step") or 0.0)), latency_ms_per_token=float(record.get("dense_decode_ms_per_step") or 0.0), kv_bytes=int(record.get("dense_final_kv_cache_bytes") or 0), prompt_length=prompt_length, decode_steps=decode_steps, ) if str(request.get("mode") or "") == "dense": candidate = dict(baseline) else: candidate = _engine_payload( text=decode_generated_text(harness.tokenizer, dotcache_ids, limit=settings.decode_steps), tok_per_sec=tok_per_sec_from_latency(float(record.get("decode_ms_per_step") or 0.0)), latency_ms_per_token=float(record.get("decode_ms_per_step") or 0.0), kv_bytes=int(record.get("kv_resident_bytes") or 0), prompt_length=prompt_length, decode_steps=decode_steps, ) print_json({"baseline": baseline, "candidate": candidate}) return 0 if __name__ == "__main__": raise SystemExit(main())