import os import json from typing import List, Dict, Any import faiss import numpy as np #from elasticsearch import Elasticsearch from . import base_utils as bu class BaseRetriever: """Interface base para mecanismos de recuperação. A ideia é permitir trocar FAISS por Elasticsearch (ou outro backend) sem mudar o restante da aplicação. Cada implementação deve expor um método `retrieve` que recebe um vetor de consulta (1 x D) e devolve uma lista de metadados de trechos no formato já usado pelo sistema. """ def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]: raise NotImplementedError def _load_index_and_metadata_from_config(config: dict): """Carrega índice FAISS e metadata consolidada a partir da config. Mantém a mesma lógica que antes existia em `app/api_server.py`, mas centralizada aqui para poder ser reutilizada por diferentes backends. """ index_path = config["index"].get("index_file", "data/index/faiss.index") metadata_path = config["index"].get("metadata_file", "data/index/metadata.jsonl") if not os.path.exists(index_path) or not os.path.exists(metadata_path): raise FileNotFoundError( "Index or metadata not found. Run scripts/build_index.py first." ) index = faiss.read_index(index_path) metadata: List[Dict[str, Any]] = [] with open(metadata_path, "r", encoding="utf-8") as f: for line in f: if line.strip(): metadata.append(json.loads(line)) return index, metadata class FaissRetriever(BaseRetriever): """Retriever baseado em índice FAISS local. Usa `data/index/faiss.index` e `data/index/metadata.jsonl`, gerados pelos scripts existentes (generate_embeddings + build_index). """ def __init__(self, config: dict) -> None: self.config = config self.index, self.metadata = _load_index_and_metadata_from_config(config) # Mapa de idx global -> metadado, para lookup rápido durante a busca self._meta_by_idx: Dict[int, Dict[str, Any]] = {} for m in self.metadata: idx = m.get("idx") if idx is not None: # Usamos uma cópia simples; o chamador pode depois copiar novamente self._meta_by_idx[int(idx)] = m def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]: """Busca vetorial usando FAISS e devolve metadados dos trechos.""" if query_embedding.ndim != 2: raise ValueError("query_embedding must be a 2D array of shape (1, D)") # Busca em FAISS (mesma lógica anterior) scores, indices = self.index.search(query_embedding, top_k) idxs = indices[0].tolist() retrieved: List[Dict[str, Any]] = [] for i in idxs: m = self._meta_by_idx.get(int(i)) if m is not None: item = dict(m) # copiar para não vazar referência mutável # Garantir chaves esperadas para referências item.setdefault("document_authors", []) item.setdefault("publication_year", None) item.setdefault("publication_date", None) retrieved.append(item) return retrieved def list_documents(self) -> List[Dict[str, str]]: """Lista documentos únicos (id + título) com base na metadata carregada.""" docs: Dict[str, str] = {} for m in self.metadata: doc_id = m.get("document_id") if not doc_id: continue titulo = m.get("document_title") or doc_id if doc_id not in docs: docs[doc_id] = titulo documentos_ordenados = [ {"id": doc_id, "title": docs[doc_id]} for doc_id in sorted(docs, key=lambda d: docs[d].lower()) ] return documentos_ordenados def get_retriever(config: dict) -> BaseRetriever: """ Fábrica simples para escolher o backend de recuperação. """ index_type = config.get("index", {}).get("type", "faiss").lower() if index_type == "faiss": return FaissRetriever(config) if index_type == "elasticsearch": return ElasticRetriever(config) # Placeholder para futuras implementações. raise ValueError(f"Index backend '{index_type}' not supported. Use 'faiss' or 'elasticsearch'.") class ElasticRetriever(BaseRetriever): """ Retriever baseado em Elasticsearch (vector search). """ def __init__(self, config: dict) -> None: self.config = config idx_cfg = config.get("index", {}) self.host = idx_cfg.get("host", "http://localhost:9200") self.index_name = idx_cfg.get("index_name", "chatbot-norm") self.vector_field = idx_cfg.get("vector_field", "embedding") self.api_key = idx_cfg.get("api_key") or os.getenv("ELASTIC_API_KEY") self.username = idx_cfg.get("username") self.password = idx_cfg.get("password") # Cliente Elasticsearch (prioriza API key, depois basic_auth, depois sem auth) if self.api_key: self.client = Elasticsearch(self.host, api_key=self.api_key) elif self.username and self.password: self.client = Elasticsearch(self.host, basic_auth=(self.username, self.password)) else: self.client = Elasticsearch(self.host) def retrieve(self, query_embedding: np.ndarray, top_k: int) -> List[Dict[str, Any]]: """Executa busca vetorial k-NN em Elasticsearch.""" if query_embedding.ndim != 2: raise ValueError("query_embedding must be a 2D array of shape (1, D)") query_vec = query_embedding[0].astype(float).tolist() num_candidates = max(top_k * 5, top_k) knn_body = { "field": self.vector_field, "query_vector": query_vec, "k": top_k, "num_candidates": num_candidates, } resp = self.client.search( index=self.index_name, knn=knn_body, size=top_k, _source=[ "idx", "document_id", "document_title", "document_authors", "publication_year", "publication_date", "fragment_id", "content", ], ) hits = resp.get("hits", {}).get("hits", []) retrieved: List[Dict[str, Any]] = [] for h in hits: src = h.get("_source", {}) retrieved.append( { "idx": src.get("idx"), "document_id": src.get("document_id"), "document_title": src.get("document_title"), "document_authors": src.get("document_authors"), "publication_year": src.get("publication_year"), "publication_date": src.get("publication_date"), "fragment_id": src.get("fragment_id"), "content": src.get("content"), } ) return retrieved def list_documents(self) -> List[Dict[str, str]]: """Lista documentos únicos (id + título) a partir do índice ES. Implementação simples via `match_all` limitada a 10k documentos. Para bases muito maiores, seria melhor usar scroll / search_after. """ docs: Dict[str, str] = {} resp = self.client.search( index=self.index_name, query={"match_all": {}}, size=10000, _source=["document_id", "document_title"], ) for h in resp.get("hits", {}).get("hits", []): src = h.get("_source", {}) doc_id = src.get("document_id") if not doc_id: continue titulo = src.get("document_title") or doc_id if doc_id not in docs: docs[doc_id] = titulo documentos_ordenados = [ {"id": doc_id, "title": docs[doc_id]} for doc_id in sorted(docs, key=lambda d: docs[d].lower()) ] return documentos_ordenados