Spaces:
Sleeping
Sleeping
File size: 6,948 Bytes
e679b32 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
# 📁 Ficheiro: rag_engine.py
# 👤 Autor: Duarte Grilo – 2201320
# 🧠 Projeto Final – Engenharia Informática
# 📚 Função: Motor de RAG (Retrieval-Augmented Generation)
# Este módulo gere a criação da base vetorial, recuperação semântica de documentos,
# rerank com CrossEncoder, limpeza/tradução do contexto e entrega ao gerador Velvet.
# 📦 Importações de bibliotecas externas e internas
from langchain_chroma import Chroma
from sentence_transformers import CrossEncoder
import unicodedata
from langchain_huggingface import HuggingFaceEmbeddings
from transformers import AutoTokenizer
from langchain_community.document_loaders import DirectoryLoader, TextLoader, PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import re
from config import CHROMA_PATH, EMBEDDING_MODEL, K_SIMILARITY_SEARCH
from models.tradutor_local import traduzir
from config import CHUNK_SIZE, CHUNK_OVERLAP
# 🔧 Inicialização de modelos globais de embeddings e reranker
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL)
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2") # usado para rerank dos documentos
# 🧪 Função auxiliar para testar se a tradução está funcional
def testar_traducao():
try:
print("PT ➜ EN:", traduzir("O que é a UML?", origem="pt", destino="en"))
print("EN ➜ PT:", traduzir("What is UML?", origem="en", destino="pt"))
except Exception as erro_trad:
print("⚠️ Erro ao testar tradução:", erro_trad)
# 📦 Aceder à base vetorial Chroma já persistida
def get_chroma_db():
return Chroma(persist_directory=CHROMA_PATH, embedding_function=embeddings)
# 📂 Carrega documentos (TXT ou PDF) de um diretório especificado
def carregar_documentos_por_tipo(path, extensao, loader_cls):
loader = DirectoryLoader(path, glob=f"**/*.{extensao}", loader_cls=loader_cls)
docs = loader.load()
print(f"📄 Documentos .{extensao} encontrados: {len(docs)}")
return docs
# 🏗️ Constrói a base vetorial Chroma a partir dos documentos
def build_chroma_db(documents_path="data/documents"):
print("🚀 A iniciar construção da base de embeddings...")
# Carregar documentos de texto e PDF
docs = (
carregar_documentos_por_tipo(documents_path, "txt", TextLoader) +
carregar_documentos_por_tipo(documents_path, "pdf", lambda path: PyPDFLoader(path))
)
if not docs:
print("❌ Nenhum documento foi encontrado.")
return
# Quebrar documentos em chunks com sobreposição
splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)
chunks = splitter.split_documents(docs)
# Remover duplicados
unique_chunks = list({chunk.page_content: chunk for chunk in chunks}.values())
try:
db = Chroma.from_documents(unique_chunks, embeddings, persist_directory=CHROMA_PATH)
db.persist()
print(f"✅ Embeddings salvos em: {CHROMA_PATH}")
except Exception as erro_db:
print("❌ Erro ao criar a base ChromaDB:", erro_db)
# 🧹 Normaliza o texto: remove acentos, limita tamanho e limpa quebras de linha
def normalize_text(text: str, max_length: int = 1000) -> str:
text = unicodedata.normalize("NFKD", text).encode("ASCII", "ignore").decode("ASCII")
return text[:max_length].replace("\n", " ")
# 🧽 Remove ruído semântico do contexto: números soltos, palavras irrelevantes, etc.
def limpar_ruido_contexto(texto: str) -> str:
linhas = texto.splitlines()
filtradas = [
linha.strip()
for linha in linhas
if len(linha.strip()) > 30 # ignora linhas curtas
and not re.search(r"(FIGURA|FONTE|CONSIDERANDO|EXERC[ÍI]CIO|PAULO COELHO|INTRODUÇÃO)", linha, re.IGNORECASE)
and not re.fullmatch(r"\s*\d+\s*", linha) # ignora números sozinhos
and not re.search(r"\b\d{1,2}\b", linha) # ignora datas ou numerações soltas
]
texto_limpo = " ".join(filtradas)
return re.sub(r"\s{2,}", " ", texto_limpo).strip()
# 🔍 Função principal: recebe a query e retorna o melhor contexto (PT ou EN)
def retrieve_context(query: str, k: int = K_SIMILARITY_SEARCH, return_scores: bool = False, return_raw: bool = False):
debug_ctx = {}
# 🔍 Traduz a pergunta de PT para EN
try:
query_en = traduzir(query, origem="pt", destino="en")
debug_ctx["query_en"] = query_en
except Exception as erro_tq:
print("❌ Erro ao traduzir a query:", erro_tq)
query_en = query # fallback: usa a query original
debug_ctx["query_en"] = query_en
# 🔍 Pesquisa por similaridade no ChromaDB
db = get_chroma_db()
context_docs = db.similarity_search(query_en, k=k)
if not context_docs:
contexto_vazio = "Sem contexto relevante encontrado."
if return_scores and return_raw:
return contexto_vazio, [], "", "", debug_ctx
elif return_scores:
return contexto_vazio, []
else:
return contexto_vazio
# 🧠 Rerank com CrossEncoder
pairs = [(query_en, doc.page_content) for doc in context_docs]
scores = reranker.predict(pairs)
reranked = sorted(zip(context_docs, scores), key=lambda x: x[1], reverse=True)
# 🔹 Juntar os k melhores docs (PT)
top_docs = [doc.page_content for doc, _ in reranked[:k]]
context_pt = "\n\n".join(top_docs)
debug_ctx["context_pt"] = context_pt
# 🔁 Traduzir o contexto para EN (limitado a 3000 chars para evitar erro do modelo)
try:
context_pt_truncado = context_pt[:2000] # Limita para evitar erro de tokens (>512)
context_en = traduzir(context_pt_truncado, origem="pt", destino="en")
debug_ctx["context_en"] = context_en
except Exception as erro_ctx:
print("⚠️ Erro ao traduzir o contexto para EN:", erro_ctx)
context_en = context_pt[:500] # fallback seguro
debug_ctx["context_en"] = context_en
# 🔧 Limpar e normalizar (apenas para debug/métricas)
texto_limpo = limpar_ruido_contexto(context_pt)
texto_normalizado = normalize_text(texto_limpo[:800]) # corta para evitar excesso
debug_ctx["normalizado"] = texto_normalizado
# ✅ Retornar com os diferentes modos
if return_scores and return_raw:
return context_pt, reranked[:k], context_en, texto_normalizado, debug_ctx
elif return_scores:
return context_pt, reranked[:k]
else:
return context_pt
# 🧪 Execução direta para construir ou testar a base Chroma
if __name__ == "__main__":
build_chroma_db()
db = get_chroma_db()
print("🔁 Teste de busca:", db.similarity_search("Teste de contexto", k=1))
|