AFML / afml /cache /__init__.py
akshayboora's picture
Upload 940 files
669d6a1 verified
"""
Centralized caching system for AFML package.
Now with robust cache keys, MLflow integration, backtest caching, and monitoring.
"""
if __name__ == "__main__" and not __package__:
    # Executed as a plain script there is no parent package, so the relative
    # imports further down (``from .data_access_tracker import ...``) cannot
    # work. Print usage help and exit cleanly instead of crashing.
    print("This file is a package initializer and is not meant to be run directly.")
    print("Use it via import, for example:")
    # Single-quoted raw literal so the double quotes around the -c argument
    # actually appear in the printed command (Windows-style venv path).
    print(r' .\.venv\Scripts\python.exe -c "from afml.cache import initialize_cache_system"')
    raise SystemExit(0)
import json
import os
import threading
from collections import defaultdict
from pathlib import Path
from types import FunctionType
from typing import Callable, Dict, Optional, Union
from appdirs import user_cache_dir
from joblib import Memory
from loguru import logger
# =============================================================================
# 1) CACHE DIRECTORY SETUP
# =============================================================================
def _setup_cache_directories() -> Dict[str, Path]:
"""Setup centralized cache directories."""
# Base cache directory from environment or default
cache_env = os.getenv("AFML_CACHE")
base_dir = Path(cache_env) if cache_env else Path(user_cache_dir("afml"))
dirs = {
"base": base_dir,
"joblib": base_dir / "joblib_cache",
"numba": base_dir / "numba_cache",
"backtest": base_dir / "backtest_cache", # Added backtest cache directory
}
# Create directories
for cache_dir in dirs.values():
cache_dir.mkdir(parents=True, exist_ok=True)
return dirs
# Resolve the cache directory layout once, at import time; this also creates
# the directories on disk. Keys: "base", "joblib", "numba", "backtest".
CACHE_DIRS = _setup_cache_directories()
# =============================================================================
# 2) NUMBA CONFIGURATION
# =============================================================================
def _configure_numba():
    """Point Numba's on-disk cache at the AFML ``numba_cache`` directory.

    ``NUMBA_CACHE_DIR`` is always overwritten. ``NUMBA_DISABLE_JIT`` and
    ``NUMBA_WARNINGS`` only receive a default of ``"0"`` when the user has
    not already exported them.
    """
    cache_location = str(CACHE_DIRS["numba"])
    env = os.environ
    env["NUMBA_CACHE_DIR"] = cache_location

    # Per Numba's environment-variable docs, "0" keeps JIT compilation enabled
    # and keeps Numba warnings quiet; existing user settings win.
    for switch in ("NUMBA_DISABLE_JIT", "NUMBA_WARNINGS"):
        env.setdefault(switch, "0")

    logger.debug("Numba cache configured: {}", cache_location)
