Mori_Bot / Mori_Technical_RAGwithFAISS.py
tecuhtli's picture
Update Mori_Technical_RAGwithFAISS.py
4302827 verified
# -*- coding: utf-8 -*-
"""
Inferencia RAG para Mori usando FAISS + E5 (multilingual).
Este módulo asume que:
- Ya existe un índice FAISS guardado (mori.faiss)
- Ya existe un archivo de metadatos (mori_metas.json)
Solo expone la función:
retrieve_docs(query: str, k: int = 3) -> list[dict]
"""
import json
from pathlib import Path
import os
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
#***************************************************************************
#Setting up variables
#***************************************************************************
# Token privado desde variable de entorno
HF_TOKEN = os.environ.get("HF_TOKEN")
#***************************************************************************
#Loading FAISS Vec DB
#***************************************************************************
DATASET_REPO_ID = "tecuhtli/Mori_FAISS_Full"
# 🔹 Nombres de archivo dentro del dataset
FAISS_FILENAME = "mori.faiss" # <-- AJUSTA SI TU ARCHIVO SE LLAMA DIFERENTE
METAS_FILENAME = "mori_metas.json" # idem
model_name = "intfloat/multilingual-e5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lazy loading (solo en la primera llamada a retrieve_docs)
_rag_model = None
_rag_index = None
_rag_metas = None
def _ensure_rag_loaded(verbose: bool = False):
"""
Carga perezosa (lazy) del modelo de embeddings, índice FAISS y metadatos.
Solo se ejecuta la primera vez.
"""
global _rag_model, _rag_index, _rag_metas
# Modelo de embeddings
if _rag_model is None:
if verbose:
print(f"[*] Cargando modelo RAG ({model_name}) en {DEVICE}…")
_rag_model = SentenceTransformer(model_name, device=DEVICE)
# Índice FAISS
if _rag_index is None:
if verbose:
print(f"[*] Descargando índice FAISS desde dataset: {DATASET_REPO_ID}/{FAISS_FILENAME}…")
# Descarga el archivo al caché local de HF y devuelve la ruta
faiss_local_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
repo_type="dataset",
filename=FAISS_FILENAME,
token=HF_TOKEN, # ← NECESARIO para repos privados
)
if verbose:
print(f"[*] Leyendo índice FAISS desde {faiss_local_path}…")
_rag_index = faiss.read_index(str(faiss_local_path))
# Metadatos
if _rag_metas is None:
if verbose:
print(f"[*] Descargando metadatos desde dataset: {DATASET_REPO_ID}/{METAS_FILENAME}…")
metas_local_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
repo_type="dataset",
filename=METAS_FILENAME,
token=HF_TOKEN,
)
if verbose:
print(f"[*] Leyendo metadatos desde {metas_local_path}…")
with open(metas_local_path, "r", encoding="utf-8") as f:
_rag_metas = json.load(f)
def retrieve_docs(query: str, k: int = 3, verbose: bool = False):
"""
Recupera los k documentos más cercanos para una query dada.
Devuelve una lista de dicts con:
- score
- id
- canonical_term
- context
- input
- output
- question_type
- version
- encoder
"""
_ensure_rag_loaded(verbose)
# E5 usa el prefijo "query: " para consultas
qtext = f"query: {query}"
q_emb = _rag_model.encode(
[qtext],
normalize_embeddings=True,
convert_to_numpy=True
).astype("float32")
scores, idxs = _rag_index.search(q_emb, k)
results = []
for s, i in zip(scores[0], idxs[0]):
if i == -1:
continue
m = _rag_metas[i]
results.append({
"score": float(s),
**m
})
return results
if __name__ == "__main__":
# Pequeña prueba manual
qs = "¿Para qué sirve un isolation forest?"
docs = retrieve_docs(qs, k=3, verbose=True)
for d in docs:
print(f"[score={d['score']:.3f}] {d['input']} -> {d['output']}")