File size: 9,314 Bytes
60a9595
bf6d44e
d76b941
 
2dcb7ad
 
1d79762
2dcb7ad
 
 
60a9595
 
bf6d44e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dcb7ad
 
60a9595
2dcb7ad
60a9595
2dcb7ad
bf6d44e
2dcb7ad
bf6d44e
2dcb7ad
2ad6a17
2dcb7ad
2ad6a17
2dcb7ad
 
1d79762
2dcb7ad
bf6d44e
 
2dcb7ad
2ad6a17
2dcb7ad
 
 
bf6d44e
1d79762
bf6d44e
 
 
 
 
 
 
1d79762
 
 
bf6d44e
d76b941
849364d
 
d76b941
11cacc3
 
bf6d44e
11cacc3
849364d
bf6d44e
849364d
 
bf6d44e
11cacc3
d279e64
d76b941
d279e64
 
 
 
 
 
 
213e916
d279e64
 
 
bf6d44e
d76b941
d279e64
 
 
d76b941
d279e64
 
11cacc3
d279e64
 
bf6d44e
d279e64
849364d
213e916
bf6d44e
213e916
 
bf6d44e
213e916
11cacc3
849364d
bf6d44e
 
 
 
 
 
d76b941
 
849364d
 
bf6d44e
2dcb7ad
 
 
 
 
 
bf6d44e
2dcb7ad
 
 
 
 
 
bf6d44e
 
 
 
 
7471f75
60a9595
 
bf6d44e
7471f75
bf6d44e
d76b941
d279e64
 
 
bf6d44e
d279e64
 
 
bf6d44e
d279e64
bf6d44e
d279e64
bf6d44e
d279e64
 
bf6d44e
d279e64
d76b941
 
 
 
bf6d44e
 
 
d76b941
60a9595
2ad6a17
 
d76b941
2ad6a17
bf6d44e
 
 
 
 
 
 
 
 
2ad6a17
 
bf6d44e
 
 
 
213e916
bf6d44e
213e916
bf6d44e
213e916
7471f75
bf6d44e
2ad6a17
 
 
d76b941
bf6d44e
d76b941
bf6d44e
d76b941
2ad6a17
60a9595
2ad6a17
bf6d44e
d76b941
2ad6a17
bf6d44e
 
2ad6a17
 
 
 
 
213e916
2ad6a17
 
bf6d44e
 
60a9595
 
bf6d44e
1b21789
 
 
bf6d44e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
# hf_backend.py
import json
import logging
import time
import uuid
from contextlib import nullcontext
from typing import Any, AsyncIterable, Dict, Tuple

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from backends_base import ChatBackend, ImagesBackend
from config import settings

# Module logger (named after this module); every [init]/[load]/[gen]/[req] message below uses it.
logger = logging.getLogger(__name__)

# ---------- logging helpers ----------
def _snippet(txt: str, n: int = 800) -> str:
    if not isinstance(txt, str):
        return f"<non-str:{type(txt)}>"
    return txt if len(txt) <= n else txt[:n] + f"... <+{len(txt)-n} chars>"

def _json_snippet(obj: Any, n: int = 800) -> str:
    """Render *obj* as pretty-printed JSON and truncate it via ``_snippet``.

    Non-ASCII characters are kept as-is. If ``json.dumps`` cannot encode *obj*,
    its ``str()`` form is truncated instead.
    """
    try:
        encoded = json.dumps(obj, ensure_ascii=False, indent=2)
    except Exception:
        # Not JSON-serializable: fall back to the plain string representation.
        return _snippet(str(obj), n)
    return _snippet(encoded, n)


# ---------- HF Spaces imports ----------
try:
    # Optional `spaces` package (HF Spaces). `zero_client` is used by
    # HFChatBackend.stream to forward the caller's X-IP-Token via its HEADERS.
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    # Package (or its `zero` submodule) unavailable: both names are None and
    # the chat backend takes its CPU-only path.
    spaces, zero_client = None, None

# ---------- Model setup ----------
# Model repo id: settings.LlmHFModelID if set, otherwise Qwen2.5-1.5B-Instruct.
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"[init] MODEL_ID={MODEL_ID}")

