"""
Singleton model manager for AI inference — Hybrid Architecture.

Supports two inference modes, selected by ``settings.inference_mode``
(the ``INFERENCE_MODE`` setting):

**"local"**
    Loads models on-device for CPU inference on HuggingFace Spaces (16 GB RAM):

    1. Qwen2.5-0.5B-Instruct – causal LM for summarization, chat, keyword
       extraction.
    2. mDeBERTa-v3-base-xnli – zero-shot classifier for topic categorization.

**"groq"**
    Uses the Groq Cloud API (llama-3.3-70b-versatile by default) for all text
    generation.  Skips loading Qwen & mDeBERTa to save ~3 GB RAM and 30s+ boot.
    The mDeBERTa classifier is skipped; categorization falls back to the
    keyword-based classifier in ``topic_classifier.py``.

Any mode value other than ``"groq"`` is treated as local mode.

The ``generate_text()`` function is the single public API consumed by
``note_generator.py``, ``recommender.py``, and ``chat_routes.py``.  It
transparently routes to the correct backend based on the active mode.
"""

import os
import threading
from typing import Tuple

from src.utils.config import settings
from src.utils.logger import setup_logger

logger = setup_logger(__name__)

# ── Configuration ────────────────────────────────────────────────────────────
INFERENCE_MODE = settings.inference_mode  # "groq" or "local"

QWEN_MODEL_ID = os.environ.get(
    "QWEN_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct"
)
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
)

# Model files are cached next to the working directory; the directory is
# created at import time so the loaders below can rely on it existing.
HF_CACHE_DIR = os.path.join(os.getcwd(), "hf_cache")
os.makedirs(HF_CACHE_DIR, exist_ok=True)

# ── Internal state (module-level singletons) ─────────────────────────────────
# One lock per lazily-created resource so that loading one of them does not
# block callers that only need another.
_qwen_lock = threading.Lock()
_clf_lock = threading.Lock()
_groq_lock = threading.Lock()

_qwen_model = None
_qwen_tokenizer = None
_classifier_pipe = None
_groq_client = None


# ═════════════════════════════════════════════════════════════════════════════
#  GROQ BACKEND
# ═════════════════════════════════════════════════════════════════════════════

def get_groq_client():
    """Return a lazily-initialized Groq client singleton.

    The ``groq`` package is imported on first use, so it is only required
    when this function is actually called.

    Raises
    ------
    RuntimeError
        If ``settings.groq_api_key`` is empty or missing.
    """
    global _groq_client

    # Fast path: no locking once the client exists.
    if _groq_client is not None:
        return _groq_client

    with _groq_lock:
        # Re-check: another thread may have created the client while this
        # one was waiting for the lock.
        if _groq_client is not None:
            return _groq_client

        api_key = settings.groq_api_key
        if not api_key:
            raise RuntimeError(
                "INFERENCE_MODE is set to 'groq' but GROQ_API_KEY is missing. "
                "Please set the GROQ_API_KEY environment variable or HF Secret."
            )

        from groq import Groq

        _groq_client = Groq(api_key=api_key)
        logger.info("✅ Groq client initialized (model: %s).", settings.groq_model)

    return _groq_client


