File size: 4,577 Bytes
206d8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d831f
206d8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12d831f
 
206d8b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Run the feedback-to-model refresh pipeline in one command.

Usage:
    python scripts/retrain_from_feedback.py
    python scripts/retrain_from_feedback.py --skip_train
    python scripts/retrain_from_feedback.py --dry_run
"""

import argparse
import os
import shlex
import subprocess
import sys
from pathlib import Path


def run_step(command: list[str], dry_run: bool) -> None:
    """Echo a pipeline command and, unless in dry-run mode, execute it.

    Args:
        command: The argv list to run (program followed by its arguments).
        dry_run: When true, only print the command and return without
            executing anything.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
    """
    # shlex.join quotes arguments containing spaces or shell metacharacters,
    # so the echoed line is unambiguous and can be pasted back into a POSIX
    # shell. A plain " ".join would blur argument boundaries.
    rendered = shlex.join(command)
    # Flush so the header is emitted before the child process writes to the
    # same stream; otherwise buffered output can appear out of order when
    # stdout is redirected to a file or pipe.
    print(f"\n>> {rendered}", flush=True)
    if dry_run:
        return
    subprocess.run(command, check=True)


def count_feedback_records(feedback_dir: Path) -> int:
    """Return how many ``*.json`` files exist anywhere under *feedback_dir*.

    The search is recursive. Only regular files are counted, so a directory
    whose name happens to end in ``.json`` is ignored (its contents are still
    searched). A missing *feedback_dir* yields ``0``.
    """
    if feedback_dir.exists():
        json_files = [entry for entry in feedback_dir.rglob("*.json") if entry.is_file()]
        return len(json_files)
    return 0


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the feedback retraining pipeline."""
    parser = argparse.ArgumentParser(description="Ingest reviewed feedback and optionally retrain the model.")
    # Directories handed to the helper scripts.
    parser.add_argument("--feedback_dir", default="feedback_queue")
    parser.add_argument("--feedback_output_dir", default="data/feedback_labeled")
    parser.add_argument("--trashnet_dir", default="data/raw/trashnet_sample/dataset-resized")
    parser.add_argument("--realwaste_dir", default="data/raw/realwaste")
    parser.add_argument("--extra_dir", default="data/local_boost")
    parser.add_argument("--processed_dir", default="data/processed")
    parser.add_argument("--models_dir", default="models")
    # Options forwarded to scripts/train.py.
    parser.add_argument("--phase1_epochs", type=int, default=3)
    parser.add_argument("--phase2_epochs", type=int, default=12)
    parser.add_argument("--balance_strategy", choices=["class_weight", "oversample"], default="class_weight")
    # Switches controlling which steps run and how.
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_export", action="store_true")
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--overwrite_feedback", action="store_true")
    parser.add_argument("--min_feedback", type=int, default=20)
    parser.add_argument("--force_train", action="store_true")
    parser.add_argument("--dry_run", action="store_true")
    return parser


def _build_commands(args: argparse.Namespace, python_executable: Path) -> dict[str, list[str]]:
    """Assemble the argv list for every pipeline step from the parsed options.

    Returns a mapping with the keys ``"ingest"``, ``"prepare"``, ``"train"``
    and ``"export"``. Script paths are relative, so the commands must be run
    with the project root as the working directory.
    """
    interpreter = str(python_executable)

    ingest_command = [
        interpreter,
        "scripts/ingest_feedback.py",
        "--feedback_dir",
        args.feedback_dir,
        "--output_dir",
        args.feedback_output_dir,
    ]
    if args.overwrite_feedback:
        ingest_command.append("--overwrite")

    # The prepare step receives the *output* of the ingest step as its
    # feedback directory.
    prepare_command = [
        interpreter,
        "scripts/download_data.py",
        "--trashnet_dir",
        args.trashnet_dir,
        "--realwaste_dir",
        args.realwaste_dir,
        "--feedback_dir",
        args.feedback_output_dir,
        "--extra_dir",
        args.extra_dir,
        "--output_dir",
        args.processed_dir,
    ]

    train_command = [
        interpreter,
        "scripts/train.py",
        "--data_dir",
        args.processed_dir,
        "--output_dir",
        args.models_dir,
        "--phase1_epochs",
        str(args.phase1_epochs),
        "--phase2_epochs",
        str(args.phase2_epochs),
        "--balance_strategy",
        args.balance_strategy,
    ]

    export_command = [
        interpreter,
        "scripts/export_tflite.py",
        "--saved_model",
        str(Path(args.models_dir) / "waste_classifier_v1"),
        "--data_dir",
        args.processed_dir,
        "--output",
        str(Path(args.models_dir) / "model.tflite"),
    ]
    if args.benchmark:
        export_command.append("--benchmark")

    return {
        "ingest": ingest_command,
        "prepare": prepare_command,
        "train": train_command,
        "export": export_command,
    }


def main() -> None:
    """Run the feedback pipeline: ingest, prepare data, then train and export.

    Steps, in order:

    1. Run ``scripts/ingest_feedback.py`` and ``scripts/download_data.py``
       (always).
    2. Count the ``*.json`` files under ``--feedback_dir``.
    3. Run ``scripts/train.py`` unless ``--skip_train`` is given or the count
       is below ``--min_feedback`` (``--force_train`` bypasses the threshold,
       but not ``--skip_train``).
    4. Run ``scripts/export_tflite.py`` after a training run unless
       ``--skip_export`` is given.

    With ``--dry_run`` every command is printed but none is executed.

    Raises:
        subprocess.CalledProcessError: If any executed step exits non-zero.
    """
    args = _build_parser().parse_args()

    project_root = Path(__file__).resolve().parent.parent
    python_executable = Path(sys.executable)

    # The helper scripts and the default data directories are relative paths,
    # while the feedback count below is resolved against project_root. Pin the
    # working directory to the project root so both agree and the pipeline
    # also works when launched from another directory.
    os.chdir(project_root)

    commands = _build_commands(args, python_executable)

    print("Feedback Retraining Pipeline")
    print(f"Project root : {project_root}")
    print(f"Python       : {python_executable}")

    run_step(commands["ingest"], args.dry_run)
    run_step(commands["prepare"], args.dry_run)

    # NOTE(review): the count is taken from --feedback_dir *after* the ingest
    # step has run; if ingest_feedback.py removes files it has processed, the
    # threshold below would see fewer records than were ingested -- confirm.
    feedback_count = count_feedback_records(project_root / args.feedback_dir)
    print(f"Feedback count: {feedback_count}")

    should_train = not args.skip_train
    if should_train and not args.force_train and feedback_count < args.min_feedback:
        print(
            f"\nSkipping training because feedback count is below --min_feedback "
            f"({feedback_count} < {args.min_feedback})."
        )
        print("Use --force_train to retrain anyway.")
        should_train = False

    if should_train:
        run_step(commands["train"], args.dry_run)

        if not args.skip_export:
            run_step(commands["export"], args.dry_run)
    elif not args.skip_export:
        # Export is only attempted after a training run in this invocation.
        print("\nSkipping export because training was skipped.")

    print("\nPipeline complete.")


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()