Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| import yaml | |
| import logging | |
| from typing import List, Dict | |
| from langchain_community.embeddings import HuggingFaceEmbeddings | |
| from langchain_community.vectorstores import FAISS | |
# Module-level logger; the root logger is configured at INFO on import so
# the progress messages emitted by QueryEngine are visible by default.
logger = logging.getLogger(__name__)
logging.basicConfig(level="INFO")
class QueryEngine:
    """Semantic search over a pre-built LangChain FAISS index for one cluster.

    Settings are read from a YAML file. The keys used here are
    ``embedding_model``, ``faiss_path``, ``cluster_id``, ``chunk_start`` and
    ``chunk_end``.
    """

    # Fallback used when the config omits ``embedding_model``. It is shared by
    # ``__init__`` and ``get_cluster_info`` so both agree on the model in use.
    DEFAULT_EMBEDDING_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'
    # Fallback for ``get_cluster_info`` when the index dimension cannot be read
    # from the vector store (this was the previously hard-coded value).
    DEFAULT_EMBEDDING_DIM = 384

    def __init__(self, config_path='config.yaml'):
        """Load the config, the embedding model and the FAISS index.

        Args:
            config_path: Path to the YAML configuration file.
        """
        logger.info("Inicializando QueryEngine...")
        with open(config_path, encoding='utf-8') as f:
            # safe_load returns None for an empty file; normalise to {} so the
            # .get() lookups below do not raise AttributeError.
            self.config = yaml.safe_load(f) or {}
        model_name = self.config.get('embedding_model', self.DEFAULT_EMBEDDING_MODEL)
        logger.info("Modelo: %s", model_name)
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},
        )
        faiss_path = self.config.get('faiss_path', '/app/faiss_index')
        logger.info("Carregando FAISS de: %s", faiss_path)
        # SECURITY: allow_dangerous_deserialization lets load_local unpickle
        # the stored docstore. Only point faiss_path at an index you trust.
        self.vectorstore = FAISS.load_local(
            faiss_path,
            self.embeddings,
            allow_dangerous_deserialization=True,
        )
        logger.info("✅ QueryEngine pronto!")

    @staticmethod
    def _format_doc(doc, score=None) -> Dict:
        """Convert a LangChain document into the response item dict.

        The ``score`` key is included only when a score is given, which keeps
        the key order ``id, ementa, [score,] metadata``.
        """
        item = {'id': doc.metadata.get('id'), 'ementa': doc.page_content}
        if score is not None:
            item['score'] = float(score)
        item['metadata'] = doc.metadata
        return item

    def _iter_documents(self):
        """Yield every document held in the vector store's docstore."""
        docstore = self.vectorstore.docstore
        for doc_key in self.vectorstore.index_to_docstore_id.values():
            doc = docstore.search(doc_key)
            # InMemoryDocstore.search returns an error *string* for unknown
            # keys instead of raising, so skip those.
            if not isinstance(doc, str):
                yield doc

    def search_by_embedding(self, query: str, top_k: int = 10, return_embeddings: bool = False) -> Dict:
        """Return the ``top_k`` documents most similar to ``query``.

        Args:
            query: Free-text query, embedded by the vector store.
            top_k: Maximum number of results to return.
            return_embeddings: Accepted for API compatibility; currently ignored.

        Returns:
            A dict with ``cluster_id``, ``query``, ``total_results`` and
            ``results``. Each result has ``id``, ``ementa``, ``score`` and
            ``metadata``. ``score`` is the raw value from
            ``similarity_search_with_score`` (for LangChain FAISS this is
            presumably a distance where lower means closer; confirm against
            the index's distance strategy).
        """
        hits = self.vectorstore.similarity_search_with_score(query, k=top_k)
        formatted = [self._format_doc(doc, score=score) for doc, score in hits]
        return {
            'cluster_id': self.config.get('cluster_id'),
            'query': query,
            'total_results': len(formatted),
            'results': formatted,
        }

    def search_by_keywords(self, keywords: List[str], operator: str = 'AND', top_k: int = 20) -> Dict:
        """Run a semantic search for the space-joined ``keywords``.

        NOTE: ``operator`` is accepted for API compatibility but is not applied.
        The keywords are joined into one embedding query, so this is not a
        boolean AND/OR keyword filter.
        """
        return self.search_by_embedding(' '.join(keywords), top_k=top_k)

    def search_by_ids(self, ids: List[str], return_embeddings: bool = False) -> Dict:
        """Return every stored document whose ``metadata['id']`` is in ``ids``.

        The whole in-memory docstore is scanned. The previous approach was a
        ``similarity_search("", k=10000)``, which silently missed documents in
        indexes larger than 10000 entries. Results follow docstore order.

        Args:
            ids: Document ids to look up; duplicates are ignored.
            return_embeddings: Accepted for API compatibility; currently ignored.

        Returns:
            A dict with ``cluster_id``, ``total_results`` and ``results``. Each
            result has ``id``, ``ementa`` and ``metadata``.
        """
        wanted = set(ids)
        results = []
        if wanted:
            results = [
                self._format_doc(doc)
                for doc in self._iter_documents()
                if doc.metadata.get('id') in wanted
            ]
        return {
            'cluster_id': self.config.get('cluster_id'),
            'total_results': len(results),
            'results': results,
        }

    def get_cluster_info(self) -> Dict:
        """Describe this cluster and the loaded index.

        ``embedding_dim`` is read from the FAISS index (its ``d`` attribute)
        when available, so it stays correct for models other than MiniLM.
        """
        index = getattr(self.vectorstore, 'index', None)
        dim = getattr(index, 'd', None)
        return {
            'cluster_id': self.config.get('cluster_id'),
            'chunk_range': [self.config.get('chunk_start'), self.config.get('chunk_end')],
            'embedding_model': self.config.get('embedding_model', self.DEFAULT_EMBEDDING_MODEL),
            'embedding_dim': int(dim) if dim is not None else self.DEFAULT_EMBEDDING_DIM,
            'vector_store': 'FAISS',
            'backend': 'LangChain + CPU',
            'status': 'ready',
        }