File size: 7,672 Bytes
7c96057
 
 
 
 
1f2f106
7c96057
 
 
 
 
 
1f2f106
ea25b4a
 
 
 
7c96057
 
 
 
 
 
 
 
 
935bdc8
7c96057
 
 
935bdc8
7c96057
935bdc8
 
 
 
 
7c96057
 
 
 
935bdc8
 
 
 
 
 
7c96057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea25b4a
 
 
 
 
 
 
7c96057
 
 
 
 
 
 
 
 
 
 
935bdc8
7c96057
 
 
 
 
3b5014d
 
 
 
 
7c96057
3b5014d
 
 
 
 
7c96057
 
 
 
 
935bdc8
7c96057
 
935bdc8
3b5014d
 
 
7c96057
ea25b4a
 
1d58cce
 
 
 
 
 
 
ea25b4a
 
 
 
 
1f2f106
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d58cce
 
7c96057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935bdc8
7c96057
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
935bdc8
7c96057
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Parallel load and inference for all 6 models (Baguettotron + 5 Luth).
Baguettotron uses EOS-safe formatting: "<|im_end>" (no trailing pipe), stop=["<|im_end>", "</think>"].
"""

import threading
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from transformers.utils import logging as hf_logging

from model_config import MODEL_IDS

# Reduce load-time noise (e.g. "lm_head.weight | MISSING" for Qwen3 tied-embedding models)
hf_logging.set_verbosity_error()

# In-memory cache: model_id -> (model, tokenizer)
_model_cache: dict[str, tuple[Any, Any]] = {}
# Guards all reads/writes of _model_cache across worker threads.
_cache_lock = threading.Lock()

# Baguettotron repo_id for EOS quirk handling
BAGUETTOTRON_ID = "PleIAs/Baguettotron"


def _format_prompt_baguettotron(prompt: str, system_prompt: str = "") -> tuple[str, list[str]]:
    """
    Manual prompt build for Baguettotron. Uses "<|im_end>" (no trailing pipe)
    per tokenizer; stop=["<|im_end>", "</think>"] for generation.
    Qwen-style: system (optional) + user + assistant.
    """
    parts: list[str] = []
    if system_prompt.strip():
        parts.append(f"<|im_start|>system\n{system_prompt.strip()}<|im_end>\n")
    parts.append(f"<|im_start|>user\n{prompt}<|im_end>\n<|im_start|>assistant\n<think>\n")
    text = "".join(parts)
    stop = ["<|im_end>", "</think>"]
    return text, stop


def _format_prompt_luth(prompt: str, tokenizer: Any, system_prompt: str = "") -> tuple[dict[str, Any], list[str] | None]:
    """Use tokenizer's chat template for Luth models. Supports optional system message."""
    messages: list[dict[str, str]] = []
    if system_prompt.strip():
        messages.append({"role": "system", "content": system_prompt.strip()})
    messages.append({"role": "user", "content": prompt})
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
        return_dict=True,
    )
    return inputs, None  # no custom stop for Luth


def _get_device() -> str:
    return "cuda" if torch.cuda.is_available() else "cpu"


def _load_model(model_id: str, device: str | None = None) -> tuple[Any, Any]:
    """
    Load a model and its tokenizer, caching the pair by model_id.

    The download/load happens OUTSIDE the cache lock so different models can
    load in parallel. Two threads may therefore race on the *same* model_id;
    previously the loser silently overwrote the winner's cache entry, leaving
    callers holding a pair the cache no longer references. Now the store is
    re-checked under the lock: the first finisher wins, later finishers
    discard their copy and return the cached pair.

    Returns (model, tokenizer).
    """
    if device is None:
        device = _get_device()
    with _cache_lock:
        if model_id in _model_cache:
            return _model_cache[model_id]

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype="auto",
        device_map="auto" if device == "cuda" else device,
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Avoid float vs bfloat16 mismatch: on CPU use float32; on CUDA keep the
    # loaded half/bfloat16 dtype (autocast handles mixed precision at generate time).
    model_dtype = next(model.parameters()).dtype
    if device == "cpu" and model_dtype in (torch.bfloat16, torch.float16):
        model = model.float()
    elif str(device).startswith("cuda") and model_dtype in (torch.bfloat16, torch.float16):
        model = model.to(model_dtype)

    with _cache_lock:
        # Re-check: another thread may have finished loading this model_id
        # while we were outside the lock. First writer wins; return whatever
        # is cached so all callers share one live copy.
        if model_id not in _model_cache:
            _model_cache[model_id] = (model, tokenizer)
        return _model_cache[model_id]


