flozi00's picture
refactor
a86cdfa
"""
Caching system for generated audio.
Supports local and Hugging Face Hub storage.
"""
import hashlib
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from loguru import logger
@dataclass
class CacheConfig:
    """Settings that control how generated audio is cached.

    Attributes:
        enabled: Master switch for caching.
        local_cache_dir: Local cache directory; ``None`` means no local
            directory is configured.
        hf_repo_id: Hugging Face Hub repo used as a remote cache; ``None``
            means no remote repo is configured.
        max_duration_seconds: Only audio shorter than this (in seconds) is
            meant to be cached.
    """

    enabled: bool = True
    local_cache_dir: Optional[str] = None
    hf_repo_id: Optional[str] = None
    max_duration_seconds: float = 30.0
class AudioCache:
    """
    Cache for generated TTS audio.

    Supports both local filesystem and Hugging Face Hub storage. Entries are
    stored as ``<key>.mp3`` where ``<key>`` is the MD5 hex digest of
    ``"{backend}:{voice_id}:{text}"``. Locally the file lives directly in
    ``config.local_cache_dir``; on the Hub it lives under
    ``{hf_repo_id}/{voice_id}/``.

    Storage errors in ``get``/``set``/``clear_local`` are logged and reported
    as a miss, ``False`` or a skipped file rather than raised.
    """

    def __init__(self, config: Optional["CacheConfig"] = None):
        """
        Create a cache.

        Args:
            config: Cache settings; a default ``CacheConfig`` is used when omitted.
        """
        self.config = config or CacheConfig()
        # HF filesystem handle, created lazily by _get_hf_fs().
        self._hf_fs = None
        # Set once initialisation has failed so that later calls do not retry
        # the import (and repeat the warning) on every get()/set().
        self._hf_init_failed = False

    def _get_cache_key(self, text: str, voice_id: str, backend: str) -> str:
        """Generate a unique cache key for the given parameters."""
        # MD5 is used purely to derive a file name, not for security.
        content = f"{backend}:{voice_id}:{text}"
        return hashlib.md5(content.encode()).hexdigest()

    def _get_hf_fs(self):
        """
        Get HuggingFace filesystem (lazy initialization).

        Returns:
            The filesystem object, or None when no repo is configured or the
            filesystem could not be initialised.
        """
        if self._hf_fs is not None or self._hf_init_failed:
            return self._hf_fs
        if not self.config.hf_repo_id:
            return None
        try:
            # Imported lazily so the dependency is only needed when a Hub
            # repo is actually configured.
            from huggingface_hub import HfFileSystem

            self._hf_fs = HfFileSystem(token=os.environ.get("HF_TOKEN"))
        except Exception as e:
            self._hf_init_failed = True
            logger.warning(f"Could not initialize HF filesystem: {e}")
        return self._hf_fs

    def _read_local(self, cache_key: str) -> Optional[bytes]:
        """Return the locally cached bytes for ``cache_key``, or None on a miss."""
        if not self.config.local_cache_dir:
            return None
        local_path = Path(self.config.local_cache_dir) / f"{cache_key}.mp3"
        try:
            # Read directly instead of exists()+read: the file may disappear
            # between the two calls.
            data = local_path.read_bytes()
        except (FileNotFoundError, NotADirectoryError):
            return None
        except OSError as e:
            # An unreadable entry is treated as a miss so the caller can still
            # fall back to the Hub cache.
            logger.warning(f"Failed to read local cache: {e}")
            return None
        logger.debug(f"Cache hit (local): {cache_key}")
        return data

    def _read_hf(self, voice_id: str, cache_key: str) -> Optional[bytes]:
        """Return the bytes cached on the HF Hub for ``cache_key``, or None on a miss."""
        if not self.config.hf_repo_id:
            return None
        fs = self._get_hf_fs()
        if not fs:
            return None
        # NOTE(review): voice_id is interpolated into the Hub path as-is;
        # confirm callers never pass values containing "/" or "..".
        hf_path = f"{self.config.hf_repo_id}/{voice_id}/{cache_key}.mp3"
        try:
            if fs.exists(hf_path):
                with fs.open(hf_path, "rb") as f:
                    logger.debug(f"Cache hit (HF Hub): {cache_key}")
                    return f.read()
        except Exception as e:
            # Best effort: any Hub failure is treated as a miss.
            logger.debug(f"HF cache lookup failed: {e}")
        return None

    def _write_local(self, cache_key: str, audio_data: bytes) -> bool:
        """Store ``audio_data`` in the local cache; return True on success."""
        if not self.config.local_cache_dir:
            return False
        try:
            cache_dir = Path(self.config.local_cache_dir)
            cache_dir.mkdir(parents=True, exist_ok=True)
            local_path = cache_dir / f"{cache_key}.mp3"
            # Write to a uniquely named temp file in the same directory and
            # rename it into place, so readers never observe a partially
            # written "<key>.mp3" and a crash cannot leave a truncated entry.
            tmp_path = cache_dir / (
                f"{cache_key}.{os.getpid()}.{os.urandom(4).hex()}.tmp"
            )
            try:
                tmp_path.write_bytes(audio_data)
                os.replace(tmp_path, local_path)
            except Exception:
                # Do not leave a stray temp file behind.
                if tmp_path.exists():
                    tmp_path.unlink()
                raise
        except Exception as e:
            # Best effort: caching must never break synthesis.
            logger.warning(f"Failed to cache locally: {e}")
            return False
        logger.debug(f"Cached locally: {cache_key}")
        return True

    def _write_hf(self, voice_id: str, cache_key: str, audio_data: bytes) -> bool:
        """Store ``audio_data`` on the HF Hub; return True on success."""
        if not self.config.hf_repo_id:
            return False
        fs = self._get_hf_fs()
        if not fs:
            return False
        try:
            voice_dir = f"{self.config.hf_repo_id}/{voice_id}"
            if not fs.exists(voice_dir):
                fs.makedirs(voice_dir, exist_ok=True)
            hf_path = f"{voice_dir}/{cache_key}.mp3"
            with fs.open(hf_path, "wb") as f:
                f.write(audio_data)
        except Exception as e:
            # Best effort: any Hub failure is logged and reported as False.
            logger.warning(f"Failed to cache to HF Hub: {e}")
            return False
        logger.debug(f"Cached to HF Hub: {cache_key}")
        return True

    def get(self, text: str, voice_id: str, backend: str) -> Optional[bytes]:
        """
        Retrieve cached audio if it exists.

        The local cache is consulted first, then the HF Hub cache.

        Args:
            text: Original text that was synthesized
            voice_id: Voice identifier used
            backend: Backend name used for synthesis

        Returns:
            Cached audio bytes or None if not found
        """
        if not self.config.enabled:
            return None

        cache_key = self._get_cache_key(text, voice_id, backend)

        cached = self._read_local(cache_key)
        if cached is not None:
            return cached
        return self._read_hf(voice_id, cache_key)

    def set(
        self,
        text: str,
        voice_id: str,
        backend: str,
        audio_data: bytes,
        duration_seconds: Optional[float] = None,
    ) -> bool:
        """
        Store audio in cache.

        Args:
            text: Original text that was synthesized
            voice_id: Voice identifier used
            backend: Backend name used for synthesis
            audio_data: Audio bytes to cache
            duration_seconds: Duration of the audio (for max duration check);
                the check is skipped when this is None

        Returns:
            True if cached successfully in at least one store, False otherwise
        """
        if not self.config.enabled:
            return False

        # Check duration limit
        if (
            duration_seconds is not None
            and duration_seconds > self.config.max_duration_seconds
        ):
            logger.debug(
                f"Audio too long to cache: {duration_seconds}s > {self.config.max_duration_seconds}s"
            )
            return False

        cache_key = self._get_cache_key(text, voice_id, backend)

        # Attempt both stores unconditionally; success in either one counts.
        stored_locally = self._write_local(cache_key, audio_data)
        stored_on_hub = self._write_hf(voice_id, cache_key, audio_data)
        return stored_locally or stored_on_hub

    def clear_local(self) -> int:
        """Clear local cache. Returns number of files deleted."""
        if not self.config.local_cache_dir:
            return 0

        cache_dir = Path(self.config.local_cache_dir)
        if not cache_dir.exists():
            return 0

        count = 0
        for file in cache_dir.glob("*.mp3"):
            try:
                file.unlink()
            except FileNotFoundError:
                # Removed by someone else in the meantime; nothing to count.
                continue
            except OSError as e:
                # Keep going so one undeletable entry does not abort the sweep.
                logger.warning(f"Failed to delete cache file {file}: {e}")
                continue
            count += 1

        logger.info(f"Cleared {count} files from local cache")
        return count