AIdea-Server / src /utils /model_loader.py
Ali Hashhash
feat: implement chat API routes, note generation logic, and model loading utilities
3bc6c02
Raw
History Blame Contribute Delete
10.7 kB
"""
Singleton model manager for AI inference — Hybrid Architecture.
Supports two inference modes controlled by ``INFERENCE_MODE`` env var:
**"local"** (default)
Loads models on-device for CPU inference on HuggingFace Spaces (16 GB RAM):
1. Qwen2.5-0.5B-Instruct – causal LM for summarization, chat, keyword extraction.
2. mDeBERTa-v3-base-xnli – zero-shot classifier for topic categorization.
**"groq"**
Uses the Groq Cloud API (llama-3.3-70b-versatile by default) for all text
generation. Skips loading Qwen & mDeBERTa to save ~3 GB RAM and 30s+ boot.
The mDeBERTa classifier is skipped; categorization falls back to the
keyword-based classifier in ``topic_classifier.py``.
The ``generate_text()`` function is the single public API consumed by
``note_generator.py``, ``recommender.py``, and ``chat_routes.py``. It
transparently routes to the correct backend based on the active mode.
"""
import os
import threading
from typing import Tuple
from src.utils.config import settings
from src.utils.logger import setup_logger
logger = setup_logger(__name__)

# ── Configuration ────────────────────────────────────────────────────────────
# Backend selector, read once at import time. generate_text() and
# preload_all_models() compare it against "groq"; any other value takes the
# local path.
INFERENCE_MODE = settings.inference_mode  # "groq" or "local"

# Hugging Face model IDs for the local backend; each is overridable via env var.
QWEN_MODEL_ID = os.environ.get(
    "QWEN_MODEL_ID", "Qwen/Qwen2.5-0.5B-Instruct"
)
CLASSIFIER_MODEL_ID = os.environ.get(
    "CLASSIFIER_MODEL_ID", "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7"
)

# Download cache for local model weights. Like the model IDs above it can be
# overridden via env var (HF_CACHE_DIR); when unset or empty it falls back to
# ./hf_cache relative to the process working directory, as before. Pinning it
# makes the cache location independent of where the server is launched.
HF_CACHE_DIR = os.environ.get("HF_CACHE_DIR") or os.path.join(os.getcwd(), "hf_cache")
# Import-time side effect: the directory is created in every inference mode.
os.makedirs(HF_CACHE_DIR, exist_ok=True)

# ── Internal state (module-level singletons) ─────────────────────────────────
# One lock per lazily-initialized resource, so initializing one resource never
# waits on another's lock.
_qwen_lock = threading.Lock()
_clf_lock = threading.Lock()
_groq_lock = threading.Lock()

# Cached singletons; None until the corresponding get_*() loader runs.
_qwen_model = None
_qwen_tokenizer = None
_classifier_pipe = None
_groq_client = None
# ═════════════════════════════════════════════════════════════════════════════
# GROQ BACKEND
# ═════════════════════════════════════════════════════════════════════════════
def get_groq_client():
    """Build the process-wide Groq client on first use and return it thereafter.

    The already-initialized case returns without taking a lock. First-time
    construction runs under ``_groq_lock`` and re-checks the cache, so
    concurrent first callers construct at most one client.

    Raises
    ------
    RuntimeError
        If ``settings.groq_api_key`` is empty or unset.
    """
    global _groq_client
    client = _groq_client
    if client is None:
        with _groq_lock:
            client = _groq_client  # another thread may have won the race
            if client is None:
                key = settings.groq_api_key
                if not key:
                    raise RuntimeError(
                        "INFERENCE_MODE is set to 'groq' but GROQ_API_KEY is missing. "
                        "Please set the GROQ_API_KEY environment variable or HF Secret."
                    )
                # Imported lazily so local-only deployments need not install groq.
                from groq import Groq

                client = Groq(api_key=key)
                _groq_client = client
                logger.info("βœ… Groq client initialized (model: %s).", settings.groq_model)
    return client
