| """ | |
| Robust cache key generation for financial ML data structures. | |
| Handles numpy arrays, pandas DataFrames, and time-series data properly. | |
| """ | |
| import hashlib | |
| import pickle | |
| from pathlib import Path | |
| from typing import Any, Optional, Tuple | |
| import numpy as np | |
| import pandas as pd | |
| from loguru import logger | |
class CacheKeyGenerator:
    """Generate robust, collision-resistant cache keys for ML data structures."""

    @staticmethod
    def generate_key(func, args: tuple, kwargs: dict) -> str:
        """
        Generate a robust cache key for a function call.

        Args:
            func: The function being cached
            args: Positional arguments
            kwargs: Keyword arguments

        Returns:
            MD5 hash string representing the unique call signature
        """
        key_parts = [
            func.__module__,
            func.__qualname__,
        ]

        # Process positional arguments
        for i, arg in enumerate(args):
            try:
                key_part = CacheKeyGenerator._hash_argument(arg, f"arg_{i}")
                key_parts.append(key_part)
            except Exception as e:
                logger.warning(f"Failed to hash argument {i} of type {type(arg)}: {e}")
                # Fallback: digest of the string representation. A content
                # digest is used instead of built-in hash(), which is
                # randomized per process and would not persist across runs.
                fallback = hashlib.md5(str(arg).encode()).hexdigest()[:8]
                key_parts.append(f"arg_{i}_{fallback}")

        # Process keyword arguments (sorted for consistency)
        for key, value in sorted(kwargs.items()):
            try:
                key_part = CacheKeyGenerator._hash_argument(value, key)
                key_parts.append(f"{key}={key_part}")
            except Exception as e:
                logger.warning(
                    f"Failed to hash kwarg '{key}' of type {type(value)}: {e}"
                )
                # Fallback (same stable digest as above)
                fallback = hashlib.md5(str(value).encode()).hexdigest()[:8]
                key_parts.append(f"{key}={fallback}")

        # Combine all parts and hash
        combined = "_".join(key_parts)
        return hashlib.md5(combined.encode()).hexdigest()
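
    # The final digest covers "<module>_<qualname>" plus one hashed part per
    # positional argument and per (sorted) keyword argument, so renaming or
    # moving the cached function naturally invalidates its keys. MD5 is used
    # here for speed and key uniqueness, not for security.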

    @staticmethod
    def _hash_argument(arg: Any, name: str) -> str:
        """Hash a single argument based on its type."""
        try:
            from sklearn.base import BaseEstimator

            if isinstance(arg, BaseEstimator):
                return CacheKeyGenerator._hash_sklearn_estimator(arg, name)
        except ImportError:
            pass  # sklearn not available, continue with other types

        if isinstance(arg, np.ndarray):
            return CacheKeyGenerator._hash_numpy_array(arg, name)
        elif isinstance(arg, pd.DataFrame):
            return CacheKeyGenerator._hash_dataframe(arg, name)
        elif isinstance(arg, pd.Series):
            return CacheKeyGenerator._hash_series(arg, name)
        elif isinstance(arg, (list, tuple)):
            return CacheKeyGenerator._hash_sequence(arg, name)
        elif isinstance(arg, dict):
            return CacheKeyGenerator._hash_dict(arg, name)
        elif isinstance(arg, (int, float, str, bool, type(None))):
            return CacheKeyGenerator._hash_primitive(arg, name)
        else:
            # Fallback for unknown types
            return CacheKeyGenerator._hash_generic(arg, name)

    @staticmethod
    def _hash_numpy_array(arr: np.ndarray, name: str) -> str:
        """Hash numpy array including shape, dtype, and content."""
        # For large arrays, sample for performance
        if arr.size > 10000:
            # Hash shape, dtype, and a sample
            sample = arr.flat[:: max(1, arr.size // 1000)]  # Sample ~1000 points
            content_hash = hashlib.md5(sample.tobytes()).hexdigest()[:8]
        else:
            # Hash full content for small arrays
            content_hash = hashlib.md5(arr.tobytes()).hexdigest()[:8]

        return f"{name}_arr_{arr.shape}_{arr.dtype}_{content_hash}"

    @staticmethod
    def _hash_dataframe(df: pd.DataFrame, name: str) -> str:
        """Hash pandas DataFrame including index, columns, dtypes, and content."""
        parts = [
            f"shape_{df.shape}",
            f"cols_{hashlib.md5(str(tuple(df.columns)).encode()).hexdigest()[:8]}",
            f"dtypes_{hashlib.md5(str(tuple(df.dtypes)).encode()).hexdigest()[:8]}",
        ]

        # Hash index
        if isinstance(df.index, pd.DatetimeIndex):
            # For a datetime index, hash start, end, and length
            parts.append(f"idx_dt_{df.index[0]}_{df.index[-1]}_{len(df.index)}")
        else:
            idx_hash = hashlib.md5(str(tuple(df.index)).encode()).hexdigest()[:8]
            parts.append(f"idx_{idx_hash}")

        # Hash content (sample for large DataFrames)
        if df.size > 10000:
            # Sample rows for hashing
            sample_rows = df.iloc[:: max(1, len(df) // 100)]  # ~100 rows
            content_hash = hashlib.md5(sample_rows.values.tobytes()).hexdigest()[:8]
        else:
            content_hash = hashlib.md5(df.values.tobytes()).hexdigest()[:8]
        parts.append(f"data_{content_hash}")

        return f"{name}_df_{'_'.join(parts)}"

    @staticmethod
    def _hash_series(series: pd.Series, name: str) -> str:
        """Hash pandas Series."""
        parts = [
            f"len_{len(series)}",
            f"dtype_{series.dtype}",
        ]

        # Hash index
        if isinstance(series.index, pd.DatetimeIndex):
            parts.append(f"idx_dt_{series.index[0]}_{series.index[-1]}")
        else:
            idx_hash = hashlib.md5(str(tuple(series.index)).encode()).hexdigest()[:8]
            parts.append(f"idx_{idx_hash}")

        # Hash values
        if len(series) > 1000:
            sample = series.iloc[:: max(1, len(series) // 100)]
            content_hash = hashlib.md5(sample.values.tobytes()).hexdigest()[:8]
        else:
            content_hash = hashlib.md5(series.values.tobytes()).hexdigest()[:8]
        parts.append(f"data_{content_hash}")

        return f"{name}_series_{'_'.join(parts)}"

    @staticmethod
    def _hash_sequence(seq: Tuple[Any, ...] | list, name: str) -> str:
        """Hash list or tuple recursively."""
        if len(seq) == 0:
            return f"{name}_empty_seq"

        # Hash each element
        element_hashes = []
        for i, item in enumerate(seq):
            elem_hash = CacheKeyGenerator._hash_argument(item, f"{name}_{i}")
            element_hashes.append(elem_hash)

        combined = "_".join(element_hashes)
        return hashlib.md5(combined.encode()).hexdigest()[:8]
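
    # Element order is part of the key: sequences containing the same items
    # in a different order hash to different values.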

    @staticmethod
    def _hash_dict(d: dict, name: str) -> str:
        """Hash dictionary recursively."""
        if len(d) == 0:
            return f"{name}_empty_dict"

        # Sort keys for consistency
        items_hash = []
        for key, value in sorted(d.items()):
            val_hash = CacheKeyGenerator._hash_argument(value, f"{name}_{key}")
            items_hash.append(f"{key}={val_hash}")

        combined = "_".join(items_hash)
        return hashlib.md5(combined.encode()).hexdigest()[:8]

    @staticmethod
    def _hash_primitive(value: Any, name: str) -> str:
        """Hash primitive types."""
        # Use a content digest rather than built-in hash(), which is
        # randomized per process for strings and would break key persistence.
        digest = hashlib.md5(repr(value).encode()).hexdigest()[:8]
        return f"{name}_{type(value).__name__}_{digest}"

    @staticmethod
    def _hash_generic(obj: Any, name: str) -> str:
        """Fallback hashing for unknown types."""
        try:
            # Try to use the object's __repr__
            digest = hashlib.md5(repr(obj).encode()).hexdigest()[:8]
            return f"{name}_{type(obj).__name__}_{digest}"
        except Exception:
            # Last resort: use id (not stable across runs)
            return f"{name}_{type(obj).__name__}_{id(obj)}"

    @staticmethod
    def _hash_sklearn_estimator(estimator: Any, name: str) -> str:
        """Hash sklearn estimator including nested estimators."""
        try:
            from sklearn.base import BaseEstimator

            if not isinstance(estimator, BaseEstimator):
                return CacheKeyGenerator._hash_generic(estimator, name)

            # Use the enhanced estimator hashing from cv_cache
            from .cv_cache import _hash_classifier

            estimator_hash = _hash_classifier(estimator)
            return f"{name}_estimator_{estimator_hash}"
        except ImportError:
            # Fallback if sklearn (or cv_cache) is not available
            return CacheKeyGenerator._hash_generic(estimator, name)
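

# Illustrative use of the generator (the function and DataFrame names below
# are hypothetical, not part of this module):
#
#     key = CacheKeyGenerator.generate_key(load_bars, (bars_df,), {"lookback": 20})
#
# The same call with identical numeric data yields the same key; changing the
# data, the arguments, or the function identity yields a different one.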


class TimeSeriesCacheKey(CacheKeyGenerator):
    """
    Extended cache key generator with time-series awareness.
    Useful for financial data where lookback periods matter.
    """

    @staticmethod
    def generate_key_with_time_range(
        func,
        args: tuple,
        kwargs: dict,
        time_range: Tuple[pd.Timestamp, pd.Timestamp] | None = None,
    ) -> str:
        """
        Generate cache key that includes time range information.

        Args:
            func: Function being cached
            args: Positional arguments
            kwargs: Keyword arguments
            time_range: Optional (start, end) timestamp tuple

        Returns:
            Cache key string
        """
        base_key = CacheKeyGenerator.generate_key(func, args, kwargs)

        if time_range is None:
            # Try to extract time range from data
            time_range = TimeSeriesCacheKey._extract_time_range(args, kwargs)

        if time_range:
            start, end = time_range
            time_hash = f"time_{start}_{end}"
            return f"{base_key}_{time_hash}"

        return base_key

    @staticmethod
    def _extract_time_range(
        args: tuple, kwargs: dict
    ) -> Tuple[pd.Timestamp, pd.Timestamp] | None:
        """
        Attempt to extract time range from function arguments.
        Looks for DataFrames with DatetimeIndex or explicit start/end parameters.
        """
        # Check kwargs for explicit time parameters
        if "start_date" in kwargs and "end_date" in kwargs:
            return (
                pd.Timestamp(kwargs["start_date"]),
                pd.Timestamp(kwargs["end_date"]),
            )

        # Check for DataFrames or Series with a DatetimeIndex in args
        for arg in args:
            if isinstance(arg, pd.DataFrame) and isinstance(
                arg.index, pd.DatetimeIndex
            ):
                if len(arg.index) > 0:
                    return (arg.index[0], arg.index[-1])
            elif isinstance(arg, pd.Series) and isinstance(arg.index, pd.DatetimeIndex):
                if len(arg.index) > 0:
                    return (arg.index[0], arg.index[-1])

        return None
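

# A time-aware key is the base digest plus a readable suffix, e.g.
# "<md5>_time_2020-01-02 00:00:00_2020-12-31 00:00:00" (illustrative dates).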

# =============================================================================
# Integration with existing cacheable decorator
# =============================================================================


def create_robust_cacheable(
    track_data_access: bool = False,
    dataset_name: Optional[str] = None,
    purpose: Optional[str] = None,
    use_time_awareness: bool = False,
):
    """
    Factory function to create robust cacheable decorators with data tracking.

    Args:
        track_data_access: Whether to track DataFrame accesses
        dataset_name: Name of the dataset for tracking
        purpose: One of 'train', 'test', 'validate', 'optimize', 'analyze'
        use_time_awareness: Whether to use time-series aware cache keys

    Returns:
        Decorator function
    """
    import time
    from functools import wraps

    from . import cache_stats, memory
    from .cache_monitoring import get_cache_monitor

    def decorator(func):
        func_name = f"{func.__module__}.{func.__qualname__}"
        cached_func = memory.cache(func)
        seen_signatures = set()
        monitor = get_cache_monitor()

        @wraps(func)
        def wrapper(*args, **kwargs):
            # Track access time (ALWAYS do this first)
            monitor.track_access(func_name)

            # Generate cache key
            cache_key = None
            is_hit = False
            computation_start = None

            try:
                if use_time_awareness:
                    cache_key = TimeSeriesCacheKey.generate_key_with_time_range(
                        func, args, kwargs
                    )
                else:
                    cache_key = CacheKeyGenerator.generate_key(func, args, kwargs)

                # Track hit/miss (check_call_in_cache returns a boolean)
                if cached_func.check_call_in_cache(*args, **kwargs):
                    is_hit = True
                    cache_stats.record_hit(func_name)
                    logger.debug(f"Cache HIT for {func_name}")
                else:
                    is_hit = False
                    cache_stats.record_miss(func_name)
                    computation_start = time.time()  # Start timing for misses
                    logger.debug(f"Cache MISS for {func_name}")

                # Add to seen_signatures for this session
                seen_signatures.add(cache_key)
            except Exception as e:
                logger.warning(f"Cache key generation failed for {func_name}: {e}")
                cache_stats.record_miss(func_name)
                cache_key = None
                is_hit = False
                computation_start = time.time()  # Start timing for error case

            # Track data access if requested
            if track_data_access:
                try:
                    from .data_access_tracker import get_data_tracker

                    _track_dataframe_access(
                        get_data_tracker(), args, kwargs, dataset_name, purpose
                    )
                except Exception as e:
                    logger.warning(f"Data tracking failed for {func_name}: {e}")

            # Execute function (joblib returns the stored result on a hit)
            try:
                result = cached_func(*args, **kwargs)
                if not is_hit and computation_start is not None:
                    # For cache misses, record how long the computation took
                    computation_time = time.time() - computation_start
                    monitor.track_computation_time(func_name, computation_time)
                    logger.debug(
                        f"Computation time for {func_name}: {computation_time:.3f}s"
                    )
                return result
            except (EOFError, pickle.PickleError, OSError) as e:
                # Handle cache corruption
                logger.warning(
                    f"Cache corruption for {func_name}: {type(e).__name__} - recomputing"
                )
                # Clear corrupted cache if possible
                if cache_key is not None:
                    _clear_corrupted_cache(cached_func, args, kwargs, func_name)

                # Execute function directly and track time
                direct_start = time.time()
                result = func(*args, **kwargs)
                if computation_start is not None:
                    # Track time only if the call was originally a miss
                    computation_time = time.time() - direct_start
                    monitor.track_computation_time(func_name, computation_time)
                    logger.debug(
                        f"Direct computation time for {func_name}: {computation_time:.3f}s"
                    )
                return result
            except Exception as e:
                # Other unexpected errors
                logger.error(f"Unexpected cache error for {func_name}: {e}")
                raise

        # Add cache info method for debugging
        def cache_info():
            return {
                "function_name": func_name,
                "seen_signatures": len(seen_signatures),
                "hits": cache_stats._stats.get(func_name, {}).get("hits", 0),
                "misses": cache_stats._stats.get(func_name, {}).get("misses", 0),
            }

        wrapper.cache_info = cache_info
        wrapper._afml_cacheable = True
        return wrapper

    return decorator
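

# Example of building a tracked decorator (hypothetical function and dataset
# names; assumes the parent package's cache backends are configured):
#
#     @create_robust_cacheable(track_data_access=True,
#                              dataset_name="spx_bars", purpose="train")
#     def compute_features(bars: pd.DataFrame) -> pd.DataFrame:
#         ...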


def _clear_corrupted_cache(cached_func, args, kwargs, func_name):
    """Helper to clear corrupted cache entries."""
    try:
        if hasattr(cached_func, "_get_cache_id"):
            joblib_cache_key = cached_func._get_cache_id(*args, **kwargs)
            cache_dir = Path(cached_func.store_backend.location)

            # Remove files matching this cache key
            removed_count = 0
            for cache_file in cache_dir.rglob("*"):
                if cache_file.is_file() and str(joblib_cache_key) in str(cache_file):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"Removed corrupted file: {cache_file.name}")

            if removed_count > 0:
                logger.info(
                    f"Cleared {removed_count} corrupted cache files for {func_name}"
                )
    except Exception as clear_exc:
        logger.warning(f"Failed to clear corrupted cache for {func_name}: {clear_exc}")


def _track_dataframe_access(tracker, args, kwargs, dataset_name, purpose):
    """Track DataFrame accesses for data hygiene monitoring."""
    # Check all arguments for DataFrames with a DatetimeIndex
    for i, arg in enumerate(args):
        if _is_trackable_dataframe(arg):
            _log_dataframe_access(tracker, arg, dataset_name or f"arg_{i}", purpose)

    for key, value in kwargs.items():
        if _is_trackable_dataframe(value):
            _log_dataframe_access(tracker, value, dataset_name or key, purpose)


def _is_trackable_dataframe(obj):
    """Check if object is a DataFrame with temporal index."""
    return (
        isinstance(obj, pd.DataFrame)
        and isinstance(obj.index, pd.DatetimeIndex)
        and len(obj) > 0
    )


def _log_dataframe_access(tracker, df, name, purpose):
    """Log DataFrame access to tracker."""
    tracker.log_access(
        dataset_name=name,
        start_date=df.index[0],
        end_date=df.index[-1],
        purpose=purpose or "unknown",
        data_shape=df.shape,
    )


# =============================================================================
# Final convenience exports
# =============================================================================

# Standard decorators (backward compatible)
robust_cacheable = create_robust_cacheable(use_time_awareness=False)
time_aware_cacheable = create_robust_cacheable(use_time_awareness=True)


# Data tracking decorators (new functionality)
def data_tracking_cacheable(dataset_name, purpose):
    return create_robust_cacheable(
        track_data_access=True,
        dataset_name=dataset_name,
        purpose=purpose,
        use_time_awareness=False,
    )


def time_aware_data_tracking_cacheable(dataset_name, purpose):
    return create_robust_cacheable(
        track_data_access=True,
        dataset_name=dataset_name,
        purpose=purpose,
        use_time_awareness=True,
    )


__all__ = [
    "CacheKeyGenerator",
    "TimeSeriesCacheKey",
    "data_tracking_cacheable",  # NEW
    "robust_cacheable",  # Backward compatible
    "time_aware_cacheable",  # Backward compatible
    "time_aware_data_tracking_cacheable",  # NEW
]
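

if __name__ == "__main__":
    # Minimal sketch of key generation on synthetic data (illustration only;
    # the demo function and DataFrame are hypothetical and not part of the
    # public API). Running the module requires its parent package, since the
    # decorator exports above use relative imports.
    def _demo(prices: pd.DataFrame, window: int = 20) -> pd.Series:
        return prices["close"].rolling(window).mean()

    idx = pd.date_range("2024-01-01", periods=5, freq="D")
    frame = pd.DataFrame({"close": np.arange(5.0)}, index=idx)

    plain_key = CacheKeyGenerator.generate_key(_demo, (frame,), {"window": 20})
    time_key = TimeSeriesCacheKey.generate_key_with_time_range(
        _demo, (frame,), {"window": 20}
    )
    print(plain_key)
    print(time_key)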