# The tokenizer is loaded eagerly at import time. A failure is not raised here:
# it is recorded in `load_error`, which HFChatBackend.stream raises per request.
tokenizer, load_error = None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_ID,
        trust_remote_code=True,
        # NOTE(review): forces the slow (Python) tokenizer; the reason is not visible here.
        use_fast=False,
    )
    # Only used for the log line below; the chat backend re-checks this itself.
    has_template = hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None)
    logger.info(f"[init] tokenizer loaded. chat_template={'yes' if has_template else 'no'}")
except Exception as e:
    load_error = f"Failed to load tokenizer: {e}"
    logger.exception(load_error)


# ---------- helpers ----------
def _pick_cpu_dtype() -> torch.dtype:
    """Choose the weight dtype for CPU inference.

    Returns ``torch.bfloat16`` when ``torch.cpu.is_bf16_supported`` exists and
    reports support, otherwise ``torch.float32``. If the probe itself raises,
    a warning is logged and FP32 is used.
    """
    bf16_ok = False
    try:
        cpu_ns = getattr(torch, "cpu", None)
        if cpu_ns is not None and hasattr(cpu_ns, "is_bf16_supported"):
            bf16_ok = bool(cpu_ns.is_bf16_supported())
    except Exception as e:
        logger.warning(f"[dtype] BF16 probe failed: {e}")

    if not bf16_ok:
        logger.info("[dtype] fallback -> torch.float32")
        return torch.float32
    logger.info("[dtype] CPU BF16 supported -> torch.bfloat16")
    return torch.bfloat16


# ---------- global cache ----------
# Loaded models keyed by (device, dtype the weights were actually loaded in).
_MODEL_CACHE: Dict[tuple[str, torch.dtype], AutoModelForCausalLM] = {}

# Requested (device, dtype) -> effective dtype, recorded only when a load had
# to fall back (CPU BF16 -> FP32). Without this, every later request for the
# original key would miss _MODEL_CACHE, repeat the failing BF16 load and then
# reload the whole model in FP32 again.
_DTYPE_FALLBACKS: Dict[tuple[str, torch.dtype], torch.dtype] = {}


def _from_pretrained(cfg: Any, device: str, dtype: torch.dtype) -> AutoModelForCausalLM:
    """Load ``MODEL_ID`` with config *cfg* in *dtype*, placed for *device*.

    Non-CPU devices use ``device_map="auto"``; on CPU every module is pinned
    to ``"cpu"``.
    """
    # `low_cpu_mem_usage=False` is deliberately NOT passed: transformers needs
    # low_cpu_mem_usage=True whenever a device_map is given, and some releases
    # raise ValueError ("Passing along a `device_map` requires
    # `low_cpu_mem_usage=True`") for that combination.
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        config=cfg,
        torch_dtype=dtype,
        trust_remote_code=True,
        device_map="auto" if device != "cpu" else {"": "cpu"},
    )


