File size: 4,119 Bytes
1aa136d 4302827 1aa136d 903d5ad 1aa136d 4302827 903d5ad 1aa136d 903d5ad 1aa136d 903d5ad 1aa136d 903d5ad 1aa136d 903d5ad 4302827 903d5ad 1aa136d 903d5ad 1aa136d 903d5ad 4302827 903d5ad 1aa136d 903d5ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 |
# -*- coding: utf-8 -*-
"""
Inferencia RAG para Mori usando FAISS + E5 (multilingual).
Este módulo asume que:
- Ya existe un índice FAISS guardado (mori.faiss)
- Ya existe un archivo de metadatos (mori_metas.json)
Solo expone la función:
retrieve_docs(query: str, k: int = 3) -> list[dict]
"""
import json
from pathlib import Path
import os
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
#***************************************************************************
#Setting up variables
#***************************************************************************
# Token privado desde variable de entorno
HF_TOKEN = os.environ.get("HF_TOKEN")
#***************************************************************************
#Loading FAISS Vec DB
#***************************************************************************
DATASET_REPO_ID = "tecuhtli/Mori_FAISS_Full"
# 🔹 Nombres de archivo dentro del dataset
FAISS_FILENAME = "mori.faiss" # <-- AJUSTA SI TU ARCHIVO SE LLAMA DIFERENTE
METAS_FILENAME = "mori_metas.json" # idem
model_name = "intfloat/multilingual-e5-base"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Lazy loading (solo en la primera llamada a retrieve_docs)
_rag_model = None
_rag_index = None
_rag_metas = None
def _ensure_rag_loaded(verbose: bool = False):
"""
Carga perezosa (lazy) del modelo de embeddings, índice FAISS y metadatos.
Solo se ejecuta la primera vez.
"""
global _rag_model, _rag_index, _rag_metas
# Modelo de embeddings
if _rag_model is None:
if verbose:
print(f"[*] Cargando modelo RAG ({model_name}) en {DEVICE}…")
_rag_model = SentenceTransformer(model_name, device=DEVICE)
# Índice FAISS
if _rag_index is None:
if verbose:
print(f"[*] Descargando índice FAISS desde dataset: {DATASET_REPO_ID}/{FAISS_FILENAME}…")
# Descarga el archivo al caché local de HF y devuelve la ruta
faiss_local_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
repo_type="dataset",
filename=FAISS_FILENAME,
token=HF_TOKEN, # ← NECESARIO para repos privados
)
if verbose:
print(f"[*] Leyendo índice FAISS desde {faiss_local_path}…")
_rag_index = faiss.read_index(str(faiss_local_path))
# Metadatos
if _rag_metas is None:
if verbose:
print(f"[*] Descargando metadatos desde dataset: {DATASET_REPO_ID}/{METAS_FILENAME}…")
metas_local_path = hf_hub_download(
repo_id=DATASET_REPO_ID,
repo_type="dataset",
filename=METAS_FILENAME,
token=HF_TOKEN,
)
if verbose:
print(f"[*] Leyendo metadatos desde {metas_local_path}…")
with open(metas_local_path, "r", encoding="utf-8") as f:
_rag_metas = json.load(f)
def retrieve_docs(query: str, k: int = 3, verbose: bool = False):
"""
Recupera los k documentos más cercanos para una query dada.
Devuelve una lista de dicts con:
- score
- id
- canonical_term
- context
- input
- output
- question_type
- version
- encoder
"""
_ensure_rag_loaded(verbose)
# E5 usa el prefijo "query: " para consultas
qtext = f"query: {query}"
q_emb = _rag_model.encode(
[qtext],
normalize_embeddings=True,
convert_to_numpy=True
).astype("float32")
scores, idxs = _rag_index.search(q_emb, k)
results = []
for s, i in zip(scores[0], idxs[0]):
if i == -1:
continue
m = _rag_metas[i]
results.append({
"score": float(s),
**m
})
return results
if __name__ == "__main__":
# Pequeña prueba manual
qs = "¿Para qué sirve un isolation forest?"
docs = retrieve_docs(qs, k=3, verbose=True)
for d in docs:
print(f"[score={d['score']:.3f}] {d['input']} -> {d['output']}") |