Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import re | |
| import shlex | |
| import os | |
| import subprocess | |
| from typing import Any | |
| from src.config import settings | |
| from src.errors import ApiError | |
| from src.prepare import ld_library_path_for | |
# Terminal escape sequences (colours, cursor movement) emitted by the CLI.
ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
# Chat-template special tokens. Matches the symmetric form ``<|name|>`` as well
# as the half-pipe forms ``<|name>`` and ``<name|>`` (e.g. ``<|channel>``,
# ``<channel|>``, ``<turn|>``) that the channel handling in clean_cli_output
# deals with. The token name may not contain whitespace, ``|``, ``<`` or ``>``,
# so a single match can never swallow ordinary text between two markers.
SPECIAL_TOKEN_RE = re.compile(r"<\|[^|<>\s]+?\|?>|<[^|<>\s]+?\|>")
# CJK Unified Ideographs Extension A and the main CJK Unified Ideographs block.
CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]")
# One line of a labelled essay draft such as "**Title:** ..." or
# "- Body Paragraph 2 (evidence): ...". Captures the label (Title, Intro,
# Introduction, Body [Paragraph] [n], Para[graph] n, Conclusion) and the value
# after the colon, tolerating list bullets, bold markers and a parenthetical.
LABELED_DRAFT_RE = re.compile(
    r"^\s*(?:[-*]\s*)*(?:\*{1,2})?\s*"
    r"(?P<label>Title|Intro(?:duction)?|Body(?:\s+Paragraph)?\s*\d*|Para(?:graph)?\s*\d+|Conclusion)"
    r"\s*(?:\([^)]*\))?\s*:\s*(?:\*{1,2}\s*)?(?P<value>.+?)\s*$",
    re.IGNORECASE,
)
# Words whose presence suggests the model leaked its planning notes; used as a
# cheap pre-check before the line-by-line LABELED_DRAFT_RE scan.
PLANNING_LEAK_HINT_RE = re.compile(
    r"\b(?:Gaokao|Introduction|Body Paragraph|Word count|Structure|Constraints|outline|draft)\b",
    re.IGNORECASE,
)
# llama.cpp diagnostic / progress lines that must not reach the user. Each
# pattern is applied with ``.match`` to a whitespace-stripped output line.
# All prefix patterns are case-insensitive; the trailing ">"-only prompt
# marker pattern is matched case-sensitively (case is irrelevant for it).
_LOG_LINE_PATTERNS = [
    *(
        re.compile(log_prefix, re.IGNORECASE)
        for log_prefix in (
            r"^llama_.*",
            r"^ggml_.*",
            r"^main:.*",
            r"^system_info:.*",
            r"^sampling:.*",
            r"^generate:.*",
            r"^load_.*",
            r"^perf_.*",
            r"^total time:.*",
            r"^throughput:.*",
            r"^\*?\s*User input:.*",
            r"^\*?\s*System::.*",
        )
    ),
    re.compile(r"^>+$"),
]
| def _optional_flag(cmd: list[str], flag: str, value: str | int | float | None) -> None: | |
| if value is not None and str(value) != "": | |
| cmd.extend([flag, str(value)]) | |
def _cjk_density(text: str) -> float:
    """Return the fraction of CJK ideographs among the word characters of *text*.

    A character counts as a word character when it is alphanumeric or matches
    ``CJK_RE``. Returns ``0.0`` when *text* has no word characters.
    """
    word_count = 0
    cjk_count = 0
    for ch in text:
        is_cjk = CJK_RE.match(ch) is not None
        if not (ch.isalnum() or is_cjk):
            continue
        word_count += 1
        if is_cjk:
            cjk_count += 1
    return cjk_count / word_count if word_count else 0.0
