File size: 4,856 Bytes
7840eb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
"""
model_runner.py — Model loading + ZeroGPU inference
The @spaces.GPU decorator on generate_stream() requests a GPU only for the
duration of each inference call, not at startup.
"""

import os
import gc
import torch
import spaces
from threading import Lock
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
from huggingface_hub import snapshot_download
import threading

# ── Global model cache (one model at a time) ──────────────────
# Module-level state shared by load_model(), _unload() and generate_stream().
_model = None  # the loaded causal-LM instance, or None when nothing is loaded
_tokenizer = None  # tokenizer loaded alongside _model, or None
_current_model_id = None  # Hub id that was passed to load_model() for the cached model
_lock = Lock()  # held by load_model() across its whole unload + load sequence


def get_device():
    """Return ``"cuda"`` when torch reports a usable CUDA device, else ``"cpu"``."""
    return "cuda" if torch.cuda.is_available() else "cpu"


def load_model(
    model_id: str,
    use_4bit: bool = True,
    use_cpu: bool = False,
):
    """
    Load a model and its tokenizer from the HuggingFace Hub into the module cache.

    Any previously cached model is unloaded first to free memory. The new
    tokenizer and model are built in local variables and only published to
    the module-level cache once both loaded successfully, so a failed load
    leaves the cache empty instead of half-populated.

    Args:
        model_id: Hub repository id of the model to load.
        use_4bit: Request 4-bit NF4 quantization. Only honoured when the
            model is placed on CUDA; ignored on CPU.
        use_cpu: Force CPU placement even when CUDA is available.

    Raises:
        Whatever ``from_pretrained`` raises (download, auth, OOM, ...). The
        previous model has already been unloaded at that point.
    """
    global _model, _tokenizer, _current_model_id

    with _lock:
        # NOTE: the cache key is the model id alone; calling again with the
        # same id but different use_4bit/use_cpu does not trigger a reload.
        if _current_model_id == model_id:
            return  # Already loaded

        # Release the previous model before allocating the next one.
        _unload()

        device = "cpu" if use_cpu else get_device()
        on_gpu = device == "cuda"

        # SECURITY NOTE: trust_remote_code=True lets the Hub repository run
        # its own Python code during loading. Only pass model ids you trust.
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=True,
        )
        if tokenizer.pad_token is None:
            # Fall back to EOS so a pad token is always defined.
            tokenizer.pad_token = tokenizer.eos_token

        model_kwargs = dict(
            trust_remote_code=True,
            torch_dtype=torch.float16 if on_gpu else torch.float32,
            device_map="auto" if on_gpu else None,
        )
        if on_gpu and use_4bit:
            model_kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=True,
            )

        model = AutoModelForCausalLM.from_pretrained(model_id, **model_kwargs)

        if device == "cpu":
            model = model.to(device)

        model.eval()

        # Publish only after everything succeeded.
        _tokenizer = tokenizer
        _model = model
        _current_model_id = model_id


def _unload():
    """
    Clear the module-level model cache and release memory.

    Does not take ``_lock`` itself; load_model() calls it while already
    holding the lock.
    """
    global _model, _tokenizer, _current_model_id

    # Rebinding to None drops this module's references to the old objects.
    _model = _tokenizer = _current_model_id = None

    gc.collect()
    cuda_present = torch.cuda.is_available()
    if cuda_present:
        torch.cuda.empty_cache()


def is_loaded() -> bool:
    """Return True when a model object is present in the module-level cache."""
    cached = _model
    return cached is not None


def current_model() -> str | None:
    """Return the Hub id of the cached model, or None when nothing is loaded."""
    loaded_id = _current_model_id
    return loaded_id


# ── Inference ─────────────────────────────────────────────────

@spaces.GPU(duration=120)
def generate_stream(
    messages: list[dict],
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    repetition_penalty: float = 1.1,
    system_prompt: str = "",
):
    """
    Stream generated text for a chat conversation.

    Decorated with @spaces.GPU so a GPU is requested for this call only.
    Generation runs in a background thread that feeds a
    TextIteratorStreamer; this generator yields the decoded chunks as they
    become available.

    Args:
        messages: Chat history as dicts with ``"role"`` and ``"content"`` keys.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature. A value <= 0 selects greedy
            decoding, in which case ``temperature`` and ``top_p`` are not
            forwarded to ``generate``.
        top_p: Nucleus-sampling threshold (used only when sampling).
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        system_prompt: Optional system message prepended to ``messages``.

    Yields:
        Text chunks in generation order. If no model is loaded, a single
        warning message is yielded instead.

    Raises:
        Exception: Any error raised by ``model.generate`` in the worker
            thread is re-raised here once the stream has been drained.
    """
    # Snapshot the cache so a concurrent load_model() cannot swap the model
    # or tokenizer out from under this call.
    model, tokenizer = _model, _tokenizer
    if model is None or tokenizer is None:
        yield "⚠️ Aucun modèle chargé. Veuillez d'abord sélectionner et charger un modèle."
        return

    # Build prompt using chat template if available
    chat_messages = []
    if system_prompt:
        chat_messages.append({"role": "system", "content": system_prompt})
    chat_messages.extend(messages)

    try:
        input_ids = tokenizer.apply_chat_template(
            chat_messages,
            add_generation_prompt=True,
            return_tensors="pt",
        )
    except Exception:
        # Deliberate best-effort fallback (e.g. tokenizer without a chat
        # template): plain "Human:/Assistant:" transcript.
        parts = []
        if system_prompt:
            parts.append(f"System: {system_prompt}\n\n")
        for m in messages:
            role = "Human" if m["role"] == "user" else "Assistant"
            parts.append(f"{role}: {m['content']}\n")
        parts.append("Assistant:")
        input_ids = tokenizer("".join(parts), return_tensors="pt").input_ids

    # NOTE(review): some transformers versions may return a mapping rather
    # than a bare tensor from apply_chat_template; unwrap it if so.
    if not torch.is_tensor(input_ids):
        input_ids = input_ids["input_ids"]

    device = next(model.parameters()).device
    input_ids = input_ids.to(device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    do_sample = temperature > 0
    gen_kwargs = dict(
        input_ids=input_ids,
        # Single unpadded sequence: every position is a real token. Passing
        # the mask explicitly avoids ambiguity when pad and EOS ids coincide.
        attention_mask=torch.ones_like(input_ids),
        max_new_tokens=max_new_tokens,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        streamer=streamer,
        pad_token_id=tokenizer.eos_token_id,
    )
    if do_sample:
        # Sampling-only knobs; with greedy decoding they are meaningless and
        # temperature=0 is not a valid sampling temperature.
        gen_kwargs.update(temperature=temperature, top_p=top_p)

    errors = []

    def _run_generation():
        """Run generate(); on failure, end the stream so the consumer wakes up."""
        try:
            model.generate(**gen_kwargs)
        except Exception as exc:
            errors.append(exc)
            # generate() failed before signalling end-of-stream; without
            # this the consumer loop below would block forever.
            streamer.end()

    thread = threading.Thread(target=_run_generation)
    thread.start()

    for chunk in streamer:
        yield chunk

    thread.join()

    if errors:
        raise errors[0]