DocAgentSystem / rag /retrieval.py
RamsesCamas's picture
Initial clean commit for HF Space deployment
d0d2f42
"""
rag/retrieval.py — Clase 6: RAG Avanzado
Técnicas implementadas:
1. BM25 (keyword search)
2. Hybrid Retrieval (BM25 + vector)
3. Multi-query expansion
4. Reciprocal Rank Fusion (RRF)
5. Cross-encoder reranking
6. Context compression
7. Pipeline completo (advanced_rag_query)
LLM: GPT-OSS vía Groq (compatible OpenAI)
Embeddings: sentence-transformers (local)
"""
import json
import os
import re
from dotenv import load_dotenv
from openai import OpenAI
from rank_bm25 import BM25Okapi
from rag.ingestion import Chunk
from rag.vectorstore import SearchResult, search as vector_search
load_dotenv()
# ---------------------------------------------------------------------------
# Configuración LLM (Groq vía OpenAI-compatible API)
# ---------------------------------------------------------------------------
# Model identifier sent with every chat completion; overridable via the
# GROQ_MODEL environment variable.
GROQ_MODEL = os.getenv("GROQ_MODEL", "openai/gpt-oss-120b")


def _get_groq_client() -> OpenAI:
    """Build an OpenAI-compatible client that targets the Groq endpoint.

    Returns:
        A new ``OpenAI`` client configured with the Groq base URL and the key
        read from the ``GROQ_API_KEY`` environment variable.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or empty.
    """
    api_key = os.environ.get("GROQ_API_KEY")
    if api_key:
        return OpenAI(
            base_url="https://api.groq.com/openai/v1",
            api_key=api_key,
        )
    raise RuntimeError(
        "GROQ_API_KEY no está configurada. "
        "Exporta la variable de entorno: export GROQ_API_KEY='gsk_...'"
    )
# Lazily created client shared by every call in this module.
_groq_client: OpenAI | None = None


def _get_client() -> OpenAI:
    """Return the shared Groq client, creating it on first use."""
    global _groq_client
    if _groq_client is not None:
        return _groq_client
    _groq_client = _get_groq_client()
    return _groq_client
# ---------------------------------------------------------------------------
# Tracker de uso de tokens (para métricas en compare_rag.py)
# ---------------------------------------------------------------------------
# Running totals of token usage and request count across LLM calls.
_usage_tracker = {"input_tokens": 0, "output_tokens": 0, "calls": 0}


def reset_usage_tracker() -> None:
    """Reset the token/call accumulator to zero.

    The dict is reset in place rather than rebound, so any code that already
    holds a reference to the tracker observes the reset instead of keeping a
    stale object.
    """
    _usage_tracker.clear()
    _usage_tracker.update({"input_tokens": 0, "output_tokens": 0, "calls": 0})
def get_usage():
    """Return a shallow copy of the current token/call accumulator.

    Mutating the returned dict does not affect the module-level tracker.
    """
    return _usage_tracker.copy()
# ---------------------------------------------------------------------------
# Helper: LLM call wrapper
# ---------------------------------------------------------------------------
def call_llm(prompt: str, system: str = "", temperature: float = 0.3) -> str:
    """Send a chat completion request to the configured model and return its text.

    Token counts reported by the API are added to the module-level usage
    tracker, and the call counter is incremented for every completed request.

    Args:
        prompt: Content of the user message.
        system: Optional system message; left out of the request when empty.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        The text of the first choice, or an empty string when the response
        carries no text content for it.
    """
    client = _get_client()
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})
    response = client.chat.completions.create(
        model=GROQ_MODEL,
        messages=messages,
        temperature=temperature,
    )
    # Accumulate metrics. Token counts are only added when the API reports
    # them; the request itself is always counted.
    usage = response.usage
    if usage:
        _usage_tracker["input_tokens"] += usage.prompt_tokens or 0
        _usage_tracker["output_tokens"] += usage.completion_tokens or 0
    _usage_tracker["calls"] += 1
    # ``content`` may be None; callers expect a str (they call len() and
    # re.search() on it), so normalise to an empty string.
    return response.choices[0].message.content or ""
