python-rag-chatbot / rag_engine.py
DiegoPerezPonce's picture
Upload 4 files
c9af776 verified
import json
import torch
from sentence_transformers import SentenceTransformer, util
from transformers import AutoTokenizer, AutoModelForCausalLM
from sklearn.metrics.pairwise import cosine_similarity
# --- Carga de Modelos y Datos ---
# Modelo de embeddings
embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir")
# Modelo de lenguaje y tokenizer
model_name = "PleIAs/Pleias-RAG-350M"
tokenizer = AutoTokenizer.from_pretrained(model_name)
llm_model = AutoModelForCausalLM.from_pretrained(model_name)
# Cargar documentos
with open("documents.json", "r", encoding="utf-8") as f:
docs_data = json.load(f)
# Extraemos solo el texto de los documentos
docs_texts = list(docs_data.values())
# Precalcular embeddings de los documentos (una sola vez)
docs_embeddings = embedding_model.encode(docs_texts)
def recuperar_documentos(consulta, top_k=2, umbral=0.4):
"""Recupera los documentos más similares a la consulta."""
# 1. Calcular embedding de la consulta
query_embedding = embedding_model.encode([consulta])
# 2. Calcular similitud del coseno
similitudes = cosine_similarity(query_embedding, docs_embeddings)[0]
# 3. Emparejar textos con sus similitudes y ordenar
docs_con_similitud = sorted(
zip(docs_texts, similitudes),
key=lambda x: x[1],
reverse=True
)
# 4. Filtrar por umbral y top_k
seleccionados = []
for texto, sim in docs_con_similitud:
if sim >= umbral and len(seleccionados) < top_k:
seleccionados.append(texto)
return seleccionados
def generar_respuesta(consulta, documentos_recuperados):
"""Genera una respuesta usando el contexto inyectado."""
# 1. Concatenar documentos
contexto = " ".join(documentos_recuperados)
# 2. Construir el prompt (formato exacto pedido)
prompt = f"Answer the question based only on the context provided\nContext: {contexto}\nQuestion: {consulta}\nAnswer:"
# 3. Generar respuesta
inputs = tokenizer(prompt, return_tensors="pt")
outputs = llm_model.generate(**inputs, max_new_tokens=150)
respuesta_completa = tokenizer.decode(outputs[0], skip_special_tokens=True)
# Extraer solo la parte después de "Answer:"
respuesta = respuesta_completa.split("Answer:")[-1].strip()
return respuesta
def preguntar(consulta, top_k=2, umbral=0.4):
"""Función de alto nivel que une recuperación y generación."""
docs = recuperar_documentos(consulta, top_k, umbral)
if not docs:
return "I'm sorry, I couldn't find relevant information in the knowledge base."
respuesta = generar_respuesta(consulta, docs)
return respuesta