def _generate_text_groq(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run text generation via the Groq Cloud API.

    Maps the local-model calling convention to Groq's OpenAI-compatible
    Chat Completions API.

    Returns
    -------
    str
        The assistant's reply, stripped of surrounding whitespace.  An empty
        string is returned when the API call fails (the error is logged) or
        when the reply has no content.

    Raises
    ------
    RuntimeError
        Propagated from ``get_groq_client()`` when the API key is missing;
        client creation happens outside the best-effort ``try`` below.
    """
    client = get_groq_client()

    # Groq exposes 'max_tokens' and has no greedy toggle, so the greedy
    # request (do_sample=False) is expressed as temperature 0.
    effective_temp = temperature if do_sample else 0.0

    try:
        chat_completion = client.chat.completions.create(
            model=settings.groq_model,
            messages=prompt_messages,
            max_tokens=max_new_tokens,
            temperature=effective_temp,
        )
        reply = chat_completion.choices[0].message.content or ""
        return reply.strip()
    except Exception as e:
        # Deliberate best-effort boundary: callers get "" instead of an
        # exception, and the full traceback goes to the log.
        logger.error("❌ Groq API call failed: %s", e, exc_info=True)
        return ""


# ═════════════════════════════════════════════════════════════════════════════
#  LOCAL BACKEND  (Qwen + mDeBERTa)
# ═════════════════════════════════════════════════════════════════════════════

def get_qwen_model() -> Tuple:
    """Return ``(model, tokenizer)`` for the configured Qwen causal LM.

    Loads ``QWEN_MODEL_ID`` on first call (CPU, float32, cached under
    ``HF_CACHE_DIR``); subsequent calls return the cached objects.
    """
    global _qwen_model, _qwen_tokenizer

    if _qwen_model is not None:
        return _qwen_model, _qwen_tokenizer

    with _qwen_lock:
        # Double-check after acquiring the lock
        if _qwen_model is not None:
            return _qwen_model, _qwen_tokenizer

        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("🤖 Loading Qwen model: %s (CPU, float32) …", QWEN_MODEL_ID)

        # NOTE(review): trust_remote_code=True lets the checkpoint named by
        # the QWEN_MODEL_ID env var execute its own Python code — confirm
        # this is still needed for the deployed model.
        tokenizer = AutoTokenizer.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        model.eval()

        # Publish only after everything is ready, and publish the model
        # last: the lock-free fast path above keys on ``_qwen_model``, so it
        # must never observe a model without its tokenizer or before eval().
        _qwen_tokenizer = tokenizer
        _qwen_model = model
        logger.info("✅ Qwen model loaded successfully.")

    return _qwen_model, _qwen_tokenizer


def get_classifier_pipeline():
    """Return a zero-shot-classification ``Pipeline`` backed by mDeBERTa.

    Loads ``CLASSIFIER_MODEL_ID`` on first call; subsequent calls return the
    cached pipeline.
    """
    global _classifier_pipe

    if _classifier_pipe is not None:
        return _classifier_pipe

    with _clf_lock:
        # Re-check under the lock in case another thread finished loading.
        if _classifier_pipe is not None:
            return _classifier_pipe

        from transformers import pipeline as hf_pipeline

        logger.info(
            "🤖 Loading zero-shot classifier: %s (CPU) …", CLASSIFIER_MODEL_ID
        )
        _classifier_pipe = hf_pipeline(
            "zero-shot-classification",
            model=CLASSIFIER_MODEL_ID,
            device=-1,  # CPU
            # ``model_kwargs`` is what pipeline() forwards to
            # ``from_pretrained``; a bare ``cache_dir=`` keyword is handed to
            # the pipeline constructor instead and does not reach the loader.
            model_kwargs={"cache_dir": HF_CACHE_DIR},
        )
        logger.info("✅ Zero-shot classifier loaded successfully.")

    return _classifier_pipe


def _generate_text_local(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run text generation via the local Qwen model on CPU.

    Parameters
    ----------
    prompt_messages:
        Chat messages rendered through the tokenizer's chat template.
    max_new_tokens:
        Cap on generated tokens.
    temperature / do_sample:
        Sampling is used only when ``do_sample`` is true *and*
        ``temperature`` is positive; otherwise decoding is greedy.  This
        mirrors the Groq backend, where a temperature of 0 is the
        deterministic setting.

    Returns
    -------
    str
        The decoded continuation only (prompt tokens removed), stripped.
    """
    import torch

    model, tokenizer = get_qwen_model()

    input_text = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # NOTE(review): truncation at 2048 tokens presumably drops the END of an
    # over-long prompt (including the generation prompt) — confirm the
    # tokenizer's truncation side if long inputs are expected.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to("cpu")
    prompt_len = inputs["input_ids"].shape[1]

    # A non-positive temperature is not a valid sampling temperature, so it
    # is treated as a request for greedy decoding rather than an error.
    use_sampling = do_sample and temperature > 0

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        # Always explicit: if ``do_sample`` were omitted, generate() would
        # fall back to the checkpoint's own generation config, which may
        # enable sampling and make the "greedy" path non-deterministic.
        "do_sample": use_sampling,
    }
    if use_sampling:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # Decode only the newly generated tokens
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


# ═════════════════════════════════════════════════════════════════════════════
#  PUBLIC API — UNIFIED ROUTER
# ═════════════════════════════════════════════════════════════════════════════

def generate_text(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """High-level helper: run a chat completion and return the reply text.

    Transparently routes to either the Groq Cloud API or the local Qwen model
    based on the ``INFERENCE_MODE`` setting.  All downstream consumers
    (note_generator, recommender, chat) call this function — no import
    changes needed anywhere else.

    Parameters
    ----------
    prompt_messages:
        List of ``{"role": ..., "content": ...}`` dicts compatible with
        both ``tokenizer.apply_chat_template`` and the Groq API.
    max_new_tokens:
        Cap on generated tokens.
    temperature / do_sample:
        Sampling config.  Defaults to greedy (deterministic).

    Returns
    -------
    str
        The reply text.  In Groq mode an empty string signals a failed API
        call; the local backend lets loading/generation errors propagate.
    """
    backend = (
        _generate_text_groq if INFERENCE_MODE == "groq" else _generate_text_local
    )
    return backend(
        prompt_messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
    )


def preload_all_models() -> None:
    """Eagerly load AI models at startup based on the active inference mode.

    Call this once from the FastAPI lifespan so that the first real request
    doesn't trigger a slow cold-load.

    - **"groq"** mode: only initializes the Groq client (lightweight).
      Skips Qwen & mDeBERTa to save ~3 GB RAM and 30s+ boot time.
    - any other mode (local): loads both Qwen and mDeBERTa.

    Raises
    ------
    RuntimeError
        In Groq mode, when the API key is missing (from ``get_groq_client``).
    """
    logger.info("⏳ Pre-loading AI models (mode: %s) …", INFERENCE_MODE)

    if INFERENCE_MODE == "groq":
        # Only validate that the Groq client can be created
        get_groq_client()
        logger.info(
            "✅ Groq mode active — skipped local model loading. "
            "Using %s via Groq Cloud API.",
            settings.groq_model,
        )
        return

    # Local mode — load everything on-device
    get_qwen_model()
    get_classifier_pipeline()
    logger.info("✅ All local AI models loaded and ready.")