Spaces:
Runtime error
Runtime error
| """ | |
| Avatar Cache System for DittoTalkingHead | |
| Implements image pre-upload and embedding caching | |
| """ | |
| import os | |
| import pickle | |
| import hashlib | |
| import time | |
| from typing import Optional, Dict, Any, Tuple | |
| from datetime import datetime, timedelta | |
| import json | |
| from pathlib import Path | |
class AvatarCache:
    """
    Avatar embedding cache system.

    Stores pre-computed image embeddings on disk for faster video generation.

    Files kept in ``cache_dir``:
      * ``<token>.pkl``   - pickled dict with keys ``embedding``, ``created_at``,
                            ``expires_at`` and ``additional_info``
      * ``metadata.json`` - index mapping token -> {``expires_at``,
                            ``created_at``, ``file_size``}

    An entry expires ``ttl_days`` after it was last stored. Expired entries are
    removed when the cache is constructed and whenever an expired token is
    looked up.
    """

    def __init__(self, cache_dir: str = "/tmp/avatar_cache", ttl_days: int = 14):
        """
        Initialize avatar cache.

        Creates ``cache_dir`` if needed, loads the metadata index and removes
        entries that have already expired.

        Args:
            cache_dir: Directory to store cache files
            ttl_days: Time to live for cache entries in days
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.ttl_seconds = ttl_days * 24 * 60 * 60
        self.metadata_file = self.cache_dir / "metadata.json"

        # Load existing metadata
        self.metadata = self._load_metadata()

        # Clean expired entries on initialization
        self._cleanup_expired()

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _cache_file(self, token: str) -> Path:
        """Return the path of the pickle file holding ``token``'s entry."""
        return self.cache_dir / f"{token}.pkl"

    @staticmethod
    def _remove_file(path: Path) -> None:
        """Delete ``path``, ignoring the case where it is already gone."""
        try:
            path.unlink()
        except FileNotFoundError:
            pass

    def _atomic_write(self, path: Path, payload: bytes) -> None:
        """
        Write ``payload`` to ``path`` atomically.

        The data goes to a temporary sibling file first and is then moved into
        place with ``os.replace``. Readers therefore see either the old file or
        the complete new one, never a truncated file.
        """
        # Include the pid so concurrent processes do not clobber each other's
        # temp file. The ".tmp" suffix keeps it out of clear_cache()'s "*.pkl" glob.
        tmp_path = path.with_name(f".{path.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_path, 'wb') as f:
                f.write(payload)
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave a partial temp file behind on failure.
            self._remove_file(tmp_path)
            raise

    def _load_metadata(self) -> Dict[str, Any]:
        """
        Load the cache metadata index.

        Returns:
            The token -> info mapping. An empty dict is returned when the file
            is missing, unreadable, not valid JSON, or does not contain a JSON
            object, so a corrupted index never prevents the cache from starting.
        """
        if not self.metadata_file.exists():
            return {}
        try:
            with open(self.metadata_file, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            return {}
        return data if isinstance(data, dict) else {}

    def _save_metadata(self):
        """Persist the in-memory metadata index to ``metadata.json`` atomically."""
        payload = json.dumps(self.metadata, indent=2).encode('utf-8')
        self._atomic_write(self.metadata_file, payload)

    def _cleanup_expired(self):
        """
        Remove expired cache entries.

        Deletes both the pickle file and the index entry. Index entries that
        are malformed (not a dict, or missing ``expires_at``) are treated as
        expired.
        """
        now = time.time()
        expired_tokens = [
            token
            for token, info in self.metadata.items()
            if not isinstance(info, dict) or now > info.get('expires_at', 0)
        ]

        for token in expired_tokens:
            self._remove_file(self._cache_file(token))
            del self.metadata[token]

        if expired_tokens:
            self._save_metadata()
            print(f"Cleaned up {len(expired_tokens)} expired cache entries")

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def generate_token(self, img_bytes: bytes) -> str:
        """
        Generate a deterministic token for an image.

        The same bytes always map to the same token, which is what lets
        identical uploads share a cache entry.

        Args:
            img_bytes: Image data as bytes

        Returns:
            SHA-1 hex digest of ``img_bytes``
        """
        return hashlib.sha1(img_bytes).hexdigest()

    def store_embedding(
        self,
        img_bytes: bytes,
        embedding: Any,
        additional_info: Optional[Dict[str, Any]] = None
    ) -> Tuple[str, datetime]:
        """
        Store an image embedding in the cache.

        Storing the same image again overwrites the entry and restarts its TTL.

        Args:
            img_bytes: Image data as bytes
            embedding: Pre-computed embedding (must be picklable)
            additional_info: Additional metadata to store alongside it

        Returns:
            Tuple of (token, expiration_date), where expiration_date is a
            naive local-time datetime.
        """
        token = self.generate_token(img_bytes)
        cache_file = self._cache_file(token)

        # Use one timestamp so the pickle and the index agree on created_at.
        now = time.time()
        expires_at = now + self.ttl_seconds
        expiration_date = datetime.fromtimestamp(expires_at)

        cache_data = {
            'embedding': embedding,
            'created_at': now,
            'expires_at': expires_at,
            'additional_info': additional_info or {}
        }
        payload = pickle.dumps(cache_data)
        self._atomic_write(cache_file, payload)

        self.metadata[token] = {
            'expires_at': expires_at,
            'created_at': now,
            'file_size': len(payload)  # equals the on-disk size after the write
        }
        self._save_metadata()

        return token, expiration_date

    def load_embedding(self, token: str) -> Optional[Any]:
        """
        Load an embedding from the cache.

        Side effects:
          * an expired token triggers a full expiry sweep;
          * a token whose pickle file is missing has its index entry dropped;
          * a token whose pickle file cannot be read has both its index entry
            and its file removed, so the failure is not repeated on every call.

        Args:
            token: Avatar token

        Returns:
            The embedding if found and valid, None otherwise
        """
        info = self.metadata.get(token)
        if info is None:
            return None

        if time.time() > info.get('expires_at', 0):
            # Token expired
            self._cleanup_expired()
            return None

        cache_file = self._cache_file(token)
        if not cache_file.exists():
            # File missing, clean up metadata
            del self.metadata[token]
            self._save_metadata()
            return None

        try:
            # NOTE(security): pickle.load can execute arbitrary code. This is
            # only safe as long as cache_dir is writable solely by this service
            # (the /tmp default may deserve a private directory).
            with open(cache_file, 'rb') as f:
                cache_data = pickle.load(f)
            return cache_data['embedding']
        except Exception as e:
            # Best-effort cache: report, drop the broken entry, and fall back to a miss.
            print(f"Error loading cache for token {token}: {e}")
            self.metadata.pop(token, None)
            self._remove_file(cache_file)
            self._save_metadata()
            return None

    def get_cache_info(self) -> Dict[str, Any]:
        """
        Get cache statistics.

        Returns:
            Dict with:
              * ``cache_dir``: the cache directory
              * ``active_entries``: number of unexpired entries
              * ``total_entries``: number of entries in the index
              * ``total_size_mb``: summed file size of the unexpired entries only
              * ``ttl_days``: the configured TTL
        """
        now = time.time()
        total_size = 0
        active_entries = 0

        for info in self.metadata.values():
            if now <= info['expires_at']:
                active_entries += 1
                total_size += info.get('file_size', 0)

        return {
            'cache_dir': str(self.cache_dir),
            'active_entries': active_entries,
            'total_entries': len(self.metadata),
            'total_size_mb': total_size / (1024 * 1024),
            'ttl_days': self.ttl_seconds / (24 * 60 * 60)
        }

    def clear_cache(self):
        """Delete every ``*.pkl`` file in the cache directory and reset the index."""
        for file in self.cache_dir.glob("*.pkl"):
            self._remove_file(file)

        self.metadata = {}
        self._save_metadata()
        print("Avatar cache cleared")
class AvatarTokenManager:
    """
    Manages avatar tokens and their lifecycle.

    A thin layer over an avatar cache. It turns raw image bytes into a token,
    computing and caching the appearance embedding on first use, and answers
    validity and info queries about existing tokens.
    """

    def __init__(self, cache: "AvatarCache"):
        """
        Initialize token manager.

        Args:
            cache: Avatar cache instance. It must provide ``generate_token``,
                ``load_embedding``, ``store_embedding`` and a ``metadata`` dict.
        """
        self.cache = cache

    def prepare_avatar(
        self,
        image_data: bytes,
        appearance_encoder_func: callable,
        **encoder_kwargs
    ) -> Dict[str, Any]:
        """
        Prepare an avatar by pre-computing its embedding.

        If a valid embedding for ``image_data`` is already cached, the encoder
        is not called and the existing token is returned.

        Args:
            image_data: Image data as bytes
            appearance_encoder_func: Called as
                ``appearance_encoder_func(image_data, **encoder_kwargs)``;
                its return value is cached as the embedding
            **encoder_kwargs: Additional arguments for the encoder. They are
                also stored in the entry's ``additional_info``.

        Returns:
            Dict with ``avatar_token``, ``expires`` (ISO-8601 string) and
            ``cached`` (True if served from an existing entry)

        Raises:
            RuntimeError: If encoding or storing fails. The original exception
                is chained as ``__cause__``.
        """
        # Check if already cached
        token = self.cache.generate_token(image_data)
        existing_embedding = self.cache.load_embedding(token)

        if existing_embedding is not None:
            # Already cached, return existing token
            metadata = self.cache.metadata.get(token, {})
            expires_at = datetime.fromtimestamp(metadata.get('expires_at', 0))
            return {
                'avatar_token': token,
                'expires': expires_at.isoformat(),
                'cached': True
            }

        # Compute new embedding and store it
        try:
            embedding = appearance_encoder_func(image_data, **encoder_kwargs)
            token, expiration = self.cache.store_embedding(
                image_data,
                embedding,
                additional_info={'encoder_kwargs': encoder_kwargs}
            )
        except Exception as e:
            # Chain the cause so the encoder's traceback is not lost.
            raise RuntimeError(f"Failed to prepare avatar: {str(e)}") from e

        return {
            'avatar_token': token,
            'expires': expiration.isoformat(),
            'cached': False
        }

    def validate_token(self, token: str) -> bool:
        """
        Check whether a token is valid and not expired.

        This goes through ``load_embedding``, so an unreadable cache entry also
        counts as invalid.

        Args:
            token: Avatar token to validate

        Returns:
            True if valid, False otherwise
        """
        return self.cache.load_embedding(token) is not None

    def get_token_info(self, token: str) -> Optional[Dict[str, Any]]:
        """
        Get index information about a token without loading its embedding.

        Args:
            token: Avatar token

        Returns:
            Dict with ``token``, ``valid`` (not yet expired), ``created_at`` and
            ``expires_at`` (ISO-8601 strings) and ``file_size_kb``, or None if
            the token is unknown. Missing timestamps are treated as the epoch,
            which reports the entry as invalid.
        """
        info = self.cache.metadata.get(token)
        if info is None:
            return None

        expires_at = info.get('expires_at', 0)
        return {
            'token': token,
            'valid': time.time() <= expires_at,
            'created_at': datetime.fromtimestamp(info.get('created_at', 0)).isoformat(),
            'expires_at': datetime.fromtimestamp(expires_at).isoformat(),
            'file_size_kb': info.get('file_size', 0) / 1024
        }