"""
Parallel load and inference for all 6 models (Baguettotron + 5 Luth).

Baguettotron uses EOS-safe formatting: "<|im_end>" (no trailing pipe),
stop=["<|im_end>", ""]. Empty stop entries are ignored when truncating output.
"""

import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import torch
from model_config import MODEL_IDS
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.utils import logging as hf_logging

# Reduce load-time noise (e.g. "lm_head.weight | MISSING" for Qwen3 tied-embedding models)
hf_logging.set_verbosity_error()

# In-memory cache: model_id -> (model, tokenizer)
_model_cache: dict[str, tuple[Any, Any]] = {}
_cache_lock = threading.Lock()

# Baguettotron repo_id for EOS quirk handling
BAGUETTOTRON_ID = "PleIAs/Baguettotron"


def _format_prompt_baguettotron(prompt: str, system_prompt: str = "") -> tuple[str, list[str]]:
    """
    Build the Baguettotron prompt by hand and return ``(text, stop)``.

    Uses "<|im_end>" (no trailing pipe) as the turn terminator. The layout is
    Qwen-style: optional system turn, user turn, then an open assistant turn.

    ``stop`` is returned exactly as before, including its empty second entry.
    NOTE(review): the empty entry looks like a placeholder for a second stop
    token -- confirm. Consumers must skip empty entries, since an empty string
    is contained in every text and cannot be used as a split separator.
    """
    parts: list[str] = []
    system_text = system_prompt.strip()
    if system_text:
        parts.append(f"<|im_start|>system\n{system_text}<|im_end>\n")
    parts.append(f"<|im_start|>user\n{prompt}<|im_end>\n<|im_start|>assistant\n\n")
    text = "".join(parts)
    stop = ["<|im_end>", ""]
    return text, stop


def _format_prompt_luth(
    prompt: str, tokenizer: Any, system_prompt: str = ""
) -> tuple[dict[str, Any], list[str] | None]:
    """
    Build model inputs for Luth models through the tokenizer's chat template.

    An optional system message is prepended when ``system_prompt`` is non-blank.
    Returns ``(inputs, None)``; Luth models use no custom stop strings.
    """
    messages: list[dict[str, str]] = []
    system_text = system_prompt.strip()
    if system_text:
        messages.append({"role": "system", "content": system_text})
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    return inputs, None  # no custom stop for Luth