# ---------------------------------------------------------------------------
# 1. BM25Index
# ---------------------------------------------------------------------------
class BM25Index:
    """Keyword search index over chunks using BM25 (Okapi)."""

    def __init__(self, chunks: list[Chunk]) -> None:
        """Tokenize every chunk and build the BM25 model.

        Args:
            chunks: Chunks to index; each must expose a ``content`` string.
        """
        self.chunks = chunks
        tokenized = [self._tokenize(c.content) for c in chunks]
        # rank_bm25 divides by the corpus size while building the model, so an
        # empty corpus would raise ZeroDivisionError. Keep no model in that
        # case and let search() return nothing.
        self.bm25 = BM25Okapi(tokenized) if tokenized else None

    @staticmethod
    def _tokenize(text: str) -> list[str]:
        """Simple tokenization: lowercase the text and extract ``\\w+`` runs."""
        return re.findall(r"\w+", text.lower())

    def search(self, query: str, top_k: int = 10) -> list[tuple[Chunk, float]]:
        """Score every chunk against the query and return the best matches.

        Args:
            query: Free-text query; tokenized the same way as the chunks.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` ``(chunk, score)`` pairs in descending score order.
            Chunks whose score is not strictly positive are dropped, so the
            list may be shorter than ``top_k`` (or empty).
        """
        if self.bm25 is None or top_k <= 0:
            return []
        tokens = self._tokenize(query)
        scores = self.bm25.get_scores(tokens)
        ranked = sorted(
            enumerate(scores), key=lambda pair: pair[1], reverse=True
        )[:top_k]
        return [(self.chunks[i], float(s)) for i, s in ranked if s > 0]
# ---------------------------------------------------------------------------
# 2. HybridRetriever
# ---------------------------------------------------------------------------
class HybridRetriever:
    """Combine BM25 (keyword) and vector (semantic) search with normalized scores."""

    def __init__(
        self, collection, chunks: list[Chunk], alpha: float = 0.5
    ) -> None:
        """Build the BM25 index and remember the vector collection.

        Args:
            collection: Collection object passed through to ``vector_search``
                (a ChromaDB collection according to the original author).
            chunks: Indexed chunks; used for BM25 and to resolve chunk ids
                that only the keyword search returned.
            alpha: Weight of the vector score; BM25 receives ``1 - alpha``.
        """
        self.collection = collection
        self.chunks = chunks
        self.alpha = alpha
        self.bm25 = BM25Index(chunks)
        self._chunk_map = {c.chunk_id: c for c in chunks}

    @staticmethod
    def _normalize_by_max(pairs: list[tuple[str, float]]) -> dict[str, float]:
        """Scale scores by their maximum; return {} when the maximum is not positive."""
        if not pairs:
            return {}
        top = max(score for _, score in pairs)
        if top <= 0:
            return {}
        return {cid: score / top for cid, score in pairs}

    def search(self, query: str, top_k: int = 5) -> list[SearchResult]:
        """Run both retrievers and merge their max-normalized scores.

        Each retriever is asked for ``2 * top_k`` candidates. A chunk found by
        only one retriever gets 0.0 for the other component.

        Args:
            query: Free-text query.
            top_k: Maximum number of results to return.

        Returns:
            Up to ``top_k`` results ordered by combined score (descending).
            Ties are broken by chunk id so the order is reproducible.
        """
        if top_k <= 0:
            return []
        candidates = top_k * 2

        # --- BM25 ---
        bm25_scores = self._normalize_by_max(
            [
                (chunk.chunk_id, score)
                for chunk, score in self.bm25.search(query, top_k=candidates)
            ]
        )

        # --- Vector search ---
        vec_results = vector_search(self.collection, query, n_results=candidates)
        vec_scores = self._normalize_by_max(
            [(r.chunk_id, r.score) for r in vec_results]
        )
        vec_content = {r.chunk_id: r for r in vec_results}

        # --- Combine scores ---
        alpha = self.alpha
        combined = [
            (
                cid,
                alpha * vec_scores.get(cid, 0.0)
                + (1 - alpha) * bm25_scores.get(cid, 0.0),
            )
            for cid in bm25_scores.keys() | vec_scores.keys()
        ]
        # The ids come out of a set, whose iteration order for strings varies
        # between interpreter runs; the secondary key makes ties deterministic.
        combined.sort(key=lambda item: (-item[1], item[0]))

        # --- Build results ---
        results = []
        for cid, score in combined[:top_k]:
            # Prefer the payload returned by the vector store; fall back to the
            # in-memory chunk for hits that only BM25 produced.
            source = vec_content.get(cid)
            if source is None:
                source = self._chunk_map.get(cid)
            if source is None:
                continue
            results.append(SearchResult(
                content=source.content, metadata=source.metadata,
                score=score, chunk_id=cid,
            ))
        return results