def _strip_labeled_draft_leak(text: str) -> str:
    """Collapse a leaked, labelled essay draft into plain paragraphs.

    Only acts when *text* contains a ``PLANNING_LEAK_HINT_RE`` hint. Each line
    matching ``LABELED_DRAFT_RE`` whose value (with surrounding ``*`` removed)
    has a CJK density of at least 0.35 is collected. If at least two such
    entries exist, the entries from the last ``Title`` entry onward (or all
    of them when there is no title) are kept. The kept values are joined with
    blank lines, skipping a value equal to the one just before it.

    Returns *text* unchanged whenever fewer than two entries qualify or the
    joined result would be empty.
    """
    if PLANNING_LEAK_HINT_RE.search(text) is None:
        return text

    # (is_title, value) for every qualifying labelled line, in order.
    drafts: list[tuple[bool, str]] = []
    for source_line in text.splitlines():
        hit = LABELED_DRAFT_RE.match(source_line.strip())
        if hit is None:
            continue
        body = hit.group("value").strip().strip("*").strip()
        # Skip labelled lines that are mostly non-CJK (e.g. English notes).
        if _cjk_density(body) < 0.35:
            continue
        drafts.append((hit.group("label").strip().lower() == "title", body))

    if len(drafts) < 2:
        return text

    # Restart from the most recent "Title" so earlier drafts are discarded.
    first = max((pos for pos, (is_title, _) in enumerate(drafts) if is_title), default=0)
    tail = drafts[first:]
    if len(tail) < 2:
        return text

    paragraphs: list[str] = []
    for _, body in tail:
        if not body:
            continue
        if paragraphs and paragraphs[-1] == body:
            continue
        paragraphs.append(body)

    joined = "\n\n".join(paragraphs).strip()
    return joined if joined else text
def clean_cli_output(stdout: str, prompt: str | None = None) -> str:
    """Extract the model's answer from raw ``llama-diffusion-cli`` stdout.

    Processing order:

    1. Normalise line endings (CRLF and bare CR become LF) and remove ANSI
       escape sequences.
    2. Keep only the text after the last ``<|channel>final`` marker; failing
       that, after the last ``<channel|>`` marker; failing that, after the
       last ``<|channel>thought`` marker.
    3. Drop everything from the first newline-preceded ``total time:`` line
       on (the CLI's benchmark summary).
    4. Drop lines matching ``_LOG_LINE_PATTERNS``. Blank lines are kept, but
       with their whitespace removed.
    5. Remove special tokens matched by ``SPECIAL_TOKEN_RE``.
    6. If ``settings.strip_prompt_echo`` is enabled, remove a leading copy of
       *prompt*. Then strip a leading ``Assistant:`` / ``assistant:`` label.
    7. Collapse a leaked labelled draft via ``_strip_labeled_draft_leak``.

    Args:
        stdout: Captured standard output of the CLI.
        prompt: The prompt sent to the model, used only for echo removal.

    Returns:
        The cleaned response text with surrounding whitespace stripped.
    """
    # Normalise CRLF before bare CR. Replacing "\r" alone would turn every
    # "\r\n" into "\n\n" and inject spurious blank lines into the answer.
    text = stdout.replace("\r\n", "\n").replace("\r", "\n")
    text = ANSI_RE.sub("", text).strip()

    if "<|channel>final" in text:
        text = text.rsplit("<|channel>final", 1)[-1]
    elif "<channel|>" in text:
        text = text.rsplit("<channel|>", 1)[-1]
    elif "<|channel>thought" in text:
        # No final channel and no closing marker: fall back to whatever
        # follows the last thought marker.
        text = text.rsplit("<|channel>thought", 1)[-1]

    # The CLI appends benchmark text after the generated canvas. Keep the model
    # response only; detailed performance can be captured separately later.
    text = re.split(r"\n\s*total time:", text, maxsplit=1, flags=re.IGNORECASE)[0]

    # Remove common progress/log lines. This is intentionally conservative; tune it
    # after observing the exact stdout emitted by the DiffusionGemma PR binary.
    kept: list[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            kept.append("")
            continue
        if any(p.match(line) for p in _LOG_LINE_PATTERNS):
            continue
        kept.append(raw_line)

    cleaned = "\n".join(kept).strip()
    cleaned = SPECIAL_TOKEN_RE.sub("", cleaned).strip()

    if settings.strip_prompt_echo and prompt and cleaned.startswith(prompt):
        cleaned = cleaned[len(prompt):].lstrip()

    # Common prompt labels may remain when using a plain text prompt.
    for prefix in ("Assistant:", "assistant:"):
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].lstrip()

    return _strip_labeled_draft_leak(cleaned).strip()
