"""
Parallel load and inference for all 6 models (Baguettotron + 5 Luth).
Baguettotron uses EOS-safe formatting: "<|im_end>" (no trailing pipe), stop=["<|im_end>", "</think>"].
"""
import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.utils import logging as hf_logging

from model_config import MODEL_IDS
# Reduce load-time noise (e.g. "lm_head.weight | MISSING" for Qwen3 tied-embedding models)
hf_logging.set_verbosity_error()

# In-memory cache: model_id -> (model, tokenizer)
_model_cache: dict[str, tuple[Any, Any]] = {}
# Guards _model_cache. Uses a plain top-level import instead of the previous
# `__import__("threading")` hack, which obscured the dependency from linters.
_cache_lock = threading.Lock()

# Baguettotron repo_id for EOS quirk handling
BAGUETTOTRON_ID = "PleIAs/Baguettotron"
def _format_prompt_baguettotron(prompt: str, system_prompt: str = "") -> tuple[str, list[str]]:
"""
Manual prompt build for Baguettotron. Uses "<|im_end>" (no trailing pipe)
per tokenizer; stop=["<|im_end>", "</think>"] for generation.
Qwen-style: system (optional) + user + assistant.
"""
parts: list[str] = []
if system_prompt.strip():
parts.append(f"<|im_start|>system\n{system_prompt.strip()}<|im_end>\n")
parts.append(f"<|im_start|>user\n{prompt}<|im_end>\n<|im_start|>assistant\n<think>\n")
text = "".join(parts)
stop = ["<|im_end>", "</think>"]
return text, stop
def _format_prompt_luth(prompt: str, tokenizer: Any, system_prompt: str = "") -> tuple[dict[str, Any], list[str] | None]:
"""Use tokenizer's chat template for Luth models. Supports optional system message."""
messages: list[dict[str, str]] = []
if system_prompt.strip():
messages.append({"role": "system", "content": system_prompt.strip()})
messages.append({"role": "user", "content": prompt})
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
)
return inputs, None # no custom stop for Luth
def _get_device() -> str:
return "cuda" if torch.cuda.is_available() else "cpu"
def _load_model(model_id: str, device: str | None = None) -> tuple[Any, Any]:
    """Load model and tokenizer for *model_id*, caching the pair by model_id.

    Thread-safe: the cache is read and written only under ``_cache_lock``.
    If two threads race past the initial cache miss and both load the model,
    the second writer now keeps the entry already stored instead of clobbering
    it (previously both threads loaded and the later one silently overwrote
    the earlier cache entry).

    Args:
        model_id: HF repo id to load.
        device: "cuda"/"cpu" hint; autodetected when None.

    Returns:
        (model, tokenizer) tuple.
    """
    if device is None:
        device = _get_device()
    with _cache_lock:
        if model_id in _model_cache:
            return _model_cache[model_id]
    # Load outside the lock so one slow download doesn't serialize all models.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    # Avoid float vs bfloat16 mismatch: on CPU upcast half precision to float32.
    # (The former CUDA branch converted the model to its own dtype — a no-op —
    # and has been removed; CUDA relies on autocast at generation time.)
    model_dtype = next(model.parameters()).dtype
    if device == "cpu" and model_dtype in (torch.bfloat16, torch.float16):
        model = model.float()
    with _cache_lock:
        # Another thread may have finished loading first; prefer its entry.
        if model_id not in _model_cache:
            _model_cache[model_id] = (model, tokenizer)
        return _model_cache[model_id]
