# hf_backend.py
import logging
import time
from typing import Any, AsyncIterable, Dict

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from backends_base import ChatBackend, ImagesBackend
from config import settings

logger = logging.getLogger(__name__)

# ZeroGPU's `spaces` package exists only inside a Hugging Face Space; fall
# back to None so the module still imports and fails loudly at request time.
try:
    import spaces
    from spaces.zero import client as zero_client
except ImportError:
    spaces, zero_client = None, None

# --- Model setup (CPU-safe load, real inference on GPU only) ---
MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info(f"Preloading tokenizer for {MODEL_ID} on CPU (ZeroGPU safe)...")

tokenizer, model, load_error = None, None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,   # CPU-safe preload dtype; cast to fp16 on GPU at request time
        trust_remote_code=True,
    )
    model.eval()
except Exception as e:
    load_error = f"Failed to load model/tokenizer: {e}"
    logger.exception(load_error)
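

# --- Load-status helper (sketch) ---
# Hypothetical convenience, not required by backends_base: lets a caller fail
# fast before queuing a ZeroGPU request if the CPU preload above failed.
def load_status() -> Dict[str, Any]:
    """Report whether the module-level model/tokenizer preload succeeded."""
    return {"model_id": MODEL_ID, "loaded": load_error is None, "error": load_error}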


# ---------------- Chat Backend ----------------
class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        # Emits a single OpenAI-style chunk: ZeroGPU runs the whole generation
        # inside one @spaces.GPU call, so there is no token-level streaming.
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        # Only the latest turn is forwarded as a raw prompt; see the
        # build_chat_prompt sketch below for a full-history alternative.
        prompt = messages[-1]["content"] if messages else "(empty)"
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        if not spaces:
            raise RuntimeError("ZeroGPU (spaces) is required but not available!")

        # --- Inject X-IP-Token into global headers ---
        # ZeroGPU reads this header to attribute GPU quota to the end user
        # rather than to the Space itself.
        x_ip_token = request.get("x_ip_token")
        if x_ip_token and zero_client:
            zero_client.HEADERS["X-IP-Token"] = x_ip_token
            logger.debug("Injected X-IP-Token into ZeroGPU headers")

        # --- Define the GPU-only inference function ---
        @spaces.GPU(duration=120)
        def run_once(prompt: str) -> str:
            device = "cuda"   # @spaces.GPU guarantees a CUDA device here
            dtype = torch.float16

            model.to(device=device, dtype=dtype).eval()
            inputs = tokenizer(prompt, return_tensors="pt").to(device)

            gen_kwargs: Dict[str, Any] = {
                "max_new_tokens": max_tokens,
                "pad_token_id": tokenizer.eos_token_id,  # avoid the missing-pad-token warning
            }
            if temperature > 0:
                gen_kwargs.update(do_sample=True, temperature=temperature)
            # else: generate() rejects temperature == 0, so fall back to greedy decoding

            with torch.inference_mode(), torch.autocast(device_type=device, dtype=dtype):
                outputs = model.generate(**inputs, **gen_kwargs)

            # Decode only the newly generated tokens so the prompt is not
            # echoed back in the completion.
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        try:
            text = run_once(prompt)
            yield {
                "id": rid,
                "object": "chat.completion.chunk",
                "created": now,
                "model": MODEL_ID,
                "choices": [
                    {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
                ],
            }
        except Exception:
            logger.exception("HF inference failed")
            raise
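

# --- Full-history prompt building (sketch) ---
# stream() above forwards only the latest user message as a raw prompt. A
# hypothetical helper for giving the model the whole conversation, using the
# tokenizer's built-in chat template (assumes the model ships one, as the
# Qwen instruct models do):
def build_chat_prompt(messages: list) -> str:
    """Render an OpenAI-style messages list with the model's chat template."""
    if tokenizer is None:
        raise RuntimeError(load_error or "Tokenizer unavailable")
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,              # return a prompt string, not token IDs
        add_generation_prompt=True,  # append the assistant-turn marker
    )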


# ---------------- Stub Images Backend ----------------
class StubImagesBackend(ImagesBackend):
    """
    Stub backend for images since HFChatBackend is text-only.
    Returns a transparent 1x1 PNG placeholder.
    """
    async def generate_b64(self, request: Dict[str, Any]) -> str:
        logger.warning("Image generation not supported in HF backend.")
        return (
            "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR4nGP4BwQACfsD/etCJH0AAAAASUVORK5CYII="
        )
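

# --- Local smoke test (sketch) ---
# Hypothetical usage, not part of the backend contract: it assumes requests
# follow the OpenAI-style shape read in stream(). Outside a Space (`spaces`
# is None) the call raises by design.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Say hello in one sentence."}],
            "temperature": 0.7,
            "max_tokens": 64,
        }
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())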