def _run_once(cmd: list[str], prompt: str | None) -> subprocess.CompletedProcess[str]:
    """Execute *cmd* a single time and return the completed process.

    stdout and stderr are captured as text. *prompt*, when not ``None``, is
    written to the child's stdin. The child receives a copy of the current
    environment with ``LD_LIBRARY_PATH`` set to
    ``ld_library_path_for(settings.llama_diffusion_bin)``. The run is bounded
    by ``settings.cli_timeout_seconds``; on expiry ``subprocess.run`` raises
    ``subprocess.TimeoutExpired``.
    """
    child_env = {
        **os.environ,
        "LD_LIBRARY_PATH": ld_library_path_for(settings.llama_diffusion_bin),
    }
    return subprocess.run(
        cmd,
        env=child_env,
        input=prompt,
        capture_output=True,
        text=True,
        timeout=settings.cli_timeout_seconds,
    )
def _should_retry_with_stdin(proc: subprocess.CompletedProcess[str]) -> bool:
    """Decide whether a failed ``-p`` invocation should be retried via stdin.

    Returns True only when the process exited non-zero, its stderr contains
    one of the "unknown/unrecognized/invalid option/argument" messages, and
    that stderr mentions ``-p`` or ``prompt``. The stderr check is
    case-insensitive.
    """
    if proc.returncode == 0:
        return False
    err_text = (proc.stderr or "").lower()
    option_rejected = any(
        phrase in err_text
        for phrase in (
            "unknown argument",
            "unrecognized option",
            "invalid option",
            "unknown option",
        )
    )
    mentions_prompt = "-p" in err_text or "prompt" in err_text
    return option_rejected and mentions_prompt
def _run_with_prompt(
    cmd: list[str], prompt: str, prompt_mode: str
) -> subprocess.CompletedProcess[str]:
    """Run *cmd*, delivering *prompt* according to *prompt_mode*.

    ``"stdin"`` pipes the prompt (plus a trailing newline) on stdin. ``"arg"``
    passes it as ``-p <prompt>``. ``"auto"`` tries ``-p`` first and re-runs
    with stdin when ``_should_retry_with_stdin`` reports that the binary
    rejected the option.

    Raises:
        ApiError: ``runner_failed`` (500) if the CLI exceeds its timeout or
            cannot be launched (for example, the binary is missing or not
            executable).
    """
    try:
        if prompt_mode == "stdin":
            return _run_once(cmd, prompt + "\n")
        proc = _run_once(cmd + ["-p", prompt], None)
        if prompt_mode == "auto" and _should_retry_with_stdin(proc):
            proc = _run_once(cmd, prompt + "\n")
        return proc
    except subprocess.TimeoutExpired as exc:
        raise ApiError(
            "runner_failed",
            f"llama-diffusion-cli timed out after {exc.timeout} seconds.",
            500,
        ) from exc
    except OSError as exc:
        raise ApiError(
            "runner_failed",
            f"llama-diffusion-cli could not be started: {exc}",
            500,
        ) from exc