def _generate_one(
    model_id: str,
    prompt: str,
    params: dict[str, Any],
    device: str = "cuda",
    system_prompt: str = "",
) -> tuple[str, str]:
    """Load (or use cached) model, run one inference, return (model_id, text).

    Args:
        model_id: HF repo id; BAGUETTOTRON_ID gets manual prompt formatting
            and custom stop-string truncation, everything else uses the
            tokenizer's chat template.
        prompt: user message.
        params: sampling params (temperature, max_tokens, top_p, top_k,
            repeat_penalty); missing keys fall back to defaults below.
        device: preferred device hint passed to the loader; the model's real
            placement is used afterwards (device_map="auto" may differ).
        system_prompt: optional system message (blank = omitted).

    Returns:
        (model_id, decoded completion text).
    """
    model, tokenizer = _load_model(model_id, device)
    device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype
    # Clamp temperature/top_p to avoid CUDA assertion (inf/nan in softmax)
    temp = max(float(params.get("temperature", 0.7)), 0.01)
    top_p = max(min(float(params.get("top_p", 0.9)), 1.0), 1e-6)
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": int(params.get("max_tokens", 256)),
        "temperature": temp,
        "top_p": top_p,
        "top_k": max(int(params.get("top_k", 40)), 1),
        "repetition_penalty": float(params.get("repeat_penalty", 1.1)),
        "do_sample": True,
        # Explicit None check: token id 0 is valid but falsy, so the previous
        # `eos_token_id or pad_token_id` wrongly discarded an eos id of 0.
        "pad_token_id": (
            tokenizer.eos_token_id
            if tokenizer.eos_token_id is not None
            else tokenizer.pad_token_id
        ),
    }
    if model_id == BAGUETTOTRON_ID:
        text_prompt, _stop = _format_prompt_baguettotron(prompt, system_prompt)
        inputs = tokenizer(text_prompt, return_tensors="pt")
    else:
        inputs = _format_prompt_luth(prompt, tokenizer, system_prompt)[0]
    # Move to device (input_ids/attention_mask are int; no dtype cast needed)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    def do_generate(kwargs: dict[str, Any], use_autocast: bool = True):
        # Autocast only on CUDA half-precision models; CPU/float32 run as-is.
        if use_autocast and str(device).startswith("cuda") and model_dtype in (torch.bfloat16, torch.float16):
            with torch.amp.autocast(device_type="cuda", dtype=model_dtype):
                return model.generate(**inputs, **kwargs)
        return model.generate(**inputs, **kwargs)

    try:
        outputs = do_generate(gen_kwargs)
    except RuntimeError as e:
        if "expected m1 and m2 to have the same dtype" in str(e) or "float != c10::BFloat16" in str(e):
            # Qwen3 (e.g. Luth-0.6B/1.7B) can hit float vs bfloat16 in some envs; retry in float32
            model.float()
            outputs = do_generate(gen_kwargs, use_autocast=False)
        elif "probability tensor contains" in str(e):
            # Fallback to greedy decoding when sampling yields invalid logits (inf/nan/<0).
            # Use explicit GenerationConfig without sampling params; suppress "generation flags
            # are not valid" warning (model config can still merge in temperature/top_p/top_k).
            fallback_config = GenerationConfig(
                do_sample=False,
                max_new_tokens=gen_kwargs["max_new_tokens"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                pad_token_id=gen_kwargs["pad_token_id"],
            )
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=".*generation flags are not valid.*",
                    category=UserWarning,
                )
                outputs = do_generate({"generation_config": fallback_config})
        else:
            raise
    # Decode only the newly generated tokens (strip the prompt prefix).
    input_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
    # Post-process: truncate at stop strings for Baguettotron
    if model_id == BAGUETTOTRON_ID:
        for s in ["<|im_end>", "</think>"]:
            if s in text:
                text = text.split(s)[0].strip()
    return model_id, text
def run_all(
    prompt: str,
    params_by_model: dict[str, dict[str, Any]],
    device: str | None = None,
    max_workers: int = 6,
    system_prompt: str = "",
) -> dict[str, str]:
    """
    Fan out: load and run inference for every model in MODEL_IDS concurrently.

    Per-model overrides from params_by_model are layered over shared defaults.
    Returns {model_id: generated_text}.
    """
    device = device if device is not None else _get_device()
    base_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }

    def run_one(mid: str) -> tuple[str, str]:
        overrides = params_by_model.get(mid) or {}
        return _generate_one(mid, prompt, {**base_params, **overrides}, device, system_prompt)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(run_one, mid) for mid in MODEL_IDS]
        return dict(fut.result() for fut in as_completed(pending))