Spaces:
Running
Running
Ali Hashhash
feat: implement chat API routes, note generation logic, and model loading utilities
3bc6c02 | """ | |
| Singleton model manager for AI inference — Hybrid Architecture. | |
| Supports two inference modes controlled by ``INFERENCE_MODE`` env var: | |
| **"local"** (default) | |
| Loads models on-device for CPU inference on HuggingFace Spaces (16 GB RAM): | |
| 1. Qwen2.5-0.5B-Instruct — causal LM for summarization, chat, keyword extraction. | |
| 2. mDeBERTa-v3-base-xnli — zero-shot classifier for topic categorization. | |
| **"groq"** | |
| Uses the Groq Cloud API (llama-3.3-70b-versatile by default) for all text | |
| generation. Skips loading Qwen & mDeBERTa to save ~3 GB RAM and 30s+ boot. | |
| The mDeBERTa classifier is skipped; categorization falls back to the | |
| keyword-based classifier in ``topic_classifier.py``. | |
| The ``generate_text()`` function is the single public API consumed by | |
| ``note_generator.py``, ``recommender.py``, and ``chat_routes.py``. It | |
| transparently routes to the correct backend based on the active mode. | |
| """ | |
| import os | |
| import threading | |
| from typing import Tuple | |
| from src.utils.config import settings | |
| from src.utils.logger import setup_logger | |
logger = setup_logger(__name__)

# -- Configuration ------------------------------------------------------------
# Active backend, read once at import time: "groq" or "local".
INFERENCE_MODE = settings.inference_mode

# Hugging Face model identifiers; each can be overridden through the
# environment variable of the same name.
QWEN_MODEL_ID = os.getenv("QWEN_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct")
CLASSIFIER_MODEL_ID = os.getenv(
    "CLASSIFIER_MODEL_ID",
    "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
)

# Model download cache, located under the current working directory.
# The directory is created as an import-time side effect.
HF_CACHE_DIR = os.path.join(os.getcwd(), "hf_cache")
os.makedirs(HF_CACHE_DIR, exist_ok=True)

# -- Internal state (module-level singletons) ---------------------------------
# Each lazily-created resource has its own lock guarding first-time creation.

# Local Qwen causal LM and its tokenizer.
_qwen_model = None
_qwen_tokenizer = None
_qwen_lock = threading.Lock()

# Zero-shot classification pipeline.
_classifier_pipe = None
_clf_lock = threading.Lock()

# Groq API client.
_groq_client = None
_groq_lock = threading.Lock()
| # ───────────────────────────────────────────────────────────────────────────── | |
| # GROQ BACKEND | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def get_groq_client():
    """Return the shared Groq client, building it on the first call.

    The client is stored in the module-level ``_groq_client`` and reused by
    every later call. Creation happens while holding ``_groq_lock``.

    Raises
    ------
    RuntimeError
        If ``settings.groq_api_key`` is empty when the client has to be built.
    """
    global _groq_client

    if _groq_client is None:
        with _groq_lock:
            # Re-check under the lock: another thread may have finished
            # building the client while this one was waiting.
            if _groq_client is None:
                key = settings.groq_api_key
                if not key:
                    raise RuntimeError(
                        "INFERENCE_MODE is set to 'groq' but GROQ_API_KEY is missing. "
                        "Please set the GROQ_API_KEY environment variable or HF Secret."
                    )

                # Imported here so the ``groq`` package is only required
                # when a client is actually created.
                from groq import Groq

                _groq_client = Groq(api_key=key)
                logger.info(
                    "β Groq client initialized (model: %s).", settings.groq_model
                )

    return _groq_client
