Spaces:
Sleeping
Sleeping
"""
rag/retrieval.py — Clase 6: RAG Avanzado

Técnicas implementadas:
  1. BM25 (keyword search)
  2. Hybrid Retrieval (BM25 + vector)
  3. Multi-query expansion
  4. Reciprocal Rank Fusion (RRF)
  5. Cross-encoder reranking
  6. Context compression
  7. Pipeline completo (advanced_rag_query)

LLM: GPT-OSS vía Groq (compatible OpenAI)
Embeddings: sentence-transformers (local)
"""
| import json | |
| import os | |
| import re | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| from rank_bm25 import BM25Okapi | |
| from rag.ingestion import Chunk | |
| from rag.vectorstore import SearchResult, search as vector_search | |
| load_dotenv() | |
| # --------------------------------------------------------------------------- | |
| # Configuración LLM (Groq vía OpenAI-compatible API) | |
| # --------------------------------------------------------------------------- | |
# Model identifier sent with every chat completion request; the GROQ_MODEL
# environment variable overrides the built-in default.
GROQ_MODEL = os.getenv("GROQ_MODEL", "openai/gpt-oss-120b")
def _get_groq_client() -> OpenAI:
    """Build an OpenAI-compatible client that targets the Groq API.

    Returns:
        An ``OpenAI`` client configured with Groq's base URL and the API key
        read from the ``GROQ_API_KEY`` environment variable.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or empty.
    """
    groq_key = os.environ.get("GROQ_API_KEY")
    if groq_key:
        return OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=groq_key,
        )
    raise RuntimeError(
        "GROQ_API_KEY no está configurada. "
        "Exporta la variable de entorno: export GROQ_API_KEY='gsk_...'"
    )
# Module-level cache for the Groq client; populated on first use.
_groq_client: OpenAI | None = None


def _get_client() -> OpenAI:
    """Return the shared Groq client, creating it the first time it is needed."""
    global _groq_client
    if _groq_client is not None:
        return _groq_client
    _groq_client = _get_groq_client()
    return _groq_client
| # --------------------------------------------------------------------------- | |
| # Tracker de uso de tokens (para métricas en compare_rag.py) | |
| # --------------------------------------------------------------------------- | |
# Running totals of LLM token usage and number of calls.
_usage_tracker = {"input_tokens": 0, "output_tokens": 0, "calls": 0}


def reset_usage_tracker():
    """Replace the usage accumulator with a fresh, zeroed one."""
    global _usage_tracker
    _usage_tracker = dict(input_tokens=0, output_tokens=0, calls=0)


def get_usage():
    """Return a copy of the current usage totals.

    The copy is independent: mutating it does not affect the accumulator.
    """
    return _usage_tracker.copy()
| # --------------------------------------------------------------------------- | |
| # Helper: LLM call wrapper | |
| # --------------------------------------------------------------------------- | |
def call_llm(prompt: str, system: str = "", temperature: float = 0.3) -> str:
    """Send a chat completion request to the LLM and return the reply text.

    Token counts reported by the API and the number of calls are added to
    the module-level usage tracker.

    Args:
        prompt: Text sent as the user message.
        system: Optional system message; omitted from the request when empty.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The content of the first choice's message, or an empty string when
        the API returns no content, so callers always receive a ``str``.
    """
    client = _get_client()

    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    response = client.chat.completions.create(
        model=GROQ_MODEL,
        messages=messages,
        temperature=temperature,
    )

    # Accumulate metrics. The tracker dict is mutated in place, so no
    # ``global`` declaration is needed.
    if response.usage:
        _usage_tracker["input_tokens"] += response.usage.prompt_tokens or 0
        _usage_tracker["output_tokens"] += response.usage.completion_tokens or 0
    _usage_tracker["calls"] += 1

    # ``message.content`` may be None; normalise it so that callers doing
    # ``len(...)`` or ``re.search(...)`` on the result do not fail.
    return response.choices[0].message.content or ""
