Mirrowel
refactor(cache): πŸ”¨ decouple memory cleanup from disk persistence lifecycle
9bfc01f
# src/rotator_library/providers/provider_cache.py
"""
Shared cache utility for providers.
A modular, async-capable cache system supporting:
- Dual-TTL: short-lived memory cache, longer-lived disk persistence
- Background persistence with batched writes
- Automatic cleanup of expired entries
- Generic key-value storage for any provider-specific needs
Usage examples:
- Gemini 3: thoughtSignatures (tool_call_id β†’ encrypted signature)
- Claude: Thinking content (composite_key β†’ thinking text + signature)
- General: Any transient data that benefits from persistence across requests
"""
from __future__ import annotations
import asyncio
import json
import logging
import os
import time
from pathlib import Path
from typing import Any, Dict, Optional, Tuple
from ..utils.resilient_io import safe_write_json
lib_logger = logging.getLogger("rotator_library")
# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================
def _env_bool(key: str, default: bool = False) -> bool:
"""Get boolean from environment variable."""
return os.getenv(key, str(default).lower()).lower() in ("true", "1", "yes")
def _env_int(key: str, default: int) -> int:
"""Get integer from environment variable."""
return int(os.getenv(key, str(default)))
# =============================================================================
# PROVIDER CACHE CLASS
# =============================================================================
class ProviderCache:
    """
    Server-side cache for provider conversation state preservation.
    A generic, modular cache supporting any key-value data that providers need
    to persist across requests. Features:
    - Dual-TTL system: entries live in memory for memory_ttl, but persist on
      disk for the longer disk_ttl. Memory cleanup does NOT affect disk entries.
    - Merge-on-save: disk writes merge current memory with existing disk entries,
      preserving disk-only entries until they exceed disk_ttl
    - Async disk persistence with batched writes
    - Background cleanup task for memory-expired entries (disk untouched)
    - Statistics tracking (hits, misses, writes, disk preservation)
    Args:
        cache_file: Path to disk cache file
        memory_ttl_seconds: In-memory entry lifetime (default: 1 hour)
        disk_ttl_seconds: Disk entry lifetime (default: 48 hours)
        enable_disk: Whether to enable disk persistence (default: from env or True)
        write_interval: Seconds between background disk writes (default: 60)
        cleanup_interval: Seconds between expired entry cleanup (default: 30 min)
        env_prefix: Environment variable prefix for configuration overrides
    Environment Variables (with default prefix "PROVIDER_CACHE"):
        {PREFIX}_ENABLE: Enable/disable disk persistence
        {PREFIX}_WRITE_INTERVAL: Background write interval in seconds
        {PREFIX}_CLEANUP_INTERVAL: Cleanup interval in seconds
    """

    def __init__(
        self,
        cache_file: Path,
        memory_ttl_seconds: int = 3600,
        disk_ttl_seconds: int = 172800,  # 48 hours
        enable_disk: Optional[bool] = None,
        write_interval: Optional[int] = None,
        cleanup_interval: Optional[int] = None,
        env_prefix: str = "PROVIDER_CACHE",
    ):
        # In-memory cache: {cache_key: (data, timestamp)}
        self._cache: Dict[str, Tuple[str, float]] = {}
        self._memory_ttl = memory_ttl_seconds
        self._disk_ttl = disk_ttl_seconds
        self._lock = asyncio.Lock()
        self._disk_lock = asyncio.Lock()
        # Disk persistence configuration
        self._cache_file = cache_file
        self._enable_disk = (
            enable_disk
            if enable_disk is not None
            else _env_bool(f"{env_prefix}_ENABLE", True)
        )
        self._dirty = False
        self._write_interval = write_interval or _env_int(
            f"{env_prefix}_WRITE_INTERVAL", 60
        )
        self._cleanup_interval = cleanup_interval or _env_int(
            f"{env_prefix}_CLEANUP_INTERVAL", 1800
        )
        # Background tasks
        self._writer_task: Optional[asyncio.Task] = None
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False
        # Strong references to fire-and-forget tasks (init/store/disk-fallback).
        # asyncio keeps only weak references to scheduled tasks, so without
        # this set a pending task could be garbage-collected before it runs.
        self._pending_tasks: set = set()
        # Statistics
        self._stats = {
            "memory_hits": 0,
            "disk_hits": 0,
            "misses": 0,
            "writes": 0,
            "disk_errors": 0,
        }
        # Track disk health for monitoring
        self._disk_available = True
        # Metadata about this cache instance
        self._cache_name = cache_file.stem if cache_file else "unnamed"
        if self._enable_disk:
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: Disk enabled "
                f"(memory_ttl={memory_ttl_seconds}s, disk_ttl={disk_ttl_seconds}s)"
            )
            # Guarded spawn: constructing the cache outside a running event
            # loop must not raise; disk load then happens lazily on first
            # save/retrieve from an async context.
            self._spawn(self._async_init())
        else:
            lib_logger.debug(f"ProviderCache[{self._cache_name}]: Memory-only mode")

    # =========================================================================
    # TASK SCHEDULING
    # =========================================================================
    def _spawn(self, coro) -> Optional[asyncio.Task]:
        """Schedule *coro* on the running loop, holding a strong reference.

        Returns the created task, or None when no event loop is running
        (the coroutine is closed instead of letting asyncio.create_task
        raise RuntimeError).
        """
        try:
            task = asyncio.get_running_loop().create_task(coro)
        except RuntimeError:
            coro.close()  # avoid "coroutine was never awaited" warning
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: No running event loop; task skipped"
            )
            return None
        self._pending_tasks.add(task)
        task.add_done_callback(self._pending_tasks.discard)
        return task

    # =========================================================================
    # INITIALIZATION
    # =========================================================================
    async def _async_init(self) -> None:
        """Async initialization: load from disk and start background tasks."""
        try:
            await self._load_from_disk()
            await self._start_background_tasks()
        except Exception as e:
            lib_logger.error(
                f"ProviderCache[{self._cache_name}] async init failed: {e}"
            )

    async def _load_from_disk(self) -> None:
        """Load cache from disk file with TTL validation.

        Entries stored via store()/store_async() before this load completes
        take precedence over their (stale) disk counterparts.
        """
        if not self._enable_disk or not self._cache_file.exists():
            return
        try:
            async with self._disk_lock:
                with open(self._cache_file, "r", encoding="utf-8") as f:
                    data = json.load(f)
            if data.get("version") != "1.0":
                lib_logger.warning(
                    f"ProviderCache[{self._cache_name}]: Version mismatch, starting fresh"
                )
                return
            now = time.time()
            entries = data.get("entries", {})
            loaded = expired = 0
            # Populate memory under the cache lock so we don't interleave
            # with concurrent store/cleanup operations.
            async with self._lock:
                for cache_key, entry in entries.items():
                    age = now - entry.get("timestamp", 0)
                    if age <= self._disk_ttl:
                        value = entry.get(
                            "value", entry.get("signature", "")
                        )  # Support both formats
                        if value:
                            # setdefault: never clobber an entry that was
                            # written while the load was in flight.
                            self._cache.setdefault(
                                cache_key, (value, entry["timestamp"])
                            )
                            loaded += 1
                    else:
                        expired += 1
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: Loaded {loaded} entries ({expired} expired)"
            )
        except json.JSONDecodeError as e:
            lib_logger.warning(
                f"ProviderCache[{self._cache_name}]: File corrupted: {e}"
            )
        except Exception as e:
            lib_logger.error(f"ProviderCache[{self._cache_name}]: Load failed: {e}")

    # =========================================================================
    # DISK PERSISTENCE
    # =========================================================================
    async def _save_to_disk(self, merge: bool = True) -> bool:
        """Persist cache to disk using atomic write with health tracking.

        Implements dual-TTL preservation: merges current memory state with
        existing disk entries that haven't exceeded disk_ttl. This ensures
        entries persist on disk for the full disk_ttl even after they expire
        from memory (which uses the shorter memory_ttl).

        Args:
            merge: When True (default), preserve still-valid disk-only
                entries. Pass False to overwrite the file with exactly the
                current memory state (used by clear(), which would otherwise
                have its wipe undone by the merge resurrecting old entries).

        Returns:
            True if write succeeded, False otherwise.
        """
        if not self._enable_disk:
            return True  # Not an error if disk is disabled
        async with self._disk_lock:
            now = time.time()
            # Step 1: Load existing disk entries (if any, and if merging)
            existing_entries: Dict[str, Dict[str, Any]] = {}
            if merge and self._cache_file.exists():
                try:
                    with open(self._cache_file, "r", encoding="utf-8") as f:
                        data = json.load(f)
                        existing_entries = data.get("entries", {})
                except (json.JSONDecodeError, IOError, OSError):
                    pass  # Start fresh if corrupted or unreadable
            # Step 2: Filter existing disk entries by disk_ttl (not memory_ttl)
            # This preserves entries that expired from memory but are still valid on disk
            valid_disk_entries = {
                k: v
                for k, v in existing_entries.items()
                if now - v.get("timestamp", 0) <= self._disk_ttl
            }
            # Step 3: Merge - memory entries take precedence (fresher timestamps)
            merged_entries = valid_disk_entries.copy()
            for key, (val, ts) in self._cache.items():
                merged_entries[key] = {"value": val, "timestamp": ts}
            # Count entries that were preserved from disk (not in memory)
            memory_keys = set(self._cache.keys())
            preserved_from_disk = len(
                [k for k in valid_disk_entries if k not in memory_keys]
            )
            # Step 4: Build and save merged cache data
            cache_data = {
                "version": "1.0",
                "memory_ttl_seconds": self._memory_ttl,
                "disk_ttl_seconds": self._disk_ttl,
                "entries": merged_entries,
                "statistics": {
                    "total_entries": len(merged_entries),
                    "memory_entries": len(self._cache),
                    "disk_preserved": preserved_from_disk,
                    "last_write": now,
                    **self._stats,
                },
            }
            if safe_write_json(
                self._cache_file, cache_data, lib_logger, secure_permissions=True
            ):
                self._stats["writes"] += 1
                self._disk_available = True
                # Log merge info only when we preserved disk-only entries (infrequent)
                if preserved_from_disk > 0:
                    lib_logger.debug(
                        f"ProviderCache[{self._cache_name}]: Saved {len(merged_entries)} entries "
                        f"(memory={len(self._cache)}, preserved_from_disk={preserved_from_disk})"
                    )
                return True
            else:
                self._stats["disk_errors"] += 1
                self._disk_available = False
                return False

    # =========================================================================
    # BACKGROUND TASKS
    # =========================================================================
    async def _start_background_tasks(self) -> None:
        """Start background writer and cleanup tasks."""
        if not self._enable_disk or self._running:
            return
        self._running = True
        self._writer_task = asyncio.create_task(self._writer_loop())
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        lib_logger.debug(f"ProviderCache[{self._cache_name}]: Started background tasks")

    async def _writer_loop(self) -> None:
        """Background task: periodically flush dirty cache to disk."""
        try:
            while self._running:
                await asyncio.sleep(self._write_interval)
                if self._dirty:
                    try:
                        success = await self._save_to_disk()
                        if success:
                            self._dirty = False
                        # If save failed, _dirty remains True so we retry next interval
                    except Exception as e:
                        lib_logger.error(
                            f"ProviderCache[{self._cache_name}]: Writer error: {e}"
                        )
        except asyncio.CancelledError:
            pass

    async def _cleanup_loop(self) -> None:
        """Background task: periodically clean up expired entries."""
        try:
            while self._running:
                await asyncio.sleep(self._cleanup_interval)
                await self._cleanup_expired()
        except asyncio.CancelledError:
            pass

    async def _cleanup_expired(self) -> None:
        """Remove expired entries from memory cache.

        Only cleans memory - disk entries are preserved and cleaned during
        _save_to_disk() based on their own disk_ttl.
        """
        async with self._lock:
            now = time.time()
            expired = [
                k for k, (_, ts) in self._cache.items() if now - ts > self._memory_ttl
            ]
            for k in expired:
                del self._cache[k]
            # Don't set dirty flag: memory cleanup shouldn't trigger disk write
            # Disk entries are cleaned separately in _save_to_disk() by disk_ttl
        if expired:
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: Cleaned {len(expired)} expired entries from memory"
            )

    # =========================================================================
    # CORE OPERATIONS
    # =========================================================================
    def store(self, key: str, value: str) -> None:
        """
        Store a value synchronously (schedules async storage).

        If no event loop is running, falls back to a direct synchronous
        write instead of raising RuntimeError.

        Args:
            key: Cache key
            value: Value to store (typically JSON-serialized data)
        """
        if self._spawn(self._async_store(key, value)) is None:
            # No running loop: write directly. asyncio code is single-
            # threaded, so a plain dict assignment is safe here.
            self._cache[key] = (value, time.time())
            self._dirty = True

    async def _async_store(self, key: str, value: str) -> None:
        """Async implementation of store."""
        async with self._lock:
            self._cache[key] = (value, time.time())
            self._dirty = True

    async def store_async(self, key: str, value: str) -> None:
        """
        Store a value asynchronously (awaitable).
        Use this when you need to ensure the value is stored before continuing.
        """
        await self._async_store(key, value)

    def retrieve(self, key: str) -> Optional[str]:
        """
        Retrieve a value by key (synchronous, with optional async disk fallback).

        Args:
            key: Cache key
        Returns:
            Cached value if found and not expired, None otherwise
        """
        if key in self._cache:
            value, timestamp = self._cache[key]
            if time.time() - timestamp <= self._memory_ttl:
                self._stats["memory_hits"] += 1
                return value
            else:
                # Entry expired from memory - remove from memory only
                # Don't set dirty flag: disk copy should persist until disk_ttl
                del self._cache[key]
        self._stats["misses"] += 1
        if self._enable_disk:
            # Schedule async disk lookup for next time (no-op without a loop)
            self._spawn(self._check_disk_fallback(key))
        return None

    async def retrieve_async(self, key: str) -> Optional[str]:
        """
        Retrieve a value asynchronously (checks disk if not in memory).
        Use this when you can await and need guaranteed disk fallback.
        """
        # Check memory first
        if key in self._cache:
            value, timestamp = self._cache[key]
            if time.time() - timestamp <= self._memory_ttl:
                self._stats["memory_hits"] += 1
                return value
            else:
                # Entry expired from memory - remove from memory only
                # Don't set dirty flag: disk copy should persist until disk_ttl
                async with self._lock:
                    if key in self._cache:
                        del self._cache[key]
        # Check disk
        if self._enable_disk:
            return await self._disk_retrieve(key)
        self._stats["misses"] += 1
        return None

    async def _read_disk_entry(self, key: str) -> Optional[Tuple[str, float]]:
        """Read one entry from the disk file (shared by both disk lookups).

        Returns (value, timestamp) when the key exists, is within disk_ttl,
        and has a non-empty value; None otherwise. Exceptions propagate to
        the caller, which decides how to log/count them.
        """
        if not self._cache_file.exists():
            return None
        async with self._disk_lock:
            with open(self._cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        entry = data.get("entries", {}).get(key)
        if not entry:
            return None
        ts = entry.get("timestamp", 0)
        if time.time() - ts > self._disk_ttl:
            return None
        # Support both the current "value" key and the legacy "signature" key
        value = entry.get("value", entry.get("signature", ""))
        return (value, ts) if value else None

    async def _check_disk_fallback(self, key: str) -> None:
        """Check disk for key and load into memory if found (background)."""
        try:
            found = await self._read_disk_entry(key)
            if found is not None:
                value, ts = found
                async with self._lock:
                    self._cache[key] = (value, ts)
                self._stats["disk_hits"] += 1
                lib_logger.debug(
                    f"ProviderCache[{self._cache_name}]: Loaded {key} from disk"
                )
        except Exception as e:
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: Disk fallback failed: {e}"
            )

    async def _disk_retrieve(self, key: str) -> Optional[str]:
        """Direct disk retrieval with loading into memory."""
        try:
            found = await self._read_disk_entry(key)
            if found is not None:
                value, ts = found
                async with self._lock:
                    self._cache[key] = (value, ts)
                self._stats["disk_hits"] += 1
                return value
            self._stats["misses"] += 1
            return None
        except Exception as e:
            lib_logger.debug(
                f"ProviderCache[{self._cache_name}]: Disk retrieve failed: {e}"
            )
            self._stats["misses"] += 1
            return None

    # =========================================================================
    # UTILITY METHODS
    # =========================================================================
    def contains(self, key: str) -> bool:
        """Check if key exists in memory cache (without updating stats)."""
        if key in self._cache:
            _, timestamp = self._cache[key]
            return time.time() - timestamp <= self._memory_ttl
        return False

    def get_stats(self) -> Dict[str, Any]:
        """Get cache statistics including disk health."""
        return {
            **self._stats,
            "memory_entries": len(self._cache),
            "dirty": self._dirty,
            "disk_enabled": self._enable_disk,
            "disk_available": self._disk_available,
        }

    async def clear(self) -> None:
        """Clear all cached data from memory AND disk.

        Uses a non-merging save so the wipe is not silently undone by
        merge-on-save resurrecting the old disk entries.
        """
        async with self._lock:
            self._cache.clear()
            self._dirty = True
        if self._enable_disk:
            # merge=False: write exactly the (now empty) memory state
            if await self._save_to_disk(merge=False):
                self._dirty = False

    async def shutdown(self) -> None:
        """Graceful shutdown: flush pending writes and stop background tasks."""
        lib_logger.info(f"ProviderCache[{self._cache_name}]: Shutting down...")
        self._running = False
        # Cancel background tasks
        for task in (self._writer_task, self._cleanup_task):
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        # Final save
        if self._dirty and self._enable_disk:
            await self._save_to_disk()
        lib_logger.info(
            f"ProviderCache[{self._cache_name}]: Shutdown complete "
            f"(stats: mem_hits={self._stats['memory_hits']}, "
            f"disk_hits={self._stats['disk_hits']}, misses={self._stats['misses']})"
        )
