Spaces:
Runtime error
Runtime error
| """ | |
| model_runner.py — Model loading + ZeroGPU inference | |
| The @spaces.GPU decorator is applied lazily so the GPU is only | |
| allocated during actual inference calls, not at startup. | |
| """ | |
| import os | |
| import gc | |
| import torch | |
| import spaces | |
| from threading import Lock | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForCausalLM, | |
| TextIteratorStreamer, | |
| BitsAndBytesConfig, | |
| ) | |
| from huggingface_hub import snapshot_download | |
| import threading | |
| # ── Global model cache (one model at a time) ────────────────── | |
| _model = None | |
| _tokenizer = None | |
| _current_model_id = None | |
| _lock = Lock() | |
def get_device() -> str:
    """Return ``"cuda"`` when PyTorch reports CUDA as available, otherwise ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"
# Effective (device, quantised) settings the cached model was loaded with;
# None while nothing is cached. Lets load_model() detect a request for the
# same model id under different settings.
_loaded_settings = None


def load_model(
    model_id: str,
    use_4bit: bool = True,
    use_cpu: bool = False,
):
    """
    Load a causal-LM and its tokenizer from the Hugging Face Hub into the
    module-level cache.

    The previously cached model is unloaded first to free memory. The call
    is a no-op when the same ``model_id`` is already cached with the same
    effective settings (target device and whether 4-bit quantisation is
    active). The module globals are only updated once both the tokenizer and
    the model have loaded successfully, so a failed load leaves the cache
    empty rather than half-initialised.

    Args:
        model_id: Hub repository id passed to ``from_pretrained``.
        use_4bit: Request 4-bit NF4 quantisation. Only honoured when the
            target device is CUDA; ignored otherwise.
        use_cpu: Force CPU even when CUDA is available.

    Raises:
        Whatever ``AutoTokenizer.from_pretrained`` or
        ``AutoModelForCausalLM.from_pretrained`` raise; the exception is
        propagated unchanged.
    """
    global _model, _tokenizer, _current_model_id, _loaded_settings

    device = "cpu" if use_cpu else get_device()
    quantize = use_4bit and device == "cuda"
    settings = (device, quantize)

    with _lock:
        if _current_model_id == model_id and _loaded_settings == settings:
            return  # Already loaded with the same effective settings

        # Drop the previous model before allocating the new one.
        _unload()
        _loaded_settings = None

        # SECURITY NOTE: trust_remote_code=True executes Python code shipped
        # in the Hub repository. model_id should only come from a trusted
        # allow-list, never from free-form user input.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=True,
        )
        if tokenizer.pad_token is None:
            # Some causal-LM tokenizers ship without a pad token; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
        )
        if quantize:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

        model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)
        if device == "cpu":
            model = model.to(device)
        model.eval()

        # Publish only after everything succeeded. The tokenizer goes first
        # so that a non-None _model always has its tokenizer available.
        _tokenizer = tokenizer
        _model = model
        _current_model_id = model_id
        _loaded_settings = settings
def _unload():
    """
    Clear the cached model, tokenizer and model id, then reclaim memory.

    Does not acquire ``_lock`` itself: ``load_model`` calls it while already
    holding the lock.
    """
    global _model, _tokenizer, _current_model_id

    # Rebinding to None drops this module's references to the old objects.
    _model = _tokenizer = _current_model_id = None

    gc.collect()
    if torch.cuda.is_available():
        # Release PyTorch's cached CUDA blocks.
        torch.cuda.empty_cache()
def is_loaded() -> bool:
    """Report whether a model object is currently held in the module cache."""
    if _model is None:
        return False
    return True
| def current_model() -> str | None: | |
| return _current_model_id | |
| # ── Inference ───────────────────────────────────────────────── | |
def _plain_text_prompt(messages: list[dict], system_prompt: str) -> str:
    """Render the conversation as a plain System/Human/Assistant transcript."""
    parts = []
    if system_prompt:
        parts.append(f"System: {system_prompt}\n\n")
    for message in messages:
        # Every role other than "user" is rendered as the assistant.
        speaker = "Human" if message["role"] == "user" else "Assistant"
        parts.append(f"{speaker}: {message['content']}\n")
    parts.append("Assistant:")
    return "".join(parts)


def generate_stream(
    messages: list[dict],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    system_prompt: str = "",
):
    """
    Stream the model's reply to a chat conversation as text chunks.

    Generation runs in a background thread that feeds a
    ``TextIteratorStreamer``; this generator yields each decoded chunk as it
    becomes available. If no model is loaded, a single user-facing warning
    string is yielded instead.

    NOTE(review): no ``@spaces.GPU`` decorator is applied to this function in
    this file; if ZeroGPU allocation is required it is presumably applied by
    the caller. Confirm.

    Args:
        messages: Chat turns as dicts with ``"role"`` and ``"content"`` keys.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; a value <= 0 selects greedy
            decoding.
        top_p: Nucleus-sampling threshold (only used when sampling).
        repetition_penalty: Repetition penalty passed to ``generate``.
        system_prompt: Optional system message placed before ``messages``.

    Yields:
        Decoded text chunks, excluding the prompt and special tokens.

    Raises:
        Exception: any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has ended, instead of
            leaving the consumer blocked forever.
    """
    # Work on local references so a concurrent unload cannot turn them into
    # None between the check below and their use.
    model, tokenizer = _model, _tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Aucun modèle chargé. Veuillez d'abord sélectionner et charger un modèle."
        return

    messages = list(messages or [])
    chat_messages = []
    if system_prompt:
        chat_messages.append({"role": "system", "content": system_prompt})
    chat_messages.extend(messages)

    try:
        # return_dict=True is explicit so the result shape does not depend on
        # the library's default and the attention mask is available.
        encoded = tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            return_tensors="pt",
            return_dict=True,
        )
    except Exception:
        # Deliberate best-effort fallback (e.g. the tokenizer has no chat
        # template): use a simple role-prefixed transcript instead.
        encoded = tokenizer(
            _plain_text_prompt(messages, system_prompt),
            return_tensors="pt",
        )

    device = next(model.parameters()).device
    input_ids = encoded["input_ids"].to(device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    sample = temperature > 0
    gen_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        do_sample=sample,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )
    if "attention_mask" in encoded:
        gen_kwargs["attention_mask"] = encoded["attention_mask"].to(device)
    if sample:
        # Sampling-only knobs; they have no effect on greedy decoding.
        gen_kwargs.update(temperature=temperature, top_p=top_p)

    failures = []

    def _run_generation():
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            failures.append(exc)
            # generate() did not finish, so the streamer never received its
            # stop signal; end it here so the consumer loop can terminate.
            streamer.end()

    worker = threading.Thread(target=_run_generation)
    worker.start()
    for chunk in streamer:
        yield chunk
    worker.join()

    if failures:
        raise failures[0]