File size: 5,592 Bytes
456aba5
 
 
56a777c
feddcd9
56a777c
feddcd9
456aba5
 
 
feddcd9
 
 
 
 
 
 
 
456aba5
 
56a777c
 
 
456aba5
 
56a777c
 
feddcd9
56a777c
 
 
 
 
 
 
feddcd9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56a777c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
456aba5
feddcd9
 
56a777c
feddcd9
456aba5
feddcd9
 
 
 
 
 
 
 
 
 
 
 
 
456aba5
 
feddcd9
56a777c
 
 
 
feddcd9
56a777c
 
feddcd9
56a777c
 
 
 
 
 
feddcd9
 
 
 
56a777c
feddcd9
 
 
 
56a777c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
feddcd9
56a777c
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
# src/resources.py
from __future__ import annotations

import os
from pathlib import Path
from typing import Optional, List, Dict, Any

from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings

from src.config import (
    DB_DIR,
    EMBED_MODEL,
    LLM_MODEL_PATH,
    LLM_N_CTX,
    LLM_N_THREADS,
    LLM_N_BATCH,
)


# --------------------
# Lazy singletons
# --------------------
_VS: Optional[FAISS] = None

# LLM local (fallback)
_LLM_LOCAL = None

# Groq client (primary when GROQ_API_KEY is set)
_GROQ_CLIENT = None


# --------------------
# Helpers
# --------------------
def _assert_vectorstore_files(db_dir: Path) -> None:
    if not db_dir.exists() or not db_dir.is_dir():
        raise RuntimeError(
            f"Vectorstore introuvable : {db_dir}\n"
            "Attendu : un dossier contenant un index FAISS (ex: index.faiss, index.pkl)."
        )

    faiss_file = db_dir / "index.faiss"
    pkl_file = db_dir / "index.pkl"
    if not faiss_file.exists() or not pkl_file.exists():
        raise RuntimeError(
            f"Vectorstore incomplet dans {db_dir}\n"
            f"Fichiers attendus : {faiss_file.name} et {pkl_file.name}"
        )


def _assert_llm_file(model_path: Path) -> None:
    if not model_path.exists() or not model_path.is_file():
        raise RuntimeError(
            f"Modèle GGUF introuvable : {model_path}\n"
            "Assure-toi que app.py a bien téléchargé/copier le modèle dans models/ "
            "ou que LLM_MODEL_PATH pointe vers un fichier GGUF valide."
        )


def is_groq_enabled() -> bool:
    """Groq est actif si une clé est définie."""
    return bool(os.environ.get("GROQ_API_KEY", "").strip())


def _get_groq_settings() -> Dict[str, Any]:
    """Récupère les paramètres Groq depuis les variables d'environnement."""
    return {
        "model": os.environ.get("GROQ_MODEL", "llama-3.1-8b-instant"),
        "temperature": float(os.environ.get("GROQ_TEMPERATURE", "0.1")),
        "max_tokens_summary": int(os.environ.get("GROQ_MAX_TOKENS_SUMMARY", "120")),
        "max_tokens_qa": int(os.environ.get("GROQ_MAX_TOKENS_QA", "220")),
    }


# --------------------
# Vectorstore (FAISS)
# --------------------
def get_vectorstore() -> FAISS:
    """

    Charge FAISS + embeddings UNE fois (lazy-loading).

    IMPORTANT : coûteux (CPU + I/O). N'appelle que si nécessaire.

    """
    global _VS
    if _VS is not None:
        return _VS

    db_dir = Path(DB_DIR)
    _assert_vectorstore_files(db_dir)

    embeddings = HuggingFaceEmbeddings(model_name=EMBED_MODEL)

    _VS = FAISS.load_local(
        str(db_dir),
        embeddings,
        allow_dangerous_deserialization=True,
    )
    return _VS


# --------------------
# LLM local (fallback)
# --------------------
def get_llm_local():
    """

    Charge le modèle GGUF UNE fois (fallback uniquement).

    Si Groq est activé, tu n'es pas censé l'appeler dans SUMMARY/QA.

    """
    global _LLM_LOCAL
    if _LLM_LOCAL is not None:
        return _LLM_LOCAL

    # Import ici pour éviter de charger llama_cpp inutilement si Groq est utilisé
    from llama_cpp import Llama

    model_path = Path(LLM_MODEL_PATH)
    _assert_llm_file(model_path)

    _LLM_LOCAL = Llama(
        model_path=str(model_path),
        n_ctx=int(LLM_N_CTX),
        n_threads=int(LLM_N_THREADS),
        n_batch=int(LLM_N_BATCH),
        verbose=False,
    )
    return _LLM_LOCAL


# --------------------
# Groq client
# --------------------
def get_groq_client():
    """

    Instancie le client Groq UNE fois.

    Utilise GROQ_API_KEY depuis l'environnement.

    """
    global _GROQ_CLIENT
    if _GROQ_CLIENT is not None:
        return _GROQ_CLIENT

    # Import ici pour ne pas dépendre du package si on veut fallback local
    from groq import Groq  # type: ignore

    # Le SDK lit GROQ_API_KEY automatiquement (ou via Groq(api_key=...))
    _GROQ_CLIENT = Groq(api_key=os.environ.get("GROQ_API_KEY"))
    return _GROQ_CLIENT


# --------------------
# Unified chat generation
# --------------------
def generate_chat(

    messages: List[Dict[str, str]],

    *,

    max_tokens: int,

    temperature: float,

) -> str:
    """

    Génère une réponse à partir de messages de chat.



    - Si GROQ_API_KEY est défini : utilise Groq (rapide).

    - Sinon : fallback llama.cpp local.



    messages format:

      [{"role": "system"|"user"|"assistant", "content": "..."}]

    """
    if is_groq_enabled():
        settings = _get_groq_settings()
        client = get_groq_client()

        resp = client.chat.completions.create(
            model=settings["model"],
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
        )
        return (resp.choices[0].message.content or "").strip()

    # Fallback local llama.cpp
    llm = get_llm_local()
    out = llm.create_chat_completion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return out["choices"][0]["message"]["content"].strip()


def groq_max_tokens_for(mode: str) -> int:
    """

    Helper pratique : renvoie la valeur max_tokens recommandée selon le mode.

    mode : "summary" ou "qa"

    """
    s = _get_groq_settings()
    if mode.lower().startswith("sum"):
        return int(s["max_tokens_summary"])
    return int(s["max_tokens_qa"])