"""Train RF-DETR model on car_data for aerial car detection."""

import argparse
from pathlib import Path

import rfdetr

# Maps the ``model_size`` name (API argument / ``--model-size`` flag) to the
# rfdetr model class that is instantiated for training.
MODEL_CLASSES: dict[str, type] = {
    "nano": rfdetr.RFDETRNano,
    "small": rfdetr.RFDETRSmall,
    "base": rfdetr.RFDETRBase,
    "medium": rfdetr.RFDETRMedium,
    "large": rfdetr.RFDETRLarge,
}

# Dataset location used when ``--dataset-dir`` is not given on the command line
# (resolved relative to this script, so the working directory does not matter).
DEFAULT_DATASET_DIR = Path(__file__).resolve().parent / "car_data" / "mydata" / "mydata"


def _require_positive(name: str, value: float) -> None:
    """Raise ValueError unless *value* is strictly greater than zero."""
    if value <= 0:
        raise ValueError(f"{name} must be positive, got {value!r}")


def run_training(
    dataset_dir: str | Path,
    epochs: int = 50,
    batch_size: int = 4,
    lr: float = 1e-4,
    resolution: int = 640,
    output_dir: str | Path = "output",
    model_size: str = "base",
    grad_accum_steps: int = 1,
    num_classes: int = 1,
) -> None:
    """Run RF-DETR training.

    Args:
        dataset_dir: Path to dataset (YOLO or COCO format, auto-detected).
        epochs: Number of training epochs.
        batch_size: Batch size.
        lr: Learning rate.
        resolution: Input resolution.
        output_dir: Checkpoint output directory.
        model_size: Model variant (nano/small/base/medium/large).
        grad_accum_steps: Gradient accumulation steps.
        num_classes: Number of object classes.

    Raises:
        ValueError: If ``model_size`` is not a key of ``MODEL_CLASSES`` or any
            numeric hyper-parameter is not strictly positive.
        FileNotFoundError: If ``dataset_dir`` is not an existing directory.
    """
    model_cls = MODEL_CLASSES.get(model_size)
    if model_cls is None:
        raise ValueError(
            f"Unknown model_size {model_size!r}, "
            f"choose from: {', '.join(MODEL_CLASSES)}"
        )

    # Validate everything up front: constructing the model below is the
    # expensive step, so bad inputs should be rejected before it happens.
    for name, value in (
        ("epochs", epochs),
        ("batch_size", batch_size),
        ("lr", lr),
        ("resolution", resolution),
        ("grad_accum_steps", grad_accum_steps),
        ("num_classes", num_classes),
    ):
        _require_positive(name, value)

    dataset_path = Path(dataset_dir)
    if not dataset_path.is_dir():
        raise FileNotFoundError(f"Dataset directory not found: {dataset_path}")

    model = model_cls()
    model.train(
        dataset_dir=str(dataset_path),
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        resolution=resolution,
        output_dir=str(output_dir),
        num_classes=num_classes,
        grad_accum_steps=grad_accum_steps,
        run_test=False,
    )


def main() -> None:
    """Parse command-line arguments and launch ``run_training``.

    Every flag defaults to the same value as the matching ``run_training``
    parameter; ``--dataset-dir`` defaults to ``DEFAULT_DATASET_DIR``.
    """
    parser = argparse.ArgumentParser(description="Train RF-DETR on car_data")
    parser.add_argument("--dataset-dir", type=Path, default=DEFAULT_DATASET_DIR)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=640)
    parser.add_argument("--output-dir", type=str, default="output")
    # choices mirror MODEL_CLASSES so argparse rejects unknown variants early.
    parser.add_argument("--model-size", choices=list(MODEL_CLASSES), default="base")
    parser.add_argument("--grad-accum-steps", type=int, default=1)
    parser.add_argument("--num-classes", type=int, default=1)
    args = parser.parse_args()

    run_training(
        dataset_dir=args.dataset_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        resolution=args.resolution,
        output_dir=args.output_dir,
        model_size=args.model_size,
        grad_accum_steps=args.grad_accum_steps,
        num_classes=args.num_classes,
    )


if __name__ == "__main__":
    main()