def _generate_one(
    model_id: str,
    prompt: str,
    params: dict[str, Any],
    device: str = "cuda",
    system_prompt: str = "",
) -> tuple[str, str]:
    """
    Load (or use cached) model, run inference, return (model_id, text).

    Prompt construction is model-specific: Baguettotron gets hand-built
    Qwen-style markup (EOS quirk: "<|im_end>" with no trailing pipe), all
    other models go through their tokenizer's chat template.

    Params keys read (all optional, defaults applied here): temperature,
    max_tokens, top_p, top_k, repeat_penalty.

    Two runtime fallbacks are handled, keyed on exception message text:
    float-vs-bfloat16 matmul mismatch -> retry in float32 without autocast;
    invalid sampling probabilities -> retry with greedy decoding.
    """
    model, tokenizer = _load_model(model_id, device)

    # From here on trust the model's actual placement/dtype rather than the
    # requested `device` string (device_map="auto" may have relocated it).
    device = next(model.parameters()).device
    model_dtype = next(model.parameters()).dtype

    # Clamp temperature/top_p to avoid CUDA assertion (inf/nan in softmax)
    temp = max(float(params.get("temperature", 0.7)), 0.01)
    top_p = max(min(float(params.get("top_p", 0.9)), 1.0), 1e-6)
    gen_kwargs: dict[str, Any] = {
        "max_new_tokens": int(params.get("max_tokens", 256)),
        "temperature": temp,
        "top_p": top_p,
        "top_k": max(int(params.get("top_k", 40)), 1),  # top_k=0 would disable; force >=1
        "repetition_penalty": float(params.get("repeat_penalty", 1.1)),
        "do_sample": True,
        # NOTE(review): `or` falls through on any falsy id — a legitimate
        # eos_token_id of 0 would be skipped in favor of pad_token_id; confirm
        # none of the 6 tokenizers uses id 0 as EOS.
        "pad_token_id": tokenizer.eos_token_id or tokenizer.pad_token_id,
    }

    if model_id == BAGUETTOTRON_ID:
        # Stop strings are applied by post-processing below, not passed to generate.
        text_prompt, _stop = _format_prompt_baguettotron(prompt, system_prompt)
        inputs = tokenizer(text_prompt, return_tensors="pt")
    else:
        inputs = _format_prompt_luth(prompt, tokenizer, system_prompt)[0]

    # Move to device (input_ids/attention_mask are int; no dtype cast needed)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    def do_generate(kwargs: dict[str, Any], use_autocast: bool = True):
        # Autocast only on CUDA with a half/bfloat16 model; plain generate otherwise.
        if use_autocast and str(device).startswith("cuda") and model_dtype in (torch.bfloat16, torch.float16):
            with torch.amp.autocast(device_type="cuda", dtype=model_dtype):
                return model.generate(**inputs, **kwargs)
        return model.generate(**inputs, **kwargs)

    try:
        outputs = do_generate(gen_kwargs)
    except RuntimeError as e:
        if "expected m1 and m2 to have the same dtype" in str(e) or "float != c10::BFloat16" in str(e):
            # Qwen3 (e.g. Luth-0.6B/1.7B) can hit float vs bfloat16 in some envs; retry in float32.
            # NOTE: model.float() mutates the cached model in place — later calls see float32.
            model.float()
            outputs = do_generate(gen_kwargs, use_autocast=False)
        elif "probability tensor contains" in str(e):
            # Fallback to greedy decoding when sampling yields invalid logits (inf/nan/<0).
            # Use explicit GenerationConfig without sampling params; suppress "generation flags
            # are not valid" warning (model config can still merge in temperature/top_p/top_k).
            fallback_config = GenerationConfig(
                do_sample=False,
                max_new_tokens=gen_kwargs["max_new_tokens"],
                repetition_penalty=gen_kwargs["repetition_penalty"],
                pad_token_id=gen_kwargs["pad_token_id"],
            )
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message=".*generation flags are not valid.*",
                    category=UserWarning,
                )
                outputs = do_generate({"generation_config": fallback_config})
        else:
            # Unknown failure mode: surface it to the caller.
            raise
    # Decode only the newly generated tokens (strip the echoed prompt).
    input_len = inputs["input_ids"].shape[-1]
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

    # Post-process: truncate at stop strings for Baguettotron
    if model_id == BAGUETTOTRON_ID:
        for s in ["<|im_end>", "</think>"]:
            if s in text:
                text = text.split(s)[0].strip()

    return model_id, text


def run_all(
    prompt: str,
    params_by_model: dict[str, dict[str, Any]],
    device: str | None = None,
    max_workers: int = 6,
    system_prompt: str = "",
) -> dict[str, str]:
    """
    Load all 6 models in parallel, run all 6 inferences in parallel.

    Per-model overrides from *params_by_model* are merged over the defaults
    (missing or None entries fall back entirely to defaults).
    Returns dict {model_id: text}.
    """
    if device is None:
        device = _get_device()
    default_params = {
        "temperature": 0.7,
        "max_tokens": 256,
        "top_p": 0.9,
        "top_k": 40,
        "repeat_penalty": 1.1,
    }

    def run_one(model_id: str) -> tuple[str, str]:
        overrides = params_by_model.get(model_id) or {}
        merged = dict(default_params, **overrides)
        return _generate_one(model_id, prompt, merged, device, system_prompt)

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        pending = [pool.submit(run_one, mid) for mid in MODEL_IDS]
        results: dict[str, str] = dict(fut.result() for fut in as_completed(pending))

    return results