"""Módulo de caché semántico para RAG. Evita llamadas repetidas al LLM almacenando respuestas previas y buscándolas por similitud semántica con las queries entrantes. """ import time from typing import Callable import numpy as np from access_control import User, retrieve_with_access def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: """Calcula la similitud coseno entre dos vectores.""" norm_a = np.linalg.norm(a) norm_b = np.linalg.norm(b) if norm_a == 0 or norm_b == 0: return 0.0 return float(np.dot(a, b) / (norm_a * norm_b)) class SemanticCache: """Caché semántico que almacena respuestas del LLM y las busca por similitud. Cada entrada contiene: embedding, query, response, doc_ids, timestamp. """ def __init__( self, embed_fn: Callable[[str], list[float]], threshold: float = 0.95, ttl_seconds: int = 3600, ): """Inicializa el caché semántico. Args: embed_fn: Función que recibe texto y devuelve un vector de embedding. threshold: Umbral de similitud coseno para considerar cache hit. ttl_seconds: Tiempo de vida de las entradas en segundos. """ self.embed_fn = embed_fn self.threshold = threshold self.ttl_seconds = ttl_seconds self._entries: list[dict] = [] self._hits = 0 self._misses = 0 def get(self, query: str) -> str | None: """Busca una respuesta en caché para la query dada. Args: query: Texto de la consulta. Returns: Respuesta almacenada si hay cache hit, None si es miss. """ query_embedding = np.array(self.embed_fn(query)) now = time.time() for entry in self._entries: # Saltar entradas expiradas if entry["timestamp"] + self.ttl_seconds < now: continue sim = _cosine_similarity(query_embedding, np.array(entry["embedding"])) if sim >= self.threshold: return entry["response"] return None def put( self, query: str, response: str, doc_ids: list[str] = None, ) -> None: """Almacena una nueva entrada en el caché. Args: query: Texto de la consulta. response: Respuesta generada por el LLM. doc_ids: IDs de los documentos usados para generar la respuesta. """ embedding = self.embed_fn(query) self._entries.append({ "embedding": embedding, "query": query, "response": response, "doc_ids": doc_ids or [], "timestamp": time.time(), }) def invalidate_by_doc(self, doc_id: str) -> None: """Elimina todas las entradas del caché asociadas a un doc_id. Args: doc_id: ID del documento a invalidar. """ self._entries = [ entry for entry in self._entries if doc_id not in entry.get("doc_ids", []) ] def cleanup_expired(self) -> int: """Elimina todas las entradas expiradas del caché. Returns: Número de entradas eliminadas. """ now = time.time() before = len(self._entries) self._entries = [ entry for entry in self._entries if entry["timestamp"] + self.ttl_seconds >= now ] removed = before - len(self._entries) return removed def hit_rate(self) -> dict: """Devuelve estadísticas de hit rate del caché. Returns: Diccionario con total_queries, hits, misses y hit_rate_percent. """ total = self._hits + self._misses return { "total_queries": total, "hits": self._hits, "misses": self._misses, "hit_rate_percent": (self._hits / total * 100) if total > 0 else 0.0, } def rag_query( query: str, user: User, vectorstore, cache: SemanticCache, llm_fn: Callable[[str], str], embed_fn: Callable[[str], list[float]], rerank_fn: Callable = None, ) -> str: """Ejecuta una query RAG con caché semántico y control de acceso. Pipeline: 1. Buscar en caché 2. Si miss, recuperar chunks con control de acceso 3. Opcionalmente re-rankear 4. Generar respuesta con LLM 5. Guardar en caché Args: query: Texto de la consulta. user: Usuario que realiza la consulta. vectorstore: Colección de ChromaDB. cache: Instancia de SemanticCache. llm_fn: Función que recibe un prompt y devuelve texto. embed_fn: Función de embedding. rerank_fn: Función opcional de reranking. Returns: Respuesta generada. """ # 1. Intentar cache hit cached = cache.get(query) if cached is not None: cache._hits += 1 return cached # 2. Cache miss cache._misses += 1 # 3. Recuperar chunks con control de acceso results = retrieve_with_access(query, user, vectorstore) documents = results.get("documents", [[]])[0] metadatas = results.get("metadatas", [[]])[0] if not documents: response = llm_fn(f"No se encontraron documentos relevantes. Pregunta: {query}") cache.put(query, response, []) return response # 4. Reranking opcional if rerank_fn is not None: ranked = rerank_fn(query, documents) documents = ranked # 5. Construir prompt y llamar al LLM context = "\n\n---\n\n".join(documents) prompt = ( f"Contexto:\n{context}\n\n" f"Pregunta: {query}\n\n" f"Responde basándote únicamente en el contexto proporcionado." ) response = llm_fn(prompt) # 6. Guardar en caché doc_ids = list({m.get("doc_id", "") for m in metadatas if m.get("doc_id")}) cache.put(query, response, doc_ids) return response