File size: 3,882 Bytes
1aa566a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""Intelligent retraining trigger with a dual-gate decision.

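Retraining fires only when ALL of the following hold:
  1. Feature drift detected (PSI >= threshold or KS p-value < threshold)
  2. Performance degradation detected (RMSE increased by >= N%)
  3. Minimum number of new samples collected since the last retrain
  4. Cooldown period has elapsed (or no retrain has been recorded yet)

Conditions 1 and 2 are the two gates; conditions 3 and 4 are additional guards.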
"""
from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Optional

from src.utils.config import settings, resolve
from src.utils.logging_config import get_logger

log = get_logger(__name__)


class RetrainingTrigger:
    """Evaluate whether all retraining conditions are met.

    A retrain is approved only when feature drift AND performance drift are
    both reported, enough new samples have been collected, and the cooldown
    since the last recorded retrain has elapsed.

    The time of the last retrain is kept on the instance only: it starts as
    ``None`` and is set by :meth:`record_retrain_completed`. Every decision is
    appended as one JSON line to the configured retrain log.
    """

    # Timestamp layout used both in the decision record and in log messages.
    _TIMESTAMP_FORMAT = "%Y-%m-%dT%H:%M:%SZ"

    def __init__(self) -> None:
        # Wall-clock time (``time.time()``) of the last completed retrain;
        # ``None`` means no retrain has been recorded, so no cooldown applies.
        self._last_retrain_time: Optional[float] = None
        self._retrain_log_path = resolve(settings.retraining.retrain_log_path)
        self._cooldown_seconds = settings.retraining.cooldown_hours * 3600

    @classmethod
    def _utc_timestamp(cls, epoch_seconds: Optional[float] = None) -> str:
        """Format *epoch_seconds* (default: now) as a UTC timestamp string."""
        return time.strftime(cls._TIMESTAMP_FORMAT, time.gmtime(epoch_seconds))

    def should_retrain(
        self,
        feature_drift_report: dict,
        performance_report: dict,
        samples_since_last_retrain: int,
    ) -> dict:
        """Return a decision dict with should_retrain, reasons, and blocking_reasons.

        Args:
            feature_drift_report: Mapping read for ``"drift_detected"`` and,
                when drift is reported, ``"drifted_features"``.
            performance_report: Mapping read for ``"drift_detected"`` and,
                when drift is reported, ``"pct_change"``.
            samples_since_last_retrain: Number of samples gathered since the
                last retrain; compared against
                ``settings.retraining.min_samples_since_last_retrain``.

        Returns:
            A dict with the keys ``should_retrain``, ``reasons``,
            ``blocking_reasons``, ``feature_drift``, ``performance_drift``,
            ``samples_since_last_retrain`` and ``timestamp``. The same dict is
            appended to the retrain log before it is returned.
        """
        # Read the clock once so the cooldown message, the verdict and the
        # timestamp all describe the same instant.
        now = time.time()

        reasons: list[str] = []
        blocking: list[str] = []

        # Coerce to plain bool: report flags may be non-bool truthy values,
        # which would otherwise leak into the decision and can break
        # json.dumps when the decision is logged.
        feature_drifted = bool(feature_drift_report.get("drift_detected", False))
        perf_drifted = bool(performance_report.get("drift_detected", False))

        if feature_drifted:
            drifted_features = feature_drift_report.get("drifted_features", [])
            reasons.append(f"Feature drift detected in: {drifted_features}")
        else:
            blocking.append("No significant feature drift.")

        if perf_drifted:
            pct = performance_report.get("pct_change")
            # A missing or explicit None pct_change is reported as 0.0.
            pct = 0.0 if pct is None else float(pct)
            reasons.append(f"Performance degraded by {pct:.1f}% vs baseline.")
        else:
            blocking.append("Performance within acceptable range.")

        min_samples = settings.retraining.min_samples_since_last_retrain
        enough_samples = bool(samples_since_last_retrain >= min_samples)
        if not enough_samples:
            blocking.append(
                f"Only {samples_since_last_retrain} new samples (need {min_samples})."
            )

        cooldown_active = False
        if self._last_retrain_time is not None:
            elapsed = now - self._last_retrain_time
            if elapsed < self._cooldown_seconds:
                cooldown_active = True
                remaining = int(self._cooldown_seconds - elapsed)
                blocking.append(f"Cooldown active: {remaining}s remaining.")

        should_retrain = (
            feature_drifted
            and perf_drifted
            and enough_samples
            and not cooldown_active
        )

        decision = {
            "should_retrain": should_retrain,
            "reasons": reasons,
            "blocking_reasons": blocking,
            "feature_drift": feature_drifted,
            "performance_drift": perf_drifted,
            # int() keeps the record JSON-serializable for integer-like inputs.
            "samples_since_last_retrain": int(samples_since_last_retrain),
            "timestamp": self._utc_timestamp(now),
        }

        if should_retrain:
            log.warning("RETRAINING TRIGGERED! Reasons: %s", reasons)
        else:
            log.info("Retraining NOT triggered. Blocking: %s", blocking or ["none"])

        self._log_decision(decision)
        return decision

    def record_retrain_completed(self) -> None:
        """Reset cooldown after a successful retraining run."""
        completed_at = time.time()
        self._last_retrain_time = completed_at
        log.info("Retraining cooldown reset at %s", self._utc_timestamp(completed_at))

    def _log_decision(self, decision: dict) -> None:
        """Append *decision* as one JSON line to the retrain log file."""
        log_path = Path(self._retrain_log_path)
        # Create the log directory on first use instead of failing with
        # FileNotFoundError when it does not exist yet.
        log_path.parent.mkdir(parents=True, exist_ok=True)
        with log_path.open("a", encoding="utf-8") as fh:
            fh.write(json.dumps(decision) + "\n")