Spaces:
Runtime error
Runtime error
File size: 4,856 Bytes
7840eb9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 | """
model_runner.py — Model loading + ZeroGPU inference
The @spaces.GPU decorator is applied lazily so the GPU is only
allocated during actual inference calls, not at startup.
"""
import os
import gc
import torch
import spaces
from threading import Lock
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
TextIteratorStreamer,
BitsAndBytesConfig,
)
from huggingface_hub import snapshot_download
import threading
# ── Global model cache (one model at a time) ──────────────────
# These module-level slots are written by load_model() and cleared by
# _unload(); generate_stream() reads the model and tokenizer from them.
_model = None  # model returned by AutoModelForCausalLM.from_pretrained, or None
_tokenizer = None  # tokenizer loaded alongside _model, or None
_current_model_id = None  # model_id of the last completed load_model(), or None
_lock = Lock()  # held by load_model() around its unload-then-load sequence
def get_device():
    """Report which torch device string to use.

    Returns:
        ``"cuda"`` when ``torch.cuda.is_available()`` is true, ``"cpu"``
        otherwise.
    """
    return "cuda" if torch.cuda.is_available() else "cpu"
def load_model(
    model_id: str,
    use_4bit: bool = True,
    use_cpu: bool = False,
):
    """
    Load a causal LM and its tokenizer from the HuggingFace Hub into the
    module-level cache.

    The previously cached model is unloaded first to free memory. The new
    tokenizer and model are built in local variables and only published to
    the globals once everything succeeded, so a failed load leaves the
    cache consistently empty instead of holding a tokenizer without a model.

    Args:
        model_id: Hub repository id passed to ``from_pretrained``.
        use_4bit: Request 4-bit NF4 quantization. Only applied when the
            resolved device is CUDA; ignored on CPU.
        use_cpu: Force CPU even when CUDA is available.

    Raises:
        Any exception raised by the ``from_pretrained`` calls propagates
        unchanged.

    Note:
        The "already loaded" check compares ``model_id`` only. Calling
        again with the same id but different ``use_4bit`` / ``use_cpu``
        returns without reloading.
    """
    global _model, _tokenizer, _current_model_id
    with _lock:
        if _current_model_id == model_id:
            return  # Already loaded

        # Free the previous model before allocating the new one.
        _unload()

        device = "cpu" if use_cpu else get_device()

        quant_cfg = None
        if not use_cpu and device == "cuda" and use_4bit:
            quant_cfg = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

        # SECURITY NOTE(review): trust_remote_code=True lets the Hub
        # repository run its own Python code in this process. Only safe if
        # model_id comes from a trusted list — confirm against the caller.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=True,
        )
        if tokenizer.pad_token is None:
            # Some tokenizers ship without a pad token; reuse EOS for it.
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16 if device != "cpu" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )
        if quant_cfg is not None:
            model_kwargs["quantization_config"] = quant_cfg

        model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
        if device == "cpu":
            model = model.to(device)
        model.eval()

        # Publish only after every step succeeded.
        _tokenizer = tokenizer
        _model = model
        _current_model_id = model_id
def _unload():
    """Clear the module-level model cache and release its memory.

    Resets ``_model``, ``_tokenizer`` and ``_current_model_id`` to ``None``,
    runs a garbage-collection pass and, when CUDA is available, empties the
    CUDA caching allocator. Safe to call when nothing is loaded.
    """
    global _model, _tokenizer, _current_model_id
    # Rebind instead of `del` + rebind: `del` on a global removes the name
    # for a moment, so a concurrent reader could hit NameError. Rebinding
    # drops the reference just the same.
    _model = None
    _tokenizer = None
    _current_model_id = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
def is_loaded() -> bool:
    """Tell whether the module-level cache currently holds a model."""
    return not (_model is None)
def current_model() -> str | None:
return _current_model_id
# ── Inference ─────────────────────────────────────────────────
@spaces.GPU(duration=120)
def generate_stream(
    messages: list[dict],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    system_prompt: str = "",
):
    """
    Stream the cached model's reply to a chat conversation.

    Decorated with @spaces.GPU so the GPU is requested for this call.
    Generation runs in a background thread and decoded text is yielded as
    the streamer produces it.

    Args:
        messages: Chat turns as dicts with ``"role"`` and ``"content"`` keys.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature; a value <= 0 switches to greedy
            decoding.
        top_p: Nucleus-sampling threshold (used only when sampling).
        repetition_penalty: Penalty passed to ``generate``.
        system_prompt: Optional system message prepended to ``messages``.

    Yields:
        Text chunks as they are generated. If no model is loaded, a single
        warning message is yielded instead.

    Raises:
        Any exception raised by ``generate`` in the worker thread is
        re-raised here once the stream has been drained.
    """
    # Snapshot the globals once so a concurrent load/unload cannot change
    # them between the check below and their later use.
    model, tokenizer = _model, _tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Aucun modèle chargé. Veuillez d'abord sélectionner et charger un modèle."
        return

    # Build prompt using chat template if available
    chat_messages = []
    if system_prompt:
        chat_messages.append({"role": "system", "content": system_prompt})
    chat_messages.extend(messages)

    try:
        input_ids = tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
    except Exception:
        # Deliberate best-effort fallback (e.g. tokenizer has no chat
        # template): plain "Role: text" concatenation.
        parts = []
        if system_prompt:
            parts.append(f"System: {system_prompt}\n\n")
        for m in messages:
            role = "Human" if m["role"] == "user" else "Assistant"
            parts.append(f"{role}: {m['content']}\n")
        parts.append("Assistant:")
        input_ids = tokenizer("".join(parts), return_tensors="pt").input_ids

    # apply_chat_template may return a mapping rather than a bare tensor;
    # unwrap it so the code below always sees the id tensor.
    if not torch.is_tensor(input_ids):
        input_ids = input_ids["input_ids"]

    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    # Single, unpadded prompt: every position is a real token. Passing the
    # mask explicitly matters because pad_token_id is set to EOS below, so
    # padding cannot be told apart from a real EOS by token id alone.
    attention_mask = torch.ones_like(input_ids)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0
    gen_kwargs = dict(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Sampling-only knobs; they have no effect on greedy decoding.
        gen_kwargs["temperature"] = temperature
        gen_kwargs["top_p"] = top_p

    failures = []

    def _run_generate():
        """Run generation; on failure, record the error and end the stream."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # If generate dies, nothing else signals end-of-stream and the
            # consumer loop below would block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generate)
    thread.start()

    for chunk in streamer:
        yield chunk

    thread.join()
    if failures:
        raise failures[0]
|