File size: 4,119 Bytes
1aa136d
 
 
 
 
 
 
 
 
 
 
 
 
 
4302827
1aa136d
 
 
 
903d5ad
1aa136d
4302827
 
 
 
 
 
903d5ad
 
 
 
1aa136d
903d5ad
 
 
 
 
 
1aa136d
 
 
 
 
 
 
 
 
 
 
 
 
 
903d5ad
1aa136d
 
 
 
 
903d5ad
1aa136d
 
903d5ad
 
 
 
 
 
 
4302827
903d5ad
 
 
 
 
1aa136d
903d5ad
1aa136d
 
903d5ad
 
 
 
 
 
4302827
903d5ad
 
 
 
 
1aa136d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903d5ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# -*- coding: utf-8 -*-
"""
Inferencia RAG para Mori usando FAISS + E5 (multilingual).

Este módulo asume que:
- Ya existe un índice FAISS guardado (mori.faiss)
- Ya existe un archivo de metadatos (mori_metas.json)

Solo expone la función:
    retrieve_docs(query: str, k: int = 3) -> list[dict]
"""

import json
from pathlib import Path
import os
import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download

#***************************************************************************
#Setting up variables
#***************************************************************************
# Token privado desde variable de entorno
HF_TOKEN = os.environ.get("HF_TOKEN")

#***************************************************************************
#Loading FAISS Vec DB
#***************************************************************************
DATASET_REPO_ID = "tecuhtli/Mori_FAISS_Full"

# 🔹 Nombres de archivo dentro del dataset
FAISS_FILENAME = "mori.faiss"         # <-- AJUSTA SI TU ARCHIVO SE LLAMA DIFERENTE
METAS_FILENAME = "mori_metas.json"    # idem

model_name = "intfloat/multilingual-e5-base"
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

# Lazy loading (solo en la primera llamada a retrieve_docs)
_rag_model = None
_rag_index = None
_rag_metas = None


def _ensure_rag_loaded(verbose: bool = False):
    """
    Carga perezosa (lazy) del modelo de embeddings, índice FAISS y metadatos.
    Solo se ejecuta la primera vez.
    """
    global _rag_model, _rag_index, _rag_metas

    # Modelo de embeddings
    if _rag_model is None:
        if verbose:
            print(f"[*] Cargando modelo RAG ({model_name}) en {DEVICE}…")
        _rag_model = SentenceTransformer(model_name, device=DEVICE)

    # Índice FAISS
    if _rag_index is None:
        if verbose:
            print(f"[*] Descargando índice FAISS desde dataset: {DATASET_REPO_ID}/{FAISS_FILENAME}…")

        # Descarga el archivo al caché local de HF y devuelve la ruta
        faiss_local_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            filename=FAISS_FILENAME,
            token=HF_TOKEN,   # ← NECESARIO para repos privados
        )

        if verbose:
            print(f"[*] Leyendo índice FAISS desde {faiss_local_path}…")
        _rag_index = faiss.read_index(str(faiss_local_path))

    # Metadatos
    if _rag_metas is None:
        if verbose:
            print(f"[*] Descargando metadatos desde dataset: {DATASET_REPO_ID}/{METAS_FILENAME}…")

        metas_local_path = hf_hub_download(
            repo_id=DATASET_REPO_ID,
            repo_type="dataset",
            filename=METAS_FILENAME,
            token=HF_TOKEN,
        )

        if verbose:
            print(f"[*] Leyendo metadatos desde {metas_local_path}…")
        with open(metas_local_path, "r", encoding="utf-8") as f:
            _rag_metas = json.load(f)


def retrieve_docs(query: str, k: int = 3, verbose: bool = False):
    """
    Recupera los k documentos más cercanos para una query dada.

    Devuelve una lista de dicts con:
        - score
        - id
        - canonical_term
        - context
        - input
        - output
        - question_type
        - version
        - encoder
    """
    _ensure_rag_loaded(verbose)

    # E5 usa el prefijo "query: " para consultas
    qtext = f"query: {query}"
    q_emb = _rag_model.encode(
        [qtext],
        normalize_embeddings=True,
        convert_to_numpy=True
    ).astype("float32")

    scores, idxs = _rag_index.search(q_emb, k)

    results = []
    for s, i in zip(scores[0], idxs[0]):
        if i == -1:
            continue
        m = _rag_metas[i]
        results.append({
            "score": float(s),
            **m
        })
    return results


if __name__ == "__main__":
    # Pequeña prueba manual
    qs = "¿Para qué sirve un isolation forest?"
    docs = retrieve_docs(qs, k=3, verbose=True)
    for d in docs:
        print(f"[score={d['score']:.3f}] {d['input']} -> {d['output']}")