PHE_Outil_IA / rag_chat_v3.py
clairedhx's picture
Upload folder using huggingface_hub
759fea8 verified
# pip install gradio langchain gpt4all chromadb pypdf tiktoken
# pip install --quiet gradio langchain gpt4all chromadb pypdf tiktoken
# imports
import os
import gradio as gr
from gradio.themes.base import Base
import glob
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
import getpass
import json
from tqdm import tqdm
# Import necessary modules
from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough
from typing import Sequence, Any, Dict
from langchain.schema import Document
import time
from functions_rag_chat_v3 import *
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY")
if __name__ == "__main__":
"""
main function
"""
print("Starting program")
start_time = time.time()
# define what LLM to use
use_llm = "mistral"
#use_llm = "phe-v2-gguf"
# define what embedding model to use
from langchain_community.embeddings import HuggingFaceEmbeddings
model_name = "clairedhx/autotrain-v2"
token=os.getenv("hugging_face_token")
model_kwargs = {'device': 'cuda', 'token': token}
encode_kwargs = {'normalize_embeddings': False}
embedding = HuggingFaceEmbeddings(
model_name=model_name,
model_kwargs=model_kwargs,
encode_kwargs=encode_kwargs
)
#print(embedding)
end_time = time.time()
print(f"Temps d'exécution pour l'initialisation des embeddings: {end_time - start_time} secondes")
# directory to persistently store the vector embedding store
db_directory = '/home/onyxia/phe/scripts/chroma_db'
#test a parir de dataframe pour avoir metadata
from datetime import datetime
import pandas as pd
#start_time = time.time()
#df = pd.read_csv('/home/onyxia/phe/scripts/gestion_base/documents_with_metadata_all_med_21_08_24.csv')
# Conversion de 'date_avis' en année
#df['année'] = pd.to_datetime(df['date_avis'], format='%Y-%m-%d').dt.year
from langchain_community.document_loaders import DataFrameLoader
#loader = DataFrameLoader(df, page_content_column="texte")
#docs= loader.load()
#splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(separators=["\n\n", "\n","."], chunk_size=400, chunk_overlap=150)
#splits = splitter.split_documents(docs)
#print("splits : ", len(splits))
#vectordb= Chroma.from_documents(documents=splits, collection_name="chromemwah", embedding=embedding, persist_directory=db_directory)
#vectordb.persist()
#end_time = time.time()
print(f"Temps d'exécution pour le chargement des documents, split et créer base chromaDB: {end_time - start_time} secondes")
start_time = time.time()
vectordb = Chroma(persist_directory=db_directory, embedding_function=embedding, collection_name="chromemwah")
end_time = time.time()
print(f"Temps d'exécution pour le chargement de la base de données persistante: {end_time - start_time} secondes")
###############################################
#RECUPERATION VANNA AI
###############################################
from dotenv import main
import os
print("Récupération des informations de connection")
start_time = time.time()
# Charger les variables d'environnement à partir du fichier .env
main.load_dotenv()
# Accéder aux variables d'environnement
Hostname = os.getenv("Hostname")
Port = os.getenv("Port")
Database = os.getenv("Database")
Username = os.getenv("Username")
Password = os.getenv("Password")
from vanna.ollama import Ollama
from vanna.chromadb import ChromaDB_VectorStore
class MyVanna(ChromaDB_VectorStore, Ollama):
def __init__(self, config=None):
ChromaDB_VectorStore.__init__(self, config=config)
Ollama.__init__(self, config=config)
vn = MyVanna(config={'model': 'mistral'})
vn.connect_to_postgres(host=Hostname, dbname=Database, user=Username, password=Password, port=Port) # Connect to your database here
vn.train(ddl="""
CREATE TABLE IF NOT EXISTS medicaments (
id SERIAL PRIMARY KEY,
nom VARCHAR(2555) NOT NULL,
nombre_avis INTEGER,
nombre_docs INTEGER,
DCI VARCHAR(2555),
exploitant VARCHAR(2555),
codes_ATC TEXT[],
cip TEXT[]
);
CREATE TABLE IF NOT EXISTS avis (
id SERIAL PRIMARY KEY,
numero_avis VARCHAR(255) NOT NULL,
maladie VARCHAR(255),
aires_therapeutiques TEXT[],
date_avis DATE,
nombre_docs INTEGER,
medicament_id INTEGER REFERENCES medicaments(id),
smr smr_type,
asmr asmr_type
);
CREATE TABLE IF NOT EXISTS documents (
id SERIAL PRIMARY KEY,
titre_doc VARCHAR(300) NOT NULL,
type document_type NOT NULL,
indication VARCHAR(100000),
medicament_id INTEGER REFERENCES medicaments(id),
avis_id INTEGER REFERENCES avis(id),
lien_doc VARCHAR(255),
transcription_ct_associee INTEGER[],
avis_ct_associe INTEGER[],
transcription_ceesp_associee INTEGER[],
avis_ceesp_associe INTEGER[],
questionnaire_associe INTEGER[],
texte TEXT -- Nouveau champ pour stocker le texte extrait
);
""")
import json
# Load the JSON file
with open('/home/onyxia/phe/scripts/modeles/text_to_SQL/entrainement_augmented.json', 'r') as file:
data = json.load(file)
# Train Vanna with the SQL query pairs
for pair in data:
question = pair['question_to_sql']
sql = pair['sql']
vn.train(question=question.strip(), sql=sql.strip())
end_time = time.time()
print(f"Temps d'exécution pour la connexion à la base de données et l'entraînement de Vanna: {end_time - start_time} secondes")
####################################################################
####################################################################
RERANKER_CROSS_ENCODER = "BAAI/bge-reranker-base"
model_hf_cross = HuggingFaceCrossEncoder(model_name=RERANKER_CROSS_ENCODER)
def complete_rag(question, selected_types, year_start, year_end):
"""
The process of retrieval augmented generation
:param question: user query
:return: sources and LLM ouput, generated using retrieved documents
"""
start_time = time.time()
vn.connect_to_postgres(host=Hostname, dbname=Database, user=Username, password=Password, port=Port)
training_data = vn.get_training_data()
print("training_data")
print(training_data)
sous_questions =generate_sous_questions(question)
question_llm, question_sql = sous_questions[0], sous_questions[1]
print("question to sql : ",question_sql)
print("question to llm : ",question_llm)
sql=vn.generate_sql(question=question_sql, allow_llm_to_see_data=True)
print(' \n \n sql : ',sql)
# Récupération des IDs et des liens `lien_med`
result_sql = vn.run_sql(sql)
list_id = result_sql['id'].tolist()
print("\n \n list_id : ",list_id)
if list_id==[]:
print("No documents", "Aucun document pouvant répondre à cette question n'a été trouvé dans la base.")
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [split questions, vanna ai]: {end_time - start_time} secondes")
start_time = time.time()
# Handle the selection of document types
if "tous" in selected_types:
selected_types = ['avis_ct', 'transcription_ct', 'avis_ceesp', 'transcription_ceesp', 'questionnaire']
else:
selected_types = [doc_type for doc_type in selected_types if doc_type != 'tous']
# Convertir les années sélectionnées en entiers
year_min = int(year_start)
year_max = int(year_end)
# La plage d'années sélectionnée est définie par year_min et year_max
years = list(range(year_min, year_max + 1))
# search_kwargs avec le filtre des années
search_kwargs = {
"k": 500,
"filter": {
'$and': [
{'id_doc': {'$in': list_id}},
{'type': {'$in': selected_types}},
{'année': {'$in': years}} # Filtre sur les années sélectionnées
]
}
}
retriever = vectordb.as_retriever(search_kwargs=search_kwargs) #{"k": 500, "filter":{'id_doc': {'$in': list_id},'type': {'$in': ['avis_ct', 'transcription_ct']}}})
compressor = CrossEncoderReranker(model=model_hf_cross, top_n=60)
retrieval_agent_hg_crossencoder = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
from langchain_community.retrievers import BM25Retriever
retrieval_agent_bm25 = BM25Retriever.from_documents(retriever.get_relevant_documents(question_llm), k=60)
from langchain.retrievers import EnsembleRetriever
# initialize the ensemble retriever
ensemble_retriever = EnsembleRetriever(
retrievers=[retrieval_agent_bm25, retrieval_agent_hg_crossencoder], weights=[0.95, 0.05]
)
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [ensemble retriever]: {end_time - start_time} secondes")
start_time = time.time()
print("retriever")
documents = ensemble_retriever.get_relevant_documents(question_llm)
print(len(documents), " chunks retrouvés")
docs_scored=[]
scores=[]
all_scores =[]
for index, doc in enumerate(documents):
# Passer la liste de documents au lieu d'un seul document
output = generate_score(question, context_formatting([doc]), "mistral")
all_scores.append(int(output))
#if(int(output)>1):
#docs_scored.append(doc)
#scores.append(int(output))
#print(len(docs_scored), "retrouvés après scores")
print("All scores : ", all_scores)
# Trier les documents gardés en fonction des scores
#docs_with_scores = list(zip(docs_scored, scores))
#docs_sorted_by_score = sorted(docs_with_scores, key=lambda x: x[1], reverse=True)
#docs_scored_sorted = [doc for doc, score in docs_sorted_by_score]
#scores_sorted = [score for doc, score in docs_sorted_by_score]
# Trier tous les documents en fonction des scores
all_docs_with_scores = list(zip(documents, all_scores))
all_docs_sorted_by_score = sorted(all_docs_with_scores, key=lambda x: x[1], reverse=True)
all_docs_scored_sorted = [doc for doc, score in all_docs_sorted_by_score]
all_scores_sorted = [score for doc, score in all_docs_sorted_by_score]
docs_ejected=[]
scores_ejected=[]
for index, score in enumerate(all_scores_sorted):
if score>2:
docs_scored.append(all_docs_scored_sorted[index])
scores.append(score)
else:
docs_ejected.append(all_docs_scored_sorted[index])
scores_ejected.append(score)
docs_with_scores = list(zip(docs_scored, scores))
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [scoring pertinence]: {end_time - start_time} secondes")
start_time = time.time()
from collections import defaultdict
# Initialisation des variables pour stocker les documents et les scores regroupés
from collections import defaultdict
grouped_documents = defaultdict(list)
grouped_scores = defaultdict(list)
# On suppose que chaque document a une clé 'avis_id' dans ses métadonnées
for doc, score in docs_with_scores:
avis_id = doc.metadata['avis_id'] # Assurez-vous que 'avis_id' est bien dans les métadonnées
grouped_documents[avis_id].append(doc)
grouped_scores[avis_id].append(score)
# Convertir les dictionnaires en listes de listes
documents_regroupes_sorted = []
scores_regroupes_sorted = []
for avis_id in grouped_documents.keys():
# Récupérer les documents et scores pour cet avis_id
docs = grouped_documents[avis_id]
scores = grouped_scores[avis_id]
# Trier les paires (doc, score) en fonction des scores
sorted_pairs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True)
# Séparer les documents et scores après tri
sorted_docs, sorted_scores = zip(*sorted_pairs)
# Ajouter les listes triées aux résultats finaux
documents_regroupes_sorted.append(list(sorted_docs))
scores_regroupes_sorted.append(list(sorted_scores))
# Maintenant, documents_regroupes_sorted et scores_regroupes_sorted sont bien triés
# Afficher le nombre de groupes trouvés
print(f"{len(documents_regroupes_sorted)} avis retrouvés après regroupement des chunks")
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [regroupement chunks par avis]: {end_time - start_time} secondes")
start_time = time.time()
# Appel de la fonction source_formatting avec les scores associés
sources = source_formatting(documents_regroupes_sorted, scores_regroupes_sorted, docs_ejected, scores_ejected)
outputs = ""
outputs_for_last_llm =""
final_docs=documents_regroupes_sorted
for index, doc in enumerate(tqdm(final_docs, desc="question sur chaque chunk - mistral")):
output = generate_2(question_llm, context_formatting(doc), "mistral")
outputs += "Réponse à l'avis numéro " + str(index+1) + " : " + output + "\n\n"
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [boucle mistral question sur chaque avis]: {end_time - start_time} secondes")
start_time = time.time()
output_agreg = generate_agregated_2(outputs, question, "mistral")
synthese = "SYNTHESE : \n\n" +output_agreg + "\n\n\nREPONSE POUR CHAQUE AVIS : \n\n" + outputs
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [question synthese]: {end_time - start_time} secondes")
start_time = time.time()
import psycopg2
# Connexion à la base de données PostgreSQL
conn = psycopg2.connect(host=Hostname, dbname=Database, user=Username, password=Password, port=Port)
cursor = conn.cursor()
# Conversion de la liste en une chaîne compatible SQL
id_string = ','.join(map(str, list_id))
# Requête SQL pour obtenir les nombres uniques
query = f"""
SELECT
COUNT(DISTINCT d.avis_id) AS unique_avis_count,
COUNT(DISTINCT d.medicament_id) AS unique_medicament_count,
COUNT(DISTINCT d.id) AS document_count
FROM
documents d
WHERE
d.id IN ({id_string});
"""
query_lien_meds = f"""
SELECT DISTINCT m.lien_med
FROM documents as d
JOIN medicaments m ON d.medicament_id = m.id
WHERE
d.id IN ({id_string});
"""
# Exécution de la requête
cursor.execute(query)
result = cursor.fetchone()
# Exécution de la requête pour les liens `lien_med`
cursor.execute(query_lien_meds)
result_lien_meds = cursor.fetchall()
# Conversion des résultats de `lien_meds` en une liste
lien_meds = [row[0] for row in result_lien_meds if row[0]] # Évite les valeurs nulles
# Affichage des résultats
unique_avis_count, unique_medicament_count, document_count = result
comptes = (f"Nombre de médicaments concernés par la question : {unique_medicament_count}<br>"
f"Nombre d'avis concernés par la question : {unique_avis_count}<br>"
f"Nombre de documents concernés par la question : {document_count}<br><br>"
"Liens des médicaments concernés :<br>" +
"<br>".join([f"[{lien}]({lien})" for lien in lien_meds])) # Conversion en liens Markdown cliquables avec balises HTML
# Fermeture de la connexion
cursor.close()
conn.close()
end_time = time.time()
print(f"Temps d'exécution pour complete_rag [recupération effectifs]: {end_time - start_time} secondes")
return sources, synthese, comptes
# for web view of prompting
# code below is copied from: https://www.youtube.com/watch?v=JEBDfGqrAUA (Project 2)
with gr.Blocks(theme=Base(), title="Q&A on your data with RAG") as demo:
gr.Markdown("# Q&A sur les documents de la HAS")
# Sélection du type de document
doc_type_selection = gr.CheckboxGroup(
choices=["tous", "avis_ct", "transcription_ct", "avis_ceesp", "transcription_ceesp", "questionnaire"],
label="Sélectionnez les types de documents",
value=["tous"] # Preselect "tous"
)
# Boîte déroulante pour sélectionner l'année de début
year_start_dropdown = gr.Dropdown(
choices=[str(year) for year in range(2000, 2025)], # De 2000 à 2024
value="2000", # Valeur par défaut
label="Sélectionnez l'année de début"
)
# Boîte déroulante pour sélectionner l'année de fin
year_end_dropdown = gr.Dropdown(
choices=[str(year) for year in range(2000, 2025)], # De 2000 à 2024
value="2024", # Valeur par défaut
label="Sélectionnez l'année de fin"
)
textbox = gr.Textbox(label="Question:")
with gr.Row():
button = gr.Button("Entrée", variant="primary")
with gr.Column():
output3 = gr.Markdown(label="Effectifs")
output2 = gr.Textbox(lines=1, max_lines=1000, label="Réponse générée")
output1 = gr.Markdown(label="Sources")
# Mise à jour des inputs pour inclure les deux boîtes déroulantes
button.click(complete_rag, inputs=[textbox, doc_type_selection, year_start_dropdown, year_end_dropdown], outputs=[output1, output2, output3])
demo.launch(share=True)