""" rag.py — Lógica central do RAG (Retrieval-Augmented Generation) Fluxo: 1. Indexação : documento → chunks → embeddings → armazenados em memória 2. Recuperação: pergunta → embedding → busca por similaridade (cosine) 3. Geração : contexto recuperado + pergunta → LLM (chat) → resposta """ import re import math from typing import List, Tuple from huggingface_hub import InferenceClient # --------------------------------------------------------------------------- # Configurações # --------------------------------------------------------------------------- EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # Cerebras: rápido, gratuito via HF token, sem necessidade de aceitar termos GENERATION_MODEL = "meta-llama/Llama-3.1-8B-Instruct" GENERATION_PROVIDER = "cerebras" # --------------------------------------------------------------------------- # 1. CHUNKING # --------------------------------------------------------------------------- def chunk_text(text: str, chunk_size: int = 300, overlap: int = 50) -> List[str]: """Divide o texto em chunks com sobreposição para manter contexto.""" text = re.sub(r"\s+", " ", text).strip() words = text.split() chunks, start = [], 0 while start < len(words): chunks.append(" ".join(words[start:start + chunk_size])) start += chunk_size - overlap return chunks # --------------------------------------------------------------------------- # 2. EMBEDDINGS — usa hf-inference (ótimo para embeddings) # --------------------------------------------------------------------------- def get_embeddings(texts: List[str], hf_token: str) -> List[List[float]]: """Gera embeddings via HF Inference API.""" client = InferenceClient(provider="hf-inference", token=hf_token) embeddings = client.feature_extraction(texts, model=EMBEDDING_MODEL) return [list(map(float, vec)) for vec in embeddings] # --------------------------------------------------------------------------- # 3. SIMILARIDADE # --------------------------------------------------------------------------- def cosine_similarity(vec_a: List[float], vec_b: List[float]) -> float: dot = sum(a * b for a, b in zip(vec_a, vec_b)) mag_a = math.sqrt(sum(a * a for a in vec_a)) mag_b = math.sqrt(sum(b * b for b in vec_b)) return 0.0 if mag_a == 0 or mag_b == 0 else dot / (mag_a * mag_b) def retrieve( query_embedding: List[float], chunk_embeddings: List[List[float]], chunks: List[str], top_k: int = 3, ) -> List[Tuple[str, float]]: scores = [ (chunk, cosine_similarity(query_embedding, emb)) for chunk, emb in zip(chunks, chunk_embeddings) ] scores.sort(key=lambda x: x[1], reverse=True) return scores[:top_k] # --------------------------------------------------------------------------- # 4. GERAÇÃO — Cerebras (rápido, gratuito, sem restrições de aceite) # --------------------------------------------------------------------------- def generate_answer( question: str, context_chunks: List[Tuple[str, float]], hf_token: str, max_tokens: int = 512, ) -> str: context = "\n\n".join( f"[Trecho {i+1} | score={score:.2f}]\n{text}" for i, (text, score) in enumerate(context_chunks) ) system_msg = ( "Você é um assistente prestativo. " "Responda à pergunta usando APENAS as informações fornecidas no contexto. " "Se a resposta não estiver no contexto, diga: " "'Não encontrei essa informação no documento.' " "Responda sempre em português." ) user_msg = f"### CONTEXTO:\n{context}\n\n### PERGUNTA:\n{question}" client = InferenceClient(provider=GENERATION_PROVIDER, token=hf_token) response = client.chat_completion( model=GENERATION_MODEL, messages=[ {"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}, ], max_tokens=max_tokens, temperature=0.3, ) return response.choices[0].message.content.strip() # --------------------------------------------------------------------------- # 5. PIPELINE COMPLETO # --------------------------------------------------------------------------- class SimpleRAG: """ Encapsula todo o pipeline RAG de forma simples e didática. Uso: rag = SimpleRAG(hf_token="hf_...") rag.index(meu_texto) resposta, trechos = rag.query("Qual é o tema principal?") """ def __init__(self, hf_token: str): self.hf_token = hf_token self.chunks: List[str] = [] self.embeddings: List[List[float]] = [] self.indexed = False def index(self, text: str, chunk_size: int = 300, overlap: int = 50) -> int: self.chunks = chunk_text(text, chunk_size, overlap) self.embeddings = get_embeddings(self.chunks, self.hf_token) self.indexed = True return len(self.chunks) def query(self, question: str, top_k: int = 3) -> Tuple[str, List[Tuple[str, float]]]: if not self.indexed: raise RuntimeError("Nenhum documento indexado. Chame .index() primeiro.") query_emb = get_embeddings([question], self.hf_token)[0] retrieved = retrieve(query_emb, self.embeddings, self.chunks, top_k) answer = generate_answer(question, retrieved, self.hf_token) return answer, retrieved