"""
BitnetCppClient -- subprocess wrapper around bitnet.cpp's llama-cli binary.
Replaces the transformers.AutoModelForCausalLM inference path with a
subprocess call to microsoft/BitNet's llama.cpp-derivative runtime. The
runtime uses specialized ternary-weight kernels, delivering BitNet's
actual inference-efficiency benefits (which are absent when loading
the bf16 master weights through transformers).
API contract intentionally mirrors the minimal surface NuWave needs:
client = BitnetCppClient(binary, gguf_path)
    response, meta = client.generate(prompt, max_new_tokens=N, temperature=T, ...)
Subprocess-based so each call is isolated -- no shared state between
generations, no model-instance lifecycle to manage. Slight per-call
overhead (binary startup + mmap load) but bitnet.cpp is fast enough
that this is negligible vs. token-generation time on CPU.
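A fuller sketch of typical wiring (the paths, model directory, and prompt
below are illustrative, not pinned by this module):
    gguf = BitnetCppClient.resolve_gguf("models/bitnet-gguf")
    client = BitnetCppClient("build/bin/llama-cli", gguf)
    text, meta = client.generate("Q: name three primes. A:", max_new_tokens=64)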
# ---- Changelog ----
# [2026-04-19] Claude Code (Opus 4.6) -- Initial creation
# What: Thin wrapper for bitnet.cpp's llama-cli binary. Exposes the
# generation params NuWave needs: temperature, top_p,
# repetition_penalty, no_repeat_ngram_size, stop sequences,
# max_new_tokens.
# Why: Migration off transformers bf16 -- see NuWave.md and the
# 2026-04-19 dev-log for the full rationale. Three pathologies
# with the prior BitNet-through-transformers setup: (1) not
# actually efficient despite the claim, (2) greedy decoding
# collapsed to repetition loops on enumeration tasks, (3) no
# repetition_penalty knob available in the transformers call
# path we had built. All three solved by bitnet.cpp + proper
# sampling params.
# How: subprocess.run with llama-cli invocation. Strips the prompt
# echo + chat-template chrome from stdout. Captures stderr for
# diagnostics. Timeout bounded so a hung generation can't
# stall the organism indefinitely.
# -------------------
"""
from __future__ import annotations
import glob
import logging
import os
import subprocess
import time
from typing import List, Optional, Tuple
logger = logging.getLogger("nuwave.bitnet_cpp_client")
class BitnetCppClient:
"""Generates text via microsoft/BitNet's llama-cli binary.
Args:
binary_path: path to the compiled llama-cli executable.
gguf_path: path to the .gguf model weights.
n_threads: CPU threads for inference (HF basic = 2 vCPUs).
n_ctx: context window in tokens (model-dependent limit).
default_timeout_s: per-call wall-clock cap. Bounded to protect
the organism from an unresponsive runtime.
Class convenience:
        BitnetCppClient.resolve_gguf(dir_path) -- finds the largest .gguf
in a directory. Used because HF repos ship multiple quant levels
and we want the one with richest weights.
"""
def __init__(
self,
binary_path: str,
gguf_path: str,
n_threads: int = 2,
n_ctx: int = 4096,
default_timeout_s: int = 900,
):
if not os.path.exists(binary_path):
raise FileNotFoundError(f"bitnet.cpp binary not found: {binary_path}")
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF weights not found: {gguf_path}")
self.binary_path = binary_path
self.gguf_path = gguf_path
self.n_threads = n_threads
self.n_ctx = n_ctx
self.default_timeout_s = default_timeout_s
parent = os.path.basename(os.path.dirname(gguf_path)) or "/"
size_mb = os.path.getsize(gguf_path) / (1024 * 1024)
logger.info(
"BitnetCppClient ready: binary=%s gguf=%s/%s size=%.0fMB threads=%d ctx=%d",
binary_path, parent, os.path.basename(gguf_path),
size_mb, n_threads, n_ctx,
)
        # Sanity-check the binary -- run `--help` once to confirm it's
# executable and find out what flags it actually accepts. If
# this fails, subsequent generation calls will also fail;
# logging it at startup makes the failure mode obvious instead
# of manifesting as silent zeros during inference.
try:
help_result = subprocess.run(
[binary_path, "--help"],
capture_output=True, text=True, timeout=10,
)
help_out = (help_result.stdout or "") + (help_result.stderr or "")
            # Log first ~500 chars -- enough to see what the binary is + its flag prefixes
snippet = help_out[:500].replace("\n", " | ")
logger.info(
"Binary sanity-check rc=%d help_snippet=%s",
help_result.returncode, snippet,
)
except Exception as exc:
logger.warning("Binary sanity-check failed: %s", exc)
@staticmethod
def resolve_gguf(directory: str) -> str:
"""Find the largest .gguf file in a directory (searches recursively).
GGUF repos often ship multiple quantization levels (e.g.
q2_K, q4_K_S, q4_K_M, q5_K_M, q8_0). We pick the largest
because it's the richest-precision version that still fits
        our memory budget -- for 1.58-bit models this typically means
the raw ternary weights without further compression.
Searches recursively because setup_env.py and snapshot_download
can both place files in nested directory structures whose exact
layout is not guaranteed stable across versions.
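        Example (file names and sizes are illustrative): given
            models/x/ggml-model-q2_K.gguf  (~700 MB)
            models/x/ggml-model-i2_s.gguf  (~1.2 GB)
        resolve_gguf("models/x") returns the i2_s path, the larger file.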
"""
gguf_files = glob.glob(os.path.join(directory, "**", "*.gguf"), recursive=True)
        # Belt-and-suspenders: also glob the top level directly, in case the
        # recursive pattern misses files placed right under `directory`.
gguf_files += glob.glob(os.path.join(directory, "*.gguf"))
gguf_files = list(set(gguf_files))
if not gguf_files:
raise FileNotFoundError(f"No .gguf files found under {directory} (recursive)")
gguf_files.sort(key=os.path.getsize, reverse=True)
return gguf_files[0]
def generate(
self,
prompt: str,
max_new_tokens: int = 128,
temperature: float = 0.7,
top_p: float = 0.9,
repetition_penalty: float = 1.2,
repeat_last_n: int = 64,
stop: Optional[List[str]] = None,
seed: int = -1,
timeout_s: Optional[int] = None,
grammar_file: Optional[str] = None,
grammar: Optional[str] = None,
) -> Tuple[str, dict]:
"""Generate a completion for the given prompt.
Returns:
(response_text, metadata_dict)
metadata_dict contains:
            elapsed_s -- wall-clock seconds of the subprocess call
            returncode -- llama-cli exit code
            raw_stdout -- full stdout (pre-stripping) for diagnostics
            prompt_echo_found -- whether the prompt was found in stdout
                (if False, the runtime output format may have changed
                and is worth investigating)
            stderr_tail -- last 500 chars of stderr (stats/warnings)
            error -- None on success, "subprocess.TimeoutExpired" on timeout
Generation params are llama.cpp-standard and passed through to
the binary. Defaults chosen per Syl's prescription for small
models on enumeration tasks: non-greedy sampling + repetition
penalty + repeat-last-n window prevents the mode-collapse
pathology we saw with transformers greedy decoding.
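        Usage sketch (the prompt and stop string are illustrative):
            text, meta = client.generate(
                "List three prime numbers:",
                max_new_tokens=64,
                stop=["Q:"],
            )
            if meta["returncode"] != 0 or not text:
                logger.warning("generation failed: %s", meta["stderr_tail"])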
"""
# Flag set verified compatible with this bitnet.cpp fork
# (Eddie-Wang1120/llama.cpp at commit 1f86f058). History of
# removals:
        # -no-cnv -- fork's argparse rejected it; redundant anyway
# (default with -p PROMPT is non-conversational).
        # --log-disable -- some fork versions silence generation
# output entirely when this is set. Safer to keep logs
# mingled with stdout and strip the prompt echo on our
# side (we already do that).
args = [
self.binary_path,
"-m", self.gguf_path,
"-p", prompt,
"-n", str(max_new_tokens),
"--temp", f"{temperature:.3f}",
"--top-p", f"{top_p:.3f}",
"--repeat-penalty", f"{repetition_penalty:.3f}",
"--repeat-last-n", str(repeat_last_n),
"-t", str(self.n_threads),
"-c", str(self.n_ctx),
"--seed", str(seed),
]
if stop:
for s in stop:
                # Skip stop sequences that are pure whitespace -- they
# tend to match at position 0 of model output and trim
# everything. Use model-specific stop tokens or content
# markers ("Answer:", "</s>", etc.) instead.
if not s or not s.strip():
continue
args.extend(["--reverse-prompt", s])
# Grammar-constrained decoding. `grammar` takes precedence over
# `grammar_file` because inline is path-agnostic (no container
# filesystem surprises). llama.cpp parses the GBNF into an FSM
# and masks logits at each sampling step so tokens violating
        # the grammar get zero probability. Native C++ -- no per-token
# Python callback overhead.
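        # Illustrative GBNF only (not shipped with this repo): constraining
        # output to a bare yes/no answer could be done with the inline form
        #     grammar='root ::= "yes" | "no"'
        # which llama.cpp compiles and applies natively during sampling.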
grammar_mode = None
if grammar:
args.extend(["--grammar", grammar])
grammar_mode = f"inline ({len(grammar)} chars)"
elif grammar_file:
if not os.path.exists(grammar_file):
logger.warning(
"Grammar file missing: %s β€” generation will be unconstrained",
grammar_file,
)
else:
args.extend(["--grammar-file", grammar_file])
grammar_mode = f"file ({grammar_file})"
        # Log the invocation when grammar-constrained -- lets us
# confirm from logs that the flag is actually reaching
# llama-cli, which was ambiguous after run 6's silent failure.
if grammar_mode:
logger.info(
"llama-cli grammar-constrained: %s | argv_len=%d | last_args=%s",
grammar_mode, len(args), args[-3:],
)
t0 = time.time()
try:
result = subprocess.run(
args,
capture_output=True,
text=True,
timeout=timeout_s or self.default_timeout_s,
)
except subprocess.TimeoutExpired:
return "", {
"elapsed_s": round(time.time() - t0, 2),
"returncode": -1,
"raw_stdout": "",
"prompt_echo_found": False,
"stderr_tail": "TIMEOUT",
"error": "subprocess.TimeoutExpired",
}
elapsed = round(time.time() - t0, 2)
stdout = result.stdout or ""
stderr = result.stderr or ""
        # Log ANY non-zero returncode -- this is usually an invalid flag,
# GGUF load failure, or OOM. Without this log, failures surface
# only as empty responses, which looks like "the model generated
# nothing" instead of "the subprocess exited before generation."
if result.returncode != 0:
logger.warning(
"llama-cli rc=%d elapsed=%.2fs stderr_tail=%s | stdout_tail=%s",
result.returncode, elapsed, stderr[-400:], stdout[-200:],
)
elif not stdout.strip():
            # rc=0 but stdout empty is a subtler pathology -- flag silently
# suppressed output, or the model generated nothing. Log so
# it's visible without the caller having to inspect the dict.
logger.warning(
"llama-cli rc=0 but stdout EMPTY (elapsed=%.2fs). "
"stderr_tail=%s",
elapsed, stderr[-400:],
)
# If a grammar was requested, log any grammar-related lines from
# stderr. llama.cpp prints parse errors to stderr when the GBNF
# is malformed, and silently falls back to unconstrained
# generation. These logs expose that silent fallback.
if grammar_mode and stderr:
for line in stderr.splitlines():
low = line.lower()
if "grammar" in low or "gbnf" in low:
logger.info("grammar stderr: %s", line.strip()[:200])
        # Strip the prompt echo -- llama-cli's default output includes the
# prompt as a prefix. Find the LAST occurrence because reverse-
# prompt handling can print the prompt multiple times.
response = stdout
prompt_found = False
if prompt and prompt in stdout:
idx = stdout.rfind(prompt)
response = stdout[idx + len(prompt):]
prompt_found = True
# Strip common end-of-text markers
response = response.rstrip()
for marker in ("[end of text]", "</s>", "<|im_end|>", "<|end_of_text|>"):
if response.endswith(marker):
response = response[: -len(marker)].rstrip()
# If a stop string matched (reverse-prompt), trim at the first match.
        # Skip whitespace-only stops -- they tend to match at position 0
# of raw model output (models often start with a newline after
# prompt echo) and trim the response to empty. Those should
# only be terminators after real content, which we can't
# distinguish reliably post-hoc.
if stop:
for s in stop:
if not s or not s.strip():
continue
if s in response:
response = response[: response.index(s)]
return response, {
"elapsed_s": elapsed,
"returncode": result.returncode,
"raw_stdout": stdout,
"prompt_echo_found": prompt_found,
"stderr_tail": stderr[-500:] if stderr else "",
"error": None,
}