# fr-on-device / inference.py
# Author: Joseph Pollack — "final interface improvements" (commit ea25b4a, unverified)
"""
Parallel load and inference for all 6 models (Baguettotron + 5 Luth).
Baguettotron uses EOS-safe formatting: "<|im_end>" (no trailing pipe), stop=["<|im_end>", "</think>"].
"""
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.utils import logging as hf_logging

from model_config import MODEL_IDS
# Reduce load-time noise (e.g. "lm_head.weight | MISSING" for Qwen3 tied-embedding models)
hf_logging.set_verbosity_error()

# In-memory cache: model_id -> (model, tokenizer); reads/writes guarded by _cache_lock.
_model_cache: dict[str, tuple[Any, Any]] = {}
# Proper top-level import instead of the __import__("threading") inline hack,
# which hid the dependency from linters and readers.
_cache_lock = threading.Lock()

# Baguettotron repo_id for EOS quirk handling
BAGUETTOTRON_ID = "PleIAs/Baguettotron"
def _format_prompt_baguettotron(prompt: str, system_prompt: str = "") -> tuple[str, list[str]]:
"""
Manual prompt build for Baguettotron. Uses "<|im_end>" (no trailing pipe)
per tokenizer; stop=["<|im_end>", "</think>"] for generation.
Qwen-style: system (optional) + user + assistant.
"""
parts: list[str] = []
if system_prompt.strip():
parts.append(f"<|im_start|>system\n{system_prompt.strip()}<|im_end>\n")
parts.append(f"<|im_start|>user\n{prompt}<|im_end>\n<|im_start|>assistant\n<think>\n")
text = "".join(parts)
stop = ["<|im_end>", "</think>"]
return text, stop
def _format_prompt_luth(prompt: str, tokenizer: Any, system_prompt: str = "") -> tuple[dict[str, Any], list[str] | None]:
"""Use tokenizer's chat template for Luth models. Supports optional system message."""
messages: list[dict[str, str]] = []
if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt.strip()})
messages.append({"role": "user", "content": prompt})
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
)
return inputs, None # no custom stop for Luth
def _get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
def _load_model(model_id: str, device: str | None = None) -> tuple[Any, Any]:
    """
    Load a model and tokenizer, caching the pair by ``model_id``.

    The lock is held only for cache reads/writes — never during the slow
    ``from_pretrained`` call — so two threads may race to load the same
    model. The store below therefore re-checks the cache and keeps the
    first writer's pair, ensuring every caller shares one model object.

    Args:
        model_id: Hugging Face repo id.
        device: "cuda" or "cpu"; auto-detected when None.

    Returns:
        (model, tokenizer) tuple.
    """
    if device is None:
        device = _get_device()
    with _cache_lock:
        if model_id in _model_cache:
            return _model_cache[model_id]
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Half-precision weights on CPU trigger float32-vs-half matmul errors; upcast.
    # On CUDA the model keeps its loaded dtype as-is (the previous
    # `model.to(model_dtype)` cast a model to its own dtype — a no-op — and
    # has been removed).
    if device == "cpu" and next(model.parameters()).dtype in (torch.bfloat16, torch.float16):
        model = model.float()
    with _cache_lock:
        # Another thread may have finished first; adopt its pair instead of
        # overwriting, so callers never end up with two live copies.
        if model_id not in _model_cache:
            _model_cache[model_id] = (model, tokenizer)
        return _model_cache[model_id]