def _get_model(device: str, dtype: torch.dtype) -> Tuple[AutoModelForCausalLM, torch.dtype]:
    """Return ``(model, effective_dtype)`` for *device*, loading on first use.

    Args:
        device: ``"cpu"`` or an accelerator device string such as ``"cuda"``.
        dtype: requested weight dtype.

    Returns:
        The eval-mode model and the dtype it was actually loaded in. That
        differs from *dtype* only when a CPU BF16 load failed and FP32 was
        used instead; repeated requests with the same arguments then hit the
        cache rather than retrying the failed load.

    Raises:
        Whatever ``AutoConfig``/``from_pretrained`` raise. The one exception
        is a CPU BF16 load failure, which is retried once in FP32.
    """
    requested = (device, dtype)
    known_dtype = _DTYPE_FALLBACKS.get(requested, dtype)
    cached = _MODEL_CACHE.get((device, known_dtype))
    if cached is not None:
        logger.info(f"[cache] hit model for device={device} dtype={known_dtype}")
        return cached, known_dtype

    logger.info(f"[load] begin from_pretrained device={device} dtype={dtype}")
    cfg = AutoConfig.from_pretrained(MODEL_ID, trust_remote_code=True)
    if hasattr(cfg, "quantization_config"):
        # Dropping the quantization config forces a plain load in `dtype`.
        logger.warning("[load] removing quantization_config from config to avoid FP8 path")
        delattr(cfg, "quantization_config")

    eff_dtype = dtype
    try:
        model = _from_pretrained(cfg, device, dtype)
    except Exception as e:
        if device != "cpu" or dtype != torch.bfloat16:
            logger.exception("[load] from_pretrained failed")
            raise
        # BF16 on CPU is best-effort: retry once in FP32 (a second failure propagates).
        logger.warning(f"[load] BF16 load failed on CPU ({e}). retry FP32.")
        eff_dtype = torch.float32
        model = _from_pretrained(cfg, device, eff_dtype)

    if device == "cpu":
        logger.info(f"[load] casting all weights to CPU dtype={eff_dtype}")
        model = model.to(device=device, dtype=eff_dtype)
    else:
        logger.info(f"[load] moving model to device={device} (no recast)")
        model = model.to(device=device)

    model.eval()
    try:
        first_dtype = next(model.parameters()).dtype
        logger.info(f"[load] ready. effective_dtype={eff_dtype} first_param_dtype={first_dtype}")
    except Exception:
        logger.info(f"[load] ready. effective_dtype={eff_dtype} (param dtype probe failed)")

    _MODEL_CACHE[(device, eff_dtype)] = model
    if eff_dtype != dtype:
        _DTYPE_FALLBACKS[requested] = eff_dtype
    return model, eff_dtype