# =============================================================================
# 3) SIMPLE CACHE STATISTICS
# =============================================================================
class CacheStats:
    """Thread-safe, lightweight per-function cache hit/miss counters.

    Counts live in memory and are persisted as JSON so they survive across
    sessions. Previously saved counts are loaded on construction. A save
    happens whenever one function's hit count (or miss count) reaches a
    multiple of ``_SAVE_EVERY``. Persistence is best effort: disk problems
    never propagate to callers.
    """

    # Number of hits (or misses) of a single function between disk writes;
    # keeps I/O off the hot path of cached calls.
    _SAVE_EVERY = 25

    def __init__(self, stats_file: Optional[Path] = None):
        """Create the tracker and load any persisted counts.

        Args:
            stats_file: JSON file used for persistence. Defaults to
                ``cache_stats.json`` inside ``CACHE_DIRS["base"]``.
        """
        self._lock = threading.Lock()
        self._stats = defaultdict(lambda: {"hits": 0, "misses": 0})
        if stats_file is None:
            stats_file = CACHE_DIRS["base"] / "cache_stats.json"
        self._stats_file = Path(stats_file)
        self._load_stats()

    def _load_stats(self) -> None:
        """Merge persisted counts from disk; start fresh if the file is missing or unusable."""
        try:
            with open(self._stats_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Missing, unreadable, or corrupt (json.JSONDecodeError and
            # UnicodeDecodeError are ValueErrors): deliberately start empty.
            return
        if not isinstance(data, dict):
            return
        for func_name, entry in data.items():
            # Only accept well-formed entries so later "+= 1" and sums cannot
            # hit KeyError/TypeError because of a truncated or hand-edited file.
            if not isinstance(entry, dict):
                continue
            hits = entry.get("hits", 0)
            misses = entry.get("misses", 0)
            if isinstance(hits, int) and isinstance(misses, int):
                self._stats[func_name] = {"hits": hits, "misses": misses}

    def _save_stats(self) -> None:
        """Persist counts to disk; failures are ignored (best effort).

        Called by ``_record`` while ``self._lock`` is held.
        """
        tmp_file = self._stats_file.with_name(
            f"{self._stats_file.name}.{os.getpid()}.tmp"
        )
        try:
            with open(tmp_file, "w", encoding="utf-8") as f:
                json.dump(
                    {name: dict(counts) for name, counts in self._stats.items()}, f
                )
            # Write-then-replace: a crash mid-write leaves the previous stats
            # file intact instead of a truncated one.
            os.replace(tmp_file, self._stats_file)
        except Exception:
            # Statistics are advisory; never let a disk problem break a
            # cached call. Clean up a partially written temp file if any.
            try:
                tmp_file.unlink()
            except OSError:
                pass

    def _record(self, func_name: str, key: str) -> None:
        """Increment ``key`` ("hits" or "misses") for ``func_name``; save periodically."""
        with self._lock:
            counts = self._stats[func_name]
            counts[key] += 1
            if counts[key] % self._SAVE_EVERY == 0:
                self._save_stats()

    def record_hit(self, func_name: str):
        """Record a cache hit for ``func_name``."""
        self._record(func_name, "hits")

    def record_miss(self, func_name: str):
        """Record a cache miss for ``func_name``."""
        self._record(func_name, "misses")

    def get_hit_rate(self, func_name: Optional[str] = None) -> float:
        """Return hits / (hits + misses) for one function, or overall.

        Args:
            func_name: Function to report on. When omitted (or empty), the
                rate is aggregated over all tracked functions.

        Returns:
            The hit rate in ``[0, 1]``, or ``0.0`` when no calls were recorded
            (including for a function that was never tracked).
        """
        with self._lock:
            if func_name:
                # .get() rather than indexing: indexing the defaultdict would
                # silently register func_name as a tracked function.
                counts = self._stats.get(func_name)
                if counts is None:
                    return 0.0
                hits = counts["hits"]
                total = counts["hits"] + counts["misses"]
            else:
                hits = sum(c["hits"] for c in self._stats.values())
                total = sum(c["hits"] + c["misses"] for c in self._stats.values())
            return hits / total if total > 0 else 0.0

    def get_stats(self) -> Dict[str, Dict[str, int]]:
        """Return an independent snapshot ``{func_name: {"hits": n, "misses": n}}``.

        The inner dicts are copied too, so the snapshot does not change as new
        hits/misses are recorded and callers cannot corrupt the live counters.
        """
        with self._lock:
            return {name: dict(counts) for name, counts in self._stats.items()}

    def clear(self):
        """Reset all counters and delete the persisted stats file (if present)."""
        with self._lock:
            self._stats.clear()
            try:
                self._stats_file.unlink()
            except FileNotFoundError:
                pass
# Global stats instance shared by the helper functions in this module.
# Constructing it reads any previously persisted counts from
# CACHE_DIRS["base"] / "cache_stats.json".
cache_stats = CacheStats()
# =============================================================================
# 4) JOBLIB MEMORY INSTANCE
# =============================================================================
# Shared joblib Memory store rooted at the "joblib_cache" directory;
# verbose=0 silences joblib's own per-call cache messages.
memory = Memory(location=str(CACHE_DIRS["joblib"]), verbose=0)
# =============================================================================
# 5) UTILITY FUNCTIONS
# =============================================================================
def get_cache_hit_rate(func_name: Optional[str] = None) -> float:
    """Return the recorded cache hit rate.

    Args:
        func_name: Tracked function to report on. When omitted (or falsy),
            the rate is aggregated over every tracked function.

    Returns:
        Fraction of recorded calls that were hits; ``0.0`` if none were recorded.
    """
    rate = cache_stats.get_hit_rate(func_name)
    return rate
def get_cache_stats() -> Dict[str, Dict[str, int]]:
    """Return a snapshot of the per-function counters.

    Returns:
        ``{func_name: {"hits": int, "misses": int}}`` as reported by the
        module-level ``cache_stats`` tracker.
    """
    snapshot = cache_stats.get_stats()
    return snapshot
def clear_cache_stats():
    """Reset every hit/miss counter and remove the persisted stats file.

    Cached results themselves are left untouched.
    """
    tracker = cache_stats
    tracker.clear()
def clear_afml_cache(warn: bool = True):
    """Wipe the shared joblib memory store and reset the hit/miss statistics.

    Only ``memory`` (the joblib cache) and ``cache_stats`` are cleared here;
    the numba and backtest cache directories are not touched.

    Args:
        warn: When true, log a warning before clearing. The same flag is
            forwarded to joblib's ``Memory.clear``.
    """
    should_warn = warn
    if should_warn:
        logger.warning("Clearing AFML cache...")
    memory.clear(warn=should_warn)
    clear_cache_stats()
def get_cache_summary() -> Dict[str, Union[float, int]]:
    """Summarize overall cache performance.

    Returns:
        ``{"hit_rate": float, "total_calls": int, "functions_tracked": int}``,
        where ``hit_rate`` is ``0.0`` when no calls were recorded.
    """
    snapshot = cache_stats.get_stats()
    hits = 0
    calls = 0
    for counts in snapshot.values():
        hits += counts["hits"]
        calls += counts["hits"] + counts["misses"]

    if calls > 0:
        overall_rate = hits / calls
    else:
        overall_rate = 0.0

    return {
        "hit_rate": overall_rate,
        "total_calls": calls,
        "functions_tracked": len(snapshot),
    }
# =============================================================================
# 6) CACHE ANALYSIS CONTEXT MANAGER
# =============================================================================
class CacheAnalyzer:
    """Context manager that logs the cache activity recorded while its body runs.

    On entry it snapshots the global ``cache_stats`` counters. On a clean
    exit it logs how many new calls were recorded and their hit rate. Nothing
    is logged if the body raised, and the exception propagates unchanged.

    Example::

        with CacheAnalyzer("feature build"):
            build_features()
    """

    def __init__(self, name: str = "analysis"):
        self.name = name
        # Per-function counters captured in __enter__; None until entered.
        self.start_stats: Optional[Dict[str, Dict[str, int]]] = None

    def __enter__(self):
        # Copy each inner counter dict: a shallow copy would keep references
        # to counters that may still be mutated in place, so the "start"
        # values would move and every delta would come out as zero.
        self.start_stats = {
            name: dict(counts) for name, counts in cache_stats.get_stats().items()
        }
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if exc_type is None:
            end_stats = cache_stats.get_stats()
            report = self._generate_report(end_stats)
            if report:
                logger.info("Cache analysis '{}': {}", self.name, report)
        # Implicit None return: exceptions from the body are not suppressed.

    def _generate_report(self, end_stats) -> Optional[str]:
        """Describe the activity between the entry snapshot and ``end_stats``.

        Returns:
            ``None`` if the context was never entered; otherwise either
            ``"<n> calls, <rate> hit rate"`` or ``"no cache activity"``.
        """
        # An empty snapshot (no stats recorded yet) is a valid starting
        # point, so test for None rather than falsiness.
        if self.start_stats is None:
            return None

        total_new_hits = 0
        total_new_calls = 0
        for func_name, end_data in end_stats.items():
            # Functions first seen inside the block start from zero.
            start_data = self.start_stats.get(func_name, {})
            new_hits = end_data.get("hits", 0) - start_data.get("hits", 0)
            new_misses = end_data.get("misses", 0) - start_data.get("misses", 0)
            total_new_hits += new_hits
            total_new_calls += new_hits + new_misses

        if total_new_calls > 0:
            hit_rate = total_new_hits / total_new_calls
            return f"{total_new_calls} calls, {hit_rate:.1%} hit rate"
        return "no cache activity"
# =============================================================================
# 7) INITIALIZATION FUNCTION
# =============================================================================
def initialize_cache_system():
    """Configure Numba's cache directory and log the active cache setup.

    Logs the joblib and numba cache locations. If statistics were loaded
    from disk, it also logs how many functions are tracked and the overall
    hit rate.
    """
    # Numba is configured through environment variables, so do it before
    # anything gets JIT-compiled. NOTE(review): the submodule imports at the
    # bottom of this module run at import time, before this function can be
    # called; if any of them compile cached @njit code they may not see this
    # cache directory — confirm.
    _configure_numba()

    logger.info("AFML cache system initialized:")
    logger.info(" Joblib cache: {}", CACHE_DIRS["joblib"])
    logger.info(" Numba cache: {}", CACHE_DIRS["numba"])

    loaded = cache_stats.get_stats()
    if not loaded:
        return
    logger.info(
        " Loaded stats: {} functions, {:.1%} hit rate",
        len(loaded),
        cache_stats.get_hit_rate(),
    )
# =============================================================================
# 8) NOW SAFE TO IMPORT OTHER MODULES
# =============================================================================
# Import robust cache key generation - NOW SAFE (memory and cache_stats exist)
from .data_access_tracker import ( # noqa: E402
DataAccessTracker, # noqa: E402
clear_data_access_log,
get_data_tracker,
log_data_access,
print_contamination_report,
)
# Import selective cleaner functions after base components are defined
from .selective_cleaner import ( # noqa: E402
analyze_cache_versions, # noqa: E402
cache_maintenance,
clean_orphaned_caches,
cleanup_by_age,
cleanup_by_size,
clear_orphaned_features_caches,
clear_orphaned_labeling_caches,
clear_orphaned_ml_caches,
find_orphaned_caches,
get_version_tracker,
print_version_analysis,
)
# Add to imports
from .unified_cache_system import (
cacheable, # noqa: E402
create_cacheable_param_grid,
cv_cacheable,
data_tracking_cacheable,
print_cache_report,
reconstruct_param_grid,
robust_cacheable,
time_aware_cacheable,
)
# MLflow integration (optional)
try:
from .mlflow_integration import (
MLFLOW_AVAILABLE,
MLflowCacheIntegration,
get_mlflow_cache,
mlflow_cached,
setup_mlflow_cache,
)
MLFLOW_INTEGRATION_AVAILABLE = True
except ImportError:
MLFLOW_INTEGRATION_AVAILABLE = False
logger.debug("MLflow integration not available (install mlflow)")
# Backtest caching
from .backtest_cache import (
BacktestCache,
BacktestMetadata, # noqa: E402
BacktestResult,
cached_backtest,
get_backtest_cache,
)
# Cache monitoring
from .cache_monitoring import (
CacheHealthReport,
CacheMonitor, # noqa: E402
FunctionCacheStats,
analyze_cache_patterns,
debug_function_cache,
diagnose_cache_issues,
get_cache_efficiency_report,
get_cache_monitor,
print_cache_health,
)
# =============================================================================
# 9) ENHANCED CONVENIENCE FUNCTIONS
# =============================================================================
def get_comprehensive_cache_status() -> dict:
    """Collect a status snapshot from every cache subsystem.

    Subsystem failures are logged at debug level and reported as ``None``
    rather than raised.

    Returns:
        Dict with keys:

        * ``"core"``: output of ``get_cache_summary()``.
        * ``"health"``: ``total_functions``/``hit_rate``/``total_calls``/
          ``cache_size_mb`` from the monitor's health report, or ``None``.
        * ``"backtest"``: the backtest cache's ``get_cache_stats()``, or ``None``.
        * ``"mlflow"``: ``{"available": MLFLOW_INTEGRATION_AVAILABLE}``.
    """
    core_summary = get_cache_summary()

    health_summary = None
    try:
        health = get_cache_monitor().generate_health_report()
        health_summary = {
            "total_functions": health.total_functions,
            "hit_rate": health.overall_hit_rate,
            "total_calls": health.total_calls,
            "cache_size_mb": health.total_cache_size_mb,
        }
    except Exception as e:
        logger.debug(f"Health report failed: {e}")

    backtest_summary = None
    try:
        backtest_summary = get_backtest_cache().get_cache_stats()
    except Exception as e:
        logger.debug(f"Backtest cache stats failed: {e}")

    return {
        "core": core_summary,
        "health": health_summary,
        "backtest": backtest_summary,
        "mlflow": {"available": MLFLOW_INTEGRATION_AVAILABLE},
    }
def optimize_cache_system(
    clear_changed: bool = True,
    max_size_mb: int = 1000,
    max_age_days: int = 30,
    print_report: bool = True,
) -> dict:
    """Run a full cache maintenance pass.

    Three independent steps run in order. A failure in any step is logged as
    a warning and leaves that step's result as ``None``; the remaining steps
    still run.

    1. ``cache_maintenance``, given the orphan-cleanup, size and age settings.
    2. A health report from the cache monitor, printed if ``print_report``.
    3. Removal of backtest runs older than ``max_age_days``.

    Args:
        clear_changed: Forwarded to ``cache_maintenance`` as ``clean_orphaned``.
        max_size_mb: Maximum total cache size in MB.
        max_age_days: Age limit in days for cache entries and backtest runs.
        print_report: Print the (non-detailed) health report.

    Returns:
        ``{"maintenance": ..., "health_report": ..., "backtest_cleanup":
        {"runs_cleared": n}}``, with ``None`` for any step that failed.
    """
    logger.info("Running comprehensive cache optimization...")

    maintenance_result = None
    try:
        maintenance_result = cache_maintenance(
            clean_orphaned=clear_changed,
            max_cache_size_mb=max_size_mb,
            max_age_days=max_age_days,
        )
    except Exception as e:
        logger.warning(f"Cache maintenance failed: {e}")

    health = None
    try:
        cache_monitor = get_cache_monitor()
        # Keep the report even if printing it fails below.
        health = cache_monitor.generate_health_report()
        if print_report:
            cache_monitor.print_health_report(detailed=False)
    except Exception as e:
        logger.warning(f"Health report failed: {e}")

    backtest_summary = None
    try:
        cleared = get_backtest_cache().clear_old_runs(days=max_age_days)
        backtest_summary = {"runs_cleared": cleared}
        logger.info(f"Cleared {cleared} old backtest runs")
    except Exception as e:
        logger.warning(f"Backtest cleanup failed: {e}")

    return {
        "maintenance": maintenance_result,
        "health_report": health,
        "backtest_cleanup": backtest_summary,
    }
def setup_production_cache(
    enable_mlflow: bool = True,
    mlflow_experiment: str = "production",
    mlflow_uri: Optional[str] = None,
    max_cache_size_mb: int = 2000,
) -> dict:
    """Bring up every cache component for production use.

    Each optional component is set up independently. A failure is logged as
    a warning and leaves that component's entry as ``None``. MLflow is
    skipped without a message when disabled or when its integration module
    could not be imported.

    Args:
        enable_mlflow: Try to enable MLflow integration.
        mlflow_experiment: MLflow experiment name.
        mlflow_uri: MLflow tracking URI (``None`` leaves the choice to
            ``setup_mlflow_cache``).
        max_cache_size_mb: Size cap used for the initial maintenance pass.

    Returns:
        Dict with keys ``"core_cache"`` (``True``), ``"mlflow_cache"``,
        ``"backtest_cache"`` and ``"monitor"``.
    """
    logger.info("Initializing production cache system...")
    components = dict.fromkeys(
        ("core_cache", "mlflow_cache", "backtest_cache", "monitor")
    )

    # Core joblib/numba setup comes first; everything else builds on it.
    initialize_cache_system()
    components["core_cache"] = True

    mlflow_wanted = enable_mlflow and MLFLOW_INTEGRATION_AVAILABLE
    if mlflow_wanted:
        try:
            components["mlflow_cache"] = setup_mlflow_cache(
                experiment_name=mlflow_experiment,
                tracking_uri=mlflow_uri,
            )
            logger.info(f"MLflow tracking enabled: {mlflow_experiment}")
        except Exception as e:
            logger.warning(f"MLflow setup failed: {e}")

    try:
        components["backtest_cache"] = get_backtest_cache()
    except Exception as e:
        logger.warning(f"Backtest cache setup failed: {e}")

    try:
        components["monitor"] = get_cache_monitor()
    except Exception as e:
        logger.warning(f"Cache monitor setup failed: {e}")

    # One maintenance pass up front with the production size cap; no report printed.
    try:
        optimize_cache_system(max_size_mb=max_cache_size_mb, print_report=False)
    except Exception as e:
        logger.warning(f"Initial optimization failed: {e}")

    logger.info("✅ Production cache system ready")
    return components
# =============================================================================
# 10) ADDITIONAL UTILITY FUNCTIONS
# =============================================================================
def get_cache_size_info(
    cache_dirs: Optional[Dict[str, Path]] = None,
) -> Dict[str, Dict[str, Union[int, float]]]:
    """Report on-disk size and file count for each cache directory.

    Args:
        cache_dirs: Mapping of cache name -> directory (``Path`` or str) to
            inspect. Defaults to the module-level ``CACHE_DIRS``. In that
            mapping ``"base"`` is the parent of the other directories, so its
            totals include theirs.

    Returns:
        ``{name: {"size_bytes": int, "size_mb": float, "file_count": int}}``
        for each directory that exists. Missing directories are omitted.
        ``size_mb`` is rounded to 2 decimals.
    """
    if cache_dirs is None:
        cache_dirs = CACHE_DIRS

    size_info = {}
    for cache_name, cache_dir in cache_dirs.items():
        cache_dir = Path(cache_dir)
        if not cache_dir.exists():
            continue
        total_size = 0
        file_count = 0
        for file_path in cache_dir.rglob("*"):
            try:
                if not file_path.is_file():
                    continue
                file_size = file_path.stat().st_size
            except OSError:
                # Entry vanished (e.g. concurrent cache cleanup) or is
                # unreadable; skip it instead of aborting the whole report.
                continue
            total_size += file_size
            file_count += 1
        size_info[cache_name] = {
            "size_bytes": total_size,
            "size_mb": round(total_size / (1024 * 1024), 2),
            "file_count": file_count,
        }
    return size_info
def clear_cache_by_pattern(
    pattern: str,
    cache_type: str = "joblib",
    cache_dirs: Optional[Dict[str, Path]] = None,
) -> int:
    """Delete cache files whose filename contains ``pattern``.

    Args:
        pattern: Non-empty substring to match against each file's name (not
            its full path).
        cache_type: Which cache to search: a key of ``cache_dirs``
            (``'base'``, ``'joblib'``, ``'numba'`` or ``'backtest'`` by default).
        cache_dirs: Mapping of cache name -> directory. Defaults to the
            module-level ``CACHE_DIRS``.

    Returns:
        Number of files removed. Files that fail to delete are logged and skipped.

    Raises:
        ValueError: If ``cache_type`` is unknown, or ``pattern`` is empty.
            An empty substring matches every file and would silently wipe
            the whole cache.
    """
    if cache_dirs is None:
        cache_dirs = CACHE_DIRS
    if cache_type not in cache_dirs:
        raise ValueError(
            f"Invalid cache type: {cache_type}. Available: {list(cache_dirs.keys())}"
        )
    if not pattern:
        raise ValueError(
            "pattern must be a non-empty string; an empty pattern matches every cache file"
        )

    cache_dir = Path(cache_dirs[cache_type])
    # Collect matches before deleting so the tree is not modified while it is
    # still being walked.
    matches = [p for p in cache_dir.rglob("*") if pattern in p.name and p.is_file()]

    removed_count = 0
    for cache_file in matches:
        try:
            cache_file.unlink()
        except OSError as e:
            logger.warning(f"Failed to remove {cache_file}: {e}")
            continue
        removed_count += 1
        logger.debug(f"Removed cache file: {cache_file.name}")

    logger.info(
        f"Removed {removed_count} cache files matching pattern '{pattern}' from {cache_type} cache"
    )
    return removed_count
def apply_decorator_to_methods(decorator: Callable, *, include_private: bool = False):
    """Build a class decorator that wraps the class's methods with ``decorator``.

    Only attributes defined directly on the class (its ``__dict__``) are
    considered; inherited methods are not touched.

    * ``staticmethod``/``classmethod`` objects have their underlying function
      wrapped and are re-wrapped in the same descriptor type.
    * Plain functions (instance methods) are wrapped directly.
    * Anything else (properties, data attributes, nested classes) is left alone.

    Names starting with ``_`` (dunders included) are skipped unless
    ``include_private`` is true.

    Args:
        decorator: Callable that takes a function and returns its replacement.
        include_private: Also wrap names that start with an underscore.

    Returns:
        A class decorator that modifies the class in place and returns it.
    """

    def class_decorator(cls):
        # Snapshot the items first because setattr mutates the class dict.
        for attr_name, member in list(vars(cls).items()):
            if attr_name.startswith("_") and not include_private:
                continue
            for descriptor in (staticmethod, classmethod):
                if isinstance(member, descriptor):
                    setattr(cls, attr_name, descriptor(decorator(member.__func__)))
                    break
            else:
                if isinstance(member, FunctionType):
                    setattr(cls, attr_name, decorator(member))
        return cls

    return class_decorator
# =============================================================================
# 11) EXPORTS
# =============================================================================
__all__ = [
    # Core caching
    "memory",
    "cacheable",  # NEW: Universal decorator
    "initialize_cache_system",
    "cache_stats",
    "get_cache_hit_rate",
    "get_cache_stats",
    "clear_cache_stats",
    "get_cache_summary",
    "CacheAnalyzer",
    "clear_afml_cache",
    "CACHE_DIRS",
    # Selective cache management (NOW FOCUSED ON CLEANUP)
    "cache_maintenance",
    "find_orphaned_caches",
    "clean_orphaned_caches",
    "cleanup_by_size",
    "cleanup_by_age",
    "get_version_tracker",
    "clear_orphaned_ml_caches",
    "clear_orphaned_labeling_caches",
    "clear_orphaned_features_caches",
    "analyze_cache_versions",
    "print_version_analysis",
    # NOTE: Removed exports:
    # - smart_cacheable (replaced by auto_versioning parameter)
    # - clear_changed_* functions (replaced by clean_orphaned_* functions)
    # - selective_cache_clear (replaced by clean_orphaned_caches)
    # Robust cache keys
    # NOTE: "CacheKeyGenerator" is intentionally not listed: nothing in this
    # module imports or defines it, so exporting it made
    # `from afml.cache import *` raise AttributeError. Add it back together
    # with an import if it is meant to be public.
    "clear_data_access_log",
    "DataAccessTracker",
    "get_data_tracker",
    "log_data_access",
    "print_contamination_report",
    "robust_cacheable",  # Alias for cacheable()
    "time_aware_cacheable",  # Alias for cacheable(time_aware=True)
    "data_tracking_cacheable",
    # MLflow integration
    "MLflowCacheIntegration",
    "setup_mlflow_cache",
    "get_mlflow_cache",
    "mlflow_cached",
    "MLFLOW_AVAILABLE",
    "MLFLOW_INTEGRATION_AVAILABLE",
    # Backtest caching
    "BacktestCache",
    "BacktestMetadata",
    "BacktestResult",
    "get_backtest_cache",
    "cached_backtest",
    # Cache monitoring
    "CacheMonitor",
    "FunctionCacheStats",
    "CacheHealthReport",
    "get_cache_monitor",
    "print_cache_report",
    "print_cache_health",
    "get_cache_efficiency_report",
    "analyze_cache_patterns",
    "debug_function_cache",
    "diagnose_cache_issues",
    # Enhanced convenience functions
    "get_comprehensive_cache_status",
    "optimize_cache_system",
    "setup_production_cache",
    # Cache cross-validation
    "cv_cacheable",  # Alias for cacheable()
    # Additional utility functions
    "get_cache_size_info",
    "clear_cache_by_pattern",
    "apply_decorator_to_methods",
    # Hyper-parameter fit helpers
    "reconstruct_param_grid",
    "create_cacheable_param_grid",
]

# These names are only bound when the optional `.mlflow_integration` import
# succeeded. Drop them otherwise, or `from afml.cache import *` would raise
# AttributeError. The availability flag itself stays exported.
_OPTIONAL_MLFLOW_EXPORTS = frozenset(
    {
        "MLflowCacheIntegration",
        "setup_mlflow_cache",
        "get_mlflow_cache",
        "mlflow_cached",
        "MLFLOW_AVAILABLE",
    }
)
if not MLFLOW_INTEGRATION_AVAILABLE:
    __all__ = [name for name in __all__ if name not in _OPTIONAL_MLFLOW_EXPORTS]
# =============================================================================
# STARTUP MESSAGE UPDATE
# =============================================================================
# Debug-level summary of the cache features available after import.
logger.debug("Enhanced cache features available:")
logger.debug(" - Unified cacheable() decorator with auto_versioning")
logger.debug(" - Robust cache keys for NumPy/Pandas")
logger.debug(" - MLflow integration: {}", "✓" if MLFLOW_INTEGRATION_AVAILABLE else "✗")
logger.debug(" - Backtest caching: ✓")
logger.debug(" - Cache monitoring: ✓")
logger.debug(" - Orphaned cache cleanup: ✓")
logger.debug(" - Cache size info and selective clearing: ✓")