# BlueSkyXN
# Clean labeled DiffusionGemma draft leaks
# 5b311ae
# Raw
# History Blame Contribute Delete
# 7.74 kB
from __future__ import annotations
import re
import shlex
import os
import subprocess
from typing import Any
from src.config import settings
from src.errors import ApiError
from src.prepare import ld_library_path_for
# Matches an ANSI CSI escape sequence: ESC, "[", parameter bytes (0-?),
# intermediate bytes (space-/), and one final byte (@-~).
ANSI_RE = re.compile(r"\x1b\[[0-?]*[ -/]*[@-~]")
# Matches a marker of the form "<|" + one or more non-"|" characters + "|>".
SPECIAL_TOKEN_RE = re.compile(r"<\|[^|]+?\|>")
# Matches a single character in U+3400-U+4DBF or U+4E00-U+9FFF
# (CJK ideograph ranges).
CJK_RE = re.compile(r"[\u3400-\u4dbf\u4e00-\u9fff]")
# Matches one "Label: value" line of an outline-style draft. Named groups:
#   label - Title, Intro/Introduction, Body [Paragraph] [N], Para/Paragraph N,
#           or Conclusion (case-insensitive)
#   value - the text following the colon
LABELED_DRAFT_RE = re.compile(
# Optional list bullets ("-" or "*") and optional bold/italic asterisks.
r"^\s*(?:[-*]\s*)*(?:\*{1,2})?\s*"
# The section label itself.
r"(?P<label>Title|Intro(?:duction)?|Body(?:\s+Paragraph)?\s*\d*|Para(?:graph)?\s*\d+|Conclusion)"
# Optional parenthetical note, the colon, optional closing asterisks, the value.
r"\s*(?:\([^)]*\))?\s*:\s*(?:\*{1,2}\s*)?(?P<value>.+?)\s*$",
re.IGNORECASE,
)
# Whole-word keywords (case-insensitive) whose presence gates the labeled-draft
# cleanup in _strip_labeled_draft_leak.
PLANNING_LEAK_HINT_RE = re.compile(
r"\b(?:Gaokao|Introduction|Body Paragraph|Word count|Structure|Constraints|outline|draft)\b",
re.IGNORECASE,
)
# Patterns matched against the start of each stripped stdout line; a line that
# matches any of them is dropped by clean_cli_output.
_LOG_LINE_PATTERNS = [
re.compile(r"^llama_.*", re.IGNORECASE),
re.compile(r"^ggml_.*", re.IGNORECASE),
re.compile(r"^main:.*", re.IGNORECASE),
re.compile(r"^system_info:.*", re.IGNORECASE),
re.compile(r"^sampling:.*", re.IGNORECASE),
re.compile(r"^generate:.*", re.IGNORECASE),
re.compile(r"^load_.*", re.IGNORECASE),
re.compile(r"^perf_.*", re.IGNORECASE),
re.compile(r"^total time:.*", re.IGNORECASE),
re.compile(r"^throughput:.*", re.IGNORECASE),
re.compile(r"^\*?\s*User input:.*", re.IGNORECASE),
re.compile(r"^\*?\s*System::.*", re.IGNORECASE),
# A line consisting only of ">" characters.
re.compile(r"^>+$"),
]
def _optional_flag(cmd: list[str], flag: str, value: str | int | float | None) -> None:
if value is not None and str(value) != "":
cmd.extend([flag, str(value)])
def _cjk_density(text: str) -> float:
    """Return the share of counted characters in ``text`` that match ``CJK_RE``.

    A character is counted when it matches ``CJK_RE`` or ``str.isalnum()`` is
    true for it; everything else (whitespace, punctuation, ...) is ignored.
    Returns ``0.0`` when no character is counted.
    """
    counted = 0
    cjk = 0
    for ch in text:
        if CJK_RE.match(ch):
            cjk += 1
            counted += 1
        elif ch.isalnum():
            counted += 1
    if counted == 0:
        return 0.0
    return cjk / counted
