rag / rag.py
FabioSantos's picture
Upload 4 files
2f8b4e1 verified
"""
rag.py — Lógica central do RAG (Retrieval-Augmented Generation)
Fluxo:
1. Indexação : documento → chunks → embeddings → armazenados em memória
2. Recuperação: pergunta → embedding → busca por similaridade (cosine)
3. Geração : contexto recuperado + pergunta → LLM (chat) → resposta
"""
import re
import math
from typing import List, Tuple
from huggingface_hub import InferenceClient
# ---------------------------------------------------------------------------
# Configurações
# ---------------------------------------------------------------------------
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Cerebras: rápido, gratuito via HF token, sem necessidade de aceitar termos
GENERATION_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
GENERATION_PROVIDER = "cerebras"
# ---------------------------------------------------------------------------
# 1. CHUNKING
# ---------------------------------------------------------------------------
def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]:
"""Divide o texto em chunks com sobreposição para manter contexto."""
text = re.sub(r"\s+", " ", text).strip()
words = text.split()
chunks, start = [], 0
while start < len(words):
chunks.append(" ".join(words[start:start + chunk_size]))
start += chunk_size - overlap
return chunks
# ---------------------------------------------------------------------------
# 2. EMBEDDINGS — usa hf-inference (ótimo para embeddings)
# ---------------------------------------------------------------------------
def get_embeddings(texts: List[str], hf_token: str) -> List[List[float]]:
"""Gera embeddings via HF Inference API."""
client = InferenceClient(provider="hf-inference", token=hf_token)
embeddings = client.feature_extraction(texts, model=EMBEDDING_MODEL)
return [list(map(float, vec)) for vec in embeddings]
# ---------------------------------------------------------------------------
# 3. SIMILARIDADE
# ---------------------------------------------------------------------------
def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float:
dot = sum(a * b for a, b in zip(vec_a, vec_b))
mag_a = math.sqrt(sum(a * a for a in vec_a))
mag_b = math.sqrt(sum(b * b for b in vec_b))
return 0.0 if mag_a == 0 or mag_b == 0 else dot / (mag_a * mag_b)
def retrieve(
query_embedding: List[float],
chunk_embeddings: List[List[float]],
chunks: List[str],
top_k: int = 3,
) -> List[Tuple[str, float]]:
scores = [
(chunk, cosine_similarity(query_embedding, emb))
for chunk, emb in zip(chunks, chunk_embeddings)
]
scores.sort(key=lambda x: x[1], reverse=True)
return scores[:top_k]
# ---------------------------------------------------------------------------
# 4. GERAÇÃO — Cerebras (rápido, gratuito, sem restrições de aceite)
# ---------------------------------------------------------------------------
def generate_answer(
question: str,
context_chunks: List[Tuple[str, float]],
hf_token: str,
max_tokens: int = 512,
) -> str:
context = "\n\n".join(
f"[Trecho {i+1} | score={score:.2f}]\n{text}"
for i, (text, score) in enumerate(context_chunks)
)
system_msg = (
"Você é um assistente prestativo. "
"Responda à pergunta usando APENAS as informações fornecidas no contexto. "
"Se a resposta não estiver no contexto, diga: "
"'Não encontrei essa informação no documento.' "
"Responda sempre em português."
)
user_msg = f"### CONTEXTO:\n{context}\n\n### PERGUNTA:\n{question}"
client = InferenceClient(provider=GENERATION_PROVIDER, token=hf_token)
response = client.chat_completion(
model=GENERATION_MODEL,
messages=[
{"role": "system", "content": system_msg},
{"role": "user", "content": user_msg},
],
max_tokens=max_tokens,
temperature=0.3,
)
return response.choices[0].message.content.strip()
# ---------------------------------------------------------------------------
# 5. PIPELINE COMPLETO
# ---------------------------------------------------------------------------
class SimpleRAG:
"""
Encapsula todo o pipeline RAG de forma simples e didática.
Uso:
rag = SimpleRAG(hf_token="hf_...")
rag.index(meu_texto)
resposta, trechos = rag.query("Qual é o tema principal?")
"""
def __init__(self, hf_token: str):
self.hf_token = hf_token
self.chunks: List[str] = []
self.embeddings: List[List[float]] = []
self.indexed = False
def index(self, text: str, chunk_size: int = 300, overlap: int = 50) -> int:
self.chunks = chunk_text(text, chunk_size, overlap)
self.embeddings = get_embeddings(self.chunks, self.hf_token)
self.indexed = True
return len(self.chunks)
def query(self, question: str, top_k: int = 3) -> Tuple[str, List[Tuple[str, float]]]:
if not self.indexed:
raise RuntimeError("Nenhum documento indexado. Chame .index() primeiro.")
query_emb = get_embeddings([question], self.hf_token)[0]
retrieved = retrieve(query_emb, self.embeddings, self.chunks, top_k)
answer = generate_answer(question, retrieved, self.hf_token)
return answer, retrieved