def run_diffusion_cli(model_path: str, prompt: str, params: dict[str, Any]) -> dict[str, Any]:
    """Generate a completion for *prompt* with ``llama-diffusion-cli``.

    Keys read from *params*; each falls back to ``settings`` where noted:

    - ``max_tokens``: defaults to ``settings.default_max_tokens``; clamped to
      ``[1, settings.max_max_tokens]``.
    - ``max_denoising_steps``: defaults to ``settings.diffusion_max_steps``;
      clamped to ``[1, 120]``.
    - ``n_gpu_layers``: defaults to ``settings.n_gpu_layers`` only when unset
      (an explicit ``0`` is honoured).
    - ``diffusion_visual``: defaults to ``settings.diffusion_visual_default``.
    - ``diffusion_kv_cache``: defaults to ``settings.diffusion_kv_cache``.
    - ``diffusion_eb_*``: optional entropy-bound sampler knobs.
    - ``extra_cli_args``: a shell-style string or a list, appended verbatim.
    - ``prompt_mode``: ``auto``/``arg``/``stdin``, defaulting to
      ``settings.prompt_mode``; invalid values fall back to ``auto``.
    - ``model``: echoed back in the result.

    Returns:
        A dict with ``model``, cleaned ``content``, ``backend`` and a
        ``metadata`` dict with the effective token/step limits and visual flag.

    Raises:
        ApiError: ``runner_failed`` (500) when the CLI exits non-zero, times
            out or cannot be launched.
    """
    max_tokens = int(params.get("max_tokens") or settings.default_max_tokens)
    max_tokens = max(1, min(max_tokens, settings.max_max_tokens))
    max_steps = int(params.get("max_denoising_steps") or settings.diffusion_max_steps)
    max_steps = max(1, min(max_steps, 120))
    # `or` would silently replace an explicit 0 (CPU-only) with the default.
    raw_ngl = params.get("n_gpu_layers")
    n_gpu_layers = int(settings.n_gpu_layers if raw_ngl is None or raw_ngl == "" else raw_ngl)
    visual = bool(params.get("diffusion_visual", settings.diffusion_visual_default))

    cmd = [
        str(settings.llama_diffusion_bin),
        "-m", model_path,
        "-ngl", str(n_gpu_layers),
        "-cnv",
        "-n", str(max_tokens),
    ]
    if visual:
        cmd.append("--diffusion-visual")

    # DiffusionGemma-specific knobs. Keep defaults close to the Unsloth guidance.
    kv_cache = params.get("diffusion_kv_cache", settings.diffusion_kv_cache)
    if kv_cache:
        _optional_flag(cmd, "--diffusion-kv-cache", kv_cache)
    _optional_flag(cmd, "--diffusion-eb-max-steps", max_steps)
    _optional_flag(cmd, "--diffusion-eb-t-max", params.get("diffusion_eb_t_max"))
    _optional_flag(cmd, "--diffusion-eb-t-min", params.get("diffusion_eb_t_min"))
    _optional_flag(cmd, "--diffusion-eb-entropy-bound", params.get("diffusion_eb_entropy_bound"))
    _optional_flag(cmd, "--diffusion-eb-confidence", params.get("diffusion_eb_confidence"))

    # NOTE(review): extra_cli_args is forwarded verbatim. If params come from
    # untrusted API clients, they can override -m, add file-writing flags, etc.
    # Consider an allow-list. (No shell is involved, so there is no shell injection.)
    extra = params.get("extra_cli_args")
    if extra:
        if isinstance(extra, str):
            cmd.extend(shlex.split(extra))
        elif isinstance(extra, list):
            cmd.extend(str(x) for x in extra)

    prompt_mode = str(params.get("prompt_mode") or settings.prompt_mode).lower()
    if prompt_mode not in {"auto", "arg", "stdin"}:
        prompt_mode = "auto"

    proc = _run_with_prompt(cmd, prompt, prompt_mode)
    if proc.returncode != 0:
        raise ApiError(
            "runner_failed",
            "llama-diffusion-cli failed. stderr tail: " + (proc.stderr or "")[-4000:],
            500,
        )

    content = clean_cli_output(proc.stdout, prompt=prompt)
    return {
        "model": params.get("model"),
        "content": content,
        "backend": "llama-diffusion-cli",
        "metadata": {
            "max_tokens": max_tokens,
            "max_denoising_steps": max_steps,
            "diffusion_visual": visual,
        },
    }