File size: 10,746 Bytes
568a4d9
ac94e8d
568a4d9
ac94e8d
568a4d9
ac94e8d
 
 
 
568a4d9
ac94e8d
 
 
 
 
 
 
 
 
568a4d9
 
 
 
 
 
ac94e8d
568a4d9
 
 
 
 
ac94e8d
 
568a4d9
600705a
568a4d9
 
 
 
 
 
 
 
 
 
ac94e8d
568a4d9
 
 
 
ac94e8d
 
 
 
 
 
 
3bc6c02
ac94e8d
 
 
 
 
568a4d9
ac94e8d
 
 
568a4d9
ac94e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3bc6c02
ac94e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
568a4d9
 
600705a
568a4d9
 
 
 
 
 
 
 
 
 
 
 
 
ac94e8d
 
 
568a4d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac94e8d
 
568a4d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ac94e8d
568a4d9
 
 
 
 
 
ac94e8d
 
568a4d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420464f
 
 
 
 
 
 
 
 
 
568a4d9
420464f
568a4d9
 
 
 
600705a
 
ac94e8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
600705a
ac94e8d
600705a
 
 
ac94e8d
 
 
 
600705a
ac94e8d
 
 
 
3bc6c02
ac94e8d
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
"""
Singleton model manager for AI inference — Hybrid Architecture.

Supports two inference modes controlled by ``INFERENCE_MODE`` env var:

**"local"** (default)
    Loads models on-device for CPU inference on HuggingFace Spaces (16 GB RAM):
    1. Qwen2.5-0.5B-Instruct  – causal LM for summarization, chat, keyword extraction.
    2. mDeBERTa-v3-base-xnli   – zero-shot classifier for topic categorization.

**"groq"**
    Uses the Groq Cloud API (llama-3.3-70b-versatile by default) for all text
    generation. Skips loading Qwen & mDeBERTa to save ~3 GB RAM and 30s+ boot.
    The mDeBERTa classifier is skipped; categorization falls back to the
    keyword-based classifier in ``topic_classifier.py``.

The ``generate_text()`` function is the single public API consumed by
``note_generator.py``, ``recommender.py``, and ``chat_routes.py``. It
transparently routes to the correct backend based on the active mode.
"""

import os
import threading
from typing import Tuple

from src.utils.config import settings
from src.utils.logger import setup_logger

logger = setup_logger(__name__)

# ── Configuration ────────────────────────────────────────────────────────────
# Backend selector, read once at import time. ``generate_text`` routes to Groq
# only when this equals "groq"; any other value uses the local backend.
INFERENCE_MODE = settings.inference_mode  # "groq" or "local"

# Hugging Face model IDs; each can be overridden through an environment variable.
QWEN_MODEL_ID = os.environ.get("QWEN_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
)

# Download cache for Hugging Face artifacts, located under the current working
# directory. It is created eagerly at import time, whatever INFERENCE_MODE is.
HF_CACHE_DIR = os.path.join(os.getcwd(), "hf_cache")
os.makedirs(HF_CACHE_DIR, exist_ok=True)

# ── Internal state (module-level singletons) ─────────────────────────────────
# One lock per lazily-built resource, serializing its first construction.
_qwen_lock = threading.Lock()  # guards _qwen_model / _qwen_tokenizer
_clf_lock = threading.Lock()  # guards _classifier_pipe
_groq_lock = threading.Lock()  # guards _groq_client

# Filled in on first use by this module's lazy loader functions; None = not loaded.
_qwen_model = _qwen_tokenizer = None
_classifier_pipe = None
_groq_client = None


# ═════════════════════════════════════════════════════════════════════════════
# GROQ BACKEND
# ═════════════════════════════════════════════════════════════════════════════

def get_groq_client():
    """Return the shared Groq client, constructing it on first use.

    The client is built at most once. Concurrent first callers serialize on
    ``_groq_lock`` and re-check the cache after acquiring it.

    Raises
    ------
    RuntimeError
        If ``settings.groq_api_key`` is empty or missing when the client
        has to be constructed.
    """
    global _groq_client

    client = _groq_client
    if client is None:
        with _groq_lock:
            # Another thread may have finished construction while we waited.
            client = _groq_client
            if client is None:
                key = settings.groq_api_key
                if not key:
                    raise RuntimeError(
                        "INFERENCE_MODE is set to 'groq' but GROQ_API_KEY is missing. "
                        "Please set the GROQ_API_KEY environment variable or HF Secret."
                    )

                # Imported lazily so the groq package is only needed in groq mode.
                from groq import Groq

                client = Groq(api_key=key)
                _groq_client = client
                logger.info("βœ… Groq client initialized (model: %s).", settings.groq_model)
    return client