def _generate_text_groq(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Produce a chat reply through Groq's OpenAI-compatible Chat Completions API.

    Translates the local-model calling convention into Groq request fields:
    ``max_new_tokens`` is sent as ``max_tokens``. The API has no separate
    greedy/sampling switch, so ``do_sample=False`` is emulated by sending
    ``temperature=0.0``; otherwise the caller's ``temperature`` is sent.

    Returns
    -------
    str
        The stripped assistant message (``""`` if the message content is empty),
        or ``""`` if the API call or reading its response raised. That error
        is logged with its traceback and not re-raised.

    Raises
    ------
    RuntimeError
        Propagated from ``get_groq_client()`` when no API key is configured;
        that call happens before the guarded section.
    """
    client = get_groq_client()
    sampling_temperature = temperature if do_sample else 0.0
    try:
        request = {
            "model": settings.groq_model,
            "messages": prompt_messages,
            "max_tokens": max_new_tokens,
            "temperature": sampling_temperature,
        }
        response = client.chat.completions.create(**request)
        content = response.choices[0].message.content
        return (content or "").strip()
    except Exception as exc:
        # Best-effort: callers receive an empty reply instead of an exception.
        logger.exception("❌ Groq API call failed: %s", exc)
        return ""
# ═════════════════════════════════════════════════════════════════════════════
# LOCAL BACKEND (Qwen + mDeBERTa)
# ═════════════════════════════════════════════════════════════════════════════
def get_qwen_model() -> Tuple:
    """Return ``(model, tokenizer)`` for the local causal LM (``QWEN_MODEL_ID``).

    Loads on first call; subsequent calls return the cached objects. Loading
    runs under ``_qwen_lock`` with a re-check, so concurrent first callers
    load only once.

    Both objects are built in locals and published to the module globals only
    once fully ready: the tokenizer first, then the model (after ``.eval()``).
    The unlocked fast path tests ``_qwen_model``, so it never observes a model
    that is still being prepared. A failed load leaves both globals untouched,
    so the next call retries cleanly.
    """
    global _qwen_model, _qwen_tokenizer
    if _qwen_model is not None:
        return _qwen_model, _qwen_tokenizer
    with _qwen_lock:
        # Double-check after acquiring the lock.
        if _qwen_model is not None:
            return _qwen_model, _qwen_tokenizer
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        logger.info("πŸ€– Loading Qwen model: %s (CPU, float32) …", QWEN_MODEL_ID)
        # NOTE(review): trust_remote_code=True executes Python code shipped in the
        # model repository. QWEN_MODEL_ID is overridable via env var, so only
        # point it at trusted repos.
        tokenizer = AutoTokenizer.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            trust_remote_code=True,
        )
        model = AutoModelForCausalLM.from_pretrained(
            QWEN_MODEL_ID,
            cache_dir=HF_CACHE_DIR,
            torch_dtype=torch.float32,
            device_map="cpu",
            trust_remote_code=True,
        )
        model.eval()
        # Publish the tokenizer before the model: the model global is the
        # "ready" flag checked by the lock-free fast path.
        _qwen_tokenizer = tokenizer
        _qwen_model = model
        logger.info("βœ… Qwen model loaded successfully.")
    return _qwen_model, _qwen_tokenizer
def get_classifier_pipeline():
    """Return the cached zero-shot-classification pipeline (``CLASSIFIER_MODEL_ID``, CPU).

    Built on first call under ``_clf_lock``, with a re-check so concurrent
    first callers build it only once. Subsequent calls return the cached
    pipeline.
    """
    global _classifier_pipe
    if _classifier_pipe is not None:
        return _classifier_pipe
    with _clf_lock:
        if _classifier_pipe is not None:
            return _classifier_pipe
        from transformers import pipeline as hf_pipeline

        logger.info(
            "πŸ€– Loading zero-shot classifier: %s (CPU) …", CLASSIFIER_MODEL_ID
        )
        # ``pipeline()`` has no ``cache_dir`` parameter of its own. A top-level
        # ``cache_dir=`` kwarg is forwarded to the pipeline class and ignored,
        # so the weights would land in the default HF cache. ``model_kwargs``
        # is forwarded to the underlying ``from_pretrained`` calls, which honor
        # ``cache_dir``.
        pipe = hf_pipeline(
            "zero-shot-classification",
            model=CLASSIFIER_MODEL_ID,
            device=-1,  # CPU
            model_kwargs={"cache_dir": HF_CACHE_DIR},
        )
        _classifier_pipe = pipe
        logger.info("βœ… Zero-shot classifier loaded successfully.")
    return _classifier_pipe
def _generate_text_local(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
    max_input_tokens: int = 2048,
) -> str:
    """Run text generation via the local Qwen model on CPU.

    Parameters
    ----------
    prompt_messages:
        Chat messages, rendered with ``tokenizer.apply_chat_template`` with
        the generation prompt appended.
    max_new_tokens:
        Cap on generated tokens.
    temperature / do_sample:
        ``do_sample=True`` samples at ``temperature``. A non-positive
        temperature is treated as greedy, mirroring the Groq backend where
        greedy maps to temperature 0. ``do_sample=False`` forces greedy
        (deterministic) decoding.
    max_input_tokens:
        Prompt token budget. Longer prompts keep only their *last* tokens, so
        the trailing assistant-turn marker from the chat template survives.
        ``None`` or a non-positive value disables the cap.

    Returns
    -------
    str
        Only the newly generated text (prompt excluded), with special tokens
        removed and surrounding whitespace stripped.
    """
    import torch

    model, tokenizer = get_qwen_model()
    input_text = tokenizer.apply_chat_template(
        prompt_messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Tokenize without built-in truncation. ``truncation=True`` drops tokens
    # from the tokenizer's ``truncation_side`` (right by default), which would
    # remove the assistant prompt appended above and make the model continue
    # the user's text instead of replying.
    inputs = tokenizer(input_text, return_tensors="pt").to("cpu")
    prompt_len = inputs["input_ids"].shape[1]
    if max_input_tokens is not None and 0 < max_input_tokens < prompt_len:
        logger.warning(
            "Prompt is %d tokens; keeping only the last %d.",
            prompt_len,
            max_input_tokens,
        )
        inputs = {name: tensor[:, -max_input_tokens:] for name, tensor in inputs.items()}
        prompt_len = max_input_tokens

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if do_sample and temperature > 0:
        gen_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy must be requested explicitly. The Qwen2.5-Instruct checkpoints
        # publish a generation_config with do_sample=true (plus temperature,
        # top_p and top_k), so omitting do_sample would still sample; verify
        # for other QWEN_MODEL_ID values. The sampling knobs are cleared to
        # None so transformers does not warn that they are unused in greedy mode.
        gen_kwargs.update(do_sample=False, temperature=None, top_p=None, top_k=None)

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)
    # Decode only the tokens generated after the prompt.
    new_tokens = output_ids[0][prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# ═════════════════════════════════════════════════════════════════════════════
# PUBLIC API β€” UNIFIED ROUTER
# ═════════════════════════════════════════════════════════════════════════════
def generate_text(
    prompt_messages: list[dict],
    *,
    max_new_tokens: int = 200,
    temperature: float = 1.0,
    do_sample: bool = False,
) -> str:
    """Run one chat completion and return the assistant's reply text.

    This is the single entry point used by the rest of the service. The
    backend is chosen on every call from the module-level ``INFERENCE_MODE``:
    exactly ``"groq"`` selects the Groq Cloud API, and any other value selects
    the local Qwen model.

    Parameters
    ----------
    prompt_messages:
        ``{"role": ..., "content": ...}`` dicts accepted by both
        ``tokenizer.apply_chat_template`` and the Groq Chat Completions API.
    max_new_tokens:
        Upper bound on the number of generated tokens.
    temperature / do_sample:
        Sampling configuration. The defaults (``do_sample=False``) ask the
        backend for greedy, deterministic decoding.

    Returns
    -------
    str
        The reply text exactly as produced by the selected backend.
    """
    backend = (
        _generate_text_groq if INFERENCE_MODE == "groq" else _generate_text_local
    )
    return backend(
        prompt_messages,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=do_sample,
    )
def preload_all_models() -> None:
    """Warm up whichever inference backend ``INFERENCE_MODE`` selects.

    Meant to be called once at startup (e.g. from the FastAPI lifespan) so the
    first real request does not pay the cold-load cost.

    * ``"groq"``: only constructs the Groq client, which also fails fast when
      the API key is missing. The local models are never loaded.
    * any other value: loads both the local causal LM and the zero-shot
      classifier.
    """
    logger.info("⏳ Pre-loading AI models (mode: %s) …", INFERENCE_MODE)
    if INFERENCE_MODE != "groq":
        # Local inference: bring both models into memory up front.
        get_qwen_model()
        get_classifier_pipeline()
        logger.info("βœ… All local AI models loaded and ready.")
        return
    # Remote inference: creating the client is enough to validate configuration.
    get_groq_client()
    logger.info(
        "βœ… Groq mode active β€” skipped local model loading. "
        "Using %s via Groq Cloud API.",
        settings.groq_model,
    )