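"""Hugging Face Transformers chat backend with optional ZeroGPU offloading.

The model is loaded on CPU at import time (ZeroGPU Spaces only allow CUDA
access inside GPU-scheduled functions); HFChatBackend.stream() yields
OpenAI-style chat.completion.chunk dicts.
"""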
import asyncio
import logging
import os
import time
from typing import Any, AsyncIterable, Dict

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from backends_base import ChatBackend, ImagesBackend
from config import settings

# ZeroGPU support is optional: the `spaces` package exists on Hugging Face
# Spaces; anywhere else we fall back to plain local inference.
try:
    import spaces
    from spaces.zero.client import SpaceZeroClient
except ImportError:
    spaces, SpaceZeroClient = None, None

logger = logging.getLogger(__name__)

MODEL_ID = settings.LlmHFModelID or "Qwen/Qwen2.5-1.5B-Instruct"
logger.info("Loading %s on CPU at startup (ZeroGPU safe)...", MODEL_ID)

# Load the tokenizer and model on CPU in float32 at import time: ZeroGPU
# Spaces forbid touching CUDA outside a GPU-scheduled function, so the move
# to an accelerator happens later, inside the inference call.
tokenizer, model, load_error = None, None, None
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    model.eval()
except Exception as e:
    # Record the failure instead of crashing at import; stream() re-raises it.
    load_error = f"Failed to load model/tokenizer: {e}"
    logger.exception(load_error)


def pick_device() -> str:
    """Prefer CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if getattr(torch.backends, "mps", None) and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def pick_dtype(device: str) -> torch.dtype:
    """bfloat16 on Ampere-or-newer CUDA, float16 on older CUDA and MPS, float32 on CPU."""
    if device == "cuda":
        major, _ = torch.cuda.get_device_capability()
        return torch.bfloat16 if major >= 8 else torch.float16
    if device == "mps":
        return torch.float16
    return torch.float32


class HFChatBackend(ChatBackend):
    async def stream(self, request: Dict[str, Any]) -> AsyncIterable[Dict[str, Any]]:
        if load_error:
            raise RuntimeError(load_error)

        messages = request.get("messages", [])
        # Render the full conversation through the tokenizer's chat template
        # when one is defined, so the model sees every turn; otherwise fall
        # back to the raw content of the last message.
        if messages and getattr(tokenizer, "chat_template", None):
            prompt = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
        else:
            prompt = messages[-1]["content"] if messages else "(empty)"
        temperature = float(request.get("temperature", settings.LlmTemp or 0.7))
        max_tokens = int(request.get("max_tokens", settings.LlmOpenAICtxSize or 512))

        rid = f"chatcmpl-hf-{int(time.time())}"
        now = int(time.time())

        # Forward the caller's X-IP-Token so ZeroGPU quota is attributed to
        # the end user rather than to the Space itself.
        x_ip_token = request.get("x_ip_token")
        headers = {}
        if x_ip_token:
            headers["X-IP-Token"] = x_ip_token
            logger.info("Using X-IP-Token from request for ZeroGPU attribution")

        def _gpu_inference_fn(prompt: str) -> str:
            # Runs on the ZeroGPU worker (or inline as a fallback): move the
            # CPU-resident model to the best available device/dtype, generate,
            # and return only the newly generated text.
            device = pick_device()
            dtype = pick_dtype(device)
            model.to(device=device, dtype=dtype).eval()

            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.inference_mode(), torch.autocast(
                device_type=device, dtype=dtype, enabled=device != "cpu"
            ):
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=temperature > 0,
                )
            # Slice off the prompt tokens so the reply does not echo the input.
            new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
            return tokenizer.decode(new_tokens, skip_special_tokens=True)

        if spaces and SpaceZeroClient:
            # Schedule the work on a ZeroGPU worker, passing through the
            # attribution headers collected above.
            client = SpaceZeroClient(headers=headers or None)
            try:
                text = await client.run(_gpu_inference_fn, args=[prompt], duration=120)
            except Exception:
                logger.exception("HF inference (ZeroGPU) failed")
                raise
        else:
            # No ZeroGPU client available: run the same code path locally in a
            # worker thread so generation does not block the event loop.
            text = await asyncio.to_thread(_gpu_inference_fn, prompt)

        # Pseudo-streaming: emit the entire completion as a single
        # OpenAI-style chunk with finish_reason "stop".
        yield {
            "id": rid,
            "object": "chat.completion.chunk",
            "created": now,
            "model": MODEL_ID,
            "choices": [
                {"index": 0, "delta": {"content": text}, "finish_reason": "stop"}
            ],
        }
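

# A minimal local smoke test, not part of the backend API. It assumes the
# model loaded successfully and that HFChatBackend takes no constructor
# arguments (an assumption about backends_base.ChatBackend).
if __name__ == "__main__":
    async def _demo() -> None:
        backend = HFChatBackend()
        request = {
            "messages": [{"role": "user", "content": "Say hello in one sentence."}],
            "max_tokens": 32,
        }
        async for chunk in backend.stream(request):
            print(chunk["choices"][0]["delta"]["content"])

    asyncio.run(_demo())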