| # --------------------------------------------------------------------------- | |
| # 1. BM25Index | |
| # --------------------------------------------------------------------------- | |
class BM25Index:
    """Keyword search index over chunks using BM25 (Okapi)."""

    def __init__(self, chunks: list[Chunk]) -> None:
        """Tokenize every chunk's content and build the BM25 index.

        Args:
            chunks: Chunks to index; each one's ``content`` is tokenized.
        """
        self.chunks = chunks
        tokenized = [self._tokenize(c.content) for c in chunks]
        # BM25Okapi cannot be built from an empty corpus (its average
        # document length computation divides by the corpus size), so an
        # empty index is represented by None and search() returns nothing.
        self.bm25 = BM25Okapi(tokenized) if tokenized else None

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Simple tokenization: lowercase, then split on the regex ``\\w+``."""
        return re.findall(r"\w+", text.lower())

    def search(self, query: str, top_k: int = 10) -> list[tuple[Chunk, float]]:
        """Score every chunk against the query and return the best matches.

        Args:
            query: Free-text query; tokenized the same way as the chunks.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(Chunk, score)`` pairs in descending score
            order. Chunks whose score is not positive are omitted, so the
            list may be shorter than ``top_k`` (or empty).
        """
        if self.bm25 is None or top_k <= 0:
            return []
        tokens = self._tokenize(query)
        scores = self.bm25.get_scores(tokens)
        ranked = sorted(
            enumerate(scores), key=lambda pair: pair[1], reverse=True
        )[:top_k]
        return [(self.chunks[i], float(s)) for i, s in ranked if s > 0]
| # --------------------------------------------------------------------------- | |
| # 2. HybridRetriever | |
| # --------------------------------------------------------------------------- | |
class HybridRetriever:
    """Combine BM25 (keyword) and vector (semantic) search with normalized scores."""

    def __init__(
        self, collection, chunks: list[Chunk], alpha: float = 0.5
    ) -> None:
        """
        Args:
            collection: ChromaDB collection passed to the vector search.
            chunks: List of indexed chunks; also used to build the BM25 index.
            alpha: Weight of the vector search score (``1 - alpha`` is the
                BM25 weight).
        """
        self.collection = collection
        self.chunks = chunks
        self.alpha = alpha
        self.bm25 = BM25Index(chunks)
        self._chunk_map = {c.chunk_id: c for c in chunks}

    @staticmethod
    def _normalize_by_max(pairs: list[tuple[str, float]]) -> dict[str, float]:
        """Divide each score by the maximum; return {} if that maximum is not positive."""
        if not pairs:
            return {}
        top = max(score for _, score in pairs)
        if top <= 0:
            return {}
        return {cid: score / top for cid, score in pairs}

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Run both retrievers and return the best combined results.

        Each retriever is asked for ``2 * top_k`` candidates. Scores from
        each retriever are divided by that retriever's maximum, then mixed as
        ``alpha * vector + (1 - alpha) * bm25``; a chunk missing from one
        retriever contributes 0 for it.

        Args:
            query: Free-text query.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``SearchResult`` objects carrying the combined
            score, in descending score order. Ties are broken by
            ``chunk_id`` so the ordering is reproducible across runs.
        """
        if top_k <= 0:
            return []
        candidates = top_k * 2

        # --- BM25 ---
        bm25_hits = self.bm25.search(query, top_k=candidates)
        bm25_scores = self._normalize_by_max(
            [(chunk.chunk_id, score) for chunk, score in bm25_hits]
        )

        # --- Vector search ---
        vec_hits = vector_search(self.collection, query, n_results=candidates)
        vec_scores = self._normalize_by_max(
            [(r.chunk_id, r.score) for r in vec_hits]
        )
        # Vector hits are only kept when their scores could be normalized.
        vec_content: dict[str, SearchResult] = (
            {r.chunk_id: r for r in vec_hits} if vec_scores else {}
        )

        # --- Combine scores ---
        combined = [
            (
                cid,
                self.alpha * vec_scores.get(cid, 0.0)
                + (1 - self.alpha) * bm25_scores.get(cid, 0.0),
            )
            for cid in bm25_scores.keys() | vec_scores.keys()
        ]
        # Set iteration order depends on per-process string hashing, so ties
        # are broken explicitly by chunk_id to keep the ranking deterministic.
        combined.sort(key=lambda item: (-item[1], item[0]))

        # --- Build results ---
        results = []
        for cid, score in combined[:top_k]:
            if cid in vec_content:
                hit = vec_content[cid]
                results.append(SearchResult(
                    content=hit.content, metadata=hit.metadata,
                    score=score, chunk_id=cid,
                ))
            elif cid in self._chunk_map:
                chunk = self._chunk_map[cid]
                results.append(SearchResult(
                    content=chunk.content, metadata=chunk.metadata,
                    score=score, chunk_id=cid,
                ))
        return results