def _generate_one(
    model_id: str,
    prompt: str,
    params: dict[str, Any],
    device: str = "cuda",
    system_prompt: str = "",
) -> tuple[str, str]:
    """Load (or use cached) model, run inference, return (model_id, text).

    Recognized ``params`` keys (all optional): temperature, top_p, top_k,
    max_tokens, repeat_penalty. Sampling values are clamped to safe ranges.
    Two runtime fallbacks are attempted when generate() raises: a float32
    retry for dtype-mismatch errors, and a greedy-decoding retry when the
    sampling distribution contains inf/nan/<0.
    """
    model, tokenizer = _load_model(model_id, device)
    # Rebind to the weights' *actual* placement/dtype — device_map="auto" may
    # have put the model somewhere other than the requested device string.
    device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype
    # Clamp temperature/top_p to avoid CUDA assertion (inf/nan in softmax)
    temp = max(float(params.get("temperature", 0.7)), 0.01)
    top_p = max(min(float(params.get("top_p", 0.9)), 1.0), 1e-6)
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": int(params.get("max_tokens", 256)),
        "temperature": temp,
        "top_p": top_p,
        # top_k=0 would disable top-k filtering entirely; force at least 1.
        "top_k": max(int(params.get("top_k", 40)), 1),
        "repetition_penalty": float(params.get("repeat_penalty", 1.1)),
        "do_sample": True,
        # NOTE(review): an eos_token_id of 0 is falsy and would silently fall
        # through to pad_token_id — confirm no tokenizer here uses id 0 as EOS.
        "pad_token_id": tokenizer.eos_token_id or tokenizer.pad_token_id,
    }
    # Baguettotron needs the hand-built prompt (EOS quirk); Luth models use
    # their tokenizer's chat template and need no custom stop strings.
    if model_id == BAGUETTOTRON_ID:
        text_prompt, _stop = _format_prompt_baguettotron(prompt, system_prompt)
        inputs = tokenizer(text_prompt, return_tensors="pt")
    else:
        inputs = _format_prompt_luth(prompt, tokenizer, system_prompt)[0]
    # Move to device (input_ids/attention_mask are int; no dtype cast needed)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    def do_generate(kwargs: dict[str, Any], use_autocast: bool = True):
        # Autocast only for half-precision CUDA models; elsewhere run plain.
        if use_autocast and str(device).startswith("cuda") and model_dtype in (torch.bfloat16, torch.float16):
            with torch.amp.autocast(device_type="cuda", dtype=model_dtype):
                return model.generate(**inputs, **kwargs)
        return model.generate(**inputs, **kwargs)
    try:
        outputs = do_generate(gen_kwargs)
    except RuntimeError as e:
        if "expected m1 and m2 to have the same dtype" in str(e) or "float != c10::BFloat16" in str(e):
            # Qwen3 (e.g. Luth-0.6B/1.7B) can hit float vs bfloat16 in some envs; retry in float32
            model.float()
            outputs = do_generate(gen_kwargs, use_autocast=False)
        elif "probability tensor contains" in str(e):
            # Fallback to greedy decoding when sampling yields invalid logits (inf/nan/<0).
            # Use explicit GenerationConfig without sampling params; suppress "generation flags
            # are not valid" warning (model config can still merge in temperature/top_p/top_k).
            fallback_config = GenerationConfig(
                do_sample=False,
                max_new_tokens=gen_kwargs["max_new_tokens"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                pad_token_id=gen_kwargs["pad_token_id"],
            )
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=".*generation flags are not valid.*",
                    category=UserWarning,
                )
                outputs = do_generate({"generation_config": fallback_config})
        else:
            raise
    # Decode only the newly generated tokens (strip the prompt prefix).
    input_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    # Post-process: truncate at stop strings for Baguettotron
    if model_id == BAGUETTOTRON_ID:
        for s in ["<|im_end>", "</think>"]:
            if s in text:
                text = text.split(s)[0].strip()
    return model_id, text
def run_all(
    prompt: str,
    params_by_model: dict[str, dict[str, Any]],
    device: str | None = None,
    max_workers: int = 6,
    system_prompt: str = "",
) -> dict[str, str]:
    """
    Fan out one inference per model id in MODEL_IDS on a thread pool,
    loading and generating in parallel.

    Per-model overrides from ``params_by_model`` are layered over the
    defaults below. Returns {model_id: generated_text}.
    """
    resolved_device = _get_device() if device is None else device
    base_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }

    def run_single(mid: str) -> tuple[str, str]:
        overrides = params_by_model.get(mid) or {}
        merged = {**base_params, **overrides}
        return _generate_one(mid, prompt, merged, resolved_device, system_prompt)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(run_single, mid) for mid in MODEL_IDS]
        return dict(fut.result() for fut in as_completed(pending))