from doc_processor import DocProcessor
from contextual_doc_processor import ContextualDocProcessor
from embedding_retriever import EmbeddingRetriever
from langchain_community.retrievers import BM25Retriever
import numpy as np
from huggingface_hub import InferenceClient
import torch
import json
import os
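
# RAG question-answering agent over locally scraped websites: JSON pages are
# chunked (optionally with LLM-generated context), indexed for dense (FAISS)
# and lexical (BM25) retrieval, the per-retriever rankings are fused with
# Reciprocal Rank Fusion, and the answer is generated by an LLM served
# through the Hugging Face InferenceClient.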


def get_list_dir(DATA_DIR):
    list_dir = os.listdir(DATA_DIR)
    list_dir = [d for d in list_dir if os.path.isdir(os.path.join(DATA_DIR, d))]
    return list_dir


def get_data_paths(DATA_DIR, list_dir):
    list_files = []
    for d in list_dir:
        start_path = os.path.join(DATA_DIR, d)
        filenames = os.listdir(start_path)
        filenames = [f for f in filenames if f.endswith(".json")]
        paths = [os.path.join(start_path, f) for f in filenames]
        list_files += paths
    return list_files
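

# Three chunking modes: plain chunks (DocProcessor), contextualized chunks
# (ContextualDocProcessor), or both corpora concatenated when
# use_context_wtcontext is set.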
def get_chunks(DATA_DIR, list_dir, use_context_wtcontext, use_context, PATH_SAVE_CHUNKS, PATH_SAVE_CONTEXT_CHUNKS):
    chunks = []
    LIST_FILES = get_data_paths(DATA_DIR, list_dir)
    if use_context_wtcontext:
        doc_process = ContextualDocProcessor(LIST_FILES, PATH_SAVE_CONTEXT_CHUNKS)
        doc_process.process_data()
        chunks = doc_process.chunks
        doc_process = DocProcessor(LIST_FILES, PATH_SAVE_CHUNKS)
        doc_process.process_data()
        chunks += doc_process.chunks
    else:
        if use_context:
            doc_process = ContextualDocProcessor(LIST_FILES, PATH_SAVE_CONTEXT_CHUNKS)
        else:
            doc_process = DocProcessor(LIST_FILES, PATH_SAVE_CHUNKS)
        doc_process.process_data()
        chunks = doc_process.chunks
    return chunks


def process_data(DATA_DIR, PATH_SAVE_CHUNKS, PATH_SAVE_CONTEXT_CHUNKS, use_context_wtcontext, use_context):
    list_dir = get_list_dir(DATA_DIR)
    chunks = get_chunks(DATA_DIR, list_dir, use_context_wtcontext, use_context, PATH_SAVE_CHUNKS, PATH_SAVE_CONTEXT_CHUNKS)
    return list_dir, chunks
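

# Retriever construction: one dense EmbeddingRetriever per embedding model
# (all sharing the FAISS index path matching the chunking mode) plus a
# lexical BM25Retriever built over the same chunks.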
def add_embedding_retriever(embedding_models, embedding_model_name, path_idx, chunks, device):
    # Key by model name so several embedding models can coexist; keying by
    # path_idx would let each model overwrite the previous one.
    embedding_models[embedding_model_name] = EmbeddingRetriever(embedding_model_name, path_idx, chunks, device)
    return embedding_models


def add_BM25_retriever(chunks, TOP_K):
    return BM25Retriever.from_documents(chunks, k=TOP_K)


def process_retrievers(embedding_model_names, chunks, TOP_K, use_context, use_context_wtcontext, PATH_IDX, PATH_IDX_CONTEXT, PATH_IDX_CONTEXT_AND_WT, device):
    embedding_models = {}
    if embedding_model_names:
        # Pick the FAISS index path that matches the chunking mode.
        path_idx = PATH_IDX_CONTEXT_AND_WT if use_context_wtcontext else PATH_IDX_CONTEXT if use_context else PATH_IDX
        for embedding_model_name in embedding_model_names:
            embedding_models = add_embedding_retriever(embedding_models, embedding_model_name, path_idx, chunks, device)
    BM25_retriever = add_BM25_retriever(chunks, TOP_K)
    return embedding_models, BM25_retriever
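

# Agent: holds the retrievers and chat history, expands the user query
# (reformulations and/or HyDE documents), collects per-retriever rankings in
# self.ranks, fuses them with RRF, and queries the LLM with the top-ranked
# chunks as context.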
class Agent:

    def __init__(self, list_dir, chunks, embedding_models, BM25_retriever, TOP_K,
                 reformulation=False, use_HyDE=False, use_HyDE_cut=False, ask_again=False):
        self.embedding_models = embedding_models
        self.BM25_retriever = BM25_retriever
        self.ranks = {}
        # current_query is a list of query variants when expansion is on,
        # otherwise a single string.
        self.current_query = [] if reformulation else ""
        self.reformulation = reformulation
        self.use_HyDE = use_HyDE
        self.use_HyDE_cut = use_HyDE_cut
        self.ask_again = ask_again
        self.history = []
        self.list_dir = list_dir
        self.TOP_K = TOP_K
        self.chunks = chunks

    def select_sub_part_of_data(self, query):
        list_dept = "\n".join(f"- {dept}" for dept in self.list_dir)
        prompt = f"""Voici une liste de secteurs d'activité :

{list_dept}

Quel est le secteur d'activité qui correspond à la question suivante : {query}
Répondre uniquement avec le nom du secteur d'activité.
"""
        response = InferenceClient().chat.completions.create(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=[{"role": "user", "content": prompt}],
        )
        dept = response.choices[0].message.content
        return dept
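
    # Query expansion: depending on the flags, the raw query is kept as a
    # string, or wrapped in a list and extended with LLM reformulations
    # and/or hypothetical (HyDE) documents.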
    def ask_a_question(self, query, nb_reformulation=5):
        self.ranks = {}
        if not self.reformulation and not self.use_HyDE and not self.use_HyDE_cut:
            self.current_query = query
        else:
            self.current_query = [query]
            if self.reformulation:
                self.reformulation_of_the_query(nb_reformulation)
            if self.use_HyDE:
                doc_hyde = self.generate_HyDE(query)
                self.current_query += [doc_hyde]
            if self.use_HyDE_cut:
                doc_hyde = self.generate_HyDE_cut(query)
                self.current_query += [doc_hyde]
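
    # HyDE (Hypothetical Document Embeddings): ask the LLM to write a
    # plausible answer document and retrieve with that document instead of
    # the bare question. The "_cut" variant truncates the generation to
    # 1000 characters instead of asking the model to stay under that limit.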
    def generate_HyDE(self, query):
        prompt = f"Voilà une question d'un utilisateur : '{query}'\n\nGénère une réponse hypothétique en français à cette question qui pourrait être présente sur un site d'information avec 1000 caractères maximum."
        response = InferenceClient().chat.completions.create(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=[{"role": "user", "content": prompt}],
        )
        doc_gen = response.choices[0].message.content
        return doc_gen

    def generate_HyDE_cut(self, query):
        prompt = f"Voilà une question d'un utilisateur : '{query}'\n\nGénère une réponse hypothétique en français à cette question qui pourrait être présente sur un site d'information."
        response = InferenceClient().chat.completions.create(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=[{"role": "user", "content": prompt}],
        )
        doc_gen = response.choices[0].message.content[:1000]
        return doc_gen

    def reformulation_of_the_query(self, nb_reformulation=5):
        prompt = f"Reformule en français {nb_reformulation} fois la question '{self.current_query[0]}'"
        response = InferenceClient().chat.completions.create(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=[{"role": "user", "content": prompt}],
        )
        queries = response.choices[0].message.content.split('\n')
        # Keep only lines that look like questions and are not the original.
        queries = [q for q in queries if "?" in q and self.current_query[0] not in q]
        self.current_query += queries

    def retrieve_data_from_embeddings(self):
        r = {}
        if self.reformulation or self.use_HyDE or self.use_HyDE_cut:
            for i, query in enumerate(self.current_query):
                for name_model, retriever in self.embedding_models.items():
                    r[str(i) + "-" + name_model] = retriever.retrieve_data(query, self.TOP_K)
        else:
            for name_model, retriever in self.embedding_models.items():
                r[name_model] = retriever.retrieve_data(self.current_query, self.TOP_K)
        self.ranks.update(r)

    def retrieve_data_from_BM25(self):
        if self.reformulation or self.use_HyDE or self.use_HyDE_cut:
            for i, query in enumerate(self.current_query):
                self.retrieve_query_from_BM25(query, str(i) + "-BM25")
        else:
            self.retrieve_query_from_BM25(self.current_query, "BM25")

    def retrieve_query_from_BM25(self, query, rank_name):
        top_k_docs = self.BM25_retriever.invoke(query)
        idx_bm25 = self.get_idx_from_lists(self.chunks, top_k_docs)
        self.ranks[rank_name] = np.array(idx_bm25)

    def get_idx_from_lists(self, main_list, sublist):
        # Map each retrieved document back to its index in self.chunks.
        idx_list = []
        for sl in sublist:
            if sl in main_list:
                idx_list += [main_list.index(sl)]
        return idx_list
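
    # Reciprocal Rank Fusion: each retriever's ranking contributes
    # 1 / (k + rank) to a document's score, so documents ranked highly by
    # several retrievers rise to the top; k=60 is the commonly used default
    # from the original RRF paper.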
    def RRF(self, k=60):
        idx_score = {}
        for idx_list in self.ranks.values():
            for rank, idx in enumerate(idx_list):
                if idx not in idx_score:
                    idx_score[idx] = 1 / (k + rank)
                else:
                    idx_score[idx] += 1 / (k + rank)
        return sorted(idx_score.items(), key=lambda x: x[1], reverse=True)

    def get_chunks_from_rank(self, idx_list):
        chunks = []
        for idx, score in idx_list:
            chunks += [self.chunks[idx]]
        return chunks

    def create_prompt_with_context(self, rank):
        chunks_list = self.get_chunks_from_rank(rank)
        page_contents = [chunk.page_content + "\n\nSource:" + chunk.metadata['source'] for chunk in chunks_list]
        context = "\n\n".join(page_contents)
        query = self.current_query[0] if self.reformulation or self.use_HyDE or self.use_HyDE_cut else self.current_query
        prompt = f"Répondez à la question suivante en utilisant le contexte ci-dessous:\n\nContexte:\n{context}\n\nQuestion: {query}"
        self.history += [{"role": "user", "content": prompt}]
        return prompt, page_contents
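
    # Answer generation with an optional self-check loop: when ask_again is
    # set, the model is asked whether its own answer is satisfactory and, if
    # not, the question is retried (up to 4 extra rounds via cpt).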
    def ask_agent(self, cpt=1):
        query = self.history[-1]['content']
        response = InferenceClient().chat_completion(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=self.history,
            temperature=0,
            seed=0,
        )
        reply = response.choices[0].message.content
        self.history += [{"role": "assistant", "content": reply}]
        if self.ask_again:
            good_reply = self.ask_again_agent(cpt)
            if not good_reply:
                start = "Peux-tu donner une meilleure réponse à cette question :"
                if not query.startswith(start):
                    query = f"{start}\n\n{query}"
                self.ask_a_question(query)
                self.history += [{"role": "user", "content": query}]
                reply = self.ask_agent(cpt + 1)
        return reply

    def ask_again_agent(self, cpt=1):
        prompt = "Est-ce que la réponse donnée est satisfaisante ? Répondre uniquement par Oui ou Non."
        self.history += [{"role": "user", "content": prompt}]
        response = InferenceClient().chat.completions.create(
            model="deepseek-ai/DeepSeek-V3-0324",
            messages=self.history,
            temperature=0,
        )
        reply = response.choices[0].message.content
        self.history += [{"role": "assistant", "content": reply}]
        # Accept after a "Oui" or once the retry budget is exhausted.
        if "Oui" in reply or cpt > 4:
            return True
        return False

    def get_url_from_paths(self, paths):
        url = []
        for p in paths:
            with open(p) as json_file:
                data = json.load(json_file)
                url += [data['url']]
        return url
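
    # End-to-end pipeline for a single query: retrieve, fuse with RRF,
    # answer, then attribute sources by running BM25 on the generated reply
    # and keeping source files that appear more than once among its top hits.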
    def get_a_reply(self, query):
        self.ask_a_question(query)
        self.retrieve_data_from_BM25()
        rank = self.RRF()
        prompt, chunks = self.create_prompt_with_context(rank)
        reply = self.ask_agent()
        self.retrieve_query_from_BM25(reply, 'source')
        sources_BM25 = [self.chunks[r].metadata['source'] for r in self.ranks['source']]
        sources = list({s for s in sources_BM25 if sources_BM25.count(s) > 1})
        if not sources:
            sources = [sources_BM25[0]]
        sources = self.get_url_from_paths(sources)
        return (reply, sources)
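

# Demo entry point: build the chunk corpus and retrievers, run one query
# step by step, and print the intermediate rankings, the dialogue history,
# the final reply, and the BM25-based source attribution.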
if __name__ == "__main__":
    DATA_DIR = "data_websites"
    PATH_SAVE_CHUNKS = "chunks_saved.json"
    PATH_SAVE_CONTEXT = "chunks_with_context.json"
    PATH_IDX = "index_faiss_data_sh"
    PATH_IDX_CONTEXT = "index_faiss_context_sh"
    PATH_IDX_CONTEXT_AND_WT = "index_faiss_context_and_wt_sh"

    embedding_model_names = []
    agent_name = "Geotrend/distilbert-base-en-fr-cased"
    TOP_K = 10

    query = "Comment se déroule la formation \"Accompagner les personnes âgées, comprendre le vieillissement et ses conséquences\" ?"

    use_context = False
    reformulation = False
    use_HyDE = False
    use_HyDE_cut = False
    use_context_wtcontext = True
    ask_again = False

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    list_dir, chunks = process_data(DATA_DIR, PATH_SAVE_CHUNKS, PATH_SAVE_CONTEXT, use_context_wtcontext, use_context)
    embedding_models, BM25_retriever = process_retrievers(embedding_model_names, chunks, TOP_K, use_context,
                                                          use_context_wtcontext, PATH_IDX, PATH_IDX_CONTEXT,
                                                          PATH_IDX_CONTEXT_AND_WT, device)

    agent = Agent(list_dir, chunks, embedding_models, BM25_retriever, TOP_K, reformulation,
                  use_HyDE, use_HyDE_cut, ask_again)

    agent.ask_a_question(query)
    agent.retrieve_data_from_embeddings()
    agent.retrieve_data_from_BM25()

    print(agent.ranks)
    for name, lr in agent.ranks.items():
        print("*******", name, "*******")
        for r in lr:
            print(r)
            print(agent.chunks[r])

    rank = agent.RRF()
    prompt, chunks = agent.create_prompt_with_context(rank)

    reply = agent.ask_agent()
    print("***************")
    for diag in agent.history:
        print(diag)
    print("***************")
    print(reply)

    print("\nRank retrievers:")
    for chunk in chunks[:10]:
        print(f"Chunk: {chunk.split('Source:')[-1]}")

    agent.retrieve_query_from_BM25(reply, 'source')
    print("\nBM25 rerank:")
    for r in agent.ranks['source']:
        print(agent.chunks[r].metadata['source'])
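
# Minimal usage sketch (a hedged illustration, not exercised above): after
# the setup in __main__, Agent.get_a_reply runs the whole
# retrieve-fuse-answer pipeline and returns the reply together with its
# likely source URLs. The query below is only an example.
#
#     reply, sources = agent.get_a_reply("Quelles sont les conditions d'admission ?")
#     print(reply)
#     for url in sources:
#         print(url)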