# =============================================================================
# CONVENIENCE FACTORY
# =============================================================================
def create_provider_cache(
    name: str,
    cache_dir: Optional[Path] = None,
    memory_ttl_seconds: int = 3600,
    disk_ttl_seconds: int = 172800,  # 48 hours
    env_prefix: Optional[str] = None,
) -> ProviderCache:
    """
    Build a ProviderCache with sensible defaults.

    Args:
        name: Cache name; becomes the cache filename stem and appears in logs
        cache_dir: Directory holding the cache file. Defaults to the project
            root's "cache" directory (four levels above this module)
        memory_ttl_seconds: Lifetime of in-memory entries
        disk_ttl_seconds: Lifetime of on-disk entries
        env_prefix: Prefix for configuration environment variables; derived
            from name when omitted (e.g. "gemini3_signatures" ->
            "GEMINI3_SIGNATURES_CACHE")

    Returns:
        Configured ProviderCache instance
    """
    target_dir = cache_dir
    if target_dir is None:
        target_dir = Path(__file__).resolve().parent.parent.parent.parent / "cache"
    prefix = env_prefix
    if prefix is None:
        # Derive env prefix from the cache name
        prefix = name.upper().replace("-", "_") + "_CACHE"
    return ProviderCache(
        cache_file=target_dir / f"{name}.json",
        memory_ttl_seconds=memory_ttl_seconds,
        disk_ttl_seconds=disk_ttl_seconds,
        env_prefix=prefix,
    )