Spaces:
Sleeping
Sleeping
File size: 5,977 Bytes
d0d2f42 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 | """Módulo de caché semántico para RAG.
Evita llamadas repetidas al LLM almacenando respuestas previas
y buscándolas por similitud semántica con las queries entrantes.
"""
import time
from typing import Callable
import numpy as np
from access_control import User, retrieve_with_access
def _cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
"""Calcula la similitud coseno entre dos vectores."""
norm_a = np.linalg.norm(a)
norm_b = np.linalg.norm(b)
if norm_a == 0 or norm_b == 0:
return 0.0
return float(np.dot(a, b) / (norm_a * norm_b))
class SemanticCache:
"""Caché semántico que almacena respuestas del LLM y las busca por similitud.
Cada entrada contiene: embedding, query, response, doc_ids, timestamp.
"""
def __init__(
self,
embed_fn: Callable[[str], list[float]],
threshold: float = 0.95,
ttl_seconds: int = 3600,
):
"""Inicializa el caché semántico.
Args:
embed_fn: Función que recibe texto y devuelve un vector de embedding.
threshold: Umbral de similitud coseno para considerar cache hit.
ttl_seconds: Tiempo de vida de las entradas en segundos.
"""
self.embed_fn = embed_fn
self.threshold = threshold
self.ttl_seconds = ttl_seconds
self._entries: list[dict] = []
self._hits = 0
self._misses = 0
def get(self, query: str) -> str | None:
"""Busca una respuesta en caché para la query dada.
Args:
query: Texto de la consulta.
Returns:
Respuesta almacenada si hay cache hit, None si es miss.
"""
query_embedding = np.array(self.embed_fn(query))
now = time.time()
for entry in self._entries:
# Saltar entradas expiradas
if entry["timestamp"] + self.ttl_seconds < now:
continue
sim = _cosine_similarity(query_embedding, np.array(entry["embedding"]))
if sim >= self.threshold:
return entry["response"]
return None
def put(
self,
query: str,
response: str,
doc_ids: list[str] = None,
) -> None:
"""Almacena una nueva entrada en el caché.
Args:
query: Texto de la consulta.
response: Respuesta generada por el LLM.
doc_ids: IDs de los documentos usados para generar la respuesta.
"""
embedding = self.embed_fn(query)
self._entries.append({
"embedding": embedding,
"query": query,
"response": response,
"doc_ids": doc_ids or [],
"timestamp": time.time(),
})
def invalidate_by_doc(self, doc_id: str) -> None:
"""Elimina todas las entradas del caché asociadas a un doc_id.
Args:
doc_id: ID del documento a invalidar.
"""
self._entries = [
entry for entry in self._entries
if doc_id not in entry.get("doc_ids", [])
]
def cleanup_expired(self) -> int:
"""Elimina todas las entradas expiradas del caché.
Returns:
Número de entradas eliminadas.
"""
now = time.time()
before = len(self._entries)
self._entries = [
entry for entry in self._entries
if entry["timestamp"] + self.ttl_seconds >= now
]
removed = before - len(self._entries)
return removed
def hit_rate(self) -> dict:
"""Devuelve estadísticas de hit rate del caché.
Returns:
Diccionario con total_queries, hits, misses y hit_rate_percent.
"""
total = self._hits + self._misses
return {
"total_queries": total,
"hits": self._hits,
"misses": self._misses,
"hit_rate_percent": (self._hits / total * 100) if total > 0 else 0.0,
}
def rag_query(
query: str,
user: User,
vectorstore,
cache: SemanticCache,
llm_fn: Callable[[str], str],
embed_fn: Callable[[str], list[float]],
rerank_fn: Callable = None,
) -> str:
"""Ejecuta una query RAG con caché semántico y control de acceso.
Pipeline:
1. Buscar en caché
2. Si miss, recuperar chunks con control de acceso
3. Opcionalmente re-rankear
4. Generar respuesta con LLM
5. Guardar en caché
Args:
query: Texto de la consulta.
user: Usuario que realiza la consulta.
vectorstore: Colección de ChromaDB.
cache: Instancia de SemanticCache.
llm_fn: Función que recibe un prompt y devuelve texto.
embed_fn: Función de embedding.
rerank_fn: Función opcional de reranking.
Returns:
Respuesta generada.
"""
# 1. Intentar cache hit
cached = cache.get(query)
if cached is not None:
cache._hits += 1
return cached
# 2. Cache miss
cache._misses += 1
# 3. Recuperar chunks con control de acceso
results = retrieve_with_access(query, user, vectorstore)
documents = results.get("documents", [[]])[0]
metadatas = results.get("metadatas", [[]])[0]
if not documents:
response = llm_fn(f"No se encontraron documentos relevantes. Pregunta: {query}")
cache.put(query, response, [])
return response
# 4. Reranking opcional
if rerank_fn is not None:
ranked = rerank_fn(query, documents)
documents = ranked
# 5. Construir prompt y llamar al LLM
context = "\n\n---\n\n".join(documents)
prompt = (
f"Contexto:\n{context}\n\n"
f"Pregunta: {query}\n\n"
f"Responde basándote únicamente en el contexto proporcionado."
)
response = llm_fn(prompt)
# 6. Guardar en caché
doc_ids = list({m.get("doc_id", "") for m in metadatas if m.get("doc_id")})
cache.put(query, response, doc_ids)
return response
|