Spaces:
No application file
No application file
| """ | |
| Centralized caching system for AFML package. | |
| Now with robust cache keys, MLflow integration, backtest caching, and monitoring. | |
| """ | |
if __name__ == "__main__" and not __package__:
    # This module is a package initializer. When it is executed as a plain
    # script (no package context), print a usage hint and exit with status 0.
    print("This file is a package initializer and is not meant to be run directly.")
    print("Use it via import, for example:")
    # Single-quoted raw string so the double quotes around the -c argument are
    # printed literally. Doubling the quotes ("") inside a double-quoted
    # literal only concatenates adjacent strings and silently drops them.
    print(r' .\.venv\Scripts\python.exe -c "from afml.cache import initialize_cache_system"')
    raise SystemExit(0)
| import json | |
| import os | |
| import threading | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from types import FunctionType | |
| from typing import Callable, Dict, Optional, Union | |
| from appdirs import user_cache_dir | |
| from joblib import Memory | |
| from loguru import logger | |
| # ============================================================================= | |
| # 1) CACHE DIRECTORY SETUP | |
| # ============================================================================= | |
def _setup_cache_directories() -> Dict[str, Path]:
    """
    Resolve and create the on-disk cache layout.

    The root directory comes from the ``AFML_CACHE`` environment variable when
    it is set to a non-empty value, otherwise from ``user_cache_dir("afml")``.

    Returns:
        Mapping with the keys ``"base"``, ``"joblib"``, ``"numba"`` and
        ``"backtest"``. Every path in it has been created on disk.
    """
    override = os.getenv("AFML_CACHE")
    root = Path(user_cache_dir("afml")) if not override else Path(override)

    layout: Dict[str, Path] = {"base": root}
    for key, folder in (
        ("joblib", "joblib_cache"),
        ("numba", "numba_cache"),
        ("backtest", "backtest_cache"),
    ):
        layout[key] = root / folder

    # parents=True / exist_ok=True make the call safe to repeat.
    for target in layout.values():
        target.mkdir(parents=True, exist_ok=True)
    return layout
# Resolved once at import time; importing this module therefore creates the
# cache directories on disk.
CACHE_DIRS = _setup_cache_directories()
| # ============================================================================= | |
| # 2) NUMBA CONFIGURATION | |
| # ============================================================================= | |
def _configure_numba():
    """
    Point Numba at the centralized AFML cache directory.

    ``NUMBA_CACHE_DIR`` is always overwritten with ``CACHE_DIRS["numba"]``.
    ``NUMBA_DISABLE_JIT`` and ``NUMBA_WARNINGS`` receive the default ``"0"``
    only when they are not already present in the environment.
    """
    target = str(CACHE_DIRS["numba"])
    env = os.environ
    env["NUMBA_CACHE_DIR"] = target

    # setdefault keeps any value that was exported before this call.
    for variable in ("NUMBA_DISABLE_JIT", "NUMBA_WARNINGS"):
        env.setdefault(variable, "0")

    logger.debug("Numba cache configured: {}", target)