def _get_device() -> str:
    """Return "cuda" when CUDA is available, otherwise "cpu"."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def _load_model(model_id: str, device: str | None = None) -> tuple[Any, Any]:
    """
    Load the model and tokenizer for ``model_id``, caching the pair by id.

    ``device`` defaults to the result of ``_get_device()``. Loading happens
    outside the cache lock so different models can load concurrently. If two
    threads race on the same id, the first pair stored wins and is returned to
    both, so every caller shares one instance.
    """
    if device is None:
        device = _get_device()

    with _cache_lock:
        cached = _model_cache.get(model_id)
    if cached is not None:
        return cached

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Avoid float vs bfloat16 mismatch: on CPU use float32; on CUDA keep autocast
    model_dtype = next(model.parameters()).dtype
    half_precision = model_dtype in (torch.bfloat16, torch.float16)
    if device == "cpu" and half_precision:
        model = model.float()
    elif str(device).startswith("cuda") and half_precision:
        model = model.to(model_dtype)

    with _cache_lock:
        # setdefault keeps an entry another thread may have stored meanwhile.
        return _model_cache.setdefault(model_id, (model, tokenizer))


def _truncate_at_stops(text: str, stops: list[str]) -> str:
    """
    Cut ``text`` at each stop string in turn and strip the kept prefix.

    Empty stop strings are skipped: "" is contained in every string and
    ``str.split("")`` raises ValueError. Text containing no stop string is
    returned unchanged (not stripped).
    """
    for stop in stops:
        if stop and stop in text:
            text = text.split(stop, 1)[0].strip()
    return text


def _generate_one(
    model_id: str,
    prompt: str,
    params: dict[str, Any],
    device: str = "cuda",
    system_prompt: str = "",
) -> tuple[str, str]:
    """
    Load (or reuse the cached) model, run one generation, return ``(model_id, text)``.

    ``params`` may contain "temperature", "top_p", "top_k", "max_tokens" and
    "repeat_penalty"; missing keys fall back to the defaults used below.

    Two RuntimeError cases are retried; any other RuntimeError propagates:
    a float/bfloat16 dtype mismatch is retried in float32 without autocast, and
    an invalid probability tensor is retried with greedy decoding.
    """
    model, tokenizer = _load_model(model_id, device)
    # From here on use the device/dtype the loaded model actually has.
    device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype

    # Clamp temperature/top_p to avoid CUDA assertion (inf/nan in softmax)
    temp = max(float(params.get("temperature", 0.7)), 0.01)
    top_p = max(min(float(params.get("top_p", 0.9)), 1.0), 1e-6)

    # Prefer EOS as the pad token; compare with None explicitly because a
    # token id of 0 is valid but falsy.
    pad_token_id = tokenizer.eos_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": int(params.get("max_tokens", 256)),
        "temperature": temp,
        "top_p": top_p,
        "top_k": max(int(params.get("top_k", 40)), 1),
        "repetition_penalty": float(params.get("repeat_penalty", 1.1)),
        "do_sample": True,
        "pad_token_id": pad_token_id,
    }

    stops: list[str] = []
    if model_id == BAGUETTOTRON_ID:
        text_prompt, stops = _format_prompt_baguettotron(prompt, system_prompt)
        inputs = tokenizer(text_prompt, return_tensors="pt")
    else:
        inputs = _format_prompt_luth(prompt, tokenizer, system_prompt)[0]

    # Move to device (input_ids/attention_mask are int; no dtype cast needed)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    def do_generate(kwargs: dict[str, Any], use_autocast: bool = True):
        """Call model.generate, under CUDA autocast for half-precision models."""
        if (
            use_autocast
            and str(device).startswith("cuda")
            and model_dtype in (torch.bfloat16, torch.float16)
        ):
            with torch.amp.autocast(device_type="cuda", dtype=model_dtype):
                return model.generate(**inputs, **kwargs)
        return model.generate(**inputs, **kwargs)

    try:
        outputs = do_generate(gen_kwargs)
    except RuntimeError as e:
        message = str(e)
        if (
            "expected m1 and m2 to have the same dtype" in message
            or "float != c10::BFloat16" in message
        ):
            # Qwen3 (e.g. Luth-0.6B/1.7B) can hit float vs bfloat16 in some envs;
            # retry in float32. This converts the cached model in place.
            model.float()
            outputs = do_generate(gen_kwargs, use_autocast=False)
        elif "probability tensor contains" in message:
            # Fallback to greedy decoding when sampling yields invalid logits (inf/nan/<0).
            # Use explicit GenerationConfig without sampling params; suppress "generation flags
            # are not valid" warning (model config can still merge in temperature/top_p/top_k).
            fallback_config = GenerationConfig(
                do_sample=False,
                max_new_tokens=gen_kwargs["max_new_tokens"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                pad_token_id=gen_kwargs["pad_token_id"],
            )
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=".*generation flags are not valid.*",
                    category=UserWarning,
                )
                outputs = do_generate({"generation_config": fallback_config})
        else:
            raise

    # Decode only the newly generated tokens (everything after the prompt).
    input_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

    # Post-process: truncate at stop strings for Baguettotron (stops is empty otherwise)
    text = _truncate_at_stops(text, stops)
    return model_id, text


def run_all(
    prompt: str,
    params_by_model: dict[str, dict[str, Any]],
    device: str | None = None,
    max_workers: int = 6,
    system_prompt: str = "",
) -> dict[str, str]:
    """
    Run ``prompt`` through every model in MODEL_IDS using a thread pool.

    Per-model entries in ``params_by_model`` override the defaults below; a
    missing or empty entry uses the defaults alone. ``device`` defaults to the
    result of ``_get_device()``.

    Returns dict {model_id: text}. An exception raised by any model's task
    propagates from this function.
    """
    if device is None:
        device = _get_device()

    default_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }

    def task(model_id: str) -> tuple[str, str]:
        """Merge defaults with per-model overrides and generate for one model."""
        merged = {**default_params, **(params_by_model.get(model_id) or {})}
        return _generate_one(model_id, prompt, merged, device, system_prompt)

    results: dict[str, str] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = {ex.submit(task, mid): mid for mid in MODEL_IDS}
        for fut in as_completed(futures):
            model_id, text = fut.result()
            results[model_id] = text
    return results