File size: 1,631 Bytes
aac9e56
 
 
e2b2661
aac9e56
e2b2661
aac9e56
e2b2661
aac9e56
 
e2b2661
 
 
 
 
 
 
 
aac9e56
 
 
e2b2661
aac9e56
 
 
 
e2b2661
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aac9e56
 
 
e2b2661
aac9e56
e2b2661
 
 
 
 
aac9e56
e2b2661
 
 
aac9e56
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
# src/should_retrain.py
import json
import os
import pandas as pd

# ---------------- CONFIG ----------------
# All paths below are relative, so they resolve against the current working
# directory (presumably the project root -- confirm how the script is invoked).

# Drift summary read by check_drift(); any entry whose "drift_flag" is truthy
# triggers retraining.
DRIFT_FILE = "drift_reports/drift_summary.json"
# CSV of newly collected data; check_new_data_volume() counts its rows.
NEW_DATA_FILE = "data/processed/new_sentiment.csv"
# JSON decision written by main(): {"retrain": ..., "reason": {...}}.
DECISION_FILE = "drift_reports/retrain_flag.json"

# Retraining is triggered when the new-data CSV has at least (>=) this many rows.
MIN_NEW_ROWS = 50   # threshold for retraining based on new data
# ----------------------------------------


def check_drift(path=None):
    """Return True if any entry in the drift summary has a truthy ``drift_flag``.

    Args:
        path: Location of the drift-summary JSON file. When omitted, the
            module-level ``DRIFT_FILE`` is used, so existing callers are
            unaffected.

    Returns:
        ``False`` when the file does not exist. Otherwise, whether at least one
        entry carries a truthy ``"drift_flag"``.

        If the summary is a JSON object, its values are inspected. If it is a
        JSON array, its elements are inspected. Entries that are not objects
        (e.g. a metadata string stored next to the per-feature records) are
        skipped instead of raising ``AttributeError``. Any other top-level
        JSON value yields ``False``.

    Raises:
        json.JSONDecodeError: if the file exists but is not valid JSON.
    """
    if path is None:
        path = DRIFT_FILE
    if not os.path.exists(path):
        return False

    with open(path, encoding="utf-8") as f:
        drift = json.load(f)

    # The summary is normally keyed by feature name. A plain list of
    # per-feature records is accepted as well.
    if isinstance(drift, dict):
        entries = drift.values()
    elif isinstance(drift, list):
        entries = drift
    else:
        return False

    return any(
        isinstance(entry, dict) and bool(entry.get("drift_flag", False))
        for entry in entries
    )


def check_new_data_volume(path=None, min_rows=None):
    """Return True if the new-data CSV has at least ``min_rows`` data rows.

    Args:
        path: CSV file to inspect. When omitted, the module-level
            ``NEW_DATA_FILE`` is used.
        min_rows: Inclusive row-count threshold. When omitted, the
            module-level ``MIN_NEW_ROWS`` is used.

    Returns:
        ``False`` when the file does not exist. Otherwise, whether the number
        of data rows (excluding the header, as parsed by ``pandas.read_csv``)
        is ``>= min_rows``. A zero-byte file counts as zero rows.
    """
    if path is None:
        path = NEW_DATA_FILE
    if min_rows is None:
        min_rows = MIN_NEW_ROWS
    if not os.path.exists(path):
        return False

    try:
        row_count = len(pd.read_csv(path))
    except pd.errors.EmptyDataError:
        # read_csv refuses a completely empty file ("No columns to parse").
        # An empty drop means no new data, not a pipeline crash.
        row_count = 0
    return row_count >= min_rows


def main():
    """Decide whether retraining is needed, persist the decision, and report it.

    Retraining is required when either trigger fires: feature drift
    (``check_drift``) or enough new data (``check_new_data_volume``). The
    decision and both individual triggers are written to ``DECISION_FILE`` as
    indented JSON. A short summary is then printed to stdout.

    Returns:
        The decision dict that was written to ``DECISION_FILE``.
    """
    drift_trigger = check_drift()
    data_trigger = check_new_data_volume()

    retrain = drift_trigger or data_trigger

    decision = {
        "retrain": retrain,
        "reason": {
            "drift_detected": drift_trigger,
            "new_data_threshold_met": data_trigger
        }
    }

    # Create the directory that actually holds DECISION_FILE, not a hard-coded
    # "drift_reports". Otherwise, changing DECISION_FILE in the config section
    # would leave the write pointing at a directory that may not exist.
    decision_dir = os.path.dirname(DECISION_FILE)
    if decision_dir:
        os.makedirs(decision_dir, exist_ok=True)
    with open(DECISION_FILE, "w", encoding="utf-8") as f:
        json.dump(decision, f, indent=4)

    # ---- Console output (important for viva/demo) ----
    if retrain:
        print("Retraining required")
        if drift_trigger:
            print("→ Reason: feature drift detected")
        if data_trigger:
            print("→ Reason: sufficient new tweet/news data")
    else:
        print("No retraining required")
        print("→ No drift and insufficient new data")

    return decision


# Run the retraining decision only when executed as a script, not on import.
if __name__ == "__main__":
    main()