def _generate_text_groq(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Generate a reply through the Groq chat-completions endpoint.

    Accepts the same keyword arguments as the local backend and translates
    them for Groq: ``max_new_tokens`` is sent as ``max_tokens``, and because
    there is no greedy switch, ``do_sample=False`` is expressed as a
    temperature of ``0.0``.

    Returns the stripped assistant message, or ``""`` when the message has
    no content or the API call raises (the failure is logged, not re-raised).
    Errors from ``get_groq_client()`` itself are not caught here.
    """
    groq = get_groq_client()

    # The requested temperature only applies when sampling was asked for.
    if do_sample:
        request_temperature = temperature
    else:
        request_temperature = 0.0

    try:
        response = groq.chat.completions.create(
            model=settings.groq_model,
            messages=prompt_messages,
            max_tokens=max_new_tokens,
            temperature=request_temperature,
        )
        content = response.choices[0].message.content
        if not content:
            content = ""
        return content.strip()
    except Exception as exc:
        # Best-effort: callers receive an empty string instead of an error.
        logger.error("β Groq API call failed: %s", exc, exc_info=True)
        return ""
| # ───────────────────────────────────────────────────────────────────────────── | |
| # LOCAL BACKEND (Qwen + mDeBERTa) | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def get_qwen_model() -> Tuple:
    """Return ``(model, tokenizer)`` for the configured Qwen causal LM.

    The pair is loaded on the first call (CPU, float32, cached under
    ``HF_CACHE_DIR``) and stored in module-level globals; later calls return
    the cached objects without taking the lock.

    Returns
    -------
    Tuple
        ``(model, tokenizer)`` as produced by ``AutoModelForCausalLM`` and
        ``AutoTokenizer``.
    """
    global _qwen_model, _qwen_tokenizer

    # Lock-free fast path. ``_qwen_model`` is published last (see below), so
    # a non-None model implies the tokenizer is set as well.
    if _qwen_model is not None:
        return _qwen_model, _qwen_tokenizer

    with _qwen_lock:
        # Re-check after acquiring the lock: another thread may have loaded
        # the model while this one was waiting.
        if _qwen_model is not None:
            return _qwen_model, _qwen_tokenizer

        # Heavy dependencies are imported lazily so that merely importing
        # this module does not pull them in.
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("π€ Loading Qwen model: %s (CPU, float32) β¦", QWEN_MODEL_ID)

        # NOTE(review): trust_remote_code=True lets the checkpoint's
        # repository supply Python code that is executed on load, and
        # QWEN_MODEL_ID is overridable via the environment. Confirm the
        # configured model is always a trusted one.
        tokenizer = AutoTokenizer.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        model.eval()

        # Publish only after both objects are fully initialised, tokenizer
        # first and model last. Callers on the lock-free fast path key off
        # ``_qwen_model``, so they can never observe a model before eval()
        # ran or a model without its tokenizer. A failed load also leaves
        # both globals untouched.
        _qwen_tokenizer = tokenizer
        _qwen_model = model

        logger.info("β Qwen model loaded successfully.")
        return model, tokenizer
def get_classifier_pipeline():
    """Return the zero-shot-classification pipeline for ``CLASSIFIER_MODEL_ID``.

    The pipeline is created on the first call (CPU) and stored in the
    module-level ``_classifier_pipe``; later calls return the cached object.
    """
    global _classifier_pipe

    if _classifier_pipe is not None:
        return _classifier_pipe

    with _clf_lock:
        # Re-check after acquiring the lock: another thread may have built
        # the pipeline while this one was waiting.
        if _classifier_pipe is not None:
            return _classifier_pipe

        # Imported lazily so importing this module does not load transformers.
        from transformers import pipeline as hf_pipeline

        logger.info(
            "π€ Loading zero-shot classifier: %s (CPU) β¦", CLASSIFIER_MODEL_ID
        )
        _classifier_pipe = hf_pipeline(
            "zero-shot-classification",
            model=CLASSIFIER_MODEL_ID,
            device=-1,  # CPU
            # ``cache_dir`` is not a parameter of ``pipeline()`` itself; it
            # has to travel through ``model_kwargs`` to reach
            # ``from_pretrained``.
            # NOTE(review): passed as a bare keyword it appears to be
            # ignored for this task, so downloads would have landed in the
            # default cache rather than HF_CACHE_DIR. Confirm against the
            # installed transformers version.
            model_kwargs={"cache_dir": HF_CACHE_DIR},
        )
        logger.info("β Zero-shot classifier loaded successfully.")
        return _classifier_pipe
def _generate_text_local(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run text generation via the local Qwen model on CPU.

    Parameters
    ----------
    prompt_messages:
        Chat messages passed to ``tokenizer.apply_chat_template``.
    max_new_tokens:
        Upper bound on the number of generated tokens.
    temperature:
        Sampling temperature; only used when sampling is active.
    do_sample:
        ``True`` to sample, ``False`` for greedy decoding. A non-positive
        ``temperature`` also selects greedy decoding, mirroring how the Groq
        backend treats a temperature of 0.

    Returns
    -------
    str
        The decoded, stripped continuation (prompt tokens excluded).
    """
    import torch

    model, tokenizer = get_qwen_model()

    # Render the chat into the model's prompt format, ending with the
    # marker that tells the model to start the assistant turn.
    input_text = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        truncation=True,
        max_length=2048,
    ).to("cpu")
    prompt_len = inputs["input_ids"].shape[1]

    # A temperature <= 0 is rejected by transformers when sampling, so such
    # requests are served with greedy decoding instead of raising.
    use_sampling = bool(do_sample) and temperature > 0

    # ``do_sample`` is always passed explicitly: any parameter left out of
    # generate() is taken from the checkpoint's generation_config, which may
    # enable sampling on its own.
    # NOTE(review): the Qwen2.5 instruct checkpoints appear to ship
    # do_sample=true. Confirm.
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": use_sampling,
    }
    if use_sampling:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; keep only the new tokens.
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
| # ───────────────────────────────────────────────────────────────────────────── | |
| # PUBLIC API — UNIFIED ROUTER | |
| # ───────────────────────────────────────────────────────────────────────────── | |
def generate_text(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run one chat completion on the active backend and return its text.

    This is the single entry point for text generation. It forwards every
    argument unchanged to the Groq backend when ``INFERENCE_MODE`` equals
    ``"groq"`` and to the local Qwen backend for any other value.

    Parameters
    ----------
    prompt_messages:
        List of ``{"role": ..., "content": ...}`` dicts, in the shape
        accepted by both ``tokenizer.apply_chat_template`` and the Groq API.
    max_new_tokens:
        Cap on generated tokens.
    temperature / do_sample:
        Sampling configuration; ``do_sample`` defaults to ``False``.
    """
    # Choose the backend first, then make one call with the shared arguments.
    if INFERENCE_MODE == "groq":
        backend = _generate_text_groq
    else:
        backend = _generate_text_local

    return backend(
        prompt_messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
    )
def preload_all_models() -> None:
    """Warm up whatever the active inference mode needs.

    Intended to be called once at application startup so the first request
    does not pay the initialisation cost.

    - ``"groq"`` mode: only creates the Groq client; no local model is loaded.
    - any other mode: loads the Qwen model and the zero-shot classifier.
    """
    logger.info("β³ Pre-loading AI models (mode: %s) β¦", INFERENCE_MODE)

    if INFERENCE_MODE != "groq":
        # Local mode: bring both on-device models into memory.
        get_qwen_model()
        get_classifier_pipeline()
        logger.info("β All local AI models loaded and ready.")
        return

    # Groq mode: building the client is enough, and it raises here at
    # startup if the API key is missing.
    get_groq_client()
    logger.info(
        "β Groq mode active β skipped local model loading. "
        "Using %s via Groq Cloud API.",
        settings.groq_model,
    )