def _generate_text_groq(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Generate a chat reply through Groq's Chat Completions API.

    Translates the local-model calling convention to Groq's parameters:
    ``max_new_tokens`` becomes ``max_tokens``. ``do_sample=False`` is
    approximated with temperature 0, because the API has no greedy switch.

    Returns the stripped assistant reply. Returns ``""`` if the API call or
    response parsing fails; the failure is logged with its traceback.
    Errors raised while obtaining the client propagate to the caller.
    """
    client = get_groq_client()

    # No greedy toggle on Groq: request temperature 0 for deterministic output.
    if do_sample:
        sampling_temperature = temperature
    else:
        sampling_temperature = 0.0

    try:
        response = client.chat.completions.create(
            model=settings.groq_model,
            messages=prompt_messages,
            max_tokens=max_new_tokens,
            temperature=sampling_temperature,
        )
        content = response.choices[0].message.content
        return (content or "").strip()
    except Exception as exc:
        # Best-effort: callers receive an empty string rather than an exception.
        logger.error("❌ Groq API call failed: %s", exc, exc_info=True)
        return ""


# ═════════════════════════════════════════════════════════════════════════════
# LOCAL BACKEND (Qwen + mDeBERTa)
# ═════════════════════════════════════════════════════════════════════════════

def get_qwen_model() -> Tuple:
    """Return ``(model, tokenizer)`` for the causal LM named by ``QWEN_MODEL_ID``.

    The objects load on the first call and are cached in module globals;
    later calls return the cached pair. First-time loading is serialized on
    ``_qwen_lock``. Both objects are fully initialized (including
    ``eval()``) in locals before they are published. This ensures the
    lock-free fast path can never return a half-initialized model.

    Returns
    -------
    tuple
        ``(model, tokenizer)``: the model loaded in float32 on CPU, and
        its tokenizer.
    """
    global _qwen_model, _qwen_tokenizer

    # Fast path without the lock. It keys on _qwen_model, which is published last.
    if _qwen_model is not None:
        return _qwen_model, _qwen_tokenizer

    with _qwen_lock:
        # Double-check after acquiring the lock
        if _qwen_model is not None:
            return _qwen_model, _qwen_tokenizer

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("πŸ€– Loading Qwen model: %s (CPU, float32) …", QWEN_MODEL_ID)

        tokenizer = AutoTokenizer.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        model.eval()

        # Publish only after both loads succeeded and eval() ran. The tokenizer
        # goes first, so a reader that sees _qwen_model also sees the tokenizer.
        # A failed load leaves both globals None, so the next call retries cleanly.
        _qwen_tokenizer = tokenizer
        _qwen_model = model

        logger.info("βœ… Qwen model loaded successfully.")
        return model, tokenizer


def get_classifier_pipeline():
    """Return a zero-shot-classification ``Pipeline`` for ``CLASSIFIER_MODEL_ID``.

    The pipeline loads on the first call and is cached in a module global;
    later calls return the cached pipeline. First-time loading is serialized
    on ``_clf_lock``.

    The download cache location is passed via ``model_kwargs``.
    ``transformers.pipeline()`` hands ``model_kwargs`` to the config, model
    and tokenizer ``from_pretrained`` calls. A bare top-level ``cache_dir``
    argument is not forwarded there, so the files would land in the default
    HF cache instead of ``HF_CACHE_DIR``.
    """
    global _classifier_pipe

    if _classifier_pipe is not None:
        return _classifier_pipe

    with _clf_lock:
        # Another thread may have loaded the pipeline while we waited.
        if _classifier_pipe is not None:
            return _classifier_pipe

        from transformers import pipeline as hf_pipeline

        logger.info(
            "πŸ€– Loading zero-shot classifier: %s (CPU) …", CLASSIFIER_MODEL_ID
        )

        _classifier_pipe = hf_pipeline(
            "zero-shot-classification",
            model=CLASSIFIER_MODEL_ID,
            device=-1,  # CPU
            model_kwargs={"cache_dir": HF_CACHE_DIR},
        )

        logger.info("βœ… Zero-shot classifier loaded successfully.")
        return _classifier_pipe


def _generate_text_local(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run chat-style text generation with the local causal LM on CPU.

    Processing steps:

    1. Render the messages with the tokenizer's chat template, appending the
       assistant generation prompt.
    2. Tokenize the rendered text and run ``model.generate``.
    3. Decode only the newly generated tokens.

    Prompts longer than 2048 tokens are trimmed from the *front*. This keeps
    the trailing assistant header produced by ``add_generation_prompt=True``.
    Trimming from the end would drop it, and the model would continue the
    user's text instead of replying.

    Decoding mode:

    - ``do_sample=False`` requests greedy decoding explicitly. A checkpoint's
      ``generation_config`` may enable sampling by default; Qwen2.5-Instruct's
      does (verify if ``QWEN_MODEL_ID`` is overridden). Omitting the flag
      would therefore not give deterministic output.
    - ``do_sample=True`` samples with ``temperature``. Any other sampling
      settings come from the model's ``generation_config``.
    """
    import torch

    # Upper bound on prompt tokens fed to the model (keeps CPU latency bounded).
    max_prompt_tokens = 2048

    model, tokenizer = get_qwen_model()

    input_text = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")

    prompt_len = inputs["input_ids"].shape[1]
    if prompt_len > max_prompt_tokens:
        # Keep the tail: the generation prompt lives at the end of the template.
        logger.warning(
            "Prompt has %d tokens; keeping only the last %d.",
            prompt_len,
            max_prompt_tokens,
        )
        inputs = {
            name: tensor[:, -max_prompt_tokens:] for name, tensor in inputs.items()
        }
        prompt_len = max_prompt_tokens

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": bool(do_sample),
    }
    if do_sample:
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy: clear any sampling knobs inherited from generation_config.
        # Leaving them set triggers "only used in sample-based generation" warnings.
        gen_kwargs.update(temperature=None, top_p=None, top_k=None)

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# ═════════════════════════════════════════════════════════════════════════════
# PUBLIC API β€” UNIFIED ROUTER
# ═════════════════════════════════════════════════════════════════════════════

