Spaces:
Sleeping
Sleeping
| # 📁 Ficheiro: rag_engine.py | |
| # 👤 Autor: Duarte Grilo – 2201320 | |
| # 🧠 Projeto Final – Engenharia Informática | |
| # 📚 Função: Motor de RAG (Retrieval-Augmented Generation) | |
| # Este módulo gere a criação da base vetorial, recuperação semântica de documentos, | |
| # rerank com CrossEncoder, limpeza/tradução do contexto e entrega ao gerador Velvet. | |
| # 📦 Importações de bibliotecas externas e internas | |
| from langchain_chroma import Chroma | |
| from sentence_transformers import CrossEncoder | |
| import unicodedata | |
| from langchain_huggingface import HuggingFaceEmbeddings | |
| from transformers import AutoTokenizer | |
| from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| import re | |
| from config import CHROMA_PATH, EMBEDDING_MODEL, K_SIMILARITY_SEARCH | |
| from models.tradutor_local import traduzir | |
| from config import CHUNK_SIZE, CHUNK_OVERLAP | |
| # 🔧 Inicialização de modelos globais de embeddings e reranker | |
| embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL) | |
| tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL) | |
| reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") # usado para rerank dos documentos | |
| # 🧪 Função auxiliar para testar se a tradução está funcional | |
| def testar_traducao(): | |
| try: | |
| print("PT ➜ EN:", traduzir("O que é a UML?", origem="pt", destino="en")) | |
| print("EN ➜ PT:", traduzir("What is UML?", origem="en", destino="pt")) | |
| except Exception as erro_trad: | |
| print("⚠️ Erro ao testar tradução:", erro_trad) | |
| # 📦 Aceder à base vetorial Chroma já persistida | |
| def get_chroma_db(): | |
| return Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings) | |
| # 📂 Carrega documentos (TXT ou PDF) de um diretório especificado | |
| def carregar_documentos_por_tipo(path, extensao, loader_cls): | |
| loader = DirectoryLoader(path, glob=f"**/*.{extensao}", loader_cls=loader_cls) | |
| docs = loader.load() | |
| print(f"📄 Documentos .{extensao} encontrados: {len(docs)}") | |
| return docs | |
| # 🏗️ Constrói a base vetorial Chroma a partir dos documentos | |
| def build_chroma_db(documents_path="data/documents"): | |
| print("🚀 A iniciar construção da base de embeddings...") | |
| # Carregar documentos de texto e PDF | |
| docs = ( | |
| carregar_documentos_por_tipo(documents_path, "txt", TextLoader) + | |
| carregar_documentos_por_tipo(documents_path, "pdf", lambda path: PyPDFLoader(path)) | |
| ) | |
| if not docs: | |
| print("❌ Nenhum documento foi encontrado.") | |
| return | |
| # Quebrar documentos em chunks com sobreposição | |
| splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP) | |
| chunks = splitter.split_documents(docs) | |
| # Remover duplicados | |
| unique_chunks = list({chunk.page_content: chunk for chunk in chunks}.values()) | |
| try: | |
| db = Chroma.from_documents(unique_chunks, embeddings, persist_directory=CHROMA_PATH) | |
| db.persist() | |
| print(f"✅ Embeddings salvos em: {CHROMA_PATH}") | |
| except Exception as erro_db: | |
| print("❌ Erro ao criar a base ChromaDB:", erro_db) | |
| # 🧹 Normaliza o texto: remove acentos, limita tamanho e limpa quebras de linha | |
| def normalize_text(text: str, max_length: int = 1000) -> str: | |
| text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII") | |
| return text[:max_length].replace("\n", " ") | |
| # 🧽 Remove ruído semântico do contexto: números soltos, palavras irrelevantes, etc. | |
| def limpar_ruido_contexto(texto: str) -> str: | |
| linhas = texto.splitlines() | |
| filtradas = [ | |
| linha.strip() | |
| for linha in linhas | |
| if len(linha.strip()) > 30 # ignora linhas curtas | |
| and not re.search(r"(FIGURA|FONTE|CONSIDERANDO|EXERC[ÍI]CIO|PAULO COELHO|INTRODUÇÃO)", linha, re.IGNORECASE) | |
| and not re.fullmatch(r"\s*\d+\s*", linha) # ignora números sozinhos | |
| and not re.search(r"\b\d{1,2}\b", linha) # ignora datas ou numerações soltas | |
| ] | |
| texto_limpo = " ".join(filtradas) | |
| return re.sub(r"\s{2,}", " ", texto_limpo).strip() | |
| # 🔍 Função principal: recebe a query e retorna o melhor contexto (PT ou EN) | |
| def retrieve_context(query: str, k: int = K_SIMILARITY_SEARCH, return_scores: bool = False, return_raw: bool = False): | |
| debug_ctx = {} | |
| # 🔍 Traduz a pergunta de PT para EN | |
| try: | |
| query_en = traduzir(query, origem="pt", destino="en") | |
| debug_ctx["query_en"] = query_en | |
| except Exception as erro_tq: | |
| print("❌ Erro ao traduzir a query:", erro_tq) | |
| query_en = query # fallback: usa a query original | |
| debug_ctx["query_en"] = query_en | |
| # 🔍 Pesquisa por similaridade no ChromaDB | |
| db = get_chroma_db() | |
| context_docs = db.similarity_search(query_en, k=k) | |
| if not context_docs: | |
| contexto_vazio = "Sem contexto relevante encontrado." | |
| if return_scores and return_raw: | |
| return contexto_vazio, [], "", "", debug_ctx | |
| elif return_scores: | |
| return contexto_vazio, [] | |
| else: | |
| return contexto_vazio | |
| # 🧠 Rerank com CrossEncoder | |
| pairs = [(query_en, doc.page_content) for doc in context_docs] | |
| scores = reranker.predict(pairs) | |
| reranked = sorted(zip(context_docs, scores), key=lambda x: x[1], reverse=True) | |
| # 🔹 Juntar os k melhores docs (PT) | |
| top_docs = [doc.page_content for doc, _ in reranked[:k]] | |
| context_pt = "\n\n".join(top_docs) | |
| debug_ctx["context_pt"] = context_pt | |
| # 🔁 Traduzir o contexto para EN (limitado a 3000 chars para evitar erro do modelo) | |
| try: | |
| context_pt_truncado = context_pt[:2000] # Limita para evitar erro de tokens (>512) | |
| context_en = traduzir(context_pt_truncado, origem="pt", destino="en") | |
| debug_ctx["context_en"] = context_en | |
| except Exception as erro_ctx: | |
| print("⚠️ Erro ao traduzir o contexto para EN:", erro_ctx) | |
| context_en = context_pt[:500] # fallback seguro | |
| debug_ctx["context_en"] = context_en | |
| # 🔧 Limpar e normalizar (apenas para debug/métricas) | |
| texto_limpo = limpar_ruido_contexto(context_pt) | |
| texto_normalizado = normalize_text(texto_limpo[:800]) # corta para evitar excesso | |
| debug_ctx["normalizado"] = texto_normalizado | |
| # ✅ Retornar com os diferentes modos | |
| if return_scores and return_raw: | |
| return context_pt, reranked[:k], context_en, texto_normalizado, debug_ctx | |
| elif return_scores: | |
| return context_pt, reranked[:k] | |
| else: | |
| return context_pt | |
| # 🧪 Execução direta para construir ou testar a base Chroma | |
| if __name__ == "__main__": | |
| build_chroma_db() | |
| db = get_chroma_db() | |
| print("🔁 Teste de busca:", db.similarity_search("Teste de contexto", k=1)) | |