CoolWasteAI / scripts /retrain_from_feedback.py
Celvin
Prepare deployable AI API for competition and free hosting
12d831f
"""
Run the feedback-to-model refresh pipeline in one command.
Usage:
python scripts/retrain_from_feedback.py
python scripts/retrain_from_feedback.py --skip_train
python scripts/retrain_from_feedback.py --dry_run
"""
import argparse
import os
import subprocess
import sys
from pathlib import Path
def run_step(command: list[str], dry_run: bool) -> None:
    """Echo one pipeline step's command and, unless *dry_run*, execute it.

    Args:
        command: Argument vector passed directly to ``subprocess.run`` (no shell).
        dry_run: When true, only print the command; nothing is executed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero status.
    """
    # Display only: arguments are joined with spaces and are not shell-quoted.
    rendered = " ".join(command)
    # Flush before launching the child: it writes straight to the inherited stdout
    # file descriptor, so any text still sitting in Python's buffer (this header and
    # earlier prints) would otherwise show up after the child's output when stdout
    # is piped or redirected.
    print(f"\n>> {rendered}", flush=True)
    if dry_run:
        return
    subprocess.run(command, check=True)
def count_feedback_records(feedback_dir: Path) -> int:
    """Return the number of ``*.json`` files found anywhere below *feedback_dir*.

    A directory that does not exist counts as zero records. Each match is checked
    with ``is_file()``, so a directory whose name happens to end in ``.json`` is
    not counted.
    """
    if feedback_dir.exists():
        json_files = [entry for entry in feedback_dir.rglob("*.json") if entry.is_file()]
        return len(json_files)
    return 0
def _parse_args() -> argparse.Namespace:
    """Define the pipeline's command-line options and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Ingest reviewed feedback and optionally retrain the model.")

    # Data locations, forwarded to the individual step scripts.
    parser.add_argument(
        "--feedback_dir",
        default="feedback_queue",
        help="raw feedback records; read by ingest_feedback.py and counted for --min_feedback",
    )
    parser.add_argument(
        "--feedback_output_dir",
        default="data/feedback_labeled",
        help="output of ingest_feedback.py; fed to download_data.py as --feedback_dir",
    )
    parser.add_argument(
        "--trashnet_dir",
        default="data/raw/trashnet_sample/dataset-resized",
        help="forwarded to download_data.py",
    )
    parser.add_argument("--realwaste_dir", default="data/raw/realwaste", help="forwarded to download_data.py")
    parser.add_argument("--extra_dir", default="data/local_boost", help="forwarded to download_data.py")
    parser.add_argument(
        "--processed_dir",
        default="data/processed",
        help="dataset written by download_data.py and read by train.py / export_tflite.py",
    )
    parser.add_argument("--models_dir", default="models", help="training output directory and export location")

    # Training options, forwarded to train.py.
    parser.add_argument("--phase1_epochs", type=int, default=3, help="forwarded to train.py")
    parser.add_argument("--phase2_epochs", type=int, default=12, help="forwarded to train.py")
    parser.add_argument(
        "--balance_strategy",
        choices=["class_weight", "oversample"],
        default="class_weight",
        help="forwarded to train.py",
    )

    # Step control.
    parser.add_argument("--skip_train", action="store_true", help="skip training (export is then skipped too)")
    parser.add_argument("--skip_export", action="store_true", help="do not run export_tflite.py after training")
    parser.add_argument("--benchmark", action="store_true", help="pass --benchmark to export_tflite.py")
    parser.add_argument("--overwrite_feedback", action="store_true", help="pass --overwrite to ingest_feedback.py")
    parser.add_argument(
        "--min_feedback",
        type=int,
        default=20,
        help="minimum number of feedback records required before training runs",
    )
    parser.add_argument("--force_train", action="store_true", help="train even when below --min_feedback")
    parser.add_argument("--dry_run", action="store_true", help="print the step commands without running them")
    return parser.parse_args()


def _build_commands(
    args: argparse.Namespace, python_executable: Path
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Return the (ingest, prepare, train, export) argument vectors for *args*."""
    python = str(python_executable)

    ingest_command = [
        python,
        "scripts/ingest_feedback.py",
        "--feedback_dir",
        args.feedback_dir,
        "--output_dir",
        args.feedback_output_dir,
    ]
    if args.overwrite_feedback:
        ingest_command.append("--overwrite")

    prepare_command = [
        python,
        "scripts/download_data.py",
        "--trashnet_dir",
        args.trashnet_dir,
        "--realwaste_dir",
        args.realwaste_dir,
        "--feedback_dir",
        args.feedback_output_dir,
        "--extra_dir",
        args.extra_dir,
        "--output_dir",
        args.processed_dir,
    ]

    train_command = [
        python,
        "scripts/train.py",
        "--data_dir",
        args.processed_dir,
        "--output_dir",
        args.models_dir,
        "--phase1_epochs",
        str(args.phase1_epochs),
        "--phase2_epochs",
        str(args.phase2_epochs),
        "--balance_strategy",
        args.balance_strategy,
    ]

    export_command = [
        python,
        "scripts/export_tflite.py",
        "--saved_model",
        str(Path(args.models_dir) / "waste_classifier_v1"),
        "--data_dir",
        args.processed_dir,
        "--output",
        str(Path(args.models_dir) / "model.tflite"),
    ]
    if args.benchmark:
        export_command.append("--benchmark")

    return ingest_command, prepare_command, train_command, export_command


def main() -> None:
    """Run the feedback retraining pipeline.

    Steps: ingest feedback -> prepare the processed dataset -> train (optional)
    -> export to TFLite (optional, and only after training). Training is skipped
    when --skip_train is given, or when fewer than --min_feedback records are
    found in --feedback_dir unless --force_train is set.

    Relative paths (the step scripts and every directory option) are resolved
    against the project root, so the pipeline behaves the same no matter which
    directory it is launched from.

    Raises:
        subprocess.CalledProcessError: If any executed step exits non-zero.
    """
    args = _parse_args()
    project_root = Path(__file__).resolve().parent.parent
    python_executable = Path(sys.executable)

    # The step commands reference "scripts/..." and project-relative default data
    # paths, and the feedback count below is resolved against project_root. Run
    # from the project root so the subprocesses and the count see the same files
    # even when the script is launched from another directory.
    os.chdir(project_root)

    ingest_command, prepare_command, train_command, export_command = _build_commands(args, python_executable)

    print("Feedback Retraining Pipeline")
    print(f"Project root : {project_root}")
    print(f"Python : {python_executable}")
    run_step(ingest_command, args.dry_run)
    run_step(prepare_command, args.dry_run)

    feedback_count = count_feedback_records(project_root / args.feedback_dir)
    print(f"Feedback count: {feedback_count}")

    should_train = not args.skip_train
    if should_train and not args.force_train and feedback_count < args.min_feedback:
        print(
            f"\nSkipping training because feedback count is below --min_feedback "
            f"({feedback_count} < {args.min_feedback})."
        )
        print("Use --force_train to retrain anyway.")
        should_train = False

    if should_train:
        run_step(train_command, args.dry_run)
        if not args.skip_export:
            run_step(export_command, args.dry_run)
    elif not args.skip_export:
        # Export reads the freshly trained SavedModel, so it only runs after training.
        print("\nSkipping export because training was skipped.")
    print("\nPipeline complete.")


if __name__ == "__main__":
    main()