# Hugging Face Space — runs on a T4 GPU.
| """ | |
| Parallel load and inference for all 6 models (Baguettotron + 5 Luth). | |
| Baguettotron uses EOS-safe formatting: "<|im_end>" (no trailing pipe), stop=["<|im_end>", "</think>"]. | |
| """ | |
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.utils import logging as hf_logging

from model_config import MODEL_IDS
# Reduce load-time noise (e.g. "lm_head.weight | MISSING" for Qwen3 tied-embedding models)
hf_logging.set_verbosity_error()

# In-memory cache: model_id -> (model, tokenizer), guarded by _cache_lock.
_model_cache: dict[str, tuple[Any, Any]] = {}
# Proper module-level import instead of the `__import__("threading")` hack,
# which hid the dependency from readers and import-order tooling.
_cache_lock = threading.Lock()

# Baguettotron repo_id for EOS quirk handling
BAGUETTOTRON_ID = "PleIAs/Baguettotron"
| def _format_prompt_baguettotron(prompt: str, system_prompt: str = "") -> tuple[str, list[str]]: | |
| """ | |
| Manual prompt build for Baguettotron. Uses "<|im_end>" (no trailing pipe) | |
| per tokenizer; stop=["<|im_end>", "</think>"] for generation. | |
| Qwen-style: system (optional) + user + assistant. | |
| """ | |
| parts: list[str] = [] | |
| if system_prompt.strip(): | |
| parts.append(f"<|im_start|>system\n{system_prompt.strip()}<|im_end>\n") | |
| parts.append(f"<|im_start|>user\n{prompt}<|im_end>\n<|im_start|>assistant\n<think>\n") | |
| text = "".join(parts) | |
| stop = ["<|im_end>", "</think>"] | |
| return text, stop | |
| def _format_prompt_luth(prompt: str, tokenizer: Any, system_prompt: str = "") -> tuple[dict[str, Any], list[str] | None]: | |
| """Use tokenizer's chat template for Luth models. Supports optional system message.""" | |
| messages: list[dict[str, str]] = [] | |
| if system_prompt.strip(): | |
| messages.append({"role": "system", "content": system_prompt.strip()}) | |
| messages.append({"role": "user", "content": prompt}) | |
| inputs = tokenizer.apply_chat_template( | |
| messages, | |
| add_generation_prompt=True, | |
| tokenize=True, | |
| return_tensors="pt", | |
| return_dict=True, | |
| ) | |
| return inputs, None # no custom stop for Luth | |
| def _get_device() -> str: | |
| return "cuda" if torch.cuda.is_available() else "cpu" | |
def _load_model(model_id: str, device: str | None = None) -> tuple[Any, Any]:
    """
    Load a model and tokenizer, caching the pair by model_id.

    Args:
        model_id: Hugging Face hub repo id.
        device: "cuda" / "cpu"; autodetected when None.

    Returns:
        (model, tokenizer). Repeated calls for the same model_id return the
        same cached pair.
    """
    if device is None:
        device = _get_device()
    # Fast path under the lock. The slow download/load below deliberately
    # runs OUTSIDE the lock so several models can load in parallel.
    with _cache_lock:
        if model_id in _model_cache:
            return _model_cache[model_id]
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # On CPU, half-precision weights trip float-vs-bf16/fp16 dtype errors in
    # many ops, so promote to float32. On CUDA the dtype is kept as loaded —
    # the old `model.to(model_dtype)` branch converted the model to its own
    # dtype (a no-op) and has been dropped.
    if device == "cpu" and next(model.parameters()).dtype in (torch.bfloat16, torch.float16):
        model = model.float()
    # setdefault: if another thread raced us and cached this model_id first,
    # keep ITS pair so every caller shares one instance (the original code
    # let the last writer clobber the first thread's cached entry).
    with _cache_lock:
        return _model_cache.setdefault(model_id, (model, tokenizer))