def _strip_labeled_draft_leak(text: str) -> str:
    """Collapse a leaked "Label: value" draft into plain paragraphs.

    The input is returned unchanged unless all of the following hold:
    ``PLANNING_LEAK_HINT_RE`` finds a hint keyword in it, at least two lines
    match ``LABELED_DRAFT_RE`` with a value whose CJK density is >= 0.35, and
    at least two such entries remain when starting from the last "title"
    entry (or from the first entry if there is no title).

    Otherwise the values of the remaining entries are joined with blank
    lines, skipping a value that repeats the one immediately before it.
    """
    if PLANNING_LEAK_HINT_RE.search(text) is None:
        return text

    # Collect (normalized label, value) pairs from lines that look like a draft.
    labeled: list[tuple[str, str]] = []
    for line in text.splitlines():
        found = LABELED_DRAFT_RE.match(line.strip())
        if found is None:
            continue
        body = found.group("value").strip().strip("*").strip()
        if _cjk_density(body) >= 0.35:
            name = re.sub(r"\s+", " ", found.group("label").strip().lower())
            labeled.append((name, body))

    if len(labeled) < 2:
        return text

    # Start from the last "title" entry so earlier draft passes are discarded.
    first = 0
    for position, (name, _) in enumerate(labeled):
        if name == "title":
            first = position
    tail = labeled[first:]
    if len(tail) < 2:
        return text

    paragraphs: list[str] = []
    for _, body in tail:
        if not body:
            continue
        if paragraphs and paragraphs[-1] == body:
            # Drop an immediate repeat of the previous paragraph.
            continue
        paragraphs.append(body)

    joined = "\n\n".join(paragraphs).strip()
    return joined if joined else text
def clean_cli_output(stdout: str, prompt: str | None = None) -> str:
    """Reduce raw CLI stdout to the generated response text.

    Steps, in order: turn carriage returns into newlines and strip ANSI
    escapes; keep only what follows the last occurrence of the first channel
    marker found; cut everything from a "total time:" line onward; drop lines
    matching ``_LOG_LINE_PATTERNS``; remove ``SPECIAL_TOKEN_RE`` markers;
    strip a leading echo of ``prompt`` when ``settings.strip_prompt_echo`` is
    set; strip leading "Assistant:" / "assistant:" labels; finally pass the
    result through ``_strip_labeled_draft_leak``.

    Args:
        stdout: Raw standard output captured from the CLI.
        prompt: The prompt that was sent, used only for echo removal.

    Returns:
        The cleaned text with surrounding whitespace removed.
    """
    text = ANSI_RE.sub("", stdout.replace("\r", "\n")).strip()

    # Markers are tried in priority order; only the first one present is used,
    # and the text after its last occurrence is kept.
    for marker in ("<|channel>final", "<channel|>", "<|channel>thought"):
        if marker in text:
            text = text.rsplit(marker, 1)[-1]
            break

    # The CLI appends benchmark text after the generated canvas. Keep the model
    # response only; detailed performance can be captured separately later.
    text = re.split(r"\n\s*total time:", text, maxsplit=1, flags=re.IGNORECASE)[0]

    # Remove common progress/log lines. This is intentionally conservative; tune
    # it after observing the exact stdout emitted by the DiffusionGemma PR binary.
    surviving: list[str] = []
    for raw_line in text.splitlines():
        stripped = raw_line.strip()
        if stripped == "":
            # Blank lines are preserved (normalized to empty strings).
            surviving.append("")
        elif not any(pattern.match(stripped) for pattern in _LOG_LINE_PATTERNS):
            surviving.append(raw_line)

    cleaned = SPECIAL_TOKEN_RE.sub("", "\n".join(surviving).strip()).strip()

    if settings.strip_prompt_echo and prompt and cleaned.startswith(prompt):
        cleaned = cleaned[len(prompt):].lstrip()

    # Common prompt labels may remain when using a plain text prompt.
    for label in ("Assistant:", "assistant:"):
        if cleaned.startswith(label):
            cleaned = cleaned[len(label):].lstrip()

    return _strip_labeled_draft_leak(cleaned).strip()
def _run_once(cmd: list[str], prompt: str | None) -> subprocess.CompletedProcess[str]:
    """Run ``cmd`` a single time and return the completed process.

    The child inherits a copy of the current environment with
    ``LD_LIBRARY_PATH`` replaced by
    ``ld_library_path_for(settings.llama_diffusion_bin)``. ``prompt`` is fed
    to the child's stdin (``None`` sends nothing); stdout and stderr are
    captured as text. The run is bounded by ``settings.cli_timeout_seconds``.
    """
    child_env = {
        **os.environ,
        "LD_LIBRARY_PATH": ld_library_path_for(settings.llama_diffusion_bin),
    }
    return subprocess.run(
        cmd,
        env=child_env,
        input=prompt,
        capture_output=True,
        text=True,
        timeout=settings.cli_timeout_seconds,
    )
