chatbot / src /streamlit_app.py
khadijaaao's picture
Update src/streamlit_app.py
f7a8ca8 verified
import streamlit as st
import os
# from llama_cpp import Llama # Nous n'utilisons plus llama-cpp-python
# from ctransformers import AutoModelForCausalLM # Pas nécessaire pour l'intégration LangChain
from langchain_community.llms import CTransformers # Correction de la dépréciation
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
import logging
# Configuration du logging pour un meilleur débogage
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# --- Configuration de la page Streamlit ---
st.set_page_config(page_title="Votre Coach RAG", layout="wide")
st.title("Votre Coach Expert")
st.write("Posez une question sur vos documents, et je vous répondrai en me basant sur leur contenu.")
# --- Fonctions de chargement mises en cache ---
# Patch pour un problème de désérialisation
def custom_setstate(self, state):
if "__fields_set__" in state:
del state["__fields_set__"]
self.__dict__.update(state)
Document.__setstate__ = custom_setstate
@st.cache_resource
def load_llm():
"""
Charge le modèle LLM directement depuis le Hub Hugging Face.
La configuration du cache est gérée par les variables d'environnement du Space.
"""
# NOTE : J'ai changé pour un modèle plus petit pour accélérer le premier chargement.
# Vous pouvez remettre "TheBloke/Llama-2-7B-Chat-GGUF" si vous le souhaitez.
model_repo_id = "TheBloke/Llama-2-7B-Chat-GGUF"
model_filename = "llama-2-7b-chat.Q4_K_M.gguf"
with st.spinner(f"Chargement du modèle '{model_filename}' en mémoire... (Ceci peut être long)"):
try:
# Le comportement du cache est maintenant contrôlé par les variables d'environnement
# du Space, et non plus par le code.
llm = CTransformers(
model=model_repo_id,
model_file=model_filename,
model_type="llama",
config={
'max_new_tokens': 1500,
'temperature': 0.7,
'context_length': 4096
}
)
logger.info("Modèle LLM chargé avec succès en mémoire.")
return llm
except Exception as e:
st.error(f"Erreur critique lors du chargement du modèle depuis le Hub : {e}")
logger.error(f"Erreur de chargement du modèle CTransformers : {e}")
st.stop()
@st.cache_resource
def load_retriever(faiss_path, embeddings_path):
"""
Charge le retriever FAISS et le modèle d'embeddings.
"""
with st.spinner("Chargement de la base de connaissances (FAISS) et des embeddings..."):
try:
# Les modèles d'embeddings SONT mis en cache par défaut, mais dans un
# emplacement autorisé par la plateforme.
embeddings_model = HuggingFaceEmbeddings(
model_name=embeddings_path,
model_kwargs={'device': 'cpu'}
)
vectorstore = FAISS.load_local(
faiss_path,
embeddings_model,
allow_dangerous_deserialization=True
)
logger.info("Retriever chargé avec succès.")
return vectorstore.as_retriever(search_kwargs={"k": 5})
except Exception as e:
st.error(f"Erreur lors du chargement du retriever : {e}")
logger.error(f"Erreur de chargement du retriever : {e}")
st.stop()
# --- Chemins d'accès (relatifs à la racine de votre projet) ---
DOSSIER_PROJET = os.path.dirname(__file__)
CHEMIN_INDEX_FAISS = os.path.join(DOSSIER_PROJET, "faiss_index_wize")
CHEMIN_MODELE_EMBEDDINGS = os.path.join(DOSSIER_PROJET, "embedding_model")
if not os.path.exists(CHEMIN_INDEX_FAISS) or not os.path.exists(CHEMIN_MODELE_EMBEDDINGS):
st.error(f"Erreur critique : Les dossiers 'faiss_index_wize' ou 'embedding_model' sont manquants dans votre dépôt.")
st.stop()
# --- Chargement principal ---
try:
llm = load_llm()
retriever = load_retriever(CHEMIN_INDEX_FAISS, CHEMIN_MODELE_EMBEDDINGS)
st.success("🎉 Les modèles sont chargés et prêts !")
except Exception as e:
st.error(f"Une erreur inattendue est survenue lors du chargement des modèles : {e}")
logger.error(f"Erreur de chargement principale : {e}")
st.stop()
# --- Interface de Chat ---
if "messages" not in st.session_state:
st.session_state.messages = []
for message in st.session_state.messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
if prompt := st.chat_input("Posez votre question ici..."):
st.session_state.messages.append({"role": "user", "content": prompt})
with st.chat_message("user"):
st.markdown(prompt)
with st.chat_message("assistant"):
with st.spinner("Je réfléchis..."):
try:
docs = retriever.invoke(prompt)
context = "\n".join([doc.dict().get("page_content", "") for doc in docs])
# Création du prompt pour CTransformers
full_prompt = f"System: Vous êtes un coach expert. Répondez à la question en vous basant UNIQUEMENT sur le contexte fourni. Ne mentionnez pas le contexte dans votre réponse.\n\nContexte:\n{context}\n\nQuestion: {prompt}\n\nRéponse:"
# Invocation du LLM
answer = llm.invoke(full_prompt)
st.markdown(answer)
st.session_state.messages.append({"role": "assistant", "content": answer})
except Exception as e:
answer = f"Désolé, une erreur est survenue lors de la génération de la réponse : {e}"
st.error(answer)
logger.error(f"Erreur de génération LLM : {e}")
st.session_state.messages.append({"role": "assistant", "content": answer})