Spaces:
No application file
No application file
| """ | |
| Specialized caching for cross-validation functions. | |
| Handles sklearn classifiers, CV generators, and complex ML workflows. | |
| """ | |
| import hashlib | |
| import inspect | |
| import json | |
| import pickle | |
| import time | |
| from functools import wraps | |
| from typing import Callable, Optional | |
| import numpy as np | |
| import pandas as pd | |
| from loguru import logger | |
| from sklearn.base import BaseEstimator | |
| def _hash_classifier(clf: BaseEstimator) -> str: | |
| """ | |
| Generate stable hash for sklearn classifier. | |
| Uses class name + parameters (not the trained state). | |
| """ | |
| try: | |
| # Get classifier type and parameters | |
| clf_type = type(clf).__name__ | |
| params = clf.get_params(deep=True) | |
| # Filter out non-serializable params (like objects, functions) | |
| serializable_params = {} | |
| for k, v in params.items(): | |
| try: | |
| # Test if JSON serializable | |
| json.dumps(v) | |
| serializable_params[k] = v | |
| except (TypeError, ValueError): | |
| # Use type name for non-serializable params | |
| serializable_params[k] = f"<{type(v).__name__}>" | |
| # Create stable hash | |
| param_str = json.dumps(serializable_params, sort_keys=True) | |
| combined = f"{clf_type}_{param_str}" | |
| return hashlib.md5(combined.encode()).hexdigest()[:12] | |
| except Exception as e: | |
| logger.debug(f"Failed to hash classifier: {e}") | |
| return f"clf_{type(clf).__name__}_{id(clf)}" | |
| def _hash_cv_generator(cv_gen) -> str: | |
| """Generate hash for CV generator (KFold, PurgedKFold, etc.)""" | |
| try: | |
| cv_type = type(cv_gen).__name__ | |
| # Get CV parameters | |
| params = {} | |
| if hasattr(cv_gen, "n_splits"): | |
| params["n_splits"] = cv_gen.n_splits | |
| if hasattr(cv_gen, "pct_embargo"): | |
| params["pct_embargo"] = cv_gen.pct_embargo | |
| if hasattr(cv_gen, "t1"): | |
| # Hash the t1 series structure (not full content for speed) | |
| t1 = cv_gen.t1 | |
| if isinstance(t1, pd.Series): | |
| params["t1_len"] = len(t1) | |
| params["t1_start"] = str(t1.index[0]) | |
| params["t1_end"] = str(t1.index[-1]) | |
| param_str = json.dumps(params, sort_keys=True) | |
| combined = f"{cv_type}_{param_str}" | |
| return hashlib.md5(combined.encode()).hexdigest()[:12] | |
| except Exception as e: | |
| logger.debug(f"Failed to hash CV generator: {e}") | |
| return f"cv_{type(cv_gen).__name__}_{id(cv_gen)}" | |
| def _hash_dataframe_fast(df: pd.DataFrame) -> str: | |
| """ | |
| Fast DataFrame hashing for CV caching. | |
| Uses shape + columns + index range + sample of data. | |
| """ | |
| parts = [ | |
| f"shape_{df.shape}", | |
| f"cols_{hashlib.md5(str(tuple(df.columns)).encode()).hexdigest()[:8]}", | |
| ] | |
| # Hash index | |
| if isinstance(df.index, pd.DatetimeIndex): | |
| parts.append(f"idx_{df.index[0]}_{df.index[-1]}_{len(df)}") | |
| else: | |
| parts.append(f"idx_{df.index[0]}_{df.index[-1]}") | |
| # Sample data for hash (for speed) | |
| if len(df) > 100: | |
| sample = df.iloc[:: max(1, len(df) // 100)] | |
| else: | |
| sample = df | |
| data_hash = hashlib.md5(sample.values.tobytes()).hexdigest()[:8] | |
| parts.append(f"data_{data_hash}") | |
| return "_".join(parts) | |
| def _hash_series_fast(series: pd.Series) -> str: | |
| """Fast Series hashing.""" | |
| parts = [f"len_{len(series)}", f"dtype_{series.dtype}"] | |
| if isinstance(series.index, pd.DatetimeIndex): | |
| parts.append(f"idx_{series.index[0]}_{series.index[-1]}") | |
| # Sample for hash | |
| if len(series) > 100: | |
| sample = series.iloc[:: max(1, len(series) // 100)] | |
| else: | |
| sample = series | |
| data_hash = hashlib.md5(sample.values.tobytes()).hexdigest()[:8] | |
| parts.append(f"data_{data_hash}") | |
| return "_".join(parts) | |
def _generate_cv_cache_key(func: Callable, args: tuple, kwargs: dict) -> str:
    """
    Generate a specialized cache key for CV functions.

    The call is bound to ``func``'s signature (defaults applied) so that
    positional and keyword spellings of the same call yield the same key.
    Each bound argument contributes one key part chosen by its type:
    estimators, CV generators (objects with ``split`` and ``n_splits``),
    DataFrames, Series, ndarrays, primitives, callables, and a string-based
    fallback for everything else.

    All digests use ``hashlib`` rather than the builtin ``hash()``: string
    hashes are salted per interpreter process, which would make on-disk cache
    keys differ from one run to the next.

    Args:
        func: The function being cached.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        Hex MD5 digest of the joined key parts.

    Raises:
        TypeError: If ``args``/``kwargs`` do not match ``func``'s signature
            (raised by ``inspect.Signature.bind``).
    """
    key_parts = [func.__module__, func.__qualname__]

    # Map args to parameter names and fill in defaults.
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()

    for param_name, param_value in bound.arguments.items():
        try:
            if param_value is None:
                key_parts.append(f"{param_name}_None")
            elif isinstance(param_value, BaseEstimator):
                clf_hash = _hash_classifier(param_value)
                key_parts.append(f"{param_name}_clf_{clf_hash}")
            elif hasattr(param_value, "split") and hasattr(param_value, "n_splits"):
                cv_hash = _hash_cv_generator(param_value)
                key_parts.append(f"{param_name}_cv_{cv_hash}")
            elif isinstance(param_value, pd.DataFrame):
                df_hash = _hash_dataframe_fast(param_value)
                key_parts.append(f"{param_name}_df_{df_hash}")
            elif isinstance(param_value, pd.Series):
                series_hash = _hash_series_fast(param_value)
                key_parts.append(f"{param_name}_ser_{series_hash}")
            elif isinstance(param_value, np.ndarray):
                if param_value.dtype == object:
                    # Raw bytes of an object array are pointers, not content.
                    arr_bytes = repr(param_value.tolist()).encode(errors="replace")
                else:
                    arr_bytes = param_value.tobytes()
                arr_hash = hashlib.md5(arr_bytes).hexdigest()[:8]
                key_parts.append(f"{param_name}_arr_{param_value.shape}_{arr_hash}")
            elif isinstance(param_value, (str, int, float, bool)):
                key_parts.append(f"{param_name}_{param_value}")
            elif callable(param_value):
                # For scoring functions: identified by name only.
                func_name = getattr(param_value, "__name__", str(type(param_value)))
                key_parts.append(f"{param_name}_func_{func_name}")
            else:
                # Fallback: process-independent digest of the string form.
                text = str(param_value).encode(errors="replace")
                key_parts.append(f"{param_name}_{hashlib.md5(text).hexdigest()[:8]}")
        except Exception as e:
            # Best effort: an unhashable argument degrades the key instead of
            # failing the call. Distinct values then share the same key part.
            logger.debug(f"Failed to hash param '{param_name}': {e}")
            key_parts.append(f"{param_name}_unknown")

    combined = "_".join(key_parts)
    return hashlib.md5(combined.encode()).hexdigest()
def cv_cacheable(
    _func=None,
    track_data_access: bool = False,
    dataset_name: Optional[str] = None,
    purpose: Optional[str] = None,
    log_metrics: bool = True,
):
    """
    Caching decorator for cross-validation functions.

    Works both bare and with arguments:

        @cv_cacheable
        def my_func(...)

        @cv_cacheable(track_data_access=True, dataset_name='my_data', purpose='train')
        def my_func(...)

    When none of ``track_data_access``, ``dataset_name`` or ``purpose`` is
    set, the function is wrapped by ``_cv_cacheable_legacy``; otherwise by
    ``_cv_cacheable_enhanced`` with all four options forwarded.

    Args:
        track_data_access: Track DataFrame access for contamination detection
        dataset_name: Name of dataset for tracking
        purpose: Purpose of data access (train/test/validate/optimize/analyze)
        log_metrics: Log results to MLflow if available

    Raises:
        ValueError: If ``purpose`` is truthy and not one of the allowed values.
    """
    allowed_purposes = ("train", "test", "validate", "optimize", "analyze")
    if purpose and purpose not in allowed_purposes:
        raise ValueError(
            f"Invalid purpose: {purpose}. Must be one of: "
            "train, test, validate, optimize, analyze"
        )

    def decorator(func):
        wants_enhanced = (
            track_data_access or dataset_name is not None or purpose is not None
        )
        if wants_enhanced:
            return _cv_cacheable_enhanced(
                func,
                track_data_access=track_data_access,
                dataset_name=dataset_name,
                purpose=purpose,
                log_metrics=log_metrics,
            )
        # No enhanced option requested: keep the original behavior.
        return _cv_cacheable_legacy(func)

    # Bare usage passes the function directly; parameterized usage does not.
    return decorator if _func is None else decorator(_func)
def _cv_cacheable_legacy(func):
    """
    Original cv_cacheable implementation for backward compatibility.

    Results are pickled to ``CACHE_DIRS["base"] / "cv_cache"`` under a key
    derived from the call arguments. Hits and misses are reported to
    ``cache_stats`` under ``"<module>.<qualname>"``.

    Args:
        func: Function whose results should be cached.

    Returns:
        A wrapper with ``func``'s metadata and ``_afml_cacheable = True``.
    """
    from . import CACHE_DIRS, cache_stats

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache"
    # parents=True: the base directory may not have been created yet.
    cache_dir.mkdir(parents=True, exist_ok=True)

    def _discard(path):
        """Best-effort removal of a corrupt or partial cache file."""
        try:
            path.unlink()
        except OSError as e:
            logger.debug(f"Could not remove cache file {path}: {e}")

    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            cache_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            # Without a key there is nothing to look up or store.
            logger.warning(f"CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)

        cache_file = cache_dir / f"{cache_key}.pkl"

        if cache_file.exists():
            try:
                # NOTE: pickle executes arbitrary code on load; the cache
                # directory must only contain files written by this process.
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                logger.warning(f"CV cache read failed: {e}")
                _discard(cache_file)
            else:
                cache_stats.record_hit(func_name)
                logger.info(f"CV cache hit for {func.__name__}")
                return result

        # Cache miss
        cache_stats.record_miss(func_name)
        result = func(*args, **kwargs)

        try:
            with open(cache_file, "wb") as f:
                pickle.dump(result, f)
        except Exception as e:
            logger.warning(f"Failed to cache CV result: {e}")
            # Do not leave a truncated file for the next call to trip over.
            _discard(cache_file)

        return result

    wrapper._afml_cacheable = True
    return wrapper
def _cv_cacheable_enhanced(
    func, track_data_access=False, dataset_name=None, purpose=None, log_metrics=True
):
    """
    Enhanced version with tracking capabilities.

    Results are pickled to ``CACHE_DIRS["base"] / "cv_cache_enhanced"``. The
    cache key is the argument-based key plus a suffix derived from the four
    tracking options, so the same call under different options is cached
    separately.

    MLflow logging is strictly best-effort: a logging failure is reported at
    debug level and never invalidates a cache entry or fails the call.

    Args:
        func: Function whose results should be cached.
        track_data_access: When true, every call (hit or miss) is reported
            through ``_track_cv_data_access``.
        dataset_name: Dataset name forwarded to the tracker.
        purpose: Access purpose forwarded to the tracker.
        log_metrics: When true and MLflow is available, results are logged
            via ``_log_cv_metrics_to_mlflow``.

    Returns:
        A wrapper with ``func``'s metadata and ``_afml_cacheable = True``.
    """
    from . import CACHE_DIRS, cache_stats
    from .mlflow_integration import MLFLOW_AVAILABLE, get_mlflow_cache

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache_enhanced"
    # parents=True: the base directory may not have been created yet.
    cache_dir.mkdir(parents=True, exist_ok=True)

    # The tracking options are fixed per decorated function, so the key
    # suffix is computed once rather than on every call.
    tracking_params = {
        "track_data_access": track_data_access,
        "dataset_name": dataset_name,
        "purpose": purpose,
        "log_metrics": log_metrics,
    }
    params_hash = hashlib.md5(
        json.dumps(tracking_params, sort_keys=True).encode()
    ).hexdigest()[:8]

    def _discard(path):
        """Best-effort removal of a corrupt or partial cache file."""
        try:
            path.unlink()
        except OSError as e:
            logger.debug(f"Could not remove cache file {path}: {e}")

    def _log_metrics_safely(result, cache_key):
        """Log metrics to MLflow without letting failures escape."""
        if not (log_metrics and MLFLOW_AVAILABLE):
            return
        try:
            _log_cv_metrics_to_mlflow(result, func_name, cache_key)
        except Exception as e:
            logger.debug(f"Failed to log CV metrics: {e}")

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Track data access on every call, whether or not it hits the cache.
        if track_data_access:
            from .data_access_tracker import get_data_tracker

            _track_cv_data_access(
                get_data_tracker(), args, kwargs, dataset_name, purpose
            )

        try:
            base_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            # Without a key there is nothing to look up or store.
            logger.warning(f"Enhanced CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)

        cache_key = f"{base_key}_tracking_{params_hash}"
        cache_file = cache_dir / f"{cache_key}.pkl"

        # Check cache. Only the read itself sits inside the try so that an
        # unrelated failure (e.g. in logging) cannot delete a valid entry.
        if cache_file.exists():
            try:
                # NOTE: pickle executes arbitrary code on load; the cache
                # directory must only contain files written by this process.
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                logger.warning(f"Enhanced CV cache read failed: {e}")
                _discard(cache_file)
            else:
                cache_stats.record_hit(func_name)
                _log_metrics_safely(result, cache_key)
                logger.info(f"Enhanced CV cache hit for {func.__name__}")
                return result

        # Cache miss
        cache_stats.record_miss(func_name)
        logger.info(f"Enhanced CV cache miss for {func.__name__} - computing...")

        start_time = time.time()
        result = func(*args, **kwargs)
        execution_time = time.time() - start_time

        # Save to enhanced cache
        try:
            with open(cache_file, "wb") as f:
                pickle.dump(result, f)
            logger.debug(f"Cached enhanced CV result: {cache_key}")
        except Exception as e:
            logger.warning(f"Failed to cache enhanced CV result: {e}")
            # Do not leave a truncated file for the next call to trip over.
            _discard(cache_file)

        # Log to MLflow
        _log_metrics_safely(result, cache_key)
        if log_metrics and MLFLOW_AVAILABLE:
            try:
                mlflow_cache = get_mlflow_cache()
                mlflow_cache._log_metrics({"execution_time_seconds": execution_time})
            except Exception as e:
                logger.debug(f"Failed to log execution time: {e}")

        return result

    wrapper._afml_cacheable = True
    return wrapper
def _track_cv_data_access(tracker, args, kwargs, dataset_name, purpose):
    """
    Track data access in CV functions.

    Looks for a feature DataFrame among the call arguments and, if one is
    found, reports its index range and shape through ``tracker.log_access``.

    A positional argument qualifies when it is a DataFrame accepted by
    ``_is_trackable_dataframe``; a keyword argument additionally has to be
    named ``X``, ``x`` or ``features``. When several candidates qualify, the
    last positional one is used and a keyword one overrides it.

    Args:
        tracker: Object exposing ``log_access(dataset_name, start_date,
            end_date, purpose, data_shape)``.
        args: Positional arguments of the CV call.
        kwargs: Keyword arguments of the CV call.
        dataset_name: Name to log; defaults to ``"cv_dataset"`` when falsy.
        purpose: Purpose to log; defaults to ``"cv"`` when falsy.
    """
    from .robust_cache_keys import _is_trackable_dataframe

    X = None

    for arg in args:
        if isinstance(arg, pd.DataFrame) and _is_trackable_dataframe(arg):
            X = arg

    for key, value in kwargs.items():
        if (
            key in ("X", "x", "features")
            and isinstance(value, pd.DataFrame)
            and _is_trackable_dataframe(value)
        ):
            X = value

    # Nothing to report without a frame, and an empty frame has no
    # first/last index label to read.
    if X is None or len(X.index) == 0:
        return

    tracker.log_access(
        dataset_name=dataset_name or "cv_dataset",
        start_date=X.index[0],
        end_date=X.index[-1],
        purpose=purpose or "cv",
        data_shape=X.shape,
    )
def _log_cv_metrics_to_mlflow(result, func_name, cache_key, source: Optional[str] = None):
    """
    Log CV metrics to MLflow for experiment tracking.

    Opens a run named ``cv_<func_name>_<first 8 chars of cache_key>`` and logs
    metrics according to the shape of ``result``:

    * ``dict``: every numeric value as ``cv_<key>``.
    * ``list`` / ``ndarray``: mean, standard deviation and length, if non-empty.
    * object with ``cv_results_``: the mean of each non-empty list/array
      entry as ``cv_<key>_mean``; entries whose mean cannot be computed are
      skipped.

    Any failure is logged at debug level and suppressed.

    Args:
        result: Value returned by the CV function.
        func_name: Qualified name of the CV function.
        cache_key: Cache key of the call.
        source: Optional origin label such as ``"cached"`` or ``"computed"``;
            when given it is added to the run tags as ``cache_status``.
    """
    from .mlflow_integration import get_mlflow_cache

    try:
        tags = {"type": "cross_validation", "function": func_name}
        if source is not None:
            tags["cache_status"] = str(source)

        mlflow_cache = get_mlflow_cache()
        with mlflow_cache.experiment_run(
            run_name=f"cv_{func_name}_{cache_key[:8]}",
            tags=tags,
        ) as ctx:
            if isinstance(result, dict):
                # Direct metric dictionary; np.number covers NumPy scalars
                # that are not subclasses of int/float (e.g. np.float32).
                for key, value in result.items():
                    if isinstance(value, (int, float, np.number)):
                        ctx.log_metric(f"cv_{key}", value)
            elif isinstance(result, (list, np.ndarray)):
                # Array of scores
                if len(result) > 0:
                    ctx.log_metric("cv_mean_score", np.mean(result))
                    ctx.log_metric("cv_std_score", np.std(result))
                    ctx.log_metric("cv_n_folds", len(result))
            elif hasattr(result, "cv_results_"):
                # Sklearn-style CV result object
                for key, value in result.cv_results_.items():
                    if not isinstance(value, (list, np.ndarray)) or len(value) == 0:
                        continue
                    try:
                        mean_value = np.mean(value)
                    except (TypeError, ValueError):
                        # Non-numeric entry: skip it, keep logging the rest.
                        continue
                    ctx.log_metric(f"cv_{key}_mean", mean_value)
    except Exception as e:
        logger.debug(f"Failed to log CV metrics to MLflow: {e}")
def clear_cv_cache():
    """
    Clear all CV cache files.

    Removes every ``*.pkl`` file from the ``cv_cache``, ``cv_cache_enhanced``
    and ``cv_cache_models`` directories under ``CACHE_DIRS["base"]``.
    Directories that do not exist are skipped.

    Returns:
        Number of files removed.
    """
    from . import CACHE_DIRS

    base_dir = CACHE_DIRS["base"]
    count = 0
    for dir_name in ("cv_cache", "cv_cache_enhanced", "cv_cache_models"):
        directory = base_dir / dir_name
        if not directory.exists():
            continue
        for cache_file in directory.glob("*.pkl"):
            try:
                cache_file.unlink()
            except FileNotFoundError:
                # Removed by someone else between glob() and unlink().
                continue
            count += 1
    logger.info(f"Cleared {count} CV cache files")
    return count
def cv_cache_with_classifier_state(func: Callable) -> Callable:
    """
    Caching decorator that also caches the trained classifier state.

    Use this if you want to cache both CV scores AND the trained models.
    Whatever ``func`` returns is pickled as a whole to
    ``CACHE_DIRS["base"] / "cv_cache_models"``, so the decorated function is
    expected to return ``(original_result, trained_classifiers)`` itself.

    Usage:
        @cv_cache_with_classifier_state
        def ml_cross_val_score_with_models(classifier, X, y, cv_gen, ...):
            # Your CV loop that returns (scores, trained_models)
            ...

    Args:
        func: Function whose results should be cached.

    Returns:
        A wrapper with ``func``'s metadata and ``_afml_cacheable = True``.
    """
    from . import CACHE_DIRS, cache_stats

    func_name = f"{func.__module__}.{func.__qualname__}"
    cache_dir = CACHE_DIRS["base"] / "cv_cache_models"
    # parents=True: the base directory may not have been created yet.
    cache_dir.mkdir(parents=True, exist_ok=True)

    def _discard(path):
        """Best-effort removal of a corrupt or partial cache file."""
        try:
            path.unlink()
        except OSError as e:
            logger.debug(f"Could not remove cache file {path}: {e}")

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Generate cache key
        try:
            cache_key = _generate_cv_cache_key(func, args, kwargs)
        except Exception as e:
            logger.warning(f"CV cache key generation failed: {e}")
            cache_stats.record_miss(func_name)
            return func(*args, **kwargs)

        # Check cache
        cache_file = cache_dir / f"{cache_key}.pkl"
        if cache_file.exists():
            try:
                # NOTE: pickle executes arbitrary code on load; the cache
                # directory must only contain files written by this process.
                with open(cache_file, "rb") as f:
                    result = pickle.load(f)
            except Exception as e:
                logger.warning(f"CV cache read failed: {e}")
                _discard(cache_file)
            else:
                cache_stats.record_hit(func_name)
                logger.info(f"CV cache hit (with models) for {func.__name__}")
                return result

        # Cache miss - compute
        cache_stats.record_miss(func_name)
        logger.info(f"CV cache miss (with models) for {func.__name__} - computing...")
        result = func(*args, **kwargs)

        # Save to cache
        try:
            with open(cache_file, "wb") as f:
                pickle.dump(result, f)
            logger.debug(f"Cached CV result with models: {cache_key}")
        except Exception as e:
            logger.warning(f"Failed to cache CV result: {e}")
            # Do not leave a truncated file for the next call to trip over.
            _discard(cache_file)

        return result

    wrapper._afml_cacheable = True
    return wrapper
# Public API of this module; helpers prefixed with an underscore are
# intentionally left out.
__all__ = [
    "cv_cacheable",
    "cv_cache_with_classifier_state",
    "clear_cv_cache",
]