File size: 5,985 Bytes
9e62b52
1c73b90
4dd18be
 
 
1c73b90
 
d649bea
 
 
 
 
 
9e62b52
1c73b90
 
 
 
 
3773f69
ce2c90d
965812d
d649bea
 
 
 
 
 
1c73b90
 
d649bea
4dd18be
 
d649bea
4dd18be
 
 
16e5505
027fe20
4dd18be
d649bea
4dd18be
 
16e5505
 
 
 
 
 
 
 
 
d649bea
16e5505
d649bea
 
16e5505
 
d649bea
 
1c73b90
 
d649bea
 
 
 
 
4dd18be
 
d649bea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c73b90
d649bea
ba29d9c
d649bea
 
1c73b90
d649bea
4dd18be
d649bea
 
ce2c90d
1c73b90
 
 
d649bea
1c73b90
d649bea
 
1c73b90
 
ce2c90d
1c73b90
 
 
 
 
 
 
40a3ddb
1c73b90
 
 
 
 
 
d649bea
027fe20
f7a8ca8
16e5505
4dd18be
16e5505
 
 
027fe20
16e5505
027fe20
 
d649bea
 
 
 
027fe20
ce2c90d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import streamlit as st
import os
# from llama_cpp import Llama # Nous n'utilisons plus llama-cpp-python
# from ctransformers import AutoModelForCausalLM # Pas nécessaire pour l'intégration LangChain
from langchain_community.llms import CTransformers # Correction de la dépréciation
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import logging

# Configuration du logging pour un meilleur débogage
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Configuration de la page Streamlit ---
st.set_page_config(page_title="Votre Coach RAG", layout="wide")
st.title("Votre Coach Expert")
st.write("Posez une question sur vos documents, et je vous répondrai en me basant sur leur contenu.")

# --- Fonctions de chargement mises en cache ---

# Patch pour un problème de désérialisation
def custom_setstate(self, state):
    if "__fields_set__" in state:
        del state["__fields_set__"]
    self.__dict__.update(state)
Document.__setstate__ = custom_setstate

@st.cache_resource
def load_llm():
    """
    Charge le modèle LLM directement depuis le Hub Hugging Face.
    La configuration du cache est gérée par les variables d'environnement du Space.
    """
    # NOTE : J'ai changé pour un modèle plus petit pour accélérer le premier chargement.
    # Vous pouvez remettre "TheBloke/Llama-2-7B-Chat-GGUF" si vous le souhaitez.
    model_repo_id = "TheBloke/Llama-2-7B-Chat-GGUF" 
    model_filename = "llama-2-7b-chat.Q4_K_M.gguf"
    
    with st.spinner(f"Chargement du modèle '{model_filename}' en mémoire... (Ceci peut être long)"):
        try:
            # Le comportement du cache est maintenant contrôlé par les variables d'environnement
            # du Space, et non plus par le code.
            llm = CTransformers(
                model=model_repo_id,
                model_file=model_filename,
                model_type="llama",
                config={
                    'max_new_tokens': 1500,
                    'temperature': 0.7,
                    'context_length': 4096
                }
            )
            logger.info("Modèle LLM chargé avec succès en mémoire.")
            return llm
        except Exception as e:
            st.error(f"Erreur critique lors du chargement du modèle depuis le Hub : {e}")
            logger.error(f"Erreur de chargement du modèle CTransformers : {e}")
            st.stop()

@st.cache_resource
def load_retriever(faiss_path, embeddings_path):
    """
    Charge le retriever FAISS et le modèle d'embeddings.
    """
    with st.spinner("Chargement de la base de connaissances (FAISS) et des embeddings..."):
        try:
            # Les modèles d'embeddings SONT mis en cache par défaut, mais dans un
            # emplacement autorisé par la plateforme.
            embeddings_model = HuggingFaceEmbeddings(
                model_name=embeddings_path,
                model_kwargs={'device': 'cpu'}
            )
            vectorstore = FAISS.load_local(
                faiss_path,
                embeddings_model,
                allow_dangerous_deserialization=True
            )
            logger.info("Retriever chargé avec succès.")
            return vectorstore.as_retriever(search_kwargs={"k": 5})
        except Exception as e:
            st.error(f"Erreur lors du chargement du retriever : {e}")
            logger.error(f"Erreur de chargement du retriever : {e}")
            st.stop()

# --- Chemins d'accès (relatifs à la racine de votre projet) ---
DOSSIER_PROJET = os.path.dirname(__file__)
CHEMIN_INDEX_FAISS = os.path.join(DOSSIER_PROJET, "faiss_index_wize")
CHEMIN_MODELE_EMBEDDINGS = os.path.join(DOSSIER_PROJET, "embedding_model")

if not os.path.exists(CHEMIN_INDEX_FAISS) or not os.path.exists(CHEMIN_MODELE_EMBEDDINGS):
    st.error(f"Erreur critique : Les dossiers 'faiss_index_wize' ou 'embedding_model' sont manquants dans votre dépôt.")
    st.stop()

# --- Chargement principal ---
try:
    llm = load_llm()
    retriever = load_retriever(CHEMIN_INDEX_FAISS, CHEMIN_MODELE_EMBEDDINGS)
    st.success("🎉 Les modèles sont chargés et prêts !")
except Exception as e:
    st.error(f"Une erreur inattendue est survenue lors du chargement des modèles : {e}")
    logger.error(f"Erreur de chargement principale : {e}")
    st.stop()

# --- Interface de Chat ---
if "messages" not in st.session_state:
    st.session_state.messages = []

for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("Posez votre question ici..."):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Je réfléchis..."):
            try:
                docs = retriever.invoke(prompt)
                context = "\n".join([doc.dict().get("page_content", "") for doc in docs])
                # Création du prompt pour CTransformers
                full_prompt = f"System: Vous êtes un coach expert. Répondez à la question en vous basant UNIQUEMENT sur le contexte fourni. Ne mentionnez pas le contexte dans votre réponse.\n\nContexte:\n{context}\n\nQuestion: {prompt}\n\nRéponse:"
                
                # Invocation du LLM
                answer = llm.invoke(full_prompt)
                
                st.markdown(answer)
                st.session_state.messages.append({"role": "assistant", "content": answer})

            except Exception as e:
                answer = f"Désolé, une erreur est survenue lors de la génération de la réponse : {e}"
                st.error(answer)
                logger.error(f"Erreur de génération LLM : {e}")
                st.session_state.messages.append({"role": "assistant", "content": answer})