# ---------- Chat Backend ----------
class HFChatBackend(ChatBackend):
    """Chat backend that runs ``MODEL_ID`` in-process with transformers.

    Each request is answered with a single OpenAI-style
    ``chat.completion.chunk`` containing the whole completion. When the
    ``spaces`` package is importable, generation goes through
    ``spaces.GPU`` (ZeroGPU); otherwise it runs on CPU.
    """

    @staticmethod
    def _request_value(request: Dict[str, Any], key: str, default: Any) -> Any:
        """Return ``request[key]``, or *default* when the key is absent or None."""
        value = request.get(key)
        return default if value is None else value

    @staticmethod
    def _last_message_text(messages: Any) -> str:
        """Return the text of the last message, or ``"(empty)"`` if there is none.

        ``content`` may be a plain string or a list of OpenAI-style content
        parts; for a list, the ``text`` of every ``{"type": "text"}`` part is
        concatenated. Missing, None or empty content yields ``"(empty)"``.
        """
        if not messages:
            return "(empty)"
        content = messages[-1].get("content")
        if isinstance(content, list):
            content = "".join(
                part.get("text") or ""
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        if not isinstance(content, str) or not content:
            return "(empty)"
        return content

    @classmethod
    def _build_prompt(cls, messages: Any, tools: Any) -> Tuple[str, bool]:
        """Render *messages* (and optional *tools*) into a prompt string.

        Returns:
            ``(prompt, templated)``. *templated* is True when the tokenizer's
            chat template produced the prompt, in which case the prompt
            already contains the model's special tokens. Otherwise the prompt
            is just the last message's text.
        """
        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            try:
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tools=tools,
                    tokenize=False,
                    add_generation_prompt=True,
                )
            except Exception as e:
                logger.warning(f"[prompt] chat_template failed -> fallback. err={e}")
                prompt = cls._last_message_text(messages)
                logger.info(f"[prompt] fallback content len={len(prompt)}\n{_snippet(prompt, 800)}")
                return prompt, False
            logger.info(f"[prompt] built via chat_template. len={len(prompt)}\n{_snippet(prompt, 1200)}")
            return prompt, True

        prompt = cls._last_message_text(messages)
        logger.info(f"[prompt] no template. using last user text len={len(prompt)}\n{_snippet(prompt, 800)}")
        return prompt, False

    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        """Generate a completion for an OpenAI-style chat *request*.

        Reads ``messages``, ``tools``, ``temperature``, ``max_tokens`` and
        ``x_ip_token`` from *request*; absent or None numeric fields fall
        back to the settings defaults. A ``temperature`` <= 0 selects greedy
        decoding (transformers rejects non-positive temperatures when
        sampling). Yields exactly one chunk dict.

        Raises:
            RuntimeError: if the tokenizer failed to load at import time.
        """
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages") or []
        tools = request.get("tools")
        temperature = float(self._request_value(request, "temperature", settings.LlmTemp or 0.7))
        max_tokens = int(self._request_value(request, "max_tokens", settings.LlmOpenAICtxSize or 512))

        # A random suffix keeps ids unique even for requests in the same second.
        rid = f"chatcmpl-hf-{uuid.uuid4().hex}"
        now = int(time.time())

        logger.info(f"[req] rid={rid} temp={temperature} max_tokens={max_tokens} "
                    f"msgs={len(messages)} tools={'yes' if tools else 'no'} "
                    f"spaces={'yes' if spaces else 'no'} cuda={'yes' if torch.cuda.is_available() else 'no'}")

        # X-IP-Token for ZeroGPU
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.info("[req] injected X-IP-Token into ZeroGPU headers")

        prompt, templated = self._build_prompt(messages, tools)

        def _run_once(prompt: str, device: str, req_dtype: torch.dtype) -> str:
            """Tokenize *prompt*, generate on *device*, and decode only the new tokens."""
            model, eff_dtype = _get_model(device, req_dtype)

            # A chat-templated prompt already carries the model's special tokens
            # (e.g. BOS); adding them again while tokenizing would duplicate them.
            inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=not templated)
            input_len = inputs["input_ids"].shape[-1]
            logger.info(f"[gen] device={device} dtype={eff_dtype} input_tokens={input_len}")

            inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

            gen_kwargs: Dict[str, Any] = {"max_new_tokens": max_tokens}
            if temperature > 0:
                gen_kwargs.update(temperature=temperature, do_sample=True)
            else:
                # temperature <= 0 (e.g. OpenAI clients sending 0) means greedy decoding;
                # sampling with such a temperature raises inside transformers.
                gen_kwargs["do_sample"] = False
            gen_kwargs["use_cache"] = True
            logger.info(f"[gen] kwargs={gen_kwargs}")

            if device != "cpu":
                autocast_ctx = torch.autocast(device_type=device, dtype=eff_dtype)
            elif eff_dtype == torch.bfloat16:
                autocast_ctx = torch.autocast(device_type="cpu", dtype=torch.bfloat16)
            else:
                autocast_ctx = nullcontext()

            with torch.inference_mode(), autocast_ctx:
                outputs = model.generate(**inputs, **gen_kwargs)

            # Only decode newly generated tokens
            generated_ids = outputs[0][input_len:]
            logger.info(f"[gen] new_tokens={generated_ids.shape[-1]}")
            text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
            logger.info(f"[gen] text len={len(text)}\n{_snippet(text, 1200)}")
            return text

        # NOTE: generation below runs synchronously inside this coroutine, so it
        # blocks the event loop until the model finishes.
        if spaces:
            @spaces.GPU(duration=120)
            def run_once(prompt: str) -> str:
                if torch.cuda.is_available():
                    logger.info("[path] ZeroGPU + CUDA")
                    return _run_once(prompt, device="cuda", req_dtype=torch.float16)
                logger.info("[path] ZeroGPU but no CUDA -> CPU fallback")
                return _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

            text = run_once(prompt)
        else:
            logger.info("[path] CPU-only runtime")
            text = _run_once(prompt, device="cpu", req_dtype=_pick_cpu_dtype())

        # Emit single OpenAI-style chunk
        chunk = {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"role": "assistant", "content": text}, "finish_reason": "stop"}
            ],
        }
        logger.info(f"[out] chunk summary -> id={rid} content_len={len(text)}")
        yield chunk


# ---------- Stub Images Backend ----------
class StubImagesBackend(ImagesBackend):
    """Placeholder images backend: image generation is unavailable here."""

    # Fixed base64-encoded PNG (tiny placeholder image) returned for every request.
    _PLACEHOLDER_PNG_B64 = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="

    async def generate_b64(self, request: Dict[str, Any]) -> str:
        """Log that images are unsupported and return the placeholder PNG.

        *request* is ignored.
        """
        logger.warning("Image generation not supported in HF backend.")
        return self._PLACEHOLDER_PNG_B64