# Source: AFML / afml/cache/cv_cache.py
# Uploaded by akshayboora ("Upload 940 files"), commit 669d6a1 (verified)
"""
Specialized caching for cross-validation functions.
Handles sklearn classifiers, CV generators, and complex ML workflows.
"""
import hashlib
import inspect
import json
import pickle
import time
from functools import wraps
from typing import Callable, Optional
import numpy as np
import pandas as pd
from loguru import logger
from sklearn.base import BaseEstimator
def _hash_classifier(clf: BaseEstimator) -> str:
"""
Generate stable hash for sklearn classifier.
Uses class name + parameters (not the trained state).
"""
try:
# Get classifier type and parameters
clf_type = type(clf).__name__
params = clf.get_params(deep=True)
# Filter out non-serializable params (like objects, functions)
serializable_params = {}
for k, v in params.items():
try:
# Test if JSON serializable
json.dumps(v)
serializable_params[k] = v
except (TypeError, ValueError):
# Use type name for non-serializable params
serializable_params[k] = f"<{type(v).__name__}>"
# Create stable hash
param_str = json.dumps(serializable_params, sort_keys=True)
combined = f"{clf_type}_{param_str}"
return hashlib.md5(combined.encode()).hexdigest()[:12]
except Exception as e:
logger.debug(f"Failed to hash classifier: {e}")
return f"clf_{type(clf).__name__}_{id(clf)}"
def _hash_cv_generator(cv_gen) -> str:
"""Generate hash for CV generator (KFold, PurgedKFold, etc.)"""
try:
cv_type = type(cv_gen).__name__
# Get CV parameters
params = {}
if hasattr(cv_gen, "n_splits"):
params["n_splits"] = cv_gen.n_splits
if hasattr(cv_gen, "pct_embargo"):
params["pct_embargo"] = cv_gen.pct_embargo
if hasattr(cv_gen, "t1"):
# Hash the t1 series structure (not full content for speed)
t1 = cv_gen.t1
if isinstance(t1, pd.Series):
params["t1_len"] = len(t1)
params["t1_start"] = str(t1.index[0])
params["t1_end"] = str(t1.index[-1])
param_str = json.dumps(params, sort_keys=True)
combined = f"{cv_type}_{param_str}"
return hashlib.md5(combined.encode()).hexdigest()[:12]
except Exception as e:
logger.debug(f"Failed to hash CV generator: {e}")
return f"cv_{type(cv_gen).__name__}_{id(cv_gen)}"
def _hash_dataframe_fast(df: pd.DataFrame) -> str:
"""
Fast DataFrame hashing for CV caching.
Uses shape + columns + index range + sample of data.
"""
parts = [
f"shape_{df.shape}",
f"cols_{hashlib.md5(str(tuple(df.columns)).encode()).hexdigest()[:8]}",
]
# Hash index
if isinstance(df.index, pd.DatetimeIndex):
parts.append(f"idx_{df.index[0]}_{df.index[-1]}_{len(df)}")
else:
parts.append(f"idx_{df.index[0]}_{df.index[-1]}")
# Sample data for hash (for speed)
if len(df) > 100:
sample = df.iloc[:: max(1, len(df) // 100)]
else:
sample = df
data_hash = hashlib.md5(sample.values.tobytes()).hexdigest()[:8]
parts.append(f"data_{data_hash}")
return "_".join(parts)
def _hash_series_fast(series: pd.Series) -> str:
"""Fast Series hashing."""
parts = [f"len_{len(series)}", f"dtype_{series.dtype}"]
if isinstance(series.index, pd.DatetimeIndex):
parts.append(f"idx_{series.index[0]}_{series.index[-1]}")
# Sample for hash
if len(series) > 100:
sample = series.iloc[:: max(1, len(series) // 100)]
else:
sample = series
data_hash = hashlib.md5(sample.values.tobytes()).hexdigest()[:8]
parts.append(f"data_{data_hash}")
return "_".join(parts)
def _generate_cv_cache_key(func: Callable, args: tuple, kwargs: dict) -> str:
    """
    Generate a specialized cache key for CV functions.

    Binds ``args`` / ``kwargs`` to ``func``'s signature (defaults applied) and
    hashes each argument according to its type:

    * sklearn estimators: their parameters;
    * CV generators: objects with ``split`` and ``n_splits``;
    * DataFrames, Series and ndarrays: their content;
    * simple scalars: their value;
    * callables: their qualified name;
    * anything else: an md5 digest of ``str(value)``.

    Every component is deterministic across interpreter runs, which the
    on-disk cache requires. Builtin ``hash()`` of a string is salted per
    process (PYTHONHASHSEED) and is therefore deliberately not used.

    Args:
        func: The function being cached.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        32-character md5 hex digest.

    Raises:
        TypeError: If ``args`` / ``kwargs`` do not match ``func``'s signature.
    """
    key_parts = [func.__module__, func.__qualname__]

    # Map args to parameter names so positional vs keyword calls hash the same.
    sig = inspect.signature(func)
    bound = sig.bind(*args, **kwargs)
    bound.apply_defaults()

    for param_name, param_value in bound.arguments.items():
        try:
            if param_value is None:
                key_parts.append(f"{param_name}_None")
            elif isinstance(param_value, BaseEstimator):
                clf_hash = _hash_classifier(param_value)
                key_parts.append(f"{param_name}_clf_{clf_hash}")
            elif hasattr(param_value, "split") and hasattr(param_value, "n_splits"):
                # CV generator (has split method and n_splits)
                cv_hash = _hash_cv_generator(param_value)
                key_parts.append(f"{param_name}_cv_{cv_hash}")
            elif isinstance(param_value, pd.DataFrame):
                df_hash = _hash_dataframe_fast(param_value)
                key_parts.append(f"{param_name}_df_{df_hash}")
            elif isinstance(param_value, pd.Series):
                series_hash = _hash_series_fast(param_value)
                key_parts.append(f"{param_name}_ser_{series_hash}")
            elif isinstance(param_value, np.ndarray):
                if param_value.dtype == object:
                    # tobytes() of an object array is raw pointers (unstable across runs).
                    raw = pd.util.hash_array(param_value.ravel()).tobytes()
                else:
                    raw = param_value.tobytes()
                arr_hash = hashlib.md5(raw).hexdigest()[:8]
                key_parts.append(
                    f"{param_name}_arr_{param_value.shape}_{param_value.dtype}_{arr_hash}"
                )
            elif isinstance(param_value, (str, int, float, bool)):
                key_parts.append(f"{param_name}_{param_value}")
            elif callable(param_value):
                qualname = getattr(param_value, "__qualname__", None)
                if qualname is not None:
                    # Functions, methods, classes: identify by import path.
                    module = getattr(param_value, "__module__", "")
                    func_id = f"{module}.{qualname}"
                else:
                    # Callable instances (e.g. scorer objects) have no name; use
                    # their repr so different configurations do not collide.
                    repr_hash = hashlib.md5(repr(param_value).encode()).hexdigest()[:8]
                    func_id = f"{type(param_value).__qualname__}_{repr_hash}"
                key_parts.append(f"{param_name}_func_{func_id}")
            else:
                # Deterministic digest of the string form (not the salted builtin hash()).
                str_hash = hashlib.md5(str(param_value).encode()).hexdigest()[:12]
                key_parts.append(f"{param_name}_{str_hash}")
        except Exception as e:
            logger.debug(f"Failed to hash param '{param_name}': {e}")
            key_parts.append(f"{param_name}_unknown")

    combined = "_".join(key_parts)
    return hashlib.md5(combined.encode()).hexdigest()
def cv_cacheable(
    _func=None,
    track_data_access: bool = False,
    dataset_name: Optional[str] = None,
    purpose: Optional[str] = None,
    log_metrics: bool = True,
):
    """
    Cache the results of a cross-validation function on disk.

    Works both bare and with arguments::

        @cv_cacheable
        def my_func(...): ...

        @cv_cacheable(track_data_access=True, dataset_name='my_data', purpose='train')
        def my_func(...): ...

    The legacy cache is used when ``track_data_access`` is falsy and both
    ``dataset_name`` and ``purpose`` are ``None``. Otherwise the enhanced,
    tracking-aware cache is used.

    Args:
        _func: The decorated function when used without parentheses.
        track_data_access: Track DataFrame access for contamination detection.
        dataset_name: Name of the dataset recorded by tracking.
        purpose: Purpose of the data access; one of train/test/validate/
            optimize/analyze.
        log_metrics: Log results to MLflow if available (passed only to the
            enhanced cache).

    Returns:
        The wrapped function when used bare, otherwise a decorator.

    Raises:
        ValueError: If ``purpose`` is truthy and not an allowed value.
    """
    valid_purposes = ("train", "test", "validate", "optimize", "analyze")
    if purpose and purpose not in valid_purposes:
        raise ValueError(
            f"Invalid purpose: {purpose}. Must be one of: "
            "train, test, validate, optimize, analyze"
        )

    def decorator(func):
        wants_tracking = (
            track_data_access or dataset_name is not None or purpose is not None
        )
        if wants_tracking:
            return _cv_cacheable_enhanced(
                func,
                track_data_access=track_data_access,
                dataset_name=dataset_name,
                purpose=purpose,
                log_metrics=log_metrics,
            )
        return _cv_cacheable_legacy(func)

    return decorator if _func is None else decorator(_func)
def _cv_cacheable_legacy(func):
    """
    Original cv_cacheable implementation, kept for backward compatibility.

    Pickles each result to ``CACHE_DIRS["base"] / "cv_cache" / <key>.pkl``, where
    ``<key>`` comes from ``_generate_cv_cache_key``.

    Results are first written to a temporary file and then moved into place.
    An interrupted or failed ``pickle.dump`` therefore never leaves a truncated
    ``.pkl`` that later reads would trip over.

    If the cache key cannot be generated, the call is counted as a miss and
    ``func`` runs uncached.
    """
    import os

    from . import CACHE_DIRS, cache_stats

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache"
    cache_dir.mkdir(exist_ok=True)

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            cache_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            logger.warning(f"CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)
        cache_file = cache_dir / f"{cache_key}.pkl"

        if cache_file.exists():
            try:
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                logger.warning(f"CV cache read failed: {e}")
                # missing_ok: another process may have removed it already.
                cache_file.unlink(missing_ok=True)
            else:
                cache_stats.record_hit(func_name)
                logger.info(f"CV cache hit for {func.__name__}")
                return result

        # Cache miss
        cache_stats.record_miss(func_name)
        result = func(*args, **kwargs)

        # Per-process temp name so concurrent writers do not clobber each other.
        tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(result, f)
            tmp_file.replace(cache_file)
        except Exception as e:
            logger.warning(f"Failed to cache CV result: {e}")
            tmp_file.unlink(missing_ok=True)
        return result

    wrapper._afml_cacheable = True
    return wrapper
def _cv_cacheable_enhanced(
    func, track_data_access=False, dataset_name=None, purpose=None, log_metrics=True
):
    """
    Enhanced CV cache decorator with data-access tracking and MLflow logging.

    Results are pickled under ``CACHE_DIRS["base"] / "cv_cache_enhanced"``. The
    cache key appends a hash of the tracking options to
    ``_generate_cv_cache_key``, so calls with different options never share an
    entry.

    Behaviour on each call:

    * When ``track_data_access`` is set, data access is recorded on every call,
      hit or miss.
    * When ``log_metrics`` is set and MLflow is available, results are logged
      on hits and misses, and the execution time is also logged on misses.
    * MLflow logging is best effort: a failure is logged at debug level and
      never deletes a valid cache entry or affects the returned result.
    * Results are written atomically (temporary file + replace).

    Args:
        func: CV function to wrap.
        track_data_access: Record DataFrame access via the data tracker.
        dataset_name: Dataset name passed to the tracker.
        purpose: Purpose passed to the tracker.
        log_metrics: Log results to MLflow if available.
    """
    import os

    from . import CACHE_DIRS, cache_stats
    from .mlflow_integration import MLFLOW_AVAILABLE, get_mlflow_cache

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache_enhanced"
    cache_dir.mkdir(exist_ok=True)

    def _generate_enhanced_cv_cache_key(
        base_key, track_data_access, dataset_name, purpose, log_metrics
    ):
        """Generate cache key that includes tracking parameters."""
        tracking_params = {
            "track_data_access": track_data_access,
            "dataset_name": dataset_name,
            "purpose": purpose,
            "log_metrics": log_metrics,
        }
        params_hash = hashlib.md5(
            json.dumps(tracking_params, sort_keys=True).encode()
        ).hexdigest()[:8]
        return f"{base_key}_tracking_{params_hash}"

    def _log_metrics_safely(result, cache_key):
        """Best-effort MLflow logging that never propagates exceptions."""
        if not (log_metrics and MLFLOW_AVAILABLE):
            return
        try:
            _log_cv_metrics_to_mlflow(result, func_name, cache_key)
        except Exception as e:
            logger.debug(f"Failed to log CV metrics to MLflow: {e}")

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Track data access on every call, including cache hits.
        if track_data_access:
            from .data_access_tracker import get_data_tracker

            _track_cv_data_access(
                get_data_tracker(), args, kwargs, dataset_name, purpose
            )

        try:
            base_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            logger.warning(f"Enhanced CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)
        cache_key = _generate_enhanced_cv_cache_key(
            base_key, track_data_access, dataset_name, purpose, log_metrics
        )
        cache_file = cache_dir / f"{cache_key}.pkl"

        if cache_file.exists():
            try:
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                # Only genuine read/unpickle failures invalidate the entry.
                logger.warning(f"Enhanced CV cache read failed: {e}")
                cache_file.unlink(missing_ok=True)
            else:
                cache_stats.record_hit(func_name)
                _log_metrics_safely(result, cache_key)
                logger.info(f"Enhanced CV cache hit for {func.__name__}")
                return result

        # Cache miss
        cache_stats.record_miss(func_name)
        logger.info(f"Enhanced CV cache miss for {func.__name__} - computing...")
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        execution_time = time.perf_counter() - start_time

        # Per-process temp name so concurrent writers do not clobber each other.
        tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(result, f)
            tmp_file.replace(cache_file)
            logger.debug(f"Cached enhanced CV result: {cache_key}")
        except Exception as e:
            logger.warning(f"Failed to cache enhanced CV result: {e}")
            tmp_file.unlink(missing_ok=True)

        if log_metrics and MLFLOW_AVAILABLE:
            _log_metrics_safely(result, cache_key)
            try:
                mlflow_cache = get_mlflow_cache()
                mlflow_cache._log_metrics({"execution_time_seconds": execution_time})
            except Exception as e:
                logger.debug(f"Failed to log execution time: {e}")
        return result

    wrapper._afml_cacheable = True
    return wrapper
def _track_cv_data_access(tracker, args, kwargs, dataset_name, purpose):
    """
    Record the slice of data a CV call touched, for contamination tracking.

    Looks for a feature DataFrame accepted by ``_is_trackable_dataframe``:

    * first among the positional ``args``;
    * then among the keyword arguments named ``X``, ``x`` or ``features``.

    The last match wins. If a non-empty match is found, its first and last
    index labels and its shape are passed to ``tracker.log_access``.

    Args:
        tracker: Object exposing ``log_access(...)`` (the enhanced decorator
            passes ``get_data_tracker()``).
        args: Positional arguments of the wrapped CV call.
        kwargs: Keyword arguments of the wrapped CV call.
        dataset_name: Name to record; defaults to ``"cv_dataset"``.
        purpose: Purpose to record; defaults to ``"cv"``.
    """
    from .robust_cache_keys import _is_trackable_dataframe

    X = None
    for arg in args:
        if isinstance(arg, pd.DataFrame) and _is_trackable_dataframe(arg):
            X = arg
    for key, value in kwargs.items():
        # Keyword X/features may be an ndarray (no .index); only DataFrames qualify.
        if (
            key in ("X", "x", "features")
            and isinstance(value, pd.DataFrame)
            and _is_trackable_dataframe(value)
        ):
            X = value

    # An empty frame has no first/last index label to record.
    if X is None or len(X) == 0:
        return

    tracker.log_access(
        dataset_name=dataset_name or "cv_dataset",
        start_date=X.index[0],
        end_date=X.index[-1],
        purpose=purpose or "cv",
        data_shape=X.shape,
    )
def _log_cv_metrics_to_mlflow(result, func_name, cache_key, status=None):
    """
    Log CV metrics to MLflow for experiment tracking (best effort).

    Opens a run named ``cv_<func_name>_<cache_key[:8]>`` and logs metrics
    extracted from ``result``:

    * ``dict``: every numeric value (Python or numpy) as ``cv_<key>``.
    * ``list`` / ``np.ndarray`` / ``pd.Series`` of scores: ``cv_mean_score``,
      ``cv_std_score`` and ``cv_n_folds``.
    * Objects with ``cv_results_``: the mean of each numeric entry as
      ``cv_<key>_mean``. Non-numeric entries such as ``params`` are skipped,
      so they no longer abort logging of the remaining entries.

    Args:
        result: Return value of the CV function.
        func_name: Qualified name of the CV function (used in run name/tags).
        cache_key: Cache key of the result; its first 8 chars go in the run name.
        status: Optional label (callers pass "cached" / "computed"), recorded
            as the ``cache_status`` tag.

    Every exception, including a failed import of the MLflow integration, is
    logged at debug level and swallowed, so logging never breaks the caller.
    """
    try:
        from .mlflow_integration import get_mlflow_cache

        tags = {"type": "cross_validation", "function": func_name}
        if status is not None:
            tags["cache_status"] = status

        mlflow_cache = get_mlflow_cache()
        with mlflow_cache.experiment_run(
            run_name=f"cv_{func_name}_{cache_key[:8]}",
            tags=tags,
        ) as ctx:
            if isinstance(result, dict):
                # Direct metric dictionary; np.int64 is not an int subclass.
                for key, value in result.items():
                    if isinstance(value, (int, float, np.integer, np.floating)):
                        ctx.log_metric(f"cv_{key}", value)
            elif isinstance(result, (list, np.ndarray, pd.Series)):
                # Array of per-fold scores
                scores = np.asarray(result)
                if scores.size > 0:
                    ctx.log_metric("cv_mean_score", np.mean(scores))
                    ctx.log_metric("cv_std_score", np.std(scores))
                    ctx.log_metric("cv_n_folds", len(result))
            elif hasattr(result, "cv_results_"):
                # Sklearn search result object
                for key, value in result.cv_results_.items():
                    values = np.asarray(value)
                    if values.size > 0 and np.issubdtype(values.dtype, np.number):
                        ctx.log_metric(f"cv_{key}_mean", np.mean(values))
    except Exception as e:
        logger.debug(f"Failed to log CV metrics to MLflow: {e}")
def clear_cv_cache():
    """
    Delete every cached CV result pickle.

    Removes ``*.pkl`` files from the legacy (``cv_cache``), enhanced
    (``cv_cache_enhanced``) and model-state (``cv_cache_models``) directories
    under ``CACHE_DIRS["base"]``. Directories that do not exist are skipped.

    Returns:
        int: Number of ``.pkl`` files removed.
    """
    from . import CACHE_DIRS

    base_dir = CACHE_DIRS["base"]
    count = 0
    for subdir_name in ("cv_cache", "cv_cache_enhanced", "cv_cache_models"):
        directory = base_dir / subdir_name
        if not directory.exists():
            continue
        for cache_file in directory.glob("*.pkl"):
            cache_file.unlink(missing_ok=True)
            count += 1
    logger.info(f"Cleared {count} CV cache files")
    return count
def cv_cache_with_classifier_state(func: Callable) -> Callable:
    """
    Caching decorator for CV functions whose result includes trained models.

    Keys results with ``_generate_cv_cache_key``, like the legacy
    ``cv_cacheable``, but stores them under
    ``CACHE_DIRS["base"] / "cv_cache_models"``. The wrapped function's return
    value is pickled and returned unchanged. Use it when that value bundles the
    CV scores with the fitted per-fold classifiers, e.g.
    ``(scores, trained_models)``. Everything in the result must be picklable.

    Results are written atomically (temporary file + replace), so a failed
    write never leaves a truncated cache file behind. If the cache key cannot
    be generated, the call is counted as a miss and ``func`` runs uncached.

    Usage:
        @cv_cache_with_classifier_state
        def ml_cross_val_score_with_models(classifier, X, y, cv_gen, ...):
            # Your CV loop that returns (scores, trained_models)
            ...
    """
    import os

    from . import CACHE_DIRS, cache_stats

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache_models"
    cache_dir.mkdir(exist_ok=True)

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            cache_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            logger.warning(f"CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)

        cache_file = cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                logger.warning(f"CV cache read failed: {e}")
                # missing_ok: another process may have removed it already.
                cache_file.unlink(missing_ok=True)
            else:
                cache_stats.record_hit(func_name)
                logger.info(f"CV cache hit (with models) for {func.__name__}")
                return result

        # Cache miss - compute
        cache_stats.record_miss(func_name)
        logger.info(f"CV cache miss (with models) for {func.__name__} - computing...")
        result = func(*args, **kwargs)

        # Per-process temp name so concurrent writers do not clobber each other.
        tmp_file = cache_file.with_name(f"{cache_file.name}.{os.getpid()}.tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(result, f)
            tmp_file.replace(cache_file)
            logger.debug(f"Cached CV result with models: {cache_key}")
        except Exception as e:
            logger.warning(f"Failed to cache CV result: {e}")
            tmp_file.unlink(missing_ok=True)
        return result

    wrapper._afml_cacheable = True
    return wrapper
# Names exported by ``from <this module> import *``.
__all__ = ["cv_cacheable", "cv_cache_with_classifier_state", "clear_cv_cache"]