def generate_text(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run a chat completion on the active backend and return the reply text.

    When ``INFERENCE_MODE`` equals ``"groq"``, the request goes to the Groq
    Cloud backend. Any other value uses the local on-device model. The mode
    is checked on every call. Downstream consumers only ever import this
    function, so switching backends needs no changes elsewhere.

    Parameters
    ----------
    prompt_messages:
        List of ``{"role": ..., "content": ...}`` dicts. The format is
        accepted by both ``tokenizer.apply_chat_template`` and the Groq API.
    max_new_tokens:
        Maximum number of tokens to generate.
    temperature / do_sample:
        Sampling configuration. The default, ``do_sample=False``, requests
        greedy (deterministic) decoding.
    """
    backend = _generate_text_groq if INFERENCE_MODE == "groq" else _generate_text_local
    return backend(
        prompt_messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
    )


def preload_all_models() -> None:
    """Warm up the resources needed by the active inference mode.

    Intended to be called once at application startup. This way the first
    real request does not pay the cold-load cost.

    - ``"groq"`` mode builds the Groq client only. This also validates that
      an API key is configured. Local models are not loaded.
    - Any other mode loads the local causal LM and the zero-shot classifier.
    """
    logger.info("⏳ Pre-loading AI models (mode: %s) …", INFERENCE_MODE)

    if INFERENCE_MODE != "groq":
        # Local mode: bring both on-device models into memory now.
        get_qwen_model()
        get_classifier_pipeline()
        logger.info("βœ… All local AI models loaded and ready.")
        return

    # Groq mode: constructing the client is all the warm-up required.
    get_groq_client()
    logger.info(
        "βœ… Groq mode active β€” skipped local model loading. "
        "Using %s via Groq Cloud API.",
        settings.groq_model,
    )