Spaces:
Runtime error
Runtime error
| # pip install gradio langchain gpt4all chromadb pypdf tiktoken | |
| # pip install --quiet gradio langchain gpt4all chromadb pypdf tiktoken | |
| # imports | |
| import os | |
| import gradio as gr | |
| from gradio.themes.base import Base | |
| import glob | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.document_loaders import DirectoryLoader | |
| from langchain_community.document_loaders import PyPDFLoader | |
| from langchain_community.document_loaders import TextLoader | |
| from langchain_community.vectorstores import Chroma | |
| from langchain_community.embeddings import GPT4AllEmbeddings | |
| from langchain_community.chat_models import ChatOllama | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| import getpass | |
| import json | |
| from tqdm import tqdm | |
| # Import necessary modules | |
| from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriever | |
| from langchain.retrievers.document_compressors import CrossEncoderReranker | |
| from langchain_community.cross_encoders import HuggingFaceCrossEncoder | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough | |
| from typing import Sequence, Any, Dict | |
| from langchain.schema import Document | |
| import time | |
| from functions_rag_chat_v3 import * | |
| os.environ["LANGCHAIN_TRACING_V2"] = "true" | |
| os.environ["LANGCHAIN_API_KEY"] = os.getenv("LANGCHAIN_API_KEY") | |
| if __name__ == "__main__": | |
| """ | |
| main function | |
| """ | |
| print("Starting program") | |
| start_time = time.time() | |
| # define what LLM to use | |
| use_llm = "mistral" | |
| #use_llm = "phe-v2-gguf" | |
| # define what embedding model to use | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| model_name = "clairedhx/autotrain-v2" | |
| token=os.getenv("hugging_face_token") | |
| model_kwargs = {'device': 'cuda', 'token': token} | |
| encode_kwargs = {'normalize_embeddings': False} | |
| embedding = HuggingFaceEmbeddings( | |
| model_name=model_name, | |
| model_kwargs=model_kwargs, | |
| encode_kwargs=encode_kwargs | |
| ) | |
| #print(embedding) | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour l'initialisation des embeddings: {end_time - start_time} secondes") | |
| # directory to persistently store the vector embedding store | |
| db_directory = '/home/onyxia/phe/scripts/chroma_db' | |
| #test a parir de dataframe pour avoir metadata | |
| from datetime import datetime | |
| import pandas as pd | |
| #start_time = time.time() | |
| #df = pd.read_csv('/home/onyxia/phe/scripts/gestion_base/documents_with_metadata_all_med_21_08_24.csv') | |
| # Conversion de 'date_avis' en année | |
| #df['année'] = pd.to_datetime(df['date_avis'], format='%Y-%m-%d').dt.year | |
| from langchain_community.document_loaders import DataFrameLoader | |
| #loader = DataFrameLoader(df, page_content_column="texte") | |
| #docs= loader.load() | |
| #splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(separators=["\n\n", "\n","."], chunk_size=400, chunk_overlap=150) | |
| #splits = splitter.split_documents(docs) | |
| #print("splits : ", len(splits)) | |
| #vectordb= Chroma.from_documents(documents=splits, collection_name="chromemwah", embedding=embedding, persist_directory=db_directory) | |
| #vectordb.persist() | |
| #end_time = time.time() | |
| print(f"Temps d'exécution pour le chargement des documents, split et créer base chromaDB: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| vectordb = Chroma(persist_directory=db_directory, embedding_function=embedding, collection_name="chromemwah") | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour le chargement de la base de données persistante: {end_time - start_time} secondes") | |
| ############################################### | |
| #RECUPERATION VANNA AI | |
| ############################################### | |
| from dotenv import main | |
| import os | |
| print("Récupération des informations de connection") | |
| start_time = time.time() | |
| # Charger les variables d'environnement à partir du fichier .env | |
| main.load_dotenv() | |
| # Accéder aux variables d'environnement | |
| Hostname = os.getenv("Hostname") | |
| Port = os.getenv("Port") | |
| Database = os.getenv("Database") | |
| Username = os.getenv("Username") | |
| Password = os.getenv("Password") | |
| from vanna.ollama import Ollama | |
| from vanna.chromadb import ChromaDB_VectorStore | |
| class MyVanna(ChromaDB_VectorStore, Ollama): | |
| def __init__(self, config=None): | |
| ChromaDB_VectorStore.__init__(self, config=config) | |
| Ollama.__init__(self, config=config) | |
| vn = MyVanna(config={'model': 'mistral'}) | |
| vn.connect_to_postgres(host=Hostname, dbname=Database, user=Username, password=Password, port=Port) # Connect to your database here | |
| vn.train(ddl=""" | |
| CREATE TABLE IF NOT EXISTS medicaments ( | |
| id SERIAL PRIMARY KEY, | |
| nom VARCHAR(2555) NOT NULL, | |
| nombre_avis INTEGER, | |
| nombre_docs INTEGER, | |
| DCI VARCHAR(2555), | |
| exploitant VARCHAR(2555), | |
| codes_ATC TEXT[], | |
| cip TEXT[] | |
| ); | |
| CREATE TABLE IF NOT EXISTS avis ( | |
| id SERIAL PRIMARY KEY, | |
| numero_avis VARCHAR(255) NOT NULL, | |
| maladie VARCHAR(255), | |
| aires_therapeutiques TEXT[], | |
| date_avis DATE, | |
| nombre_docs INTEGER, | |
| medicament_id INTEGER REFERENCES medicaments(id), | |
| smr smr_type, | |
| asmr asmr_type | |
| ); | |
| CREATE TABLE IF NOT EXISTS documents ( | |
| id SERIAL PRIMARY KEY, | |
| titre_doc VARCHAR(300) NOT NULL, | |
| type document_type NOT NULL, | |
| indication VARCHAR(100000), | |
| medicament_id INTEGER REFERENCES medicaments(id), | |
| avis_id INTEGER REFERENCES avis(id), | |
| lien_doc VARCHAR(255), | |
| transcription_ct_associee INTEGER[], | |
| avis_ct_associe INTEGER[], | |
| transcription_ceesp_associee INTEGER[], | |
| avis_ceesp_associe INTEGER[], | |
| questionnaire_associe INTEGER[], | |
| texte TEXT -- Nouveau champ pour stocker le texte extrait | |
| ); | |
| """) | |
| import json | |
| # Load the JSON file | |
| with open('/home/onyxia/phe/scripts/modeles/text_to_SQL/entrainement_augmented.json', 'r') as file: | |
| data = json.load(file) | |
| # Train Vanna with the SQL query pairs | |
| for pair in data: | |
| question = pair['question_to_sql'] | |
| sql = pair['sql'] | |
| vn.train(question=question.strip(), sql=sql.strip()) | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour la connexion à la base de données et l'entraînement de Vanna: {end_time - start_time} secondes") | |
| #################################################################### | |
| #################################################################### | |
| RERANKER_CROSS_ENCODER = "BAAI/bge-reranker-base" | |
| model_hf_cross = HuggingFaceCrossEncoder(model_name=RERANKER_CROSS_ENCODER) | |
| def complete_rag(question, selected_types, year_start, year_end): | |
| """ | |
| The process of retrieval augmented generation | |
| :param question: user query | |
| :return: sources and LLM ouput, generated using retrieved documents | |
| """ | |
| start_time = time.time() | |
| vn.connect_to_postgres(host=Hostname, dbname=Database, user=Username, password=Password, port=Port) | |
| training_data = vn.get_training_data() | |
| print("training_data") | |
| print(training_data) | |
| sous_questions =generate_sous_questions(question) | |
| question_llm, question_sql = sous_questions[0], sous_questions[1] | |
| print("question to sql : ",question_sql) | |
| print("question to llm : ",question_llm) | |
| sql=vn.generate_sql(question=question_sql, allow_llm_to_see_data=True) | |
| print(' \n \n sql : ',sql) | |
| # Récupération des IDs et des liens `lien_med` | |
| result_sql = vn.run_sql(sql) | |
| list_id = result_sql['id'].tolist() | |
| print("\n \n list_id : ",list_id) | |
| if list_id==[]: | |
| print("No documents", "Aucun document pouvant répondre à cette question n'a été trouvé dans la base.") | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [split questions, vanna ai]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| # Handle the selection of document types | |
| if "tous" in selected_types: | |
| selected_types = ['avis_ct', 'transcription_ct', 'avis_ceesp', 'transcription_ceesp', 'questionnaire'] | |
| else: | |
| selected_types = [doc_type for doc_type in selected_types if doc_type != 'tous'] | |
| # Convertir les années sélectionnées en entiers | |
| year_min = int(year_start) | |
| year_max = int(year_end) | |
| # La plage d'années sélectionnée est définie par year_min et year_max | |
| years = list(range(year_min, year_max + 1)) | |
| # search_kwargs avec le filtre des années | |
| search_kwargs = { | |
| "k": 500, | |
| "filter": { | |
| '$and': [ | |
| {'id_doc': {'$in': list_id}}, | |
| {'type': {'$in': selected_types}}, | |
| {'année': {'$in': years}} # Filtre sur les années sélectionnées | |
| ] | |
| } | |
| } | |
| retriever = vectordb.as_retriever(search_kwargs=search_kwargs) #{"k": 500, "filter":{'id_doc': {'$in': list_id},'type': {'$in': ['avis_ct', 'transcription_ct']}}}) | |
| compressor = CrossEncoderReranker(model=model_hf_cross, top_n=60) | |
| retrieval_agent_hg_crossencoder = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever) | |
| from langchain_community.retrievers import BM25Retriever | |
| retrieval_agent_bm25 = BM25Retriever.from_documents(retriever.get_relevant_documents(question_llm), k=60) | |
| from langchain.retrievers import EnsembleRetriever | |
| # initialize the ensemble retriever | |
| ensemble_retriever = EnsembleRetriever( | |
| retrievers=[retrieval_agent_bm25, retrieval_agent_hg_crossencoder], weights=[0.95, 0.05] | |
| ) | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [ensemble retriever]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| print("retriever") | |
| documents = ensemble_retriever.get_relevant_documents(question_llm) | |
| print(len(documents), " chunks retrouvés") | |
| docs_scored=[] | |
| scores=[] | |
| all_scores =[] | |
| for index, doc in enumerate(documents): | |
| # Passer la liste de documents au lieu d'un seul document | |
| output = generate_score(question, context_formatting([doc]), "mistral") | |
| all_scores.append(int(output)) | |
| #if(int(output)>1): | |
| #docs_scored.append(doc) | |
| #scores.append(int(output)) | |
| #print(len(docs_scored), "retrouvés après scores") | |
| print("All scores : ", all_scores) | |
| # Trier les documents gardés en fonction des scores | |
| #docs_with_scores = list(zip(docs_scored, scores)) | |
| #docs_sorted_by_score = sorted(docs_with_scores, key=lambda x: x[1], reverse=True) | |
| #docs_scored_sorted = [doc for doc, score in docs_sorted_by_score] | |
| #scores_sorted = [score for doc, score in docs_sorted_by_score] | |
| # Trier tous les documents en fonction des scores | |
| all_docs_with_scores = list(zip(documents, all_scores)) | |
| all_docs_sorted_by_score = sorted(all_docs_with_scores, key=lambda x: x[1], reverse=True) | |
| all_docs_scored_sorted = [doc for doc, score in all_docs_sorted_by_score] | |
| all_scores_sorted = [score for doc, score in all_docs_sorted_by_score] | |
| docs_ejected=[] | |
| scores_ejected=[] | |
| for index, score in enumerate(all_scores_sorted): | |
| if score>2: | |
| docs_scored.append(all_docs_scored_sorted[index]) | |
| scores.append(score) | |
| else: | |
| docs_ejected.append(all_docs_scored_sorted[index]) | |
| scores_ejected.append(score) | |
| docs_with_scores = list(zip(docs_scored, scores)) | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [scoring pertinence]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| from collections import defaultdict | |
| # Initialisation des variables pour stocker les documents et les scores regroupés | |
| from collections import defaultdict | |
| grouped_documents = defaultdict(list) | |
| grouped_scores = defaultdict(list) | |
| # On suppose que chaque document a une clé 'avis_id' dans ses métadonnées | |
| for doc, score in docs_with_scores: | |
| avis_id = doc.metadata['avis_id'] # Assurez-vous que 'avis_id' est bien dans les métadonnées | |
| grouped_documents[avis_id].append(doc) | |
| grouped_scores[avis_id].append(score) | |
| # Convertir les dictionnaires en listes de listes | |
| documents_regroupes_sorted = [] | |
| scores_regroupes_sorted = [] | |
| for avis_id in grouped_documents.keys(): | |
| # Récupérer les documents et scores pour cet avis_id | |
| docs = grouped_documents[avis_id] | |
| scores = grouped_scores[avis_id] | |
| # Trier les paires (doc, score) en fonction des scores | |
| sorted_pairs = sorted(zip(docs, scores), key=lambda x: x[1], reverse=True) | |
| # Séparer les documents et scores après tri | |
| sorted_docs, sorted_scores = zip(*sorted_pairs) | |
| # Ajouter les listes triées aux résultats finaux | |
| documents_regroupes_sorted.append(list(sorted_docs)) | |
| scores_regroupes_sorted.append(list(sorted_scores)) | |
| # Maintenant, documents_regroupes_sorted et scores_regroupes_sorted sont bien triés | |
| # Afficher le nombre de groupes trouvés | |
| print(f"{len(documents_regroupes_sorted)} avis retrouvés après regroupement des chunks") | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [regroupement chunks par avis]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| # Appel de la fonction source_formatting avec les scores associés | |
| sources = source_formatting(documents_regroupes_sorted, scores_regroupes_sorted, docs_ejected, scores_ejected) | |
| outputs = "" | |
| outputs_for_last_llm ="" | |
| final_docs=documents_regroupes_sorted | |
| for index, doc in enumerate(tqdm(final_docs, desc="question sur chaque chunk - mistral")): | |
| output = generate_2(question_llm, context_formatting(doc), "mistral") | |
| outputs += "Réponse à l'avis numéro " + str(index+1) + " : " + output + "\n\n" | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [boucle mistral question sur chaque avis]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| output_agreg = generate_agregated_2(outputs, question, "mistral") | |
| synthese = "SYNTHESE : \n\n" +output_agreg + "\n\n\nREPONSE POUR CHAQUE AVIS : \n\n" + outputs | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [question synthese]: {end_time - start_time} secondes") | |
| start_time = time.time() | |
| import psycopg2 | |
| # Connexion à la base de données PostgreSQL | |
| conn = psycopg2.connect(host=Hostname, dbname=Database, user=Username, password=Password, port=Port) | |
| cursor = conn.cursor() | |
| # Conversion de la liste en une chaîne compatible SQL | |
| id_string = ','.join(map(str, list_id)) | |
| # Requête SQL pour obtenir les nombres uniques | |
| query = f""" | |
| SELECT | |
| COUNT(DISTINCT d.avis_id) AS unique_avis_count, | |
| COUNT(DISTINCT d.medicament_id) AS unique_medicament_count, | |
| COUNT(DISTINCT d.id) AS document_count | |
| FROM | |
| documents d | |
| WHERE | |
| d.id IN ({id_string}); | |
| """ | |
| query_lien_meds = f""" | |
| SELECT DISTINCT m.lien_med | |
| FROM documents as d | |
| JOIN medicaments m ON d.medicament_id = m.id | |
| WHERE | |
| d.id IN ({id_string}); | |
| """ | |
| # Exécution de la requête | |
| cursor.execute(query) | |
| result = cursor.fetchone() | |
| # Exécution de la requête pour les liens `lien_med` | |
| cursor.execute(query_lien_meds) | |
| result_lien_meds = cursor.fetchall() | |
| # Conversion des résultats de `lien_meds` en une liste | |
| lien_meds = [row[0] for row in result_lien_meds if row[0]] # Évite les valeurs nulles | |
| # Affichage des résultats | |
| unique_avis_count, unique_medicament_count, document_count = result | |
| comptes = (f"Nombre de médicaments concernés par la question : {unique_medicament_count}<br>" | |
| f"Nombre d'avis concernés par la question : {unique_avis_count}<br>" | |
| f"Nombre de documents concernés par la question : {document_count}<br><br>" | |
| "Liens des médicaments concernés :<br>" + | |
| "<br>".join([f"[{lien}]({lien})" for lien in lien_meds])) # Conversion en liens Markdown cliquables avec balises HTML | |
| # Fermeture de la connexion | |
| cursor.close() | |
| conn.close() | |
| end_time = time.time() | |
| print(f"Temps d'exécution pour complete_rag [recupération effectifs]: {end_time - start_time} secondes") | |
| return sources, synthese, comptes | |
| # for web view of prompting | |
| # code below is copied from: https://www.youtube.com/watch?v=JEBDfGqrAUA (Project 2) | |
| with gr.Blocks(theme=Base(), title="Q&A on your data with RAG") as demo: | |
| gr.Markdown("# Q&A sur les documents de la HAS") | |
| # Sélection du type de document | |
| doc_type_selection = gr.CheckboxGroup( | |
| choices=["tous", "avis_ct", "transcription_ct", "avis_ceesp", "transcription_ceesp", "questionnaire"], | |
| label="Sélectionnez les types de documents", | |
| value=["tous"] # Preselect "tous" | |
| ) | |
| # Boîte déroulante pour sélectionner l'année de début | |
| year_start_dropdown = gr.Dropdown( | |
| choices=[str(year) for year in range(2000, 2025)], # De 2000 à 2024 | |
| value="2000", # Valeur par défaut | |
| label="Sélectionnez l'année de début" | |
| ) | |
| # Boîte déroulante pour sélectionner l'année de fin | |
| year_end_dropdown = gr.Dropdown( | |
| choices=[str(year) for year in range(2000, 2025)], # De 2000 à 2024 | |
| value="2024", # Valeur par défaut | |
| label="Sélectionnez l'année de fin" | |
| ) | |
| textbox = gr.Textbox(label="Question:") | |
| with gr.Row(): | |
| button = gr.Button("Entrée", variant="primary") | |
| with gr.Column(): | |
| output3 = gr.Markdown(label="Effectifs") | |
| output2 = gr.Textbox(lines=1, max_lines=1000, label="Réponse générée") | |
| output1 = gr.Markdown(label="Sources") | |
| # Mise à jour des inputs pour inclure les deux boîtes déroulantes | |
| button.click(complete_rag, inputs=[textbox, doc_type_selection, year_start_dropdown, year_end_dropdown], outputs=[output1, output2, output3]) | |
| demo.launch(share=True) | |