def _generate_one(
    model_id: str,
    prompt: str,
    params: dict[str, Any],
    device: str = "cuda",
    system_prompt: str = "",
) -> tuple[str, str]:
    """
    Run one model's inference and return (model_id, generated_text).

    Loads (or reuses the cached) model/tokenizer, builds the prompt with the
    model-appropriate formatter (manual markup for Baguettotron, chat
    template for Luth), samples a completion, and falls back to float32 or
    greedy decoding on two known RuntimeError signatures.

    Args:
        model_id: HF repo id; BAGUETTOTRON_ID selects the manual formatter.
        prompt: user message text.
        params: sampling params (temperature, max_tokens, top_p, top_k,
            repeat_penalty); missing keys fall back to defaults here.
        device: requested device string; superseded below by the model's
            actual placement.
        system_prompt: optional system message (blank = omitted).
    """
    model, tokenizer = _load_model(model_id, device)
    # Re-resolve device/dtype from the actual weights: device_map="auto" may
    # have placed the model somewhere other than the requested string.
    device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype
    # Clamp temperature/top_p to avoid CUDA assertion (inf/nan in softmax)
    temp = max(float(params.get("temperature", 0.7)), 0.01)
    top_p = max(min(float(params.get("top_p", 0.9)), 1.0), 1e-6)
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": int(params.get("max_tokens", 256)),
        "temperature": temp,
        "top_p": top_p,
        "top_k": max(int(params.get("top_k", 40)), 1),
        "repetition_penalty": float(params.get("repeat_penalty", 1.1)),
        "do_sample": True,
        # NOTE(review): `or` falls through when eos_token_id is 0 (falsy) —
        # fine for these models but worth confirming if the roster changes.
        "pad_token_id": tokenizer.eos_token_id or tokenizer.pad_token_id,
    }
    if model_id == BAGUETTOTRON_ID:
        # Baguettotron: manual Qwen-style markup; stop strings applied as a
        # text post-process after decoding (see bottom of function).
        text_prompt, _stop = _format_prompt_baguettotron(prompt, system_prompt)
        inputs = tokenizer(text_prompt, return_tensors="pt")
    else:
        inputs = _format_prompt_luth(prompt, tokenizer, system_prompt)[0]
    # Move to device (input_ids/attention_mask are int; no dtype cast needed)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    def do_generate(kwargs: dict[str, Any], use_autocast: bool = True):
        # Autocast only on CUDA with half-precision weights; the float32
        # retry path passes use_autocast=False to generate in full precision.
        if use_autocast and str(device).startswith("cuda") and model_dtype in (torch.bfloat16, torch.float16):
            with torch.amp.autocast(device_type="cuda", dtype=model_dtype):
                return model.generate(**inputs, **kwargs)
        return model.generate(**inputs, **kwargs)
    try:
        outputs = do_generate(gen_kwargs)
    except RuntimeError as e:
        # Error-string matching is brittle but is the only way to tell these
        # two recoverable failures apart from genuine errors.
        if "expected m1 and m2 to have the same dtype" in str(e) or "float != c10::BFloat16" in str(e):
            # Qwen3 (e.g. Luth-0.6B/1.7B) can hit float vs bfloat16 in some envs; retry in float32
            model.float()
            outputs = do_generate(gen_kwargs, use_autocast=False)
        elif "probability tensor contains" in str(e):
            # Fallback to greedy decoding when sampling yields invalid logits (inf/nan/<0).
            # Use explicit GenerationConfig without sampling params; suppress "generation flags
            # are not valid" warning (model config can still merge in temperature/top_p/top_k).
            fallback_config = GenerationConfig(
                do_sample=False,
                max_new_tokens=gen_kwargs["max_new_tokens"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                pad_token_id=gen_kwargs["pad_token_id"],
            )
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=".*generation flags are not valid.*",
                    category=UserWarning,
                )
                outputs = do_generate({"generation_config": fallback_config})
        else:
            raise
    # Strip the prompt tokens: decode only what was generated.
    input_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    # Post-process: truncate at stop strings for Baguettotron
    if model_id == BAGUETTOTRON_ID:
        for s in ["<|im_end>", "</think>"]:
            if s in text:
                text = text.split(s)[0].strip()
    return model_id, text
def run_all(
    prompt: str,
    params_by_model: dict[str, dict[str, Any]],
    device: str | None = None,
    max_workers: int = 6,
    system_prompt: str = "",
) -> dict[str, str]:
    """
    Fan one prompt out to every model in MODEL_IDS, loading and generating
    in parallel threads.

    Per-model entries in `params_by_model` override the defaults below.

    Returns:
        {model_id: generated_text}
    """
    if device is None:
        device = _get_device()
    base_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }

    def run_single(mid: str) -> tuple[str, str]:
        # Defaults first, then whatever the caller supplied for this model.
        merged = dict(base_params)
        merged.update(params_by_model.get(mid) or {})
        return _generate_one(mid, prompt, merged, device, system_prompt)

    outputs: dict[str, str] = {}
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = {pool.submit(run_single, mid): mid for mid in MODEL_IDS}
        for done in as_completed(pending):
            mid, text = done.result()
            outputs[mid] = text
    return outputs