DotCache-Arena / scripts /space_llama_runner.py
DeanoCalver's picture
Add live Llama lane and writable cache fallback
e135040 verified
Raw
History Blame Contribute Delete
6.33 kB
#!/usr/bin/env python3
from __future__ import annotations
import os
import sys
from pathlib import Path
# The repository root is the grandparent of this file (parents[0] is the
# containing directory, parents[1] its parent). It is prepended to sys.path
# once so that the project-local imports that follow can resolve when this
# file is executed directly.
REPO_ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path[:0] = [str(REPO_ROOT)]
from dotcache.config import DotCacheConfig # noqa: E402
from dotcache.integrations.llama import LlamaDotCacheHarness, resolve_hf_auth_kwargs # noqa: E402
from engines.live_request import ( # noqa: E402
DEFAULT_LIVE_DECODE_STEPS,
resolve_live_runtime_settings,
selective_exact_k_overrides,
)
from scripts.space_runner_common import ( # noqa: E402
configure_model_cache_env,
decode_generated_text,
load_request_from_stdin,
print_json,
tok_per_sec_from_latency,
)
def _build_exact_length_inputs(harness: LlamaDotCacheHarness, *, prompt_unit: str, prompt_length: int):
import torch
if harness.tokenizer is None:
raise ValueError("tokenizer is unavailable for exact-length prompt construction")
if prompt_length <= 0:
raise ValueError("prompt_length must be positive")
tokenizer = harness.tokenizer
unit_ids = tokenizer(prompt_unit, add_special_tokens=False)["input_ids"]
if not unit_ids:
raise ValueError("prompt text tokenized to an empty sequence")
token_ids: list[int] = []
if tokenizer.bos_token_id is not None:
token_ids.append(int(tokenizer.bos_token_id))
while len(token_ids) < prompt_length:
token_ids.extend(int(token_id) for token_id in unit_ids)
token_ids = token_ids[:prompt_length]
device = harness.adapter.device
input_ids = torch.tensor([token_ids], dtype=torch.long, device=device)
attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=device)
return input_ids, attention_mask
def _build_dotcache_config(*, settings, head_dim: int) -> DotCacheConfig:
    """Assemble the ``DotCacheConfig`` used by this runner.

    Args:
        settings: Resolved runtime settings. The attributes read here are
            ``bits_k``, ``bits_v``, ``tokens_per_page``,
            ``use_selective_exact_k``, ``recent_window_tokens``,
            ``sink_window_tokens`` and ``shortlist_top_k``.
        head_dim: Attention head width, passed through unchanged.

    Returns:
        A ``DotCacheConfig`` combining the request-driven values above with
        the fixed choices hard-coded in this function.
    """
    # Values that come from the resolved request settings.
    from_settings = {
        "bits_k": settings.bits_k,
        "bits_v": settings.bits_v,
        "tokens_per_page": settings.tokens_per_page,
        "key_mode_overrides": tuple(selective_exact_k_overrides(settings.use_selective_exact_k)),
        "execution_recent_window": settings.recent_window_tokens,
        "execution_sink_window": settings.sink_window_tokens,
        "execution_relevance_top_k": settings.shortlist_top_k,
    }
    # Values this runner always uses, regardless of the request.
    fixed = {
        "group_size": 32,
        "default_mode_k": "M0",
        "default_mode_v": "M0",
        "quant_scheme_k": "affine",
        "quant_scheme_v": "affine",
        "escape_dtype": "float16",
        "recent_page_escape_dtype": "float16",
        "execution_relevance_mode": "envelope",
    }
    return DotCacheConfig(head_dim=head_dim, **from_settings, **fixed)
def _engine_payload(*, text: str, tok_per_sec: float, latency_ms_per_token: float, kv_bytes: int, prompt_length: int, decode_steps: int):
return {
"text": text,
"tok_per_sec": tok_per_sec,
"latency_ms_per_token": latency_ms_per_token,
"kv_bytes": kv_bytes,
"trace": [
{"name": "prompt_length", "value": int(prompt_length), "unit": "tokens"},
{"name": "decode_steps", "value": int(decode_steps), "unit": "tokens"},
],
}
def _resolve_head_dim(model_config) -> int:
    """Return the attention head width for ``model_config`` as an ``int``.

    Uses the config's ``head_dim`` attribute when it is present and not
    ``None``; otherwise derives it as ``hidden_size // num_attention_heads``.
    The fallback is computed only when needed, so a config that declares
    ``head_dim`` does not also have to expose the other two attributes, and a
    ``head_dim`` of ``None`` no longer reaches ``int(None)``.
    """
    head_dim = getattr(model_config, "head_dim", None)
    if head_dim is None:
        head_dim = int(model_config.hidden_size) // int(model_config.num_attention_heads)
    return int(head_dim)


