beta-NORM / utils /retrievers.py
GitHub Actions
Sync from GitHub master
92145af
import os
import json
from typing import List, Dict, Any
import faiss
import numpy as np
#from elasticsearch import Elasticsearch
from . import base_utils as bu
class BaseRetriever:
"""Interface base para mecanismos de recuperação.
A ideia é permitir trocar FAISS por Elasticsearch (ou outro backend)
sem mudar o restante da aplicação. Cada implementação deve expor um
método `retrieve` que recebe um vetor de consulta (1 x D) e devolve
uma lista de metadados de trechos no formato já usado pelo sistema.
"""
def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]:
raise NotImplementedError
def _load_index_and_metadata_from_config(config: dict):
"""Carrega índice FAISS e metadata consolidada a partir da config.
Mantém a mesma lógica que antes existia em `app/api_server.py`, mas
centralizada aqui para poder ser reutilizada por diferentes backends.
"""
index_path = config["index"].get("index_file", "data/index/faiss.index")
metadata_path = config["index"].get("metadata_file", "data/index/metadata.jsonl")
if not os.path.exists(index_path) or not os.path.exists(metadata_path):
raise FileNotFoundError(
"Index or metadata not found. Run scripts/build_index.py first."
)
index = faiss.read_index(index_path)
metadata: List[Dict[str, Any]] = []
with open(metadata_path, "r", encoding="utf-8") as f:
for line in f:
if line.strip():
metadata.append(json.loads(line))
return index, metadata
class FaissRetriever(BaseRetriever):
"""Retriever baseado em índice FAISS local.
Usa `data/index/faiss.index` e `data/index/metadata.jsonl`, gerados
pelos scripts existentes (generate_embeddings + build_index).
"""
def __init__(self, config: dict) -> None:
self.config = config
self.index, self.metadata = _load_index_and_metadata_from_config(config)
# Mapa de idx global -> metadado, para lookup rápido durante a busca
self._meta_by_idx: Dict[int, Dict[str, Any]] = {}
for m in self.metadata:
idx = m.get("idx")
if idx is not None:
# Usamos uma cópia simples; o chamador pode depois copiar novamente
self._meta_by_idx[int(idx)] = m
def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]:
"""Busca vetorial usando FAISS e devolve metadados dos trechos."""
if query_embedding.ndim != 2:
raise ValueError("query_embedding must be a 2D array of shape (1, D)")
# Busca em FAISS (mesma lógica anterior)
scores, indices = self.index.search(query_embedding, top_k)
idxs = indices[0].tolist()
retrieved: List[Dict[str, Any]] = []
for i in idxs:
m = self._meta_by_idx.get(int(i))
if m is not None:
item = dict(m) # copiar para não vazar referência mutável
# Garantir chaves esperadas para referências
item.setdefault("document_authors", [])
item.setdefault("publication_year", None)
item.setdefault("publication_date", None)
retrieved.append(item)
return retrieved
def list_documents(self) -> List[Dict[str, str]]:
"""Lista documentos únicos (id + título) com base na metadata carregada."""
docs: Dict[str, str] = {}
for m in self.metadata:
doc_id = m.get("document_id")
if not doc_id:
continue
titulo = m.get("document_title") or doc_id
if doc_id not in docs:
docs[doc_id] = titulo
documentos_ordenados = [
{"id": doc_id, "title": docs[doc_id]}
for doc_id in sorted(docs, key=lambda d: docs[d].lower())
]
return documentos_ordenados
def get_retriever(config: dict) -> BaseRetriever:
"""
Fábrica simples para escolher o backend de recuperação.
"""
index_type = config.get("index", {}).get("type", "faiss").lower()
if index_type == "faiss":
return FaissRetriever(config)
if index_type == "elasticsearch":
return ElasticRetriever(config)
# Placeholder para futuras implementações.
raise ValueError(f"Index backend '{index_type}' not supported. Use 'faiss' or 'elasticsearch'.")
class ElasticRetriever(BaseRetriever):
"""
Retriever baseado em Elasticsearch (vector search).
"""
def __init__(self, config: dict) -> None:
self.config = config
idx_cfg = config.get("index", {})
self.host = idx_cfg.get("host", "http://localhost:9200")
self.index_name = idx_cfg.get("index_name", "chatbot-norm")
self.vector_field = idx_cfg.get("vector_field", "embedding")
self.api_key = idx_cfg.get("api_key") or os.getenv("ELASTIC_API_KEY")
self.username = idx_cfg.get("username")
self.password = idx_cfg.get("password")
# Cliente Elasticsearch (prioriza API key, depois basic_auth, depois sem auth)
if self.api_key:
self.client = Elasticsearch(self.host, api_key=self.api_key)
elif self.username and self.password:
self.client = Elasticsearch(self.host, basic_auth=(self.username, self.password))
else:
self.client = Elasticsearch(self.host)
def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]:
"""Executa busca vetorial k-NN em Elasticsearch."""
if query_embedding.ndim != 2:
raise ValueError("query_embedding must be a 2D array of shape (1, D)")
query_vec = query_embedding[0].astype(float).tolist()
num_candidates = max(top_k * 5, top_k)
knn_body = {
"field": self.vector_field,
"query_vector": query_vec,
"k": top_k,
"num_candidates": num_candidates,
}
resp = self.client.search(
index=self.index_name,
knn=knn_body,
size=top_k,
_source=[
"idx",
"document_id",
"document_title",
"document_authors",
"publication_year",
"publication_date",
"fragment_id",
"content",
],
)
hits = resp.get("hits", {}).get("hits", [])
retrieved: List[Dict[str, Any]] = []
for h in hits:
src = h.get("_source", {})
retrieved.append(
{
"idx": src.get("idx"),
"document_id": src.get("document_id"),
"document_title": src.get("document_title"),
"document_authors": src.get("document_authors"),
"publication_year": src.get("publication_year"),
"publication_date": src.get("publication_date"),
"fragment_id": src.get("fragment_id"),
"content": src.get("content"),
}
)
return retrieved
def list_documents(self) -> List[Dict[str, str]]:
"""Lista documentos únicos (id + título) a partir do índice ES.
Implementação simples via `match_all` limitada a 10k documentos.
Para bases muito maiores, seria melhor usar scroll / search_after.
"""
docs: Dict[str, str] = {}
resp = self.client.search(
index=self.index_name,
query={"match_all": {}},
size=10000,
_source=["document_id", "document_title"],
)
for h in resp.get("hits", {}).get("hits", []):
src = h.get("_source", {})
doc_id = src.get("document_id")
if not doc_id:
continue
titulo = src.get("document_title") or doc_id
if doc_id not in docs:
docs[doc_id] = titulo
documentos_ordenados = [
{"id": doc_id, "title": docs[doc_id]}
for doc_id in sorted(docs, key=lambda d: docs[d].lower())
]
return documentos_ordenados