Spaces:
Running
Running
| # src/resources.py | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Optional, List, Dict, Any | |
| from langchain_community.vectorstores import FAISS | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from src.config import ( | |
| DB_DIR, | |
| EMBED_MODEL, | |
| LLM_MODEL_PATH, | |
| LLM_N_CTX, | |
| LLM_N_THREADS, | |
| LLM_N_BATCH, | |
| ) | |
| # -------------------- | |
| # Lazy singletons | |
| # -------------------- | |
# FAISS vectorstore cache; assigned by get_vectorstore() on its first call.
_VS: Optional[FAISS] = None
# Local LLM (fallback) cache; assigned by get_llm_local() on its first call.
_LLM_LOCAL = None
# Groq client (primary when GROQ_API_KEY is set); assigned by get_groq_client().
_GROQ_CLIENT = None
| # -------------------- | |
| # Helpers | |
| # -------------------- | |
def _assert_vectorstore_files(db_dir: Path) -> None:
    """Check that ``db_dir`` is a directory holding a complete FAISS index.

    Args:
        db_dir: Directory expected to contain ``index.faiss`` and
            ``index.pkl``.

    Raises:
        RuntimeError: If ``db_dir`` does not exist or is not a directory, or
            if either of the two index files is missing from it.
    """
    if not (db_dir.exists() and db_dir.is_dir()):
        raise RuntimeError(
            f"Vectorstore introuvable : {db_dir}\n"
            "Attendu : un dossier contenant un index FAISS (ex: index.faiss, index.pkl)."
        )

    # Both files must be present for the index to be considered complete.
    required = (db_dir / "index.faiss", db_dir / "index.pkl")
    if not all(candidate.exists() for candidate in required):
        index_file, store_file = required
        raise RuntimeError(
            f"Vectorstore incomplet dans {db_dir}\n"
            f"Fichiers attendus : {index_file.name} et {store_file.name}"
        )
def _assert_llm_file(model_path: Path) -> None:
    """Check that ``model_path`` points at an existing regular file.

    Args:
        model_path: Expected location of the GGUF model file.

    Raises:
        RuntimeError: If the path does not exist or is not a file.
    """
    if model_path.exists() and model_path.is_file():
        return

    raise RuntimeError(
        f"Modèle GGUF introuvable : {model_path}\n"
        "Assure-toi que app.py a bien téléchargé/copier le modèle dans models/ "
        "ou que LLM_MODEL_PATH pointe vers un fichier GGUF valide."
    )
def is_groq_enabled() -> bool:
    """Return True when GROQ_API_KEY is set to a non-blank value."""
    raw_key = os.environ.get("GROQ_API_KEY", "")
    # A key made only of whitespace counts as "not configured".
    return len(raw_key.strip()) > 0
def _get_groq_settings() -> Dict[str, Any]:
    """Read the Groq settings from environment variables.

    Each setting falls back to a built-in default when its variable is unset:
    ``GROQ_MODEL`` ("llama-3.1-8b-instant"), ``GROQ_TEMPERATURE`` ("0.1"),
    ``GROQ_MAX_TOKENS_SUMMARY`` ("120") and ``GROQ_MAX_TOKENS_QA`` ("220").

    Returns:
        A dict with keys ``model`` (str), ``temperature`` (float),
        ``max_tokens_summary`` (int) and ``max_tokens_qa`` (int).

    Raises:
        ValueError: If a numeric variable holds a value that ``float`` /
            ``int`` cannot parse.
    """
    env = os.environ
    model_name = env.get("GROQ_MODEL", "llama-3.1-8b-instant")
    temperature = float(env.get("GROQ_TEMPERATURE", "0.1"))
    summary_budget = int(env.get("GROQ_MAX_TOKENS_SUMMARY", "120"))
    qa_budget = int(env.get("GROQ_MAX_TOKENS_QA", "220"))

    settings: Dict[str, Any] = {}
    settings["model"] = model_name
    settings["temperature"] = temperature
    settings["max_tokens_summary"] = summary_budget
    settings["max_tokens_qa"] = qa_budget
    return settings
| # -------------------- | |
| # Vectorstore (FAISS) | |
| # -------------------- | |
def get_vectorstore() -> FAISS:
    """Return the FAISS vectorstore, loading it on the first call only.

    The first call validates the directory named by ``DB_DIR``, builds the
    HuggingFace embeddings for ``EMBED_MODEL`` and loads the index; the result
    is kept in the module-level ``_VS`` and returned as-is by later calls.
    Loading is expensive (CPU + I/O), so call this only when needed.

    Returns:
        The cached ``FAISS`` instance.

    Raises:
        RuntimeError: If the vectorstore directory or its index files are
            missing (raised by ``_assert_vectorstore_files``).
    """
    global _VS
    if _VS is None:
        index_dir = Path(DB_DIR)
        _assert_vectorstore_files(index_dir)

        embedder = HuggingFaceEmbeddings(model_name=EMBED_MODEL)
        # NOTE(review): allow_dangerous_deserialization=True opts in to
        # loading the pickled index.pkl — presumably safe only while the
        # content of DB_DIR is produced by this project; confirm.
        _VS = FAISS.load_local(
            str(index_dir),
            embedder,
            allow_dangerous_deserialization=True,
        )
    return _VS
