code-education-rag / src /resources.py
FabIndy's picture
Switch to Groq-only LLM, remove GGUF dependency, speed up build and inference
56a777c
# src/resources.py
from __future__ import annotations
import os
from pathlib import Path
from typing import Optional, List, Dict, Any
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings
from src.config import (
DB_DIR,
EMBED_MODEL,
LLM_MODEL_PATH,
LLM_N_CTX,
LLM_N_THREADS,
LLM_N_BATCH,
)
# --------------------
# Lazy singletons
# --------------------
_VS: Optional[FAISS] = None
# LLM local (fallback)
_LLM_LOCAL = None
# Groq client (primary when GROQ_API_KEY is set)
_GROQ_CLIENT = None
# --------------------
# Helpers
# --------------------
def _assert_vectorstore_files(db_dir: Path) -> None:
if not db_dir.exists() or not db_dir.is_dir():
raise RuntimeError(
f"Vectorstore introuvable : {db_dir}\n"
"Attendu : un dossier contenant un index FAISS (ex: index.faiss, index.pkl)."
)
faiss_file = db_dir / "index.faiss"
pkl_file = db_dir / "index.pkl"
if not faiss_file.exists() or not pkl_file.exists():
raise RuntimeError(
f"Vectorstore incomplet dans {db_dir}\n"
f"Fichiers attendus : {faiss_file.name} et {pkl_file.name}"
)
def _assert_llm_file(model_path: Path) -> None:
if not model_path.exists() or not model_path.is_file():
raise RuntimeError(
f"Modèle GGUF introuvable : {model_path}\n"
"Assure-toi que app.py a bien téléchargé/copier le modèle dans models/ "
"ou que LLM_MODEL_PATH pointe vers un fichier GGUF valide."
)
def is_groq_enabled() -> bool:
"""Groq est actif si une clé est définie."""
return bool(os.environ.get("GROQ_API_KEY", "").strip())
def _get_groq_settings() -> Dict[str, Any]:
"""Récupère les paramètres Groq depuis les variables d'environnement."""
return {
"model": os.environ.get("GROQ_MODEL", "llama-3.1-8b-instant"),
"temperature": float(os.environ.get("GROQ_TEMPERATURE", "0.1")),
"max_tokens_summary": int(os.environ.get("GROQ_MAX_TOKENS_SUMMARY", "120")),
"max_tokens_qa": int(os.environ.get("GROQ_MAX_TOKENS_QA", "220")),
}
# --------------------
# Vectorstore (FAISS)
# --------------------
def get_vectorstore() -> FAISS:
"""
Charge FAISS + embeddings UNE fois (lazy-loading).
IMPORTANT : coûteux (CPU + I/O). N'appelle que si nécessaire.
"""
global _VS
if _VS is not None:
return _VS
db_dir = Path(DB_DIR)
_assert_vectorstore_files(db_dir)
embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
_VS = FAISS.load_local(
str(db_dir),
embeddings,
allow_dangerous_deserialization=True,
)
return _VS
# --------------------
# LLM local (fallback)
# --------------------
def get_llm_local():
"""
Charge le modèle GGUF UNE fois (fallback uniquement).
Si Groq est activé, tu n'es pas censé l'appeler dans SUMMARY/QA.
"""
global _LLM_LOCAL
if _LLM_LOCAL is not None:
return _LLM_LOCAL
# Import ici pour éviter de charger llama_cpp inutilement si Groq est utilisé
from llama_cpp import Llama
model_path = Path(LLM_MODEL_PATH)
_assert_llm_file(model_path)
_LLM_LOCAL = Llama(
model_path=str(model_path),
n_ctx=int(LLM_N_CTX),
n_threads=int(LLM_N_THREADS),
n_batch=int(LLM_N_BATCH),
verbose=False,
)
return _LLM_LOCAL
# --------------------
# Groq client
# --------------------
def get_groq_client():
"""
Instancie le client Groq UNE fois.
Utilise GROQ_API_KEY depuis l'environnement.
"""
global _GROQ_CLIENT
if _GROQ_CLIENT is not None:
return _GROQ_CLIENT
# Import ici pour ne pas dépendre du package si on veut fallback local
from groq import Groq # type: ignore
# Le SDK lit GROQ_API_KEY automatiquement (ou via Groq(api_key=...))
_GROQ_CLIENT = Groq(api_key=os.environ.get("GROQ_API_KEY"))
return _GROQ_CLIENT
# --------------------
# Unified chat generation
# --------------------
def generate_chat(
messages: List[Dict[str, str]],
*,
max_tokens: int,
temperature: float,
) -> str:
"""
Génère une réponse à partir de messages de chat.
- Si GROQ_API_KEY est défini : utilise Groq (rapide).
- Sinon : fallback llama.cpp local.
messages format:
[{"role": "system"|"user"|"assistant", "content": "..."}]
"""
if is_groq_enabled():
settings = _get_groq_settings()
client = get_groq_client()
resp = client.chat.completions.create(
model=settings["model"],
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return (resp.choices[0].message.content or "").strip()
# Fallback local llama.cpp
llm = get_llm_local()
out = llm.create_chat_completion(
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
)
return out["choices"][0]["message"]["content"].strip()
def groq_max_tokens_for(mode: str) -> int:
"""
Helper pratique : renvoie la valeur max_tokens recommandée selon le mode.
mode : "summary" ou "qa"
"""
s = _get_groq_settings()
if mode.lower().startswith("sum"):
return int(s["max_tokens_summary"])
return int(s["max_tokens_qa"])