khadijaaao commited on
Commit
d649bea
·
verified ·
1 Parent(s): ba29d9c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +127 -27
src/streamlit_app.py CHANGED
@@ -4,6 +4,12 @@ from llama_cpp import Llama
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from huggingface_hub import hf_hub_download
 
 
 
 
 
 
7
 
8
  # --- Configuration de la page Streamlit ---
9
  st.set_page_config(page_title="Votre Coach RAG", layout="wide")
@@ -11,42 +17,110 @@ st.title("Votre Coach Expert")
11
  st.write("Posez une question sur vos documents, et je vous répondrai en me basant sur leur contenu.")
12
 
13
  # --- Fonctions de chargement mises en cache ---
 
 
 
 
 
 
 
 
 
 
 
14
  @st.cache_resource
15
  def load_llm():
 
 
 
 
 
 
16
  model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
17
  model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
18
 
19
- with st.spinner(f"Téléchargement du modèle '{model_filename}'... (Cette étape est longue et n'a lieu qu'une seule fois)"):
20
- model_path = hf_hub_download(
21
- repo_id=model_repo_id,
22
- filename=model_filename,
23
- cache_dir='/tmp/hf_cache'
24
- )
25
-
 
 
 
 
 
 
26
  with st.spinner("Chargement du modèle LLM en mémoire..."):
27
- llm = Llama(model_path=model_path, n_gpu_layers=0, n_ctx=4096, verbose=False, chat_format="llama-3")
28
- return llm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  @st.cache_resource
31
  def load_retriever(faiss_path, embeddings_path):
32
- with st.spinner("Chargement de la base de connaissances (FAISS)..."):
33
- embeddings_model = HuggingFaceEmbeddings(model_name=embeddings_path, model_kwargs={'device': 'cpu'})
34
- vectorstore = FAISS.load_local(faiss_path, embeddings_model, allow_dangerous_deserialization=True)
35
- return vectorstore.as_retriever(search_kwargs={"k": 5})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- # --- Chemins d'accès (relatifs) ---
38
- # MODIFICATION : On remonte d'un dossier (de 'src' vers la racine) avec '../'
39
  DOSSIER_PROJET = os.path.dirname(__file__)
40
- CHEMIN_INDEX_FAISS = os.path.join(DOSSIER_PROJET, "../faiss_index_wize")
41
- CHEMIN_MODELE_EMBEDDINGS = os.path.join(DOSSIER_PROJET, "../embedding_model")
42
 
43
- # --- Chargement des modèles via Streamlit ---
 
 
 
 
 
 
44
  try:
45
  llm = load_llm()
46
  retriever = load_retriever(CHEMIN_INDEX_FAISS, CHEMIN_MODELE_EMBEDDINGS)
47
- st.success("Les modèles sont chargés et prêts !")
48
  except Exception as e:
49
- st.error(f"Erreur lors du chargement des modèles : {e}")
 
50
  st.stop()
51
 
52
  # --- Initialisation de l'historique de chat ---
@@ -60,19 +134,45 @@ for message in st.session_state.messages:
60
 
61
  # --- Logique de Chat ---
62
  if prompt := st.chat_input("Posez votre question ici..."):
 
63
  st.session_state.messages.append({"role": "user", "content": prompt})
64
  with st.chat_message("user"):
65
  st.markdown(prompt)
66
 
 
67
  with st.chat_message("assistant"):
68
  with st.spinner("Je réfléchis..."):
 
69
  docs = retriever.invoke(prompt)