# ---------------------------------------------------------------------------
# 3. Multi-query generation
# ---------------------------------------------------------------------------
def generate_multi_queries(original_query: str, n: int = 3) -> list[str]:
    """Ask the LLM for up to ``n`` reformulations of a query.

    This is best-effort: any failure while calling the LLM or parsing its
    answer is logged and the function falls back to the original query alone.

    Args:
        original_query: The user's question.
        n: Maximum number of reformulations to add. When not positive, the
            LLM is not called.

    Returns:
        A list whose first entry is always ``original_query``, followed by at
        most ``n`` distinct, non-blank reformulations.
    """
    import logging

    if n <= 0:
        return [original_query]

    prompt = (
        f"Genera {n} reformulaciones diferentes de la siguiente pregunta. "
        "Cada reformulación debe usar vocabulario distinto pero mantener el mismo significado.\n\n"
        f"Pregunta original: {original_query}\n\n"
        'Responde SOLO con un JSON array de strings. Ejemplo:\n'
        '["reformulación 1", "reformulación 2", "reformulación 3"]'
    )
    try:
        response = call_llm(prompt)
        # The model may wrap the array in extra text; grab the outermost [...].
        match = re.search(r"\[.*\]", response, re.DOTALL)
        parsed = json.loads(match.group()) if match else None
    except Exception as exc:
        # Deliberately broad: retrieval should still work with the original
        # query when the LLM is unreachable or answers with malformed JSON.
        logging.getLogger(__name__).warning(
            "No se pudieron generar reformulaciones (%s); "
            "se usa solo la query original",
            exc,
        )
        return [original_query]

    if not isinstance(parsed, list) or not all(isinstance(q, str) for q in parsed):
        return [original_query]

    # Drop blanks and repeats (including repeats of the original query) so
    # later stages do not run the same search twice.
    queries = [original_query]
    seen = {original_query.strip()}
    for candidate in parsed:
        text = candidate.strip()
        if not text or text in seen:
            continue
        seen.add(text)
        queries.append(text)
        if len(queries) > n:
            break
    return queries
# ---------------------------------------------------------------------------
# 4. Multi-query search
# ---------------------------------------------------------------------------
def multi_query_search(
    collection, query: str, n_results: int = 5
) -> list[SearchResult]:
    """Run a vector search per query variant and rank chunks by summed score.

    Args:
        collection: Collection object passed through to ``vector_search``.
        query: The user's question; variants come from
            ``generate_multi_queries``.
        n_results: Results requested per variant and maximum returned overall.

    Returns:
        Up to ``n_results`` results in descending order of accumulated score.
        Content and metadata come from the first hit seen for each chunk id.
    """
    totals: dict[str, float] = {}
    first_hit: dict[str, SearchResult] = {}
    for variant in generate_multi_queries(query):
        for hit in vector_search(collection, variant, n_results=n_results):
            cid = hit.chunk_id
            totals[cid] = totals.get(cid, 0.0) + hit.score
            first_hit.setdefault(cid, hit)

    ordered = sorted(totals.items(), key=lambda entry: entry[1], reverse=True)
    fused = []
    for cid, total in ordered[:n_results]:
        origin = first_hit[cid]
        fused.append(
            SearchResult(
                content=origin.content,
                metadata=origin.metadata,
                score=total,
                chunk_id=cid,
            )
        )
    return fused
