Spaces:
Running
Running
| """ | |
| Run the feedback-to-model refresh pipeline in one command. | |
| Usage: | |
| python scripts/retrain_from_feedback.py | |
| python scripts/retrain_from_feedback.py --skip_train | |
| python scripts/retrain_from_feedback.py --dry_run | |
| """ | |
| import argparse | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
def run_step(command: list[str], dry_run: bool) -> None:
    """Echo *command* and, unless *dry_run* is set, execute it.

    Args:
        command: Program followed by its arguments; handed to
            ``subprocess.run`` as a list, so no shell is involved.
        dry_run: When true, the command is only printed, never executed.

    Raises:
        subprocess.CalledProcessError: If the command exits with a non-zero
            status (``check=True``).
    """
    rendered = " ".join(command)
    # Flush explicitly: when stdout is a pipe or a file it is block-buffered,
    # and without a flush the child's output would be written before the
    # line that announces it.
    print(f"\n>> {rendered}", flush=True)
    if dry_run:
        return
    subprocess.run(command, check=True)
def count_feedback_records(feedback_dir: Path) -> int:
    """Return how many ``*.json`` files exist anywhere beneath *feedback_dir*.

    The search is recursive. Only regular files are counted, so a directory
    whose name happens to end in ``.json`` is ignored. A path that does not
    exist yields ``0``.
    """
    if feedback_dir.exists():
        json_files = [entry for entry in feedback_dir.rglob("*.json") if entry.is_file()]
        return len(json_files)
    return 0
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the retraining pipeline."""
    parser = argparse.ArgumentParser(description="Ingest reviewed feedback and optionally retrain the model.")
    parser.add_argument("--feedback_dir", default="feedback_queue")
    parser.add_argument("--feedback_output_dir", default="data/feedback_labeled")
    parser.add_argument("--trashnet_dir", default="data/raw/trashnet_sample/dataset-resized")
    parser.add_argument("--realwaste_dir", default="data/raw/realwaste")
    parser.add_argument("--extra_dir", default="data/local_boost")
    parser.add_argument("--processed_dir", default="data/processed")
    parser.add_argument("--models_dir", default="models")
    parser.add_argument("--phase1_epochs", type=int, default=3)
    parser.add_argument("--phase2_epochs", type=int, default=12)
    parser.add_argument("--balance_strategy", choices=["class_weight", "oversample"], default="class_weight")
    parser.add_argument("--skip_train", action="store_true")
    parser.add_argument("--skip_export", action="store_true")
    parser.add_argument("--benchmark", action="store_true")
    parser.add_argument("--overwrite_feedback", action="store_true")
    parser.add_argument("--min_feedback", type=int, default=20)
    parser.add_argument("--force_train", action="store_true")
    parser.add_argument("--dry_run", action="store_true")
    return parser


def _build_commands(
    args: argparse.Namespace, python: str
) -> tuple[list[str], list[str], list[str], list[str]]:
    """Assemble the ingest, prepare, train and export command lines from *args*."""
    ingest_command = [
        python,
        "scripts/ingest_feedback.py",
        "--feedback_dir",
        args.feedback_dir,
        "--output_dir",
        args.feedback_output_dir,
    ]
    if args.overwrite_feedback:
        ingest_command.append("--overwrite")

    prepare_command = [
        python,
        "scripts/download_data.py",
        "--trashnet_dir",
        args.trashnet_dir,
        "--realwaste_dir",
        args.realwaste_dir,
        "--feedback_dir",
        args.feedback_output_dir,
        "--extra_dir",
        args.extra_dir,
        "--output_dir",
        args.processed_dir,
    ]

    train_command = [
        python,
        "scripts/train.py",
        "--data_dir",
        args.processed_dir,
        "--output_dir",
        args.models_dir,
        "--phase1_epochs",
        str(args.phase1_epochs),
        "--phase2_epochs",
        str(args.phase2_epochs),
        "--balance_strategy",
        args.balance_strategy,
    ]

    export_command = [
        python,
        "scripts/export_tflite.py",
        "--saved_model",
        str(Path(args.models_dir) / "waste_classifier_v1"),
        "--data_dir",
        args.processed_dir,
        "--output",
        str(Path(args.models_dir) / "model.tflite"),
    ]
    if args.benchmark:
        export_command.append("--benchmark")

    return ingest_command, prepare_command, train_command, export_command


def main() -> None:
    """Run the feedback pipeline: ingest, prepare, then optionally train and export.

    The ingest and prepare steps always run. Training runs unless
    ``--skip_train`` is given or the number of ``*.json`` files under
    ``--feedback_dir`` is below ``--min_feedback`` (``--force_train`` overrides
    the threshold). Export runs only after training and only when
    ``--skip_export`` is not given. With ``--dry_run`` every step is printed
    but not executed.

    Raises:
        subprocess.CalledProcessError: If any executed step exits non-zero.
    """
    import os  # local import: only needed to pin the working directory

    args = _build_parser().parse_args()
    project_root = Path(__file__).resolve().parent.parent
    python_executable = Path(sys.executable)
    ingest_command, prepare_command, train_command, export_command = _build_commands(
        args, str(python_executable)
    )

    # The script paths and default data directories are relative, and the
    # feedback count below is resolved against project_root. Run everything
    # from project_root so the child steps see the same directories and the
    # pipeline works no matter where it was launched from.
    os.chdir(project_root)

    print("Feedback Retraining Pipeline")
    print(f"Project root : {project_root}")
    print(f"Python : {python_executable}")

    run_step(ingest_command, args.dry_run)
    run_step(prepare_command, args.dry_run)

    feedback_count = count_feedback_records(project_root / args.feedback_dir)
    print(f"Feedback count: {feedback_count}")

    should_train = not args.skip_train
    if should_train and not args.force_train and feedback_count < args.min_feedback:
        print(
            f"\nSkipping training because feedback count is below --min_feedback "
            f"({feedback_count} < {args.min_feedback})."
        )
        print("Use --force_train to retrain anyway.")
        should_train = False

    if should_train:
        run_step(train_command, args.dry_run)
        if not args.skip_export:
            run_step(export_command, args.dry_run)
    elif not args.skip_export:
        # Export is only run directly after a training step in this pipeline.
        print("\nSkipping export because training was skipped.")

    print("\nPipeline complete.")
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()