| # --------------------------------------------------------------------------- | |
| # 3. Multi-query generation | |
| # --------------------------------------------------------------------------- | |
def generate_multi_queries(original_query: str, n: int = 3) -> list[str]:
    """Generate up to ``n`` reformulations of a query using the LLM.

    The original query is always the first entry. Blank reformulations and
    ones that repeat an earlier entry are dropped, so the result may contain
    fewer than ``n + 1`` items.

    Args:
        original_query: The user's question.
        n: Number of reformulations to request. When not positive, the LLM
            is not called.

    Returns:
        ``[original_query, *reformulations]``; just ``[original_query]`` when
        ``n <= 0``, when the LLM call fails, or when its reply does not
        contain a JSON array of strings.
    """
    if n <= 0:
        # Nothing to generate; also avoids slicing with a negative count.
        return [original_query]

    prompt = (
        f"Genera {n} reformulaciones diferentes de la siguiente pregunta. "
        "Cada reformulación debe usar vocabulario distinto pero mantener el mismo significado.\n\n"
        f"Pregunta original: {original_query}\n\n"
        'Responde SOLO con un JSON array de strings. Ejemplo:\n'
        '["reformulación 1", "reformulación 2", "reformulación 3"]'
    )
    try:
        response = call_llm(prompt)
        match = re.search(r"\[.*\]", response or "", re.DOTALL)
        parsed = json.loads(match.group()) if match else None
    except Exception:
        # Deliberate best-effort: query expansion is optional, so any failure
        # (LLM call or JSON parsing) falls back to the original query alone.
        return [original_query]

    if not isinstance(parsed, list) or not all(isinstance(q, str) for q in parsed):
        return [original_query]

    queries = [original_query]
    for candidate in parsed:
        candidate = candidate.strip()
        if candidate and candidate not in queries:
            queries.append(candidate)
        if len(queries) > n:
            break
    return queries
| # --------------------------------------------------------------------------- | |
| # 4. Multi-query search | |
| # --------------------------------------------------------------------------- | |
def multi_query_search(
    collection, query: str, n_results: int = 5
) -> list[SearchResult]:
    """Search with several reformulations of the query and sum the scores.

    Each query variant is run through the vector search; a chunk's final
    score is the sum of its scores across all variants. Content and metadata
    come from the first hit seen for each chunk.

    Args:
        collection: Collection passed to the vector search.
        query: The user's question.
        n_results: Results requested per variant and returned overall.

    Returns:
        Up to ``n_results`` ``SearchResult`` objects with accumulated
        scores, in descending score order.
    """
    totals: dict[str, float] = {}
    first_hit: dict[str, SearchResult] = {}

    for variant in generate_multi_queries(query):
        for hit in vector_search(collection, variant, n_results=n_results):
            cid = hit.chunk_id
            totals[cid] = totals.get(cid, 0.0) + hit.score
            first_hit.setdefault(cid, hit)

    best = sorted(totals.items(), key=lambda item: item[1], reverse=True)[:n_results]

    merged = []
    for cid, total in best:
        source = first_hit[cid]
        merged.append(
            SearchResult(
                content=source.content,
                metadata=source.metadata,
                score=total,
                chunk_id=cid,
            )
        )
    return merged