# ---------------------------------------------------------------------------
# 5. Reciprocal Rank Fusion (RRF)
# ---------------------------------------------------------------------------
def reciprocal_rank_fusion(
    result_lists: list[list[SearchResult]], k: int = 60
) -> list[tuple[str, float]]:
    """Merge several ranked result lists with Reciprocal Rank Fusion.

    A result at zero-based position ``rank`` in a list contributes
    ``1 / (k + rank + 1)`` to its chunk id's total.

    Args:
        result_lists: Ranked lists, best result first; items must expose
            ``chunk_id``.
        k: Smoothing constant added to every rank.

    Returns:
        ``(chunk_id, rrf_score)`` pairs sorted by score, highest first.
    """
    fused: dict[str, float] = {}
    for ranking in result_lists:
        for rank, hit in enumerate(ranking):
            cid = hit.chunk_id
            fused[cid] = fused.get(cid, 0.0) + 1 / (k + rank + 1)
    ordered = list(fused.items())
    ordered.sort(key=lambda pair: pair[1], reverse=True)
    return ordered
# ---------------------------------------------------------------------------
# 6. Reranking con Cross-Encoder
# ---------------------------------------------------------------------------
# Cached cross-encoder instance; populated on first use.
_cross_encoder = None


def _get_cross_encoder():
    """Return the shared cross-encoder, loading it on the first call.

    ``sentence_transformers`` is imported inside the function so the import
    only happens when the model is first requested.
    """
    global _cross_encoder
    if _cross_encoder is not None:
        return _cross_encoder
    print(" Cargando cross-encoder (primera vez puede tardar unos minutos)...")
    from sentence_transformers import CrossEncoder
    _cross_encoder = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
    return _cross_encoder
def rerank(
    query: str, results: list[SearchResult], top_k: int = 5
) -> list[SearchResult]:
    """Re-score results against the query with the cross-encoder and keep the best.

    Args:
        query: The user's question.
        results: Candidate results; each must expose ``content``,
            ``metadata`` and ``chunk_id``.
        top_k: Number of results to keep after sorting.

    Returns:
        New ``SearchResult`` objects whose ``score`` is the cross-encoder
        score, sorted descending and cut to ``top_k``. An empty input returns
        an empty list without loading the model.
    """
    if not results:
        return []
    encoder = _get_cross_encoder()
    relevance = encoder.predict([(query, hit.content) for hit in results])
    rescored = [
        SearchResult(
            content=hit.content,
            metadata=hit.metadata,
            score=float(value),
            chunk_id=hit.chunk_id,
        )
        for hit, value in zip(results, relevance)
    ]
    return sorted(rescored, key=lambda hit: hit.score, reverse=True)[:top_k]
# ---------------------------------------------------------------------------
# 7. Compresión de contexto (LLM)
# ---------------------------------------------------------------------------
def compress_context(query: str, chunks: list[str]) -> str:
    """Ask the LLM to extract only the sentences relevant to the query.

    Args:
        query: The user's question.
        chunks: Retrieved chunk texts. Empty or whitespace-only entries are
            ignored.

    Returns:
        The LLM's extraction, or an empty string when there is no usable
        context. No LLM call is made in that case, since there is nothing to
        extract from.
    """
    usable = [chunk for chunk in chunks if chunk and chunk.strip()]
    if not usable:
        return ""
    combined = "\n---\n".join(usable)
    prompt = (
        f"Pregunta del usuario: {query}\n\n"
        f"Contexto recuperado:\n{combined}\n\n"
        "Extrae SOLO las oraciones relevantes para responder la pregunta. "
        "No inventes información. No parafrasees. "
        "Solo selecciona y devuelve las oraciones más relevantes."
    )
    return call_llm(prompt)