| # ============================================================================= | |
| # 3) SIMPLE CACHE STATISTICS | |
| # ============================================================================= | |
class CacheStats:
    """
    Lightweight hit/miss counters keyed by function name.

    Counters are kept in memory, guarded by a lock, and written to
    ``cache_stats.json`` in the base cache directory every 25th hit or miss
    of a given function. Persistence is best-effort: read and write failures
    never propagate to the caller.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self._stats = defaultdict(lambda: {"hits": 0, "misses": 0})
        self._stats_file = CACHE_DIRS["base"] / "cache_stats.json"
        self._load_stats()

    def _load_stats(self):
        """Load previously saved counters, ignoring missing or malformed data."""
        try:
            with open(self._stats_file, "r") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable or non-JSON file: start fresh.
            return
        if not isinstance(data, dict):
            return
        for func_name, counts in data.items():
            # Only accept well-formed entries; anything else would make
            # record_hit/record_miss fail later on.
            if not isinstance(counts, dict):
                continue
            hits = counts.get("hits", 0)
            misses = counts.get("misses", 0)
            if isinstance(hits, int) and isinstance(misses, int):
                self._stats[func_name] = {"hits": hits, "misses": misses}

    def _save_stats(self):
        """Write counters to disk. Callers must already hold ``self._lock``."""
        try:
            with open(self._stats_file, "w") as f:
                json.dump(dict(self._stats), f)
        except (OSError, TypeError, ValueError):
            # Best-effort persistence: never break the cached call itself.
            pass

    def _record(self, func_name: str, field: str):
        """Increment one counter and flush to disk on every 25th increment."""
        with self._lock:
            entry = self._stats[func_name]
            entry[field] += 1
            # Save every 25 events to reduce I/O.
            if entry[field] % 25 == 0:
                self._save_stats()

    def record_hit(self, func_name: str):
        """Record a cache hit for ``func_name``."""
        self._record(func_name, "hits")

    def record_miss(self, func_name: str):
        """Record a cache miss for ``func_name``."""
        self._record(func_name, "misses")

    def get_hit_rate(self, func_name: Optional[str] = None) -> float:
        """
        Return the hit rate for one function, or overall when no name is given.

        Args:
            func_name: Function to report on; a falsy value selects the
                aggregate over all tracked functions.

        Returns:
            ``hits / (hits + misses)``, or ``0.0`` when nothing was recorded.
        """
        with self._lock:
            if func_name:
                # .get() instead of indexing: looking up an unknown name must
                # not create an empty entry in the defaultdict.
                entry = self._stats.get(func_name)
                if entry is None:
                    return 0.0
                total = entry["hits"] + entry["misses"]
                return entry["hits"] / total if total > 0 else 0.0
            total_hits = sum(s["hits"] for s in self._stats.values())
            total_calls = sum(s["hits"] + s["misses"] for s in self._stats.values())
            return total_hits / total_calls if total_calls > 0 else 0.0

    def get_stats(self) -> Dict[str, Dict[str, int]]:
        """
        Return a snapshot of all counters.

        The per-function dicts are copied as well, so the snapshot does not
        change when further hits or misses are recorded.
        """
        with self._lock:
            return {name: dict(counts) for name, counts in self._stats.items()}

    def clear(self):
        """Reset all counters and delete the stats file if it exists."""
        with self._lock:
            self._stats.clear()
            try:
                self._stats_file.unlink()
            except FileNotFoundError:
                # Already gone (never written, or removed concurrently).
                pass
# Global stats instance, created at import time and shared by the helper
# functions in this module.
cache_stats = CacheStats()
| # ============================================================================= | |
| # 4) JOBLIB MEMORY INSTANCE | |
| # ============================================================================= | |
# Module-wide joblib store rooted at the centralized joblib cache directory.
memory = Memory(location=str(CACHE_DIRS["joblib"]), verbose=0)
| # ============================================================================= | |
| # 5) UTILITY FUNCTIONS | |
| # ============================================================================= | |
def get_cache_hit_rate(func_name: str = None) -> float:
    """
    Return the hit rate tracked by the global ``cache_stats`` instance.

    Args:
        func_name: Function to report on; when omitted the rate is computed
            across all tracked functions.

    Returns:
        The value of ``cache_stats.get_hit_rate(func_name)``.
    """
    rate = cache_stats.get_hit_rate(func_name)
    return rate
def get_cache_stats() -> Dict[str, Dict[str, int]]:
    """
    Return the hit/miss counters held by the global ``cache_stats`` instance.

    Returns:
        Mapping of function name to its ``"hits"`` / ``"misses"`` counters.
    """
    snapshot = cache_stats.get_stats()
    return snapshot
def clear_cache_stats():
    """Reset the global ``cache_stats`` counters and remove their stats file."""
    tracker = cache_stats
    tracker.clear()
def clear_afml_cache(warn: bool = True):
    """
    Clear the joblib ``memory`` store and reset the hit/miss statistics.

    Args:
        warn: When true, log a warning before clearing. The flag is also
            forwarded to ``memory.clear``.
    """
    if warn:
        logger.warning("Clearing AFML cache...")
    memory.clear(warn=warn)
    # Same effect as clear_cache_stats(): reset the global counters.
    cache_stats.clear()
def get_cache_summary() -> Dict[str, Union[float, int]]:
    """
    Summarize overall cache performance from the global counters.

    Returns:
        Dict with ``"hit_rate"`` (0.0 when nothing was recorded),
        ``"total_calls"`` (hits plus misses) and ``"functions_tracked"``.
    """
    snapshot = cache_stats.get_stats()

    hits = 0
    calls = 0
    for counts in snapshot.values():
        hits += counts["hits"]
        calls += counts["hits"] + counts["misses"]

    if calls > 0:
        rate = hits / calls
    else:
        rate = 0.0

    return {
        "hit_rate": rate,
        "total_calls": calls,
        "functions_tracked": len(snapshot),
    }
| # ============================================================================= | |
| # 6) CACHE ANALYSIS CONTEXT MANAGER | |
| # ============================================================================= | |
class CacheAnalyzer:
    """
    Context manager that reports cache activity inside a ``with`` block.

    On entry it snapshots the global hit/miss counters. On a clean exit it
    logs how many cached calls happened in between and their hit rate.
    Nothing is logged when the block exits with an exception.
    """

    def __init__(self, name: str = "analysis"):
        self.name = name
        # None means "not entered yet"; an empty dict is a valid baseline.
        self.start_stats = None

    def __enter__(self):
        # Copy each per-function counter dict as well, so later hits and
        # misses cannot alter the baseline through shared inner dicts.
        self.start_stats = {
            func_name: dict(counts)
            for func_name, counts in cache_stats.get_stats().items()
        }
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            report = self._generate_report(cache_stats.get_stats())
            if report:
                logger.info("Cache analysis '{}': {}", self.name, report)

    def _generate_report(self, end_stats) -> Optional[str]:
        """
        Build a one-line summary of the activity since ``__enter__``.

        Args:
            end_stats: Counter mapping captured at the end of the block.

        Returns:
            ``None`` when no baseline was captured, otherwise a string with
            the call count and hit rate, or ``"no cache activity"``.
        """
        # Compare against None: an empty baseline (no stats yet) must still
        # produce a report for activity recorded inside the block.
        if self.start_stats is None:
            return None

        zero = {"hits": 0, "misses": 0}
        total_new_hits = 0
        total_new_calls = 0
        for func_name, end_data in end_stats.items():
            start_data = self.start_stats.get(func_name, zero)
            new_hits = end_data["hits"] - start_data["hits"]
            new_misses = end_data["misses"] - start_data["misses"]
            total_new_hits += new_hits
            total_new_calls += new_hits + new_misses

        if total_new_calls > 0:
            hit_rate = total_new_hits / total_new_calls
            return f"{total_new_calls} calls, {hit_rate:.1%} hit rate"
        return "no cache activity"
| # ============================================================================= | |
| # 7) INITIALIZATION FUNCTION | |
| # ============================================================================= | |
def initialize_cache_system():
    """
    Initialize the AFML cache system.

    Configures the Numba environment variables, logs the joblib and Numba
    cache locations, and, when counters are already present, logs how many
    functions are tracked together with the overall hit rate.
    """
    # Configure Numba first (intended to run before any @njit functions are
    # defined).
    _configure_numba()

    logger.info("AFML cache system initialized:")
    for template, key in (
        (" Joblib cache: {}", "joblib"),
        (" Numba cache: {}", "numba"),
    ):
        logger.info(template, CACHE_DIRS[key])

    loaded = cache_stats.get_stats()
    if not loaded:
        return

    overall_rate = cache_stats.get_hit_rate()
    logger.info(
        " Loaded stats: {} functions, {:.1%} hit rate", len(loaded), overall_rate
    )
| # ============================================================================= | |
| # 8) NOW SAFE TO IMPORT OTHER MODULES | |
| # ============================================================================= | |
| # Import robust cache key generation - NOW SAFE (memory and cache_stats exist) | |
| from .data_access_tracker import ( # noqa: E402 | |
| DataAccessTracker, # noqa: E402 | |
| clear_data_access_log, | |
| get_data_tracker, | |
| log_data_access, | |
| print_contamination_report, | |
| ) | |
| # Import selective cleaner functions after base components are defined | |
| from .selective_cleaner import ( # noqa: E402 | |
| analyze_cache_versions, # noqa: E402 | |
| cache_maintenance, | |
| clean_orphaned_caches, | |
| cleanup_by_age, | |
| cleanup_by_size, | |
| clear_orphaned_features_caches, | |
| clear_orphaned_labeling_caches, | |
| clear_orphaned_ml_caches, | |
| find_orphaned_caches, | |
| get_version_tracker, | |
| print_version_analysis, | |
| ) | |
# Unified caching decorators and parameter-grid helpers
| from .unified_cache_system import ( | |
| cacheable, # noqa: E402 | |
| create_cacheable_param_grid, | |
| cv_cacheable, | |
| data_tracking_cacheable, | |
| print_cache_report, | |
| reconstruct_param_grid, | |
| robust_cacheable, | |
| time_aware_cacheable, | |
| ) | |
# MLflow integration (optional)
try:
    from .mlflow_integration import (
        MLFLOW_AVAILABLE,
        MLflowCacheIntegration,
        get_mlflow_cache,
        mlflow_cached,
        setup_mlflow_cache,
    )

    MLFLOW_INTEGRATION_AVAILABLE = True
except ImportError:
    # NOTE: on this path the five names imported above are never bound in
    # this module; only MLFLOW_INTEGRATION_AVAILABLE is defined. Check the
    # flag before using any of them.
    MLFLOW_INTEGRATION_AVAILABLE = False
    logger.debug("MLflow integration not available (install mlflow)")
| # Backtest caching | |
| from .backtest_cache import ( | |
| BacktestCache, | |
| BacktestMetadata, # noqa: E402 | |
| BacktestResult, | |
| cached_backtest, | |
| get_backtest_cache, | |
| ) | |
| # Cache monitoring | |
| from .cache_monitoring import ( | |
| CacheHealthReport, | |
| CacheMonitor, # noqa: E402 | |
| FunctionCacheStats, | |
| analyze_cache_patterns, | |
| debug_function_cache, | |
| diagnose_cache_issues, | |
| get_cache_efficiency_report, | |
| get_cache_monitor, | |
| print_cache_health, | |
| ) | |
| # ============================================================================= | |
| # 9) ENHANCED CONVENIENCE FUNCTIONS | |
| # ============================================================================= | |
def get_comprehensive_cache_status() -> dict:
    """
    Collect the status of every cache subsystem into one dict.

    Failures while querying the monitor or the backtest cache are logged at
    debug level and leave the corresponding entry as ``None``.

    Returns:
        Dict with the keys ``"core"`` (summary of the hit/miss counters),
        ``"health"``, ``"backtest"`` and ``"mlflow"`` (availability flag).
    """
    status = dict.fromkeys(("core", "health", "backtest", "mlflow"))
    status["core"] = get_cache_summary()
    status["mlflow"] = {"available": MLFLOW_INTEGRATION_AVAILABLE}

    # Health figures from the cache monitor.
    try:
        health = get_cache_monitor().generate_health_report()
        status["health"] = {
            "total_functions": health.total_functions,
            "hit_rate": health.overall_hit_rate,
            "total_calls": health.total_calls,
            "cache_size_mb": health.total_cache_size_mb,
        }
    except Exception as e:
        logger.debug(f"Health report failed: {e}")

    # Backtest cache statistics.
    try:
        status["backtest"] = get_backtest_cache().get_cache_stats()
    except Exception as e:
        logger.debug(f"Backtest cache stats failed: {e}")

    return status
def optimize_cache_system(
    clear_changed: bool = True,
    max_size_mb: int = 1000,
    max_age_days: int = 30,
    print_report: bool = True,
) -> dict:
    """
    Run maintenance, a health report and backtest cleanup in one call.

    Each of the three steps is independent: a failure is logged as a warning
    and leaves that step's result as ``None``.

    Args:
        clear_changed: Forwarded to ``cache_maintenance`` as ``clean_orphaned``.
        max_size_mb: Forwarded to ``cache_maintenance`` as ``max_cache_size_mb``.
        max_age_days: Age limit in days, used for both ``cache_maintenance``
            and the backtest run cleanup.
        print_report: Print the (non-detailed) health report.

    Returns:
        Dict with the keys ``"maintenance"``, ``"health_report"`` and
        ``"backtest_cleanup"``.
    """
    logger.info("Running comprehensive cache optimization...")
    results = dict.fromkeys(("maintenance", "health_report", "backtest_cleanup"))

    # Step 1: core cache maintenance.
    try:
        maintenance_outcome = cache_maintenance(
            clean_orphaned=clear_changed,
            max_cache_size_mb=max_size_mb,
            max_age_days=max_age_days,
        )
        results["maintenance"] = maintenance_outcome
    except Exception as e:
        logger.warning(f"Cache maintenance failed: {e}")

    # Step 2: health report.
    try:
        cache_monitor = get_cache_monitor()
        results["health_report"] = cache_monitor.generate_health_report()
        if print_report:
            cache_monitor.print_health_report(detailed=False)
    except Exception as e:
        logger.warning(f"Health report failed: {e}")

    # Step 3: drop old backtest runs.
    try:
        cleared = get_backtest_cache().clear_old_runs(days=max_age_days)
        results["backtest_cleanup"] = {"runs_cleared": cleared}
        logger.info(f"Cleared {cleared} old backtest runs")
    except Exception as e:
        logger.warning(f"Backtest cleanup failed: {e}")

    return results
def setup_production_cache(
    enable_mlflow: bool = True,
    mlflow_experiment: str = "production",
    mlflow_uri: str = None,
    max_cache_size_mb: int = 2000,
) -> dict:
    """
    Bring up every cache component for production use.

    The core cache is initialized first. MLflow, the backtest cache, the
    monitor and an initial optimization pass follow; a failure in any of
    those is logged as a warning and does not stop the setup.

    Args:
        enable_mlflow: Set up MLflow when the integration is importable.
        mlflow_experiment: MLflow experiment name.
        mlflow_uri: MLflow tracking URI.
        max_cache_size_mb: Size limit passed to the initial optimization.

    Returns:
        Dict with the keys ``"core_cache"``, ``"mlflow_cache"``,
        ``"backtest_cache"`` and ``"monitor"``; entries stay ``None`` for
        components that were skipped or failed.
    """
    logger.info("Initializing production cache system...")
    components = dict.fromkeys(
        ("core_cache", "mlflow_cache", "backtest_cache", "monitor")
    )

    # Core cache.
    initialize_cache_system()
    components["core_cache"] = True

    # MLflow, only when requested and importable.
    if enable_mlflow and MLFLOW_INTEGRATION_AVAILABLE:
        try:
            mlflow_cache = setup_mlflow_cache(
                experiment_name=mlflow_experiment,
                tracking_uri=mlflow_uri,
            )
            components["mlflow_cache"] = mlflow_cache
            logger.info(f"MLflow tracking enabled: {mlflow_experiment}")
        except Exception as e:
            logger.warning(f"MLflow setup failed: {e}")

    # Backtest cache.
    try:
        components["backtest_cache"] = get_backtest_cache()
    except Exception as e:
        logger.warning(f"Backtest cache setup failed: {e}")

    # Monitor.
    try:
        components["monitor"] = get_cache_monitor()
    except Exception as e:
        logger.warning(f"Cache monitor setup failed: {e}")

    # Initial maintenance pass, without printing the report.
    try:
        optimize_cache_system(max_size_mb=max_cache_size_mb, print_report=False)
    except Exception as e:
        logger.warning(f"Initial optimization failed: {e}")

    logger.info("✅ Production cache system ready")
    return components
| # ============================================================================= | |
| # 10) ADDITIONAL UTILITY FUNCTIONS | |
| # ============================================================================= | |
def get_cache_size_info(
    cache_dirs: Optional[Dict[str, Path]] = None,
) -> Dict[str, Dict[str, Union[int, float]]]:
    """
    Get detailed information about cache sizes.

    Args:
        cache_dirs: Mapping of cache name to directory to measure. Defaults
            to the module-level ``CACHE_DIRS``.

    Returns:
        Dict keyed by cache name. Each value holds ``"size_bytes"``,
        ``"size_mb"`` (rounded to two decimals) and ``"file_count"`` for all
        files found recursively under that directory. Directories that do
        not exist are omitted.
    """
    dirs = CACHE_DIRS if cache_dirs is None else cache_dirs

    size_info = {}
    for cache_name, cache_dir in dirs.items():
        if not cache_dir.exists():
            continue

        total_size = 0
        file_count = 0
        for file_path in cache_dir.rglob("*"):
            try:
                if not file_path.is_file():
                    continue
                file_size = file_path.stat().st_size
            except OSError:
                # A file can be removed between listing and stat();
                # skip it instead of failing the whole report.
                continue
            total_size += file_size
            file_count += 1

        size_info[cache_name] = {
            "size_bytes": total_size,
            "size_mb": round(total_size / (1024 * 1024), 2),
            "file_count": file_count,
        }
    return size_info
def clear_cache_by_pattern(pattern: str, cache_type: str = "joblib") -> int:
    """
    Clear cache entries matching a pattern.

    Files are matched by plain substring test against the file name (not the
    full path). Only files are removed; directories are left in place.

    Args:
        pattern: Non-empty substring to look for in cache file names.
        cache_type: Key of ``CACHE_DIRS`` selecting the directory to search
            ('base', 'joblib', 'numba', 'backtest').

    Returns:
        Number of files removed.

    Raises:
        ValueError: If ``pattern`` is empty or ``cache_type`` is unknown.
    """
    # An empty string is a substring of every file name, so it would wipe
    # the whole cache; reject it explicitly.
    if not pattern:
        raise ValueError(
            "pattern must be a non-empty string; an empty pattern matches every cache file"
        )
    if cache_type not in CACHE_DIRS:
        raise ValueError(
            f"Invalid cache type: {cache_type}. Available: {list(CACHE_DIRS.keys())}"
        )

    cache_dir = CACHE_DIRS[cache_type]
    removed_count = 0
    for cache_file in cache_dir.rglob("*"):
        if not (cache_file.is_file() and pattern in cache_file.name):
            continue
        try:
            cache_file.unlink()
        except Exception as e:
            # Best effort: report the failure and keep going.
            logger.warning(f"Failed to remove {cache_file}: {e}")
        else:
            removed_count += 1
            logger.debug(f"Removed cache file: {cache_file.name}")

    logger.info(
        f"Removed {removed_count} cache files matching pattern '{pattern}' from {cache_type} cache"
    )
    return removed_count
def apply_decorator_to_methods(decorator: Callable, *, include_private: bool = False):
    """
    Build a class decorator that wraps the methods of a class with `decorator`.

    Only attributes defined directly on the class are considered. Plain
    functions, staticmethods and classmethods are wrapped; staticmethods and
    classmethods are re-wrapped in their descriptor. Any other attribute is
    left untouched.

    Args:
        decorator: Callable applied to each underlying function.
        include_private: When false (default), names starting with an
            underscore are skipped.

    Returns:
        A class decorator that modifies the class in place and returns it.
    """

    def class_decorator(cls):
        # Iterate over a snapshot because setattr changes the class namespace.
        for attr_name, member in tuple(vars(cls).items()):
            if attr_name.startswith("_") and not include_private:
                continue

            if isinstance(member, (staticmethod, classmethod)):
                # Decorate the wrapped function, then restore the descriptor.
                rebuild = staticmethod if isinstance(member, staticmethod) else classmethod
                setattr(cls, attr_name, rebuild(decorator(member.__func__)))
            elif isinstance(member, FunctionType):
                # Plain function, i.e. an instance method.
                setattr(cls, attr_name, decorator(member))
        return cls

    return class_decorator
| # ============================================================================= | |
| # 11) EXPORTS | |
| # ============================================================================= | |
__all__ = [
    # Core caching
    "memory",
    "cacheable",  # NEW: Universal decorator
    "initialize_cache_system",
    "cache_stats",
    "get_cache_hit_rate",
    "get_cache_stats",
    "clear_cache_stats",
    "get_cache_summary",
    "CacheAnalyzer",
    "clear_afml_cache",
    "CACHE_DIRS",
    # Selective cache management (NOW FOCUSED ON CLEANUP)
    "cache_maintenance",
    "find_orphaned_caches",
    "clean_orphaned_caches",
    "cleanup_by_size",
    "cleanup_by_age",
    "get_version_tracker",
    "clear_orphaned_ml_caches",
    "clear_orphaned_labeling_caches",
    "clear_orphaned_features_caches",
    "analyze_cache_versions",
    "print_version_analysis",
    # NOTE: Removed exports:
    # - smart_cacheable (replaced by auto_versioning parameter)
    # - clear_changed_* functions (replaced by clean_orphaned_* functions)
    # - selective_cache_clear (replaced by clean_orphaned_caches)
    # Robust cache keys
    # NOTE(review): nothing visible in this module imports or defines
    # CacheKeyGenerator - confirm where it should come from.
    "CacheKeyGenerator",
    "clear_data_access_log",
    "DataAccessTracker",
    "get_data_tracker",
    "log_data_access",
    "print_contamination_report",
    "robust_cacheable",  # Alias for cacheable()
    "time_aware_cacheable",  # Alias for cacheable(time_aware=True)
    "data_tracking_cacheable",
    # MLflow integration (bound only when the optional import succeeded)
    "MLflowCacheIntegration",
    "setup_mlflow_cache",
    "get_mlflow_cache",
    "mlflow_cached",
    "MLFLOW_AVAILABLE",
    "MLFLOW_INTEGRATION_AVAILABLE",
    # Backtest caching
    "BacktestCache",
    "BacktestMetadata",
    "BacktestResult",
    "get_backtest_cache",
    "cached_backtest",
    # Cache monitoring
    "CacheMonitor",
    "FunctionCacheStats",
    "CacheHealthReport",
    "get_cache_monitor",
    "print_cache_report",
    "print_cache_health",
    "get_cache_efficiency_report",
    "analyze_cache_patterns",
    "debug_function_cache",
    "diagnose_cache_issues",
    # Enhanced convenience functions
    "get_comprehensive_cache_status",
    "optimize_cache_system",
    "setup_production_cache",
    # Cache cross-validation
    "cv_cacheable",  # Alias for cacheable()
    # Additional utility functions
    "get_cache_size_info",
    "clear_cache_by_pattern",
    "apply_decorator_to_methods",
    # Hyper-parameter fit helpers
    "reconstruct_param_grid",
    "create_cacheable_param_grid",
]
# Export only names that are actually bound in this module. A name listed in
# __all__ but missing from the namespace makes `from <package> import *`
# raise AttributeError; that applies to the MLflow names when the optional
# import failed and to any listed name that no import provides.
__all__ = [name for name in __all__ if name in globals()]
| # ============================================================================= | |
| # STARTUP MESSAGE UPDATE | |
| # ============================================================================= | |
# Debug-level feature summary emitted once when the package is imported.
logger.debug("Enhanced cache features available:")
logger.debug(" - Unified cacheable() decorator with auto_versioning")
logger.debug(" - Robust cache keys for NumPy/Pandas")
logger.debug(" - MLflow integration: {}", "✓" if MLFLOW_INTEGRATION_AVAILABLE else "✗")
logger.debug(" - Backtest caching: ✓")
logger.debug(" - Cache monitoring: ✓")
# The "Orphaned cache cleanup" line was logged twice; emit it once.
logger.debug(" - Orphaned cache cleanup: ✓")
logger.debug(" - Cache size info and selective clearing: ✓")