"""
Robust cache key generation for financial ML data structures.
Handles numpy arrays, pandas DataFrames, and time-series data properly.
"""

import hashlib
import pickle
from pathlib import Path
from typing import Any, Optional, Tuple

import numpy as np
import pandas as pd
from loguru import logger


class CacheKeyGenerator:
    """Generate robust, collision-resistant cache keys for ML data structures."""

    @staticmethod
    def generate_key(func, args: tuple, kwargs: dict) -> str:
        """
        Generate a robust cache key for a function call.

        Args:
            func: The function being cached
            args: Positional arguments
            kwargs: Keyword arguments

        Returns:
            MD5 hash string representing the unique call signature
        """
        key_parts = [
            func.__module__,
            func.__qualname__,
        ]

        # Process positional arguments
        for i, arg in enumerate(args):
            try:
                key_part = CacheKeyGenerator._hash_argument(arg, f"arg_{i}")
                key_parts.append(key_part)
            except Exception as e:
                logger.warning(f"Failed to hash argument {i} of type {type(arg)}: {e}")
                # Fallback: digest the string form (built-in hash() is salted
                # per process, which would break a persistent cache)
                key_parts.append(
                    f"arg_{i}_{hashlib.md5(str(arg).encode()).hexdigest()[:8]}"
                )

        # Process keyword arguments (sorted for consistency)
        for key, value in sorted(kwargs.items()):
            try:
                key_part = CacheKeyGenerator._hash_argument(value, key)
                key_parts.append(f"{key}={key_part}")
            except Exception as e:
                logger.warning(
                    f"Failed to hash kwarg '{key}' of type {type(value)}: {e}"
                )
                # Fallback: digest the string form for the same reason as above
                key_parts.append(
                    f"{key}={hashlib.md5(str(value).encode()).hexdigest()[:8]}"
                )

        # Combine all parts and hash
        combined = "_".join(key_parts)
        return hashlib.md5(combined.encode()).hexdigest()

    @staticmethod
    def _hash_argument(arg: Any, name: str) -> str:
        """Hash a single argument based on its type."""
        try:
            from sklearn.base import BaseEstimator

            if isinstance(arg, BaseEstimator):
                return CacheKeyGenerator._hash_sklearn_estimator(arg, name)
        except ImportError:
            pass  # sklearn not available, continue with other types

        if isinstance(arg, np.ndarray):
            return CacheKeyGenerator._hash_numpy_array(arg, name)
        elif isinstance(arg, pd.DataFrame):
            return CacheKeyGenerator._hash_dataframe(arg, name)
        elif isinstance(arg, pd.Series):
            return CacheKeyGenerator._hash_series(arg, name)
        elif isinstance(arg, (list, tuple)):
            return CacheKeyGenerator._hash_sequence(arg, name)
        elif isinstance(arg, dict):
            return CacheKeyGenerator._hash_dict(arg, name)
        elif isinstance(arg, (int, float, str, bool, type(None))):
            return CacheKeyGenerator._hash_primitive(arg, name)
        else:
            # Fallback for unknown types
            return CacheKeyGenerator._hash_generic(arg, name)

    @staticmethod
    def _hash_numpy_array(arr: np.ndarray, name: str) -> str:
        """Hash numpy array including shape, dtype, and content."""
        # For large arrays, sample for performance
        if arr.size > 10000:
            # Hash shape, dtype, and a sample
            sample = arr.flat[:: max(1, arr.size // 1000)]  # Sample ~1000 points
            content_hash = hashlib.md5(sample.tobytes()).hexdigest()[:8]
        else:
            # Hash full content for small arrays
            content_hash = hashlib.md5(arr.tobytes()).hexdigest()[:8]

        return f"{name}_arr_{arr.shape}_{arr.dtype}_{content_hash}"

    @staticmethod
    def _hash_dataframe(df: pd.DataFrame, name: str) -> str:
        """Hash pandas DataFrame including index, columns, dtypes, and content."""
        parts = [
            f"shape_{df.shape}",
            f"cols_{hashlib.md5(str(tuple(df.columns)).encode()).hexdigest()[:8]}",
            f"dtypes_{hashlib.md5(str(tuple(df.dtypes)).encode()).hexdigest()[:8]}",
        ]

        # Hash index
        if isinstance(df.index, pd.DatetimeIndex):
            # For a datetime index, key on its start, end, and length
            parts.append(f"idx_dt_{df.index[0]}_{df.index[-1]}_{len(df.index)}")
        else:
            idx_hash = hashlib.md5(str(tuple(df.index)).encode()).hexdigest()[:8]
            parts.append(f"idx_{idx_hash}")

        # Hash content (sample rows for large DataFrames). Use pandas' own row
        # hashing, which is deterministic for object/mixed dtypes, where raw
        # .values.tobytes() can reflect memory addresses instead of values.
        if df.size > 10000:
            sample_rows = df.iloc[:: max(1, len(df) // 100)]  # ~100 rows
        else:
            sample_rows = df
        row_hashes = pd.util.hash_pandas_object(sample_rows, index=False).values
        content_hash = hashlib.md5(row_hashes.tobytes()).hexdigest()[:8]

        parts.append(f"data_{content_hash}")

        return f"{name}_df_{'_'.join(parts)}"

    @staticmethod
    def _hash_series(series: pd.Series, name: str) -> str:
        """Hash pandas Series."""
        parts = [
            f"len_{len(series)}",
            f"dtype_{series.dtype}",
        ]

        # Hash index
        if isinstance(series.index, pd.DatetimeIndex):
            parts.append(f"idx_dt_{series.index[0]}_{series.index[-1]}")
        else:
            idx_hash = hashlib.md5(str(tuple(series.index)).encode()).hexdigest()[:8]
            parts.append(f"idx_{idx_hash}")

        # Hash values via pandas' deterministic hashing (safe for object dtype)
        if len(series) > 1000:
            sample = series.iloc[:: max(1, len(series) // 100)]
        else:
            sample = series
        value_hashes = pd.util.hash_pandas_object(sample, index=False).values
        content_hash = hashlib.md5(value_hashes.tobytes()).hexdigest()[:8]

        parts.append(f"data_{content_hash}")

        return f"{name}_series_{'_'.join(parts)}"

    @staticmethod
    def _hash_sequence(seq: Tuple[Any, ...] | list, name: str) -> str:
        """Hash list or tuple recursively."""
        if len(seq) == 0:
            return f"{name}_empty_seq"

        # Hash each element
        element_hashes = []
        for i, item in enumerate(seq):
            elem_hash = CacheKeyGenerator._hash_argument(item, f"{name}_{i}")
            element_hashes.append(elem_hash)

        combined = "_".join(element_hashes)
        return hashlib.md5(combined.encode()).hexdigest()[:8]

    @staticmethod
    def _hash_dict(d: dict, name: str) -> str:
        """Hash dictionary recursively."""
        if len(d) == 0:
            return f"{name}_empty_dict"

        # Sort keys for consistency
        items_hash = []
        for key, value in sorted(d.items()):
            val_hash = CacheKeyGenerator._hash_argument(value, f"{name}_{key}")
            items_hash.append(f"{key}={val_hash}")

        combined = "_".join(items_hash)
        return hashlib.md5(combined.encode()).hexdigest()[:8]

    @staticmethod
    def _hash_primitive(value: Any, name: str) -> str:
        """Hash primitives with a digest that is stable across interpreter runs
        (built-in hash() is salted per process for str, which would invalidate
        a persistent cache on every restart)."""
        digest = hashlib.md5(repr(value).encode()).hexdigest()[:8]
        return f"{name}_{type(value).__name__}_{digest}"

    @staticmethod
    def _hash_generic(obj: Any, name: str) -> str:
        """Fallback hashing for unknown types."""
        try:
            # Digest the repr rather than using built-in hash(), so keys
            # survive interpreter restarts
            digest = hashlib.md5(repr(obj).encode()).hexdigest()[:8]
            return f"{name}_{type(obj).__name__}_{digest}"
        except Exception:
            # Last resort: id() is only stable within a single process
            return f"{name}_{type(obj).__name__}_{id(obj)}"

    @staticmethod
    def _hash_sklearn_estimator(estimator: Any, name: str) -> str:
        """Hash sklearn estimator including nested estimators."""
        try:
            from sklearn.base import BaseEstimator

            if not isinstance(estimator, BaseEstimator):
                return CacheKeyGenerator._hash_generic(estimator, name)

            # Use the enhanced estimator hashing from cv_cache
            from .cv_cache import _hash_classifier

            estimator_hash = _hash_classifier(estimator)
            return f"{name}_estimator_{estimator_hash}"

        except ImportError:
            # Fallback if sklearn not available
            return CacheKeyGenerator._hash_generic(estimator, name)
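
# Example: keys are deterministic for identical data and change when content
# changes (a sketch; `load_prices` is a hypothetical function):
#
#     arr = np.arange(1000.0)
#     k1 = CacheKeyGenerator.generate_key(load_prices, (arr,), {"window": 20})
#     k2 = CacheKeyGenerator.generate_key(load_prices, (arr.copy(),), {"window": 20})
#     k3 = CacheKeyGenerator.generate_key(load_prices, (arr + 1,), {"window": 20})
#     assert k1 == k2 and k1 != k3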


class TimeSeriesCacheKey(CacheKeyGenerator):
    """
    Extended cache key generator with time-series awareness.
    Useful for financial data where lookback periods matter.
    """

    @staticmethod
    def generate_key_with_time_range(
        func,
        args: tuple,
        kwargs: dict,
        time_range: Optional[Tuple[pd.Timestamp, pd.Timestamp]] = None,
    ) -> str:
        """
        Generate cache key that includes time range information.

        Args:
            func: Function being cached
            args: Positional arguments
            kwargs: Keyword arguments
            time_range: Optional (start, end) timestamp tuple

        Returns:
            Cache key string
        """
        base_key = CacheKeyGenerator.generate_key(func, args, kwargs)

        if time_range is None:
            # Try to extract time range from data
            time_range = TimeSeriesCacheKey._extract_time_range(args, kwargs)

        if time_range:
            start, end = time_range
            return f"{base_key}_time_{start}_{end}"

        return base_key

    @staticmethod
    def _extract_time_range(
        args: tuple, kwargs: dict
    ) -> Optional[Tuple[pd.Timestamp, pd.Timestamp]]:
        """
        Attempt to extract time range from function arguments.
        Looks for DataFrames with DatetimeIndex or explicit start/end parameters.
        """
        # Check kwargs for explicit time parameters
        if "start_date" in kwargs and "end_date" in kwargs:
            return (
                pd.Timestamp(kwargs["start_date"]),
                pd.Timestamp(kwargs["end_date"]),
            )

        # Check for DataFrames with DatetimeIndex in args
        for arg in args:
            if isinstance(arg, pd.DataFrame) and isinstance(
                arg.index, pd.DatetimeIndex
            ):
                if len(arg.index) > 0:
                    return (arg.index[0], arg.index[-1])

            elif isinstance(arg, pd.Series) and isinstance(arg.index, pd.DatetimeIndex):
                if len(arg.index) > 0:
                    return (arg.index[0], arg.index[-1])

        return None
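
# Example: the detected (start, end) window of a DatetimeIndex-ed argument is
# appended to the base key, so calls over different date ranges never collide
# (a sketch; `compute_signal`, `df_q1`, and `df_q2` are hypothetical):
#
#     k_q1 = TimeSeriesCacheKey.generate_key_with_time_range(compute_signal, (df_q1,), {})
#     k_q2 = TimeSeriesCacheKey.generate_key_with_time_range(compute_signal, (df_q2,), {})
#     assert k_q1 != k_q2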


# =============================================================================
# Integration with existing cacheable decorator
# =============================================================================


def create_robust_cacheable(
    track_data_access: bool = False,
    dataset_name: Optional[str] = None,
    purpose: Optional[str] = None,
    use_time_awareness: bool = False,
):
    """
    Factory function to create robust cacheable decorators with data tracking.
    Args:
        track_data_access: Whether to track DataFrame accesses
        dataset_name: Name of the dataset for tracking
        purpose: One of 'train', 'test', 'validate', 'optimize', 'analyze'
        use_time_awareness: Whether to use time-series aware cache keys
    Returns:
        Decorator function
    """
    import time
    from functools import wraps

    from . import cache_stats, memory
    from .cache_monitoring import get_cache_monitor

    def decorator(func):
        func_name = f"{func.__module__}.{func.__qualname__}"
        cached_func = memory.cache(func)
        seen_signatures = set()
        monitor = get_cache_monitor()

        @wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal seen_signatures

            # Track access time (ALWAYS do this first)
            monitor.track_access(func_name)

            # Generate cache key
            cache_key = None
            is_hit = False
            computation_start = None

            try:
                if use_time_awareness:
                    cache_key = TimeSeriesCacheKey.generate_key_with_time_range(
                        func, args, kwargs
                    )
                else:
                    cache_key = CacheKeyGenerator.generate_key(func, args, kwargs)

                # Track hit/miss. joblib's check_call_in_cache returns a bool;
                # it does not raise on a miss.
                is_hit = bool(cached_func.check_call_in_cache(*args, **kwargs))
                if is_hit:
                    cache_stats.record_hit(func_name)
                    logger.debug(f"Cache HIT for {func_name}")
                else:
                    cache_stats.record_miss(func_name)
                    computation_start = time.time()  # Start timing for misses
                    logger.debug(f"Cache MISS for {func_name}")

                # Add to seen_signatures for this session
                seen_signatures.add(cache_key)

            except Exception as e:
                logger.warning(f"Cache key generation failed for {func_name}: {e}")
                cache_stats.record_miss(func_name)
                cache_key = None
                is_hit = False
                computation_start = time.time()  # Start timing for error case

            # Track data access if requested
            if track_data_access:
                try:
                    from .data_access_tracker import get_data_tracker

                    _track_dataframe_access(
                        get_data_tracker(), args, kwargs, dataset_name, purpose
                    )
                except Exception as e:
                    logger.warning(f"Data tracking failed for {func_name}: {e}")

            # Execute function (joblib serves hits from disk, computes misses)
            try:
                result = cached_func(*args, **kwargs)
                if not is_hit and computation_start is not None:
                    # For cache misses, record how long the computation took
                    computation_time = time.time() - computation_start
                    monitor.track_computation_time(func_name, computation_time)
                    logger.debug(
                        f"Computation time for {func_name}: {computation_time:.3f}s"
                    )

                return result

            except (EOFError, pickle.PickleError, OSError) as e:
                # Handle cache corruption
                logger.warning(
                    f"Cache corruption for {func_name}: {type(e).__name__} - recomputing"
                )

                # Clear corrupted cache if possible
                if cache_key is not None:
                    _clear_corrupted_cache(cached_func, args, kwargs, func_name)

                # Execute function directly and track time
                direct_start = time.time()
                result = func(*args, **kwargs)

                if computation_start is not None:  # Track time if it was a miss
                    computation_time = time.time() - direct_start
                    monitor.track_computation_time(func_name, computation_time)
                    logger.debug(
                        f"Direct computation time for {func_name}: {computation_time:.3f}s"
                    )

                return result

            except Exception as e:
                # Other unexpected errors
                logger.error(f"Unexpected cache error for {func_name}: {e}")
                raise

        # Add cache info method for debugging
        def cache_info():
            return {
                "function_name": func_name,
                "seen_signatures": len(seen_signatures),
                "hits": cache_stats._stats.get(func_name, {}).get("hits", 0),
                "misses": cache_stats._stats.get(func_name, {}).get("misses", 0),
            }

        wrapper.cache_info = cache_info
        wrapper._afml_cacheable = True
        return wrapper

    return decorator
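
# Example: building a purpose-specific decorator from the factory (a sketch;
# the dataset name and purpose labels are illustrative):
#
#     train_cacheable = create_robust_cacheable(
#         track_data_access=True,
#         dataset_name="spx_prices",
#         purpose="train",
#         use_time_awareness=True,
#     )
#
#     @train_cacheable
#     def fit_features(prices: pd.DataFrame) -> pd.DataFrame:
#         ...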


def _clear_corrupted_cache(cached_func, args, kwargs, func_name):
    """Helper to clear corrupted cache entries.

    Relies on joblib's private ``_get_output_identifiers`` to locate the
    on-disk entry, so it degrades to a silent no-op if that API changes.
    """
    try:
        if hasattr(cached_func, "_get_output_identifiers"):
            # joblib returns (func_id, args_id) identifying this call's entry
            _func_id, args_id = cached_func._get_output_identifiers(*args, **kwargs)
            cache_dir = Path(cached_func.store_backend.location)

            # Remove files belonging to this cache entry
            removed_count = 0
            for cache_file in cache_dir.rglob("*"):
                if cache_file.is_file() and args_id in str(cache_file):
                    cache_file.unlink()
                    removed_count += 1
                    logger.debug(f"Removed corrupted file: {cache_file.name}")

            if removed_count > 0:
                logger.info(
                    f"Cleared {removed_count} corrupted cache files for {func_name}"
                )

    except Exception as clear_exc:
        logger.warning(f"Failed to clear corrupted cache for {func_name}: {clear_exc}")


def _track_dataframe_access(tracker, args, kwargs, dataset_name, purpose):
    """Track DataFrame accesses for data hygiene monitoring."""
    # Check all arguments for DataFrames with DatetimeIndex
    for i, arg in enumerate(args):
        if _is_trackable_dataframe(arg):
            _log_dataframe_access(tracker, arg, dataset_name or f"arg_{i}", purpose)

    for key, value in kwargs.items():
        if _is_trackable_dataframe(value):
            _log_dataframe_access(tracker, value, dataset_name or key, purpose)


def _is_trackable_dataframe(obj):
    """Check if object is a DataFrame with temporal index."""
    return (
        isinstance(obj, pd.DataFrame)
        and isinstance(obj.index, pd.DatetimeIndex)
        and len(obj) > 0
    )


def _log_dataframe_access(tracker, df, name, purpose):
    """Log DataFrame access to tracker."""
    tracker.log_access(
        dataset_name=name,
        start_date=df.index[0],
        end_date=df.index[-1],
        purpose=purpose or "unknown",
        data_shape=df.shape,
    )


# =============================================================================
# Final convenience exports
# =============================================================================

# Standard decorators (backward compatible)
robust_cacheable = create_robust_cacheable(use_time_awareness=False)
time_aware_cacheable = create_robust_cacheable(use_time_awareness=True)
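
# Example usage of the standard decorators (a sketch; `make_features` and
# `price_df` are hypothetical):
#
#     @robust_cacheable
#     def make_features(prices: pd.DataFrame, window: int = 20) -> pd.DataFrame:
#         return prices.pct_change().rolling(window).std()
#
#     feats = make_features(price_df)  # first call: MISS, computes and caches
#     feats = make_features(price_df)  # second call: HIT, served from disk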

# Data tracking decorators (new functionality)
def data_tracking_cacheable(dataset_name, purpose):
    """Cacheable decorator factory that also logs DataFrame accesses."""
    return create_robust_cacheable(
        track_data_access=True,
        dataset_name=dataset_name,
        purpose=purpose,
        use_time_awareness=False,
    )


def time_aware_data_tracking_cacheable(dataset_name, purpose):
    """Time-series aware variant of data_tracking_cacheable."""
    return create_robust_cacheable(
        track_data_access=True,
        dataset_name=dataset_name,
        purpose=purpose,
        use_time_awareness=True,
    )

__all__ = [
    "CacheKeyGenerator",
    "TimeSeriesCacheKey",
    "data_tracking_cacheable",  # NEW
    "robust_cacheable",  # Backward compatible
    "time_aware_cacheable",  # Backward compatible
    "time_aware_data_tracking_cacheable",  # NEW
]
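

if __name__ == "__main__":
    # Minimal smoke test for key generation. Run as `python -m <package>.<module>`
    # so the relative imports inside create_robust_cacheable resolve. Checks
    # that equal data yields equal keys and changed data yields different keys.
    def _identity(x, df=None):
        return x

    arr = np.arange(256, dtype=float)
    frame = pd.DataFrame(
        {"close": np.linspace(100.0, 110.0, 10)},
        index=pd.date_range("2024-01-02", periods=10, freq="D"),
    )

    k_same_1 = CacheKeyGenerator.generate_key(_identity, (arr,), {"df": frame})
    k_same_2 = CacheKeyGenerator.generate_key(
        _identity, (arr.copy(),), {"df": frame.copy()}
    )
    k_diff = CacheKeyGenerator.generate_key(_identity, (arr + 1.0,), {"df": frame})

    assert k_same_1 == k_same_2, "equal inputs should produce equal keys"
    assert k_same_1 != k_diff, "changed inputs should produce different keys"
    print(f"cache key smoke test passed: {k_same_1}")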