| # --------------------------------------------------------------------------- | |
| # 5. Reciprocal Rank Fusion (RRF) | |
| # --------------------------------------------------------------------------- | |
def reciprocal_rank_fusion(
    result_lists: list[list[SearchResult]], k: int = 60
) -> list[tuple[str, float]]:
    """Merge several ranked result lists with Reciprocal Rank Fusion.

    Every occurrence of a chunk at 1-based position ``p`` in a list adds
    ``1 / (k + p)`` to that chunk's fused score.

    Args:
        result_lists: Ranked lists of results; only ``chunk_id`` is read.
        k: Constant added to the position in the denominator.

    Returns:
        List of ``(chunk_id, rrf_score)`` sorted by descending score.
    """
    fused: dict[str, float] = {}
    for ranking in result_lists:
        for position, hit in enumerate(ranking, start=1):
            contribution = 1 / (k + position)
            fused[hit.chunk_id] = fused.get(hit.chunk_id, 0.0) + contribution

    ordered = list(fused.items())
    ordered.sort(key=lambda entry: entry[1], reverse=True)
    return ordered
| # --------------------------------------------------------------------------- | |
| # 6. Reranking con Cross-Encoder | |
| # --------------------------------------------------------------------------- | |
# Module-level cache for the cross-encoder model; populated on first use.
_cross_encoder = None


def _get_cross_encoder():
    """Return the shared cross-encoder, loading it the first time it is needed.

    ``sentence_transformers`` is imported inside the function so the import
    only happens when the model is actually requested.
    """
    global _cross_encoder
    if _cross_encoder is not None:
        return _cross_encoder

    print(" Cargando cross-encoder (primera vez puede tardar unos minutos)...")
    from sentence_transformers import CrossEncoder

    _cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _cross_encoder
def rerank(
    query: str, results: list[SearchResult], top_k: int = 5
) -> list[SearchResult]:
    """Re-order results by cross-encoder relevance to the query.

    Args:
        query: The user's question.
        results: Candidate results; each one's ``content`` is scored
            against the query.
        top_k: Maximum number of results to keep.

    Returns:
        New ``SearchResult`` objects whose ``score`` is the cross-encoder
        score, sorted descending and truncated to ``top_k``. An empty input
        yields an empty list without loading the model.
    """
    if not results:
        return []

    encoder = _get_cross_encoder()
    relevance = encoder.predict([(query, hit.content) for hit in results])

    rescored = [
        SearchResult(
            content=hit.content,
            metadata=hit.metadata,
            score=float(value),
            chunk_id=hit.chunk_id,
        )
        for hit, value in zip(results, relevance)
    ]
    return sorted(rescored, key=lambda hit: hit.score, reverse=True)[:top_k]
| # --------------------------------------------------------------------------- | |
| # 7. Compresión de contexto (LLM) | |
| # --------------------------------------------------------------------------- | |
def compress_context(query: str, chunks: list[str]) -> str:
    """Ask the LLM to extract only the sentences relevant to the query.

    Args:
        query: The user's question.
        chunks: Retrieved chunk texts; empty or whitespace-only entries are
            ignored.

    Returns:
        The LLM's reply, or an empty string when there is no non-blank
        chunk to extract from (the LLM is not called in that case).
    """
    usable = [chunk for chunk in chunks if chunk and chunk.strip()]
    if not usable:
        # With no context there is nothing to select; skip the LLM call
        # rather than prompting it to extract sentences from nothing.
        return ""

    combined = "\n---\n".join(usable)
    prompt = (
        f"Pregunta del usuario: {query}\n\n"
        f"Contexto recuperado:\n{combined}\n\n"
        "Extrae SOLO las oraciones relevantes para responder la pregunta. "
        "No inventes información. No parafrasees. "
        "Solo selecciona y devuelve las oraciones más relevantes."
    )
    return call_llm(prompt)