def main() -> int:
    """Run one live generation request and emit baseline/candidate payloads.

    Steps, in order:
      1. Configure the model cache environment and load the request.
      2. Resolve runtime settings; ``DOTCACHE_SPACE_LIVE_DECODE_STEPS`` and
         ``DOTCACHE_SPACE_MAX_LIVE_CONTEXT`` supply the decode-step count and
         the maximum live context (default ``"4096"``).
      3. Load the model config to determine the head width, then build the
         harness; ``DOTCACHE_SPACE_BACKEND``, ``DOTCACHE_SPACE_DEVICE`` and
         ``DOTCACHE_SPACE_TORCH_DTYPE`` select backend, device and dtype.
      4. Build the prompt tensors, run greedy generation, and pass
         ``{"baseline": ..., "candidate": ...}`` to ``print_json``.

    Returns:
        ``0`` once the payload has been handed to ``print_json``.

    Raises:
        ValueError: If a custom (non exact-length) prompt tokenizes to more
            tokens than ``settings.context_length``. Invalid integer values
            in the environment variables above also surface as ``ValueError``
            from ``int``.
    """
    configure_model_cache_env()
    request = load_request_from_stdin()
    settings = resolve_live_runtime_settings(
        request,
        decode_steps=int(os.getenv("DOTCACHE_SPACE_LIVE_DECODE_STEPS", str(DEFAULT_LIVE_DECODE_STEPS))),
        max_live_context=int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096")),
    )

    # NOTE(review): imported here rather than at module top, presumably so the
    # cache environment configured above is in place first; confirm.
    from transformers import AutoConfig

    model_config = AutoConfig.from_pretrained(settings.model_id, **resolve_hf_auth_kwargs())
    head_dim = _resolve_head_dim(model_config)

    harness = LlamaDotCacheHarness.from_pretrained(
        settings.model_id,
        _build_dotcache_config(settings=settings, head_dim=head_dim),
        backend=os.getenv("DOTCACHE_SPACE_BACKEND", "auto"),
        device=os.getenv("DOTCACHE_SPACE_DEVICE"),
        torch_dtype=os.getenv("DOTCACHE_SPACE_TORCH_DTYPE", "float16"),
    )

    if settings.use_exact_length_prompt:
        input_ids, attention_mask = _build_exact_length_inputs(
            harness,
            prompt_unit=settings.prompt_text,
            prompt_length=settings.context_length,
        )
    else:
        input_ids, attention_mask = harness.tokenize_prompt(settings.prompt_text)
        prompt_length = int(input_ids.shape[1])
        if prompt_length > settings.context_length:
            raise ValueError(
                f"Custom prompt tokenized to {prompt_length} tokens, which exceeds the selected context limit "
                f"of {settings.context_length}. Increase Context length or shorten the prompt."
            )

    # One more token than decode_steps is requested, while the decoded text
    # below is limited to decode_steps tokens.
    record = harness.generate_greedy(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=settings.decode_steps + 1,
        profile=False,
    )

    # Missing or falsy record fields fall back to locally known values.
    prompt_length = int(record.get("prompt_length") or input_ids.shape[1])
    decode_steps = int(record.get("decode_steps") or settings.decode_steps)
    dense_ids = list(record.get("dense_generated_ids") or [])
    dotcache_ids = list(record.get("dotcache_generated_ids") or dense_ids)

    dense_latency_ms = float(record.get("dense_decode_ms_per_step") or 0.0)
    baseline = _engine_payload(
        text=decode_generated_text(harness.tokenizer, dense_ids, limit=settings.decode_steps),
        tok_per_sec=tok_per_sec_from_latency(dense_latency_ms),
        latency_ms_per_token=dense_latency_ms,
        kv_bytes=int(record.get("dense_final_kv_cache_bytes") or 0),
        prompt_length=prompt_length,
        decode_steps=decode_steps,
    )

    if str(request.get("mode") or "") == "dense":
        # Dense-only requests report the baseline numbers for both lanes.
        candidate = dict(baseline)
    else:
        dotcache_latency_ms = float(record.get("decode_ms_per_step") or 0.0)
        candidate = _engine_payload(
            text=decode_generated_text(harness.tokenizer, dotcache_ids, limit=settings.decode_steps),
            tok_per_sec=tok_per_sec_from_latency(dotcache_latency_ms),
            latency_ms_per_token=dotcache_latency_ms,
            kv_bytes=int(record.get("kv_resident_bytes") or 0),
            prompt_length=prompt_length,
            decode_steps=decode_steps,
        )

    print_json({"baseline": baseline, "candidate": candidate})
    return 0
# Script entry point: the process exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())