# Legal-Chatbot / core/retrieval.py
# msi
# first commit
# f7b069f
import re
from sklearn.metrics.pairwise import cosine_similarity
from core.config import collection, client_embedding, embedding_model
def get_query_embedding(query: str):
    """Embed ``query`` with the configured embedding client and model.

    Returns the ``embedding`` attribute of the first item in the response's
    ``data`` list (presumably a list of floats -- TODO confirm against the
    embedding client in ``core.config``).
    """
    api_result = client_embedding.embeddings.create(model=embedding_model, input=query)
    first_entry = api_result.data[0]
    return first_entry.embedding
# Matches "article <digits>", "article premier" and "article 1er" as whole words.
# The leading \b rejects hits inside longer words (e.g. "particle 3"); the
# trailing \b rejects "articles 5" and "article 12bis". Restricting the captured
# token to a number or "premier"/"1er" lets re.search skip non-numeric mentions
# such as "l'article sur ..." and keep scanning for a later "article 12".
_ARTICLE_NUMBER_RE = re.compile(r"\barticle\s*(premier|1er|\d+)\b", re.IGNORECASE)


def extract_article_number(query: str):
    """Return the article reference explicitly mentioned in ``query``.

    Recognises "article <digits>" and "article premier" / "article 1er",
    case-insensitively; whitespace after "article" is optional.

    Args:
        query: Free-text user question.

    Returns:
        ``"Article premier"`` for the first article, ``"Article <digits>"`` for
        a numbered article, or ``None`` when the query contains no such
        reference. When several references are present, the first one wins.
    """
    match = _ARTICLE_NUMBER_RE.search(query)
    if match is None:
        return None
    token = match.group(1).lower()
    # "1er" is the usual French abbreviation of "premier".
    if token in ("premier", "1er"):
        return "Article premier"
    return f"Article {token}"
def find_relevant_articles(query: str, threshold: float = 0.8, max_articles: int = 10):
    """Return the stored articles most relevant to ``query``.

    If the query explicitly names an article (via ``extract_article_number``)
    and a document with that ``article_num`` exists, only that document is
    returned, with a score of ``1.0``. Otherwise the query is embedded and
    compared by cosine similarity against every document's ``embedding2``
    vector; documents without that field (or with an empty one) are skipped.

    Args:
        query: Free-text user question.
        threshold: Minimum cosine similarity for a document to be kept.
        max_articles: Maximum number of results returned by the semantic search.

    Returns:
        A list of ``(document, score)`` tuples sorted by decreasing score.
        Empty when no document reaches ``threshold`` or none has an embedding.
    """
    article_num = extract_article_number(query)
    if article_num:
        doc = collection.find_one({"article_num": article_num})
        if doc:
            return [(doc, 1.0)]

    query_vector = get_query_embedding(query)

    # Collect every candidate first so cosine_similarity runs once on the whole
    # matrix instead of once per document (each call validates and normalises
    # its inputs, including the same query vector). Trade-off: every document
    # that has an embedding is kept in memory until scoring is done.
    candidates = []
    vectors = []
    for doc in collection.find():
        article_vector = doc.get("embedding2")
        if article_vector:
            candidates.append(doc)
            vectors.append(article_vector)

    # cosine_similarity rejects an empty matrix; nothing to rank anyway.
    if not vectors:
        return []

    scores = cosine_similarity([query_vector], vectors)[0]
    similarities = [
        (doc, sim) for doc, sim in zip(candidates, scores) if sim >= threshold
    ]
    similarities.sort(key=lambda pair: pair[1], reverse=True)
    return similarities[:max_articles]