DocAgentSystem / rag /cache.py
RamsesCamas's picture
Initial clean commit for HF Space deployment
d0d2f42
"""Módulo de caché semántico para RAG.
Evita llamadas repetidas al LLM almacenando respuestas previas
y buscándolas por similitud semántica con las queries entrantes.
"""
import time
from typing import Callable
import numpy as np
from access_control import User, retrieve_with_access
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Calcula la similitud coseno entre dos vectores."""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
class SemanticCache:
"""Caché semántico que almacena respuestas del LLM y las busca por similitud.
Cada entrada contiene: embedding, query, response, doc_ids, timestamp.
"""
def __init__(
self,
embed_fn: Callable[[str], list[float]],
threshold: float = 0.95,
ttl_seconds: int = 3600,
):
"""Inicializa el caché semántico.
Args:
embed_fn: Función que recibe texto y devuelve un vector de embedding.
threshold: Umbral de similitud coseno para considerar cache hit.
ttl_seconds: Tiempo de vida de las entradas en segundos.
"""
self.embed_fn = embed_fn
self.threshold = threshold
self.ttl_seconds = ttl_seconds
self._entries: list[dict] = []
self._hits = 0
self._misses = 0
def get(self, query: str) -> str | None:
"""Busca una respuesta en caché para la query dada.
Args:
query: Texto de la consulta.
Returns:
Respuesta almacenada si hay cache hit, None si es miss.
"""
query_embedding = np.array(self.embed_fn(query))
now = time.time()
for entry in self._entries:
# Saltar entradas expiradas
if entry["timestamp"] + self.ttl_seconds < now:
continue
sim = _cosine_similarity(query_embedding, np.array(entry["embedding"]))
if sim >= self.threshold:
return entry["response"]
return None
def put(
self,
query: str,
response: str,
doc_ids: list[str] = None,
) -> None:
"""Almacena una nueva entrada en el caché.
Args:
query: Texto de la consulta.
response: Respuesta generada por el LLM.
doc_ids: IDs de los documentos usados para generar la respuesta.
"""
embedding = self.embed_fn(query)
self._entries.append({
"embedding": embedding,
"query": query,
"response": response,
"doc_ids": doc_ids or [],
"timestamp": time.time(),
})
def invalidate_by_doc(self, doc_id: str) -> None:
"""Elimina todas las entradas del caché asociadas a un doc_id.
Args:
doc_id: ID del documento a invalidar.
"""
self._entries = [
entry for entry in self._entries
if doc_id not in entry.get("doc_ids", [])
]
def cleanup_expired(self) -> int:
"""Elimina todas las entradas expiradas del caché.
Returns:
Número de entradas eliminadas.
"""
now = time.time()
before = len(self._entries)
self._entries = [
entry for entry in self._entries
if entry["timestamp"] + self.ttl_seconds >= now
]
removed = before - len(self._entries)
return removed
def hit_rate(self) -> dict:
"""Devuelve estadísticas de hit rate del caché.
Returns:
Diccionario con total_queries, hits, misses y hit_rate_percent.
"""
total = self._hits + self._misses
return {
"total_queries": total,
"hits": self._hits,
"misses": self._misses,
"hit_rate_percent": (self._hits / total * 100) if total > 0 else 0.0,
}
def rag_query(
query: str,
user: User,
vectorstore,
cache: SemanticCache,
llm_fn: Callable[[str], str],
embed_fn: Callable[[str], list[float]],
rerank_fn: Callable = None,
) -> str:
"""Ejecuta una query RAG con caché semántico y control de acceso.
Pipeline:
1. Buscar en caché
2. Si miss, recuperar chunks con control de acceso
3. Opcionalmente re-rankear
4. Generar respuesta con LLM
5. Guardar en caché
Args:
query: Texto de la consulta.
user: Usuario que realiza la consulta.
vectorstore: Colección de ChromaDB.
cache: Instancia de SemanticCache.
llm_fn: Función que recibe un prompt y devuelve texto.
embed_fn: Función de embedding.
rerank_fn: Función opcional de reranking.
Returns:
Respuesta generada.
"""
# 1. Intentar cache hit
cached = cache.get(query)
if cached is not None:
cache._hits += 1
return cached
# 2. Cache miss
cache._misses += 1
# 3. Recuperar chunks con control de acceso
results = retrieve_with_access(query, user, vectorstore)
documents = results.get("documents", [[]])[0]
metadatas = results.get("metadatas", [[]])[0]
if not documents:
response = llm_fn(f"No se encontraron documentos relevantes. Pregunta: {query}")
cache.put(query, response, [])
return response
# 4. Reranking opcional
if rerank_fn is not None:
ranked = rerank_fn(query, documents)
documents = ranked
# 5. Construir prompt y llamar al LLM
context = "\n\n---\n\n".join(documents)
prompt = (
f"Contexto:\n{context}\n\n"
f"Pregunta: {query}\n\n"
f"Responde basándote únicamente en el contexto proporcionado."
)
response = llm_fn(prompt)
# 6. Guardar en caché
doc_ids = list({m.get("doc_id", "") for m in metadatas if m.get("doc_id")})
cache.put(query, response, doc_ids)
return response