rag_template / src /cache.py
Guilherme Favaron
Major update: Add hybrid search, reranking, multiple LLMs, and UI improvements
1b447de
"""
Sistema de cache para embeddings e resultados
"""
import hashlib
import pickle
import time
from typing import Optional, Any, Dict
from pathlib import Path
import numpy as np
class EmbeddingCache:
"""Cache em memória para embeddings"""
def __init__(self, max_size: int = 1000, ttl_seconds: int = 3600):
"""
Inicializa cache de embeddings
Args:
max_size: Número máximo de itens no cache
ttl_seconds: Tempo de vida dos itens em segundos (0 = sem expiração)
"""
self.cache: Dict[str, Dict[str, Any]] = {}
self.max_size = max_size
self.ttl_seconds = ttl_seconds
self.hits = 0
self.misses = 0
def _generate_key(self, text: str, model_id: str) -> str:
"""
Gera chave de cache a partir do texto e modelo
Args:
text: Texto para gerar embedding
model_id: ID do modelo de embedding
Returns:
Hash único para o par (text, model_id)
"""
combined = f"{model_id}:{text}"
return hashlib.sha256(combined.encode()).hexdigest()
def get(self, text: str, model_id: str) -> Optional[np.ndarray]:
"""
Recupera embedding do cache
Args:
text: Texto do embedding
model_id: ID do modelo
Returns:
Embedding ou None se não encontrado/expirado
"""
key = self._generate_key(text, model_id)
if key not in self.cache:
self.misses += 1
return None
item = self.cache[key]
# Verifica TTL
if self.ttl_seconds > 0:
age = time.time() - item["timestamp"]
if age > self.ttl_seconds:
del self.cache[key]
self.misses += 1
return None
self.hits += 1
return item["embedding"]
def set(self, text: str, model_id: str, embedding: np.ndarray) -> None:
"""
Armazena embedding no cache
Args:
text: Texto do embedding
model_id: ID do modelo
embedding: Vetor de embedding
"""
# Se cache está cheio, remove item mais antigo (FIFO)
if len(self.cache) >= self.max_size:
oldest_key = next(iter(self.cache))
del self.cache[oldest_key]
key = self._generate_key(text, model_id)
self.cache[key] = {
"embedding": embedding,
"timestamp": time.time(),
"text_length": len(text)
}
def get_stats(self) -> Dict[str, Any]:
"""
Retorna estatísticas do cache
Returns:
Dicionário com métricas
"""
total_requests = self.hits + self.misses
hit_rate = (self.hits / total_requests * 100) if total_requests > 0 else 0
return {
"total_items": len(self.cache),
"max_size": self.max_size,
"hits": self.hits,
"misses": self.misses,
"hit_rate": hit_rate,
"ttl_seconds": self.ttl_seconds
}
def clear(self) -> None:
"""Limpa todo o cache"""
self.cache.clear()
self.hits = 0
self.misses = 0
def remove_expired(self) -> int:
"""
Remove itens expirados do cache
Returns:
Número de itens removidos
"""
if self.ttl_seconds == 0:
return 0
now = time.time()
expired_keys = [
key for key, item in self.cache.items()
if now - item["timestamp"] > self.ttl_seconds
]
for key in expired_keys:
del self.cache[key]
return len(expired_keys)
class DiskCache:
"""Cache persistente em disco para embeddings"""
def __init__(self, cache_dir: str = ".cache/embeddings"):
"""
Inicializa cache em disco
Args:
cache_dir: Diretório para armazenar cache
"""
self.cache_dir = Path(cache_dir)
self.cache_dir.mkdir(parents=True, exist_ok=True)
def _get_cache_path(self, text: str, model_id: str) -> Path:
"""
Gera caminho do arquivo de cache
Args:
text: Texto para gerar embedding
model_id: ID do modelo
Returns:
Caminho do arquivo
"""
combined = f"{model_id}:{text}"
hash_key = hashlib.sha256(combined.encode()).hexdigest()
return self.cache_dir / f"{hash_key}.pkl"
def get(self, text: str, model_id: str) -> Optional[np.ndarray]:
"""
Recupera embedding do disco
Args:
text: Texto do embedding
model_id: ID do modelo
Returns:
Embedding ou None se não encontrado
"""
cache_path = self._get_cache_path(text, model_id)
if not cache_path.exists():
return None
try:
with open(cache_path, 'rb') as f:
data = pickle.load(f)
return data["embedding"]
except Exception:
return None
def set(self, text: str, model_id: str, embedding: np.ndarray) -> None:
"""
Armazena embedding no disco
Args:
text: Texto do embedding
model_id: ID do modelo
embedding: Vetor de embedding
"""
cache_path = self._get_cache_path(text, model_id)
data = {
"embedding": embedding,
"timestamp": time.time(),
"model_id": model_id,
"text_length": len(text)
}
try:
with open(cache_path, 'wb') as f:
pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
except Exception:
pass # Falha silenciosa
def clear(self) -> int:
"""
Limpa todo o cache em disco
Returns:
Número de arquivos removidos
"""
count = 0
for cache_file in self.cache_dir.glob("*.pkl"):
try:
cache_file.unlink()
count += 1
except Exception:
pass
return count
def get_size(self) -> int:
"""
Retorna tamanho do cache em bytes
Returns:
Tamanho total em bytes
"""
total_size = 0
for cache_file in self.cache_dir.glob("*.pkl"):
try:
total_size += cache_file.stat().st_size
except Exception:
pass
return total_size
def get_stats(self) -> Dict[str, Any]:
"""
Retorna estatísticas do cache em disco
Returns:
Dicionário com métricas
"""
cache_files = list(self.cache_dir.glob("*.pkl"))
total_size = self.get_size()
return {
"total_files": len(cache_files),
"total_size_bytes": total_size,
"total_size_mb": total_size / (1024 * 1024),
"cache_dir": str(self.cache_dir)
}