Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| from __future__ import annotations | |
| import argparse | |
| import os | |
| import sys | |
| from pathlib import Path | |
| REPO_ROOT = Path(__file__).resolve().parents[1] | |
| if str(REPO_ROOT) not in sys.path: | |
| sys.path.insert(0, str(REPO_ROOT)) | |
| from dotcache.integrations.qwen35 import Qwen35TextHarness # noqa: E402 | |
| from engines.live_request import DEFAULT_LIVE_DECODE_STEPS, resolve_live_runtime_settings # noqa: E402 | |
| from scripts.space_runner_common import ( # noqa: E402 | |
| configure_model_cache_env, | |
| decode_generated_text, | |
| load_request_from_stdin, | |
| print_json, | |
| tok_per_sec_from_latency, | |
| ) | |
| from scripts.space_task_prompts import _task_specs # noqa: E402 | |
def _build_exact_length_inputs(harness: Qwen35TextHarness, *, prompt_unit: str, prompt_length: int):
    """Tile ``prompt_unit`` into a prompt of exactly ``prompt_length`` token ids.

    The sequence starts with the tokenizer's BOS id when it defines one, then
    repeats the token ids of ``prompt_unit`` (tokenized without special
    tokens) and is cut off at ``prompt_length`` entries.

    Args:
        harness: Object exposing ``tokenizer`` and ``adapter.device``.
        prompt_unit: Text whose token ids are repeated to fill the prompt.
        prompt_length: Exact number of token ids in the returned prompt.

    Returns:
        ``(input_ids, attention_mask)``: two ``torch.long`` tensors of shape
        ``(1, prompt_length)`` placed on ``harness.adapter.device``; the mask
        is all ones.

    Raises:
        ValueError: If the harness has no tokenizer, ``prompt_length`` is not
            positive, or ``prompt_unit`` tokenizes to nothing.
    """
    import torch

    tok = harness.tokenizer
    if tok is None:
        raise ValueError("tokenizer is unavailable for exact-length prompt construction")
    if prompt_length <= 0:
        raise ValueError("prompt_length must be positive")

    unit_ids = tok(prompt_unit, add_special_tokens=False)["input_ids"]
    if not unit_ids:
        raise ValueError("prompt text tokenized to an empty sequence")

    ids: list[int] = [] if tok.bos_token_id is None else [int(tok.bos_token_id)]

    # Ceiling division: how many whole copies of the unit cover what is still missing.
    shortfall = prompt_length - len(ids)
    copies = -(-shortfall // len(unit_ids))
    for _ in range(copies):
        ids.extend(map(int, unit_ids))

    # The last copy may overshoot; drop everything past the requested length.
    del ids[prompt_length:]

    target_device = harness.adapter.device
    input_ids = torch.tensor([ids], dtype=torch.long, device=target_device)
    attention_mask = torch.ones_like(input_ids, dtype=torch.long, device=target_device)
    return input_ids, attention_mask
def _task_prompt_inputs(harness: Qwen35TextHarness, settings):
    """Return the task spec whose ``task_name`` equals ``settings.compact_task_name``.

    Task specs are produced by ``_task_specs`` for a prompt length of
    ``settings.context_length``, using fixed per-category new-token budgets
    (retrieval 64, reasoning 64, instruction 32).

    Args:
        harness: Harness forwarded to ``_task_specs``.
        settings: Object providing ``context_length`` and ``compact_task_name``.

    Returns:
        The first matching task spec, as yielded by ``_task_specs``.

    Raises:
        ValueError: If no spec carries the requested task name.
    """
    decode_budgets = argparse.Namespace(
        max_new_tokens_retrieval=64,
        max_new_tokens_reasoning=64,
        max_new_tokens_instruction=32,
    )
    candidates = _task_specs(harness, prompt_length=settings.context_length, args=decode_budgets)

    # First spec with the requested name wins; None signals that nothing matched.
    selected = next(
        (spec for spec in candidates if spec["task_name"] == settings.compact_task_name),
        None,
    )
    if selected is None:
        raise ValueError(f"Unsupported compact-task replay task: {settings.compact_task_name}")
    return selected
def main() -> int:
    """Run one dense greedy-generation request and print its metrics as JSON.

    Steps, in order:
      1. Configure the model cache environment and load the request via
         ``load_request_from_stdin``.
      2. Resolve runtime settings, taking the decode-step count and maximum
         live context from ``DOTCACHE_SPACE_LIVE_DECODE_STEPS`` (default
         ``DEFAULT_LIVE_DECODE_STEPS``) and ``DOTCACHE_SPACE_MAX_LIVE_CONTEXT``
         (default ``4096``).
      3. Load the harness; device, dtype and weight quantization come from
         ``DOTCACHE_SPACE_DEVICE``, ``DOTCACHE_SPACE_TORCH_DTYPE`` (default
         ``float16``) and ``DOTCACHE_SPACE_WEIGHT_QUANTIZATION`` (default
         ``none``).
      4. Build the prompt tensors from a custom prompt, a compact-task spec,
         or an exact-length tiled prompt, depending on the settings.
      5. Generate greedily and emit text, throughput, latency, KV bytes and
         a small trace through ``print_json``.

    Returns:
        ``0`` once the payload has been printed.

    Raises:
        ValueError: If a custom prompt exceeds the selected context length,
            or if the request matches none of the supported prompt modes.
    """
    configure_model_cache_env()
    request = load_request_from_stdin()

    env_decode_steps = int(os.getenv("DOTCACHE_SPACE_LIVE_DECODE_STEPS", str(DEFAULT_LIVE_DECODE_STEPS)))
    env_max_context = int(os.getenv("DOTCACHE_SPACE_MAX_LIVE_CONTEXT", "4096"))
    settings = resolve_live_runtime_settings(
        request,
        decode_steps=env_decode_steps,
        max_live_context=env_max_context,
    )

    load_options = {
        "device": os.getenv("DOTCACHE_SPACE_DEVICE"),
        "torch_dtype": os.getenv("DOTCACHE_SPACE_TORCH_DTYPE", "float16"),
        "weight_quantization": os.getenv("DOTCACHE_SPACE_WEIGHT_QUANTIZATION", "none"),
    }
    harness = Qwen35TextHarness.from_pretrained(settings.model_id, **load_options)

    # Pick the prompt source; each branch yields the tensors and the decode-step count.
    if settings.is_custom_prompt:
        input_ids, attention_mask = harness.tokenize_prompt(settings.prompt_text)
        token_count = int(input_ids.shape[1])
        if token_count > settings.context_length:
            raise ValueError(
                f"Custom prompt tokenized to {token_count} tokens, which exceeds the selected context limit "
                f"of {settings.context_length}. Increase Context length or shorten the prompt."
            )
        steps = int(settings.decode_steps)
    elif settings.benchmark_suite == "compact_task":
        spec = _task_prompt_inputs(harness, settings)
        input_ids = spec["input_ids"]
        attention_mask = spec["attention_mask"]
        steps = int(spec["decode_steps"])
    elif settings.use_exact_length_prompt:
        input_ids, attention_mask = _build_exact_length_inputs(
            harness,
            prompt_unit=settings.prompt_text,
            prompt_length=settings.context_length,
        )
        steps = int(settings.decode_steps)
    else:
        raise ValueError(
            "Live replay for this benchmark section is not wired in the Space yet. "
            "Use the preset-backed compare for the valid benchmark row."
        )

    # One token more than the decode-step count is requested; the decoded text
    # below is limited to `steps` tokens.
    record = harness.generate_greedy(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=steps + 1,
    )

    generated_ids = list(record.get("dense_generated_ids") or [])
    text = decode_generated_text(harness.tokenizer, generated_ids, limit=steps)

    # Missing or falsy record fields fall back to zero or to the locally known values.
    ms_per_step = float(record.get("dense_decode_ms_per_step") or 0.0)
    prefill = float(record.get("prefill_ms") or 0.0)
    throughput = tok_per_sec_from_latency(ms_per_step)
    cache_bytes = int(record.get("dense_final_cache_bytes") or 0)
    seen_prompt_tokens = int(record.get("prompt_length") or input_ids.shape[1])
    seen_decode_steps = int(record.get("decode_steps") or steps)

    trace = [
        {"name": label, "value": amount, "unit": unit}
        for label, amount, unit in (
            ("prefill_ms", prefill, "ms"),
            ("prompt_length", seen_prompt_tokens, "tokens"),
            ("decode_steps", seen_decode_steps, "tokens"),
        )
    ]
    print_json(
        {
            "text": text,
            "tok_per_sec": throughput,
            "latency_ms_per_token": ms_per_step,
            "kv_bytes": cache_bytes,
            "trace": trace,
        }
    )
    return 0
# Script entry point: run main() and use its return value as the process exit status.
if __name__ == "__main__":
    sys.exit(main())