def _should_retry_with_stdin(proc: subprocess.CompletedProcess[str]) -> bool:
stderr = (proc.stderr or "").lower()
return proc.returncode != 0 and (
"unknown argument" in stderr
or "unrecognized option" in stderr
or "invalid option" in stderr
or "unknown option" in stderr
) and ("-p" in stderr or "prompt" in stderr)
def run_diffusion_cli(model_path: str, prompt: str, params: dict[str, Any]) -> dict[str, Any]:
    """Run the diffusion CLI once for ``prompt`` and return the cleaned result.

    Args:
        model_path: Path passed to the binary with ``-m``.
        prompt: Prompt text, passed with ``-p`` or on stdin depending on
            ``prompt_mode``.
        params: Request options. Keys read here: ``max_tokens``,
            ``max_denoising_steps``, ``n_gpu_layers``, ``diffusion_visual``,
            ``diffusion_kv_cache``, ``diffusion_eb_t_max``,
            ``diffusion_eb_t_min``, ``diffusion_eb_entropy_bound``,
            ``diffusion_eb_confidence``, ``extra_cli_args`` (a string split
            with ``shlex`` or a list/tuple of arguments), ``prompt_mode``
            (``"auto"``, ``"arg"`` or ``"stdin"``; anything else is treated
            as ``"auto"``) and ``model`` (echoed back in the result).

    Returns:
        A dict with ``model``, ``content`` (cleaned stdout), ``backend`` and a
        ``metadata`` dict holding the effective ``max_tokens``,
        ``max_denoising_steps`` and ``diffusion_visual`` values.

    Raises:
        ApiError: ``runner_failed`` (500) when the process exits non-zero.

    NOTE(review): ``_run_once`` passes a timeout to ``subprocess.run``; a
    resulting ``subprocess.TimeoutExpired`` is not caught here.
    """
    # Falsy values (missing, None, 0, "") fall back to the configured default;
    # the result is then clamped to [1, settings.max_max_tokens].
    max_tokens = int(params.get("max_tokens") or settings.default_max_tokens)
    max_tokens = max(1, min(max_tokens, settings.max_max_tokens))

    # Same fallback for the step count, clamped to [1, 120].
    max_steps = int(params.get("max_denoising_steps") or settings.diffusion_max_steps)
    max_steps = max(1, min(max_steps, 120))

    # An explicit 0 is a meaningful -ngl value (offload no layers), so only a
    # missing/None/"" entry falls back to the configured default.
    requested_ngl = params.get("n_gpu_layers")
    if requested_ngl is None or requested_ngl == "":
        requested_ngl = settings.n_gpu_layers
    n_gpu_layers = int(requested_ngl)

    visual = bool(params.get("diffusion_visual", settings.diffusion_visual_default))

    cmd = [
        str(settings.llama_diffusion_bin),
        "-m", model_path,
        "-ngl", str(n_gpu_layers),
        "-cnv",
        "-n", str(max_tokens),
    ]
    if visual:
        cmd.append("--diffusion-visual")

    # DiffusionGemma-specific knobs. Keep defaults close to the Unsloth guidance.
    kv_cache = params.get("diffusion_kv_cache", settings.diffusion_kv_cache)
    if kv_cache:
        _optional_flag(cmd, "--diffusion-kv-cache", kv_cache)
    _optional_flag(cmd, "--diffusion-eb-max-steps", max_steps)
    _optional_flag(cmd, "--diffusion-eb-t-max", params.get("diffusion_eb_t_max"))
    _optional_flag(cmd, "--diffusion-eb-t-min", params.get("diffusion_eb_t_min"))
    _optional_flag(cmd, "--diffusion-eb-entropy-bound", params.get("diffusion_eb_entropy_bound"))
    _optional_flag(cmd, "--diffusion-eb-confidence", params.get("diffusion_eb_confidence"))

    # Caller-supplied raw arguments are appended as separate argv entries; the
    # command is never passed through a shell.
    extra = params.get("extra_cli_args")
    if extra:
        if isinstance(extra, str):
            cmd.extend(shlex.split(extra))
        elif isinstance(extra, (list, tuple)):
            cmd.extend(str(x) for x in extra)

    prompt_mode = str(params.get("prompt_mode") or settings.prompt_mode).lower()
    if prompt_mode not in {"auto", "arg", "stdin"}:
        prompt_mode = "auto"

    if prompt_mode == "stdin":
        proc = _run_once(cmd, prompt + "\n")
    else:
        proc = _run_once(cmd + ["-p", prompt], None)
        # In auto mode, fall back to stdin when the binary rejected "-p".
        if prompt_mode == "auto" and _should_retry_with_stdin(proc):
            proc = _run_once(cmd, prompt + "\n")

    if proc.returncode != 0:
        # Guard against a missing stderr, as _should_retry_with_stdin does.
        stderr_tail = (proc.stderr or "")[-4000:]
        raise ApiError(
            "runner_failed",
            "llama-diffusion-cli failed. stderr tail: " + stderr_tail,
            500,
        )

    content = clean_cli_output(proc.stdout, prompt=prompt)
    return {
        "model": params.get("model"),
        "content": content,
        "backend": "llama-diffusion-cli",
        "metadata": {
            "max_tokens": max_tokens,
            "max_denoising_steps": max_steps,
            "diffusion_visual": visual,
        },
    }