# ---------------------------------------------------------------------------
# 8. Compresión de contexto (reranker, sin LLM)
# ---------------------------------------------------------------------------
def compress_with_reranker(
    query: str, chunks: list[str], top_sentences: int = 10
) -> str:
    """Compress the context by reranking individual sentences, without an LLM.

    Chunks are split into sentences after ".", "!" or "?" followed by
    whitespace. Sentences whose text (ignoring trailing punctuation) is 20
    characters or shorter are discarded. The rest are scored against the
    query with the cross-encoder.

    Args:
        query: The user's question.
        chunks: Retrieved chunk texts.
        top_sentences: Maximum number of sentences to keep.

    Returns:
        The highest-scoring sentences joined by spaces, in score order. When
        no sentence survives the length filter, the chunks themselves joined
        by spaces.
    """
    terminators = (".", "!", "?")
    sentences = []
    for chunk in chunks:
        # Split after the terminator so each sentence keeps its own
        # punctuation.
        for piece in re.split(r"(?<=[.!?])\s+", chunk):
            piece = piece.strip()
            if len(piece.rstrip(".!?")) <= 20:
                continue
            # Only add a period when the sentence has no terminator of its
            # own; otherwise a question would come out as "...?.".
            sentences.append(piece if piece.endswith(terminators) else piece + ".")
    if not sentences:
        return " ".join(chunks)
    if top_sentences <= 0:
        return ""
    model = _get_cross_encoder()
    scores = model.predict([(query, sentence) for sentence in sentences])
    ranked = sorted(zip(sentences, scores), key=lambda pair: pair[1], reverse=True)
    return " ".join(sentence for sentence, _ in ranked[:top_sentences])
# ---------------------------------------------------------------------------
# 9. Pipeline completo: Advanced RAG Query
# ---------------------------------------------------------------------------
def advanced_rag_query(collection, chunks: list[Chunk], query: str) -> str:
    """Run the full advanced RAG pipeline and return the final answer.

    Steps:
        1. Multi-query: generate reformulations of the question.
        2. Hybrid search: run ``HybridRetriever`` for every query variant.
        3. Deduplicate results by ``chunk_id`` (first occurrence wins).
        4. Rerank the unique candidates with the cross-encoder.
        5. Compress the context with the LLM.
        6. Generate the final answer with the LLM.

    Args:
        collection: Collection object passed through to the vector search.
        chunks: All indexed chunks, used to build the BM25 index.
        query: The user's question.

    Returns:
        The LLM's answer. When retrieval produces no candidates, the fallback
        sentence the final prompt asks the model to use, without any further
        LLM calls.
    """
    no_answer = "No tengo suficiente información"

    # 1. Multi-query
    queries = generate_multi_queries(query)
    print(f" Multi-query: {len(queries)} queries generadas")

    # 2 + 3. Hybrid search per query, deduplicated by chunk_id. A dict keeps
    # first-seen order, so the first hit for each id is the one retained.
    hybrid = HybridRetriever(collection, chunks)
    unique_by_id: dict[str, SearchResult] = {}
    for q in queries:
        for hit in hybrid.search(q, top_k=5):
            unique_by_id.setdefault(hit.chunk_id, hit)
    unique_results = list(unique_by_id.values())
    print(f" Candidatos únicos: {len(unique_results)}")

    # With nothing retrieved there is no context to compress or answer from,
    # so skip the two LLM calls and return the fallback answer directly.
    if not unique_results:
        return no_answer

    # 4. Reranking
    reranked = rerank(query, unique_results, top_k=5)
    print(f" Rerankeados: {len(reranked)}")

    # 5. Context compression
    chunk_texts = [r.content for r in reranked]
    compressed = compress_context(query, chunk_texts)
    print(f" Contexto comprimido: {len(compressed)} chars")

    # 6. Final answer
    final_prompt = (
        "Responde la siguiente pregunta usando SOLO el contexto proporcionado.\n"
        'Si no puedes responder con el contexto dado, di "No tengo suficiente información".\n\n'
        f"Contexto:\n{compressed}\n\n"
        f"Pregunta: {query}\n\n"
        "Respuesta:"
    )
    return call_llm(final_prompt)