RAGHospital / rag_engine.py
MaxtheGenius's picture
Upload 6 files
bc96d14 verified
"""
Motor RAG de la práctica final (RA3): embeddings, recuperación por similitud coseno y generación
con Pleias-RAG-350M.
Autor: Maxime Pol Marcet.
Soy **Maxime Pol Marcet**, y este módulo lo he escrito **entero yo**: es el núcleo técnico
del chatbot y concentra toda la lógica que el enunciado (**MOPT DAW2**) exige con nombres
de función y flujo determinados.
Al **inicio del script** cargo una sola vez lo que prescribe la práctica: el modelo de
embeddings `SentenceTransformer("MongoDB/mdbr-leaf-ir")`, el tokenizer y el modelo causal
`PleIAs/Pleias-RAG-350M`, y los textos desde **`documents.json`**. He precalculado y
normalizado los embeddings de los documentos para no repetir trabajo en cada consulta; la
similitud la obtengo por **coseno** entre el vector de la pregunta y los de la base,
ordenando de mayor a menor puntuación y aplicando **umbral** y **top_k** tal como indica el
enunciado oficial.
He implementado **`recuperar_documentos`**, **`generar_respuesta`** y **`preguntar`** con
las firmas obligatorias; el **prompt** lo construyo en inglés, con el formato literal del
MOPT, porque el modelo y los documentos de ejemplo están en esa lengua y así evito
desajustes en la generación.
Los detalles de generación (`max_new_tokens`, penalizaciones) los ajusté para un LM
relativamente pequeño; son decisiones mías de ingeniería, no del enunciado mínimo.
"""
import json
from pathlib import Path
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import DEFAULT_TOP_K, DEFAULT_UMBRAL
_BASE_DIR = Path(__file__).resolve().parent
with open(_BASE_DIR / "documents.json", encoding="utf-8") as _f:
_raw_docs = json.load(_f)
_document_texts = [_raw_docs[k] for k in sorted(_raw_docs.keys())]
embedding_model = SentenceTransformer("MongoDB/mdbr-leaf-ir")
_doc_embeddings = embedding_model.encode(
_document_texts,
convert_to_numpy=True,
show_progress_bar=False,
)
_norms = np.linalg.norm(_doc_embeddings, axis=1, keepdims=True)
_norms[_norms == 0] = 1.0
_doc_embeddings_norm = _doc_embeddings / _norms
tokenizer = AutoTokenizer.from_pretrained("PleIAs/Pleias-RAG-350M")
lm_model = AutoModelForCausalLM.from_pretrained("PleIAs/Pleias-RAG-350M")
lm_model.eval()
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lm_model.to(_device)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
_EOS_ID = tokenizer.eos_token_id
_MAX_CONTEXT_TOKENS = 2048
def _safe_tokenize(prompt: str):
ml = getattr(tokenizer, "model_max_length", None)
cap = _MAX_CONTEXT_TOKENS
if isinstance(ml, int) and 0 < ml < 100_000:
cap = min(cap, ml)
return tokenizer(
prompt,
return_tensors="pt",
truncation=True,
max_length=cap,
).to(_device)
def _embed_consulta(consulta):
q_emb = embedding_model.encode([consulta], convert_to_numpy=True)[0]
q_n = np.linalg.norm(q_emb)
if q_n > 0:
q_emb = q_emb / q_n
return q_emb
def _retrieve_ranked(consulta, top_k, umbral):
q_emb = _embed_consulta(consulta)
sims = _doc_embeddings_norm @ q_emb
order = np.argsort(-sims)
out = []
for idx in order:
idx = int(idx)
s = float(sims[idx])
if s >= umbral and len(out) < top_k:
out.append((s, _document_texts[idx]))
return out
def recuperar_documentos(
consulta: str,
top_k: int = DEFAULT_TOP_K,
umbral: float = DEFAULT_UMBRAL,
) -> list[str]:
"""
MOPT DAW2: recupera textos por similitud coseno (embeddings precargados), orden descendente,
filtrando por umbral y limitando a top_k. Los valores por defecto coinciden con `config.py`.
"""
ranked = _retrieve_ranked(consulta, int(top_k), float(umbral))
return [text for _, text in ranked]
def _build_prompt(consulta: str, documentos_recuperados: list[str]) -> str:
context = " ".join(documentos_recuperados)
return (
"Answer the question based only on the context provided\n"
f"Context: {context}\n"
f"Question: {consulta}\n"
"Answer:"
)
def _generate_kwargs(do_sample: bool = False, temperature: float = 0.7):
kw = dict(
max_new_tokens=128,
do_sample=do_sample,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=_EOS_ID,
repetition_penalty=1.18,
no_repeat_ngram_size=4,
)
if do_sample:
kw["temperature"] = max(0.01, min(1.5, float(temperature)))
kw["top_p"] = 0.92
return kw
def generar_respuesta(consulta: str, documentos_recuperados: list[str]) -> str:
"""
MOPT DAW2: concatena documentos con espacios, aplica el prompt indicado y genera con el LM.
"""
prompt = _build_prompt(consulta, documentos_recuperados)
inputs = _safe_tokenize(prompt)
input_len = inputs["input_ids"].shape[1]
gen_kw = _generate_kwargs(do_sample=False, temperature=0.7)
with torch.inference_mode():
out = lm_model.generate(**inputs, **gen_kw)
generated_ids = out[0][input_len:]
return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
def preguntar(consulta: str, top_k: int = DEFAULT_TOP_K, umbral: float = DEFAULT_UMBRAL) -> str:
"""
MOPT DAW2: `recuperar_documentos` + `generar_respuesta`; devuelve solo el texto generado.
"""
documentos_recuperados = recuperar_documentos(consulta, top_k=top_k, umbral=umbral)
return generar_respuesta(consulta, documentos_recuperados)