Spaces:
Sleeping
Sleeping
| """Train RF-DETR model on car_data for aerial car detection.""" | |
| import argparse | |
| from pathlib import Path | |
| import rfdetr | |
# Lookup table from model-size name to the RF-DETR model class to instantiate;
# `run_training` resolves its `model_size` argument against these keys.
MODEL_CLASSES: dict[str, type] = dict(
    nano=rfdetr.RFDETRNano,
    small=rfdetr.RFDETRSmall,
    base=rfdetr.RFDETRBase,
    medium=rfdetr.RFDETRMedium,
    large=rfdetr.RFDETRLarge,
)
def run_training(
    dataset_dir: str | Path,
    epochs: int = 50,
    batch_size: int = 4,
    lr: float = 1e-4,
    resolution: int = 640,
    output_dir: str | Path = "output",
    model_size: str = "base",
    grad_accum_steps: int = 1,
    num_classes: int = 1,
) -> None:
    """Run RF-DETR training.

    Args:
        dataset_dir: Path to dataset (YOLO or COCO format, auto-detected).
            Passed to ``model.train`` as a string.
        epochs: Number of training epochs.
        batch_size: Batch size.
        lr: Learning rate.
        resolution: Input resolution.
        output_dir: Checkpoint output directory. Passed to ``model.train``
            as a string.
        model_size: Model variant; must be a key of ``MODEL_CLASSES``
            (nano/small/base/medium/large).
        grad_accum_steps: Gradient accumulation steps.
        num_classes: Number of object classes.

    Raises:
        ValueError: If ``model_size`` is not a key of ``MODEL_CLASSES``.
        FileNotFoundError: If ``dataset_dir`` does not exist.
    """
    model_cls = MODEL_CLASSES.get(model_size)
    if model_cls is None:
        raise ValueError(
            f"Unknown model_size {model_size!r}, "
            f"choose from: {', '.join(MODEL_CLASSES)}"
        )

    dataset_path = Path(dataset_dir)
    # Fail fast with a clear message before constructing the model; otherwise
    # a typo in the path only surfaces later, from inside rfdetr's training.
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset path does not exist: {dataset_path}")

    model = model_cls()
    model.train(
        dataset_dir=str(dataset_path),
        epochs=epochs,
        batch_size=batch_size,
        lr=lr,
        resolution=resolution,
        output_dir=str(output_dir),
        num_classes=num_classes,
        grad_accum_steps=grad_accum_steps,
        # Skip the test-split evaluation pass after training.
        run_test=False,
    )
def main() -> None:
    """Parse command-line options and launch RF-DETR training on car_data.

    Every option defaults to the value this script previously used, so running
    it with no arguments trains the ``base`` model on
    ``<script dir>/car_data/mydata/mydata`` exactly as before.
    """
    # Default dataset location, resolved relative to this file so the script
    # works regardless of the current working directory.
    default_dataset_dir = (
        Path(__file__).resolve().parent / "car_data" / "mydata" / "mydata"
    )

    parser = argparse.ArgumentParser(description="Train RF-DETR on car_data")
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--batch-size", type=int, default=4)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--resolution", type=int, default=640)
    parser.add_argument("--output-dir", type=str, default="output")
    parser.add_argument(
        "--dataset-dir",
        type=Path,
        default=default_dataset_dir,
        help="Dataset directory (default: %(default)s)",
    )
    parser.add_argument(
        "--model-size",
        choices=list(MODEL_CLASSES),
        default="base",
        help="RF-DETR model variant (default: %(default)s)",
    )
    parser.add_argument(
        "--grad-accum-steps",
        type=int,
        default=1,
        help="Gradient accumulation steps (default: %(default)s)",
    )
    parser.add_argument(
        "--num-classes",
        type=int,
        default=1,
        help="Number of object classes (default: %(default)s)",
    )
    args = parser.parse_args()

    run_training(
        dataset_dir=args.dataset_dir,
        epochs=args.epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        resolution=args.resolution,
        output_dir=args.output_dir,
        model_size=args.model_size,
        grad_accum_steps=args.grad_accum_steps,
        num_classes=args.num_classes,
    )
# Run the training CLI only when this file is executed as a script, not when
# it is imported (e.g. to reuse `run_training` from another module).
if __name__ == "__main__":
    main()