| # -------------------- | |
| # LLM local (fallback) | |
| # -------------------- | |
def get_llm_local():
    """Return the local GGUF model, loading it on the first call only.

    This is the fallback backend: when Groq is enabled, the SUMMARY/QA paths
    are not expected to call it. The loaded model is kept in the module-level
    ``_LLM_LOCAL`` and reused by later calls.

    Returns:
        The cached ``llama_cpp.Llama`` instance.

    Raises:
        RuntimeError: If ``LLM_MODEL_PATH`` does not point at an existing file
            (raised by ``_assert_llm_file``).
    """
    global _LLM_LOCAL
    if _LLM_LOCAL is None:
        # Imported lazily so llama_cpp is not loaded when Groq is used.
        from llama_cpp import Llama

        gguf_path = Path(LLM_MODEL_PATH)
        _assert_llm_file(gguf_path)

        load_options = {
            "model_path": str(gguf_path),
            "n_ctx": int(LLM_N_CTX),
            "n_threads": int(LLM_N_THREADS),
            "n_batch": int(LLM_N_BATCH),
            "verbose": False,
        }
        _LLM_LOCAL = Llama(**load_options)
    return _LLM_LOCAL
| # -------------------- | |
| # Groq client | |
| # -------------------- | |
def get_groq_client():
    """Return the Groq client, creating it on the first call only.

    The API key is read from the ``GROQ_API_KEY`` environment variable and
    stripped of surrounding whitespace, matching the check performed by
    ``is_groq_enabled()``. Without the strip, a key stored with a trailing
    newline or space would enable Groq and then be sent verbatim to the SDK.

    The client is kept in the module-level ``_GROQ_CLIENT``; later changes to
    ``GROQ_API_KEY`` do not affect an already-created client.

    Returns:
        The cached ``groq.Groq`` client instance.
    """
    global _GROQ_CLIENT
    if _GROQ_CLIENT is not None:
        return _GROQ_CLIENT

    # Imported lazily so the package is not required for the local fallback.
    from groq import Groq  # type: ignore

    api_key = os.environ.get("GROQ_API_KEY", "").strip()
    # Pass None rather than "" when no usable key is set, so the SDK applies
    # its own handling of a missing key.
    _GROQ_CLIENT = Groq(api_key=api_key or None)
    return _GROQ_CLIENT
| # -------------------- | |
| # Unified chat generation | |
| # -------------------- | |
def generate_chat(
    messages: List[Dict[str, str]],
    *,
    max_tokens: int,
    temperature: float,
) -> str:
    """Generate a reply from a list of chat messages.

    Backend selection:
      - If ``GROQ_API_KEY`` is set (see ``is_groq_enabled``): use Groq with
        the model named by the Groq settings.
      - Otherwise: fall back to the local llama.cpp model.

    Args:
        messages: Chat history in the form
            ``[{"role": "system"|"user"|"assistant", "content": "..."}]``.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to the backend.

    Returns:
        The reply text with surrounding whitespace removed. An empty string
        is returned when the backend reports no content.
    """
    if is_groq_enabled():
        settings = _get_groq_settings()
        client = get_groq_client()
        resp = client.chat.completions.create(
            model=settings["model"],
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return (resp.choices[0].message.content or "").strip()

    # Fallback: local llama.cpp model.
    llm = get_llm_local()
    out = llm.create_chat_completion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    # The message content may be None; guard it the same way as the Groq
    # branch instead of calling .strip() on None.
    content = out["choices"][0]["message"]["content"]
    return (content or "").strip()
def groq_max_tokens_for(mode: str) -> int:
    """Return the configured ``max_tokens`` value for the given mode.

    Args:
        mode: ``"summary"`` or ``"qa"``. Any value whose lowercase form starts
            with ``"sum"`` selects the summary setting; everything else
            selects the QA setting.

    Returns:
        The summary or QA token limit taken from the Groq settings.
    """
    settings = _get_groq_settings()
    wants_summary = mode.lower().startswith("sum")
    key = "max_tokens_summary" if wants_summary else "max_tokens_qa"
    return int(settings[key])