| # --------------------------------------------------------------------------- | |
| # 8. Compresión de contexto (reranker, sin LLM) | |
| # --------------------------------------------------------------------------- | |
def compress_with_reranker(
    query: str, chunks: list[str], top_sentences: int = 10
) -> str:
    """Compress the context by reranking individual sentences (no LLM).

    Chunks are split on ``". "``; pieces of 20 characters or fewer are
    discarded and repeated sentences are scored only once. The remaining
    sentences are scored against the query with the cross-encoder and the
    best ones are joined with spaces, highest score first.

    Args:
        query: The user's question.
        chunks: Retrieved chunk texts.
        top_sentences: Maximum number of sentences to keep.

    Returns:
        The selected sentences joined by spaces. If no piece is long enough
        to count as a sentence, the chunks themselves joined by spaces.
        An empty string when sentences exist but ``top_sentences`` is not
        positive.
    """
    sentences: list[str] = []
    seen: set[str] = set()
    for chunk in chunks:
        for piece in chunk.split(". "):
            piece = piece.strip()
            if len(piece) <= 20:
                continue
            # Splitting on ". " strips the period from every sentence but the
            # last; restore it, but do not add one after "?" or "!".
            if not piece.endswith((".", "!", "?")):
                piece += "."
            if piece not in seen:
                seen.add(piece)
                sentences.append(piece)

    if not sentences:
        return " ".join(chunks)
    if top_sentences <= 0:
        # Nothing requested: avoid loading the model and negative slicing.
        return ""

    model = _get_cross_encoder()
    scores = model.predict([(query, sentence) for sentence in sentences])
    ranked = sorted(zip(sentences, scores), key=lambda pair: pair[1], reverse=True)
    return " ".join(sentence for sentence, _ in ranked[:top_sentences])
| # --------------------------------------------------------------------------- | |
| # 9. Pipeline completo: Advanced RAG Query | |
| # --------------------------------------------------------------------------- | |
def advanced_rag_query(
    collection, chunks: list[Chunk], query: str, *, top_k: int = 5
) -> str:
    """Run the complete Advanced RAG pipeline and return the answer.

    Steps:
        1. Multi-query: generate reformulations of the question.
        2. Hybrid search: run ``HybridRetriever`` for every query.
        3. Deduplicate results by ``chunk_id`` (first occurrence wins).
        4. Reranking: reorder the unique candidates with the cross-encoder.
        5. Compression: extract only the relevant sentences with the LLM.
        6. Generate the final answer with the LLM.

    Progress is reported with ``print`` after each stage.

    Args:
        collection: Collection passed to the hybrid retriever.
        chunks: Indexed chunks, used to build the hybrid retriever.
        query: The user's question.
        top_k: Results fetched per query and kept after reranking.

    Returns:
        The LLM's answer. When retrieval yields no candidates, the fixed
        message "No tengo suficiente información" is returned without
        calling the LLM for compression or answering.
    """
    # 1. Multi-query
    queries = generate_multi_queries(query)
    print(f" Multi-query: {len(queries)} queries generadas")

    # 2. Hybrid search for every query
    hybrid = HybridRetriever(collection, chunks)
    all_results: list[SearchResult] = []
    for q in queries:
        all_results.extend(hybrid.search(q, top_k=top_k))

    # 3. Deduplicate by chunk_id, keeping the first occurrence
    unique_by_id: dict[str, SearchResult] = {}
    for result in all_results:
        unique_by_id.setdefault(result.chunk_id, result)
    unique_results = list(unique_by_id.values())
    print(f" Candidatos únicos: {len(unique_results)}")

    if not unique_results:
        # No context was retrieved: answer the way the final prompt instructs
        # the LLM to, instead of spending two LLM calls on empty context.
        return "No tengo suficiente información"

    # 4. Reranking
    reranked = rerank(query, unique_results, top_k=top_k)
    print(f" Rerankeados: {len(reranked)}")

    # 5. Context compression
    chunk_texts = [r.content for r in reranked]
    compressed = compress_context(query, chunk_texts)
    print(f" Contexto comprimido: {len(compressed)} chars")

    # 6. Final answer
    final_prompt = (
        "Responde la siguiente pregunta usando SOLO el contexto proporcionado.\n"
        'Si no puedes responder con el contexto dado, di "No tengo suficiente información".\n\n'
        f"Contexto:\n{compressed}\n\n"
        f"Pregunta: {query}\n\n"
        "Respuesta:"
    )
    return call_llm(final_prompt)