70
- context = "\n".join([doc.page_content for doc in docs])
71
- system_prompt = "Vous êtes Un coach expert. Répondez à la question en vous basant uniquement sur le contexte fourni."
72
- full_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{system_prompt}\nContexte : {context}<|eot_id|><|start_header_id|>user<|end_header_id|>\nQuestion : {prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>"""
 
73
 
74
- response = llm(full_prompt, max_tokens=1500, stop=["<|eot_id|>"], echo=False)
75
- answer = response['choices'][0]['text'].strip()
76
- st.markdown(answer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- st.session_state.messages.append({"role": "assistant", "content": answer})
 
 
 
4
  from langchain_community.vectorstores import FAISS
5
  from langchain_community.embeddings import HuggingFaceEmbeddings
6
  from huggingface_hub import hf_hub_download
7
+ from langchain.docstore.document import Document
8
+ import logging
9
+
10
+ # Configuration du logging pour un meilleur débogage
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
 
14
  # --- Configuration de la page Streamlit ---
15
  st.set_page_config(page_title="Votre Coach RAG", layout="wide")
 
17
  st.write("Posez une question sur vos documents, et je vous répondrai en me basant sur leur contenu.")
18
 
19
  # --- Fonctions de chargement mises en cache ---
20
+ # @st.cache_resource est CRUCIAL pour que Streamlit ne recharge pas les modèles à chaque interaction
21
+
22
+ # Patch pour un problème de désérialisation avec Langchain et une version spécifique de Pydantic
23
+ # Cela peut être nécessaire dans certains environnements.
24
+ def custom_setstate(self, state):
25
+ if "__fields_set__" in state:
26
+ del state["__fields_set__"]
27
+ self.__dict__.update(state)
28
+
29
+ Document.__setstate__ = custom_setstate
30
+
31
  @st.cache_resource
32
  def load_llm():
33
+ """
34
+ Charge le modèle LLM depuis le Hub Hugging Face.
35
+ Cette fonction est mise en cache pour n'être exécutée qu'une seule fois.
36
+ """
37
+ # Identifiants pour le modèle sur le Hub Hugging Face
38
+ # Il est recommandé d'utiliser un modèle quantifié (GGUF) pour un bon équilibre performance/taille.
39
  model_repo_id = "QuantFactory/Meta-Llama-3-8B-Instruct-GGUF"
40
  model_filename = "Meta-Llama-3-8B-Instruct.Q4_K_M.gguf"
41
 
42
+ with st.spinner(f"Téléchargement du modèle '{model_filename}' depuis le Hub... (Cette étape peut être longue et n'a lieu qu'au premier démarrage)"):
43
+ try:
44
+ # Télécharge le fichier s'il n'est pas dans le cache de Hugging Face et retourne son chemin local
45
+ model_path = hf_hub_download(
46
+ repo_id=model_repo_id,
47
+ filename=model_filename
48
+ )
49
+ logger.info(f"Modèle téléchargé avec succès : {model_path}")
50
+ except Exception as e:
51
+ st.error(f"Erreur lors du téléchargement du modèle depuis le Hub : {e}")
52
+ logger.error(f"Erreur de téléchargement du modèle : {e}")
53
+ st.stop()
54
+
55
  with st.spinner("Chargement du modèle LLM en mémoire..."):
56
+ try:
57
+ # On utilise le CPU car les Spaces gratuits n'ont pas de GPU.
58
+ # n_gpu_layers=0 garantit l'utilisation du CPU.
59
+ llm = Llama(
60
+ model_path=model_path,
61
+ n_gpu_layers=0,
62
+ n_ctx=4096, # Augmenter si nécessaire pour des contextes plus longs
63
+ verbose=False,
64
+ chat_format="llama-3"
65
+ )
66
+ logger.info("Modèle LLM chargé avec succès.")
67
+ return llm
68
+ except Exception as e:
69
+ st.error(f"Erreur lors du chargement du modèle Llama : {e}")
70
+ logger.error(f"Erreur de chargement Llama : {e}")
71
+ st.stop()
72
+
73
 
74
  @st.cache_resource
75
  def load_retriever(faiss_path, embeddings_path):
76
+ """
77
+ Charge le retriever FAISS et le modèle d'embeddings.
78
+ Cette fonction est également mise en cache.
79
+ """
80
+ with st.spinner("Chargement de la base de connaissances (FAISS) et des embeddings..."):
81
+ try:
82
+ # Spécifier 'cpu' car nous n'avons pas de GPU disponible.
83
+ embeddings_model = HuggingFaceEmbeddings(
84
+ model_name=embeddings_path,
85
+ model_kwargs={'device': 'cpu'}
86
+ )
87
+
88
+ # Charger l'index FAISS local
89
+ # allow_dangerous_deserialization est nécessaire pour les index créés avec des versions plus anciennes de langchain.
90
+ vectorstore = FAISS.load_local(
91
+ faiss_path,
92
+ embeddings_model,
93
+ allow_dangerous_deserialization=True
94
+ )
95
+
96
+ # Créer un retriever qui retournera les 5 documents les plus pertinents.
97
+ logger.info("Retriever chargé avec succès.")
98
+ return vectorstore.as_retriever(search_kwargs={"k": 5})
99
+ except Exception as e:
100
+ st.error(f"Erreur lors du chargement du retriever : {e}")
101
+ logger.error(f"Erreur de chargement du retriever : {e}")
102
+ st.stop()
103
 
104
+ # --- Chemins d'accès (relatifs à la racine de votre projet) ---
105
+ # Assurez-vous que ces dossiers sont bien à la racine de votre Space Hugging Face.
106
  DOSSIER_PROJET = os.path.dirname(__file__)
107
+ CHEMIN_INDEX_FAISS = os.path.join(DOSSIER_PROJET, "faiss_index_wize")
108
+ CHEMIN_MODELE_EMBEDDINGS = os.path.join(DOSSIER_PROJET, "embedding_model")
109
 
110
+ # --- Vérification de l'existence des dossiers locaux ---
111
+ if not os.path.exists(CHEMIN_INDEX_FAISS) or not os.path.exists(CHEMIN_MODELE_EMBEDDINGS):
112
+ st.error(f"Erreur critique : Les dossiers 'faiss_index_wize' ou 'embedding_model' sont manquants. Assurez-vous de les avoir téléversés à la racine de votre Space.")
113
+ st.stop()
114
+
115
+
116
+ # --- Chargement principal des modèles via Streamlit ---
117
  try:
118
  llm = load_llm()
119
  retriever = load_retriever(CHEMIN_INDEX_FAISS, CHEMIN_MODELE_EMBEDDINGS)
120
+ st.success("🎉 Les modèles sont chargés et prêts !")
121
  except Exception as e:
122
+ st.error(f"Une erreur inattendue est survenue lors du chargement des modèles : {e}")
123
+ logger.error(f"Erreur de chargement principale : {e}")
124
  st.stop()
125
 
126
  # --- Initialisation de l'historique de chat ---
 
134
 
135
  # --- Logique de Chat ---
136
  if prompt := st.chat_input("Posez votre question ici..."):
137
+ # Ajouter et afficher le message de l'utilisateur
138
  st.session_state.messages.append({"role": "user", "content": prompt})
139
  with st.chat_message("user"):
140
  st.markdown(prompt)
141
 
142
+ # Préparer et afficher la réponse de l'assistant
143
  with st.chat_message("assistant"):
144
  with st.spinner("Je réfléchis..."):
145
+ # 1. Récupérer le contexte pertinent depuis la base de connaissances
146
  docs = retriever.invoke(prompt)
147
+ context = "\n\n".join([doc.page_content for doc in docs])
148
+
149
+ # 2. Créer le prompt complet pour le LLM avec le contexte
150
+ system_prompt = "Vous êtes un coach expert. Répondez à la question en vous basant UNIQUEMENT sur le contexte fourni. Ne mentionnez pas le contexte dans votre réponse."
151
 
152
+ # Utilisation du template de chat Llama 3
153
+ messages_for_llm = [
154
+ {"role": "system", "content": f"{system_prompt}\n\nContexte:\n{context}"},
155
+ {"role": "user", "content": prompt}
156
+ ]
157
+
158
+ # 3. Générer la réponse
159
+ try:
160
+ # Utiliser la méthode create_chat_completion pour un format de chat structuré
161
+ response_stream = llm.create_chat_completion_stream(
162
+ messages=messages_for_llm,
163
+ max_tokens=1500,
164
+ temperature=0.7,
165
+ stop=["<|eot_id|>", "<|end_of_text|>"] # Tokens d'arrêt pour Llama 3
166
+ )
167
+
168
+ # Utiliser st.write_stream pour afficher la réponse en streaming
169
+ answer = st.write_stream(token['choices'][0]['delta'].get('content', '') for token in response_stream)
170
+
171
+ except Exception as e:
172
+ answer = f"Désolé, une erreur est survenue lors de la génération de la réponse : {e}"
173
+ st.error(answer)
174
+ logger.error(f"Erreur de génération LLM : {e}")
175
 
176
+ # Ajouter la réponse complète de l'assistant à l'historique
177
+ st.session_state.messages.append({"role": "assistant", "content": answer})
178
+