Spaces:
Sleeping
Sleeping
| """Export a trained RF-DETR checkpoint to ONNX format.""" | |
| import argparse | |
| from pathlib import Path | |
| import rfdetr | |
# Absolute directory containing this script. Default CLI paths are built from
# it so they do not depend on the current working directory.
_DIR = Path(__file__).resolve().parent

# Lookup table from the ``--model-size`` choice to its RF-DETR model class.
# The selected variant has to match the one the checkpoint was trained with.
MODEL_CLASSES: dict[str, type] = dict(
    nano=rfdetr.RFDETRNano,
    small=rfdetr.RFDETRSmall,
    base=rfdetr.RFDETRBase,
    medium=rfdetr.RFDETRMedium,
    large=rfdetr.RFDETRLarge,
)
| def export_model( | |
| checkpoint: str | Path, | |
| model_size: str = "medium", | |
| output_dir: str | Path = "./exported_models", | |
| simplify: bool = True, | |
| resolution: int = 640, | |
| ) -> None: | |
| """Export RF-DETR checkpoint to ONNX. | |
| Args: | |
| checkpoint: Path to trained checkpoint (.pth file). | |
| model_size: Model variant — must match the size used during training. | |
| output_dir: Directory to write exported ONNX files. | |
| simplify: Whether to simplify the ONNX graph. | |
| resolution: Input resolution (height and width). | |
| """ | |
| model_cls = MODEL_CLASSES.get(model_size) | |
| if model_cls is None: | |
| raise ValueError( | |
| f"Unknown model_size {model_size!r}, " | |
| f"choose from: {', '.join(MODEL_CLASSES)}" | |
| ) | |
| model = model_cls(pretrain_weights=str(checkpoint)) | |
| model.export( | |
| output_dir=str(output_dir), | |
| simplify=simplify, | |
| shape=(resolution, resolution), | |
| ) | |
| print(f"Exported ONNX model to {output_dir}/") | |
def main(argv: list[str] | None = None) -> None:
    """Command-line entry point: parse arguments and run ``export_model``.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse fall back to ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves like a normal CLI run.
    """
    parser = argparse.ArgumentParser(
        description="Export trained RF-DETR checkpoint to ONNX"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        # Default: output/checkpoint_best_regular.pth next to this script.
        default=str(_DIR / "output" / "checkpoint_best_regular.pth"),
        help="Path to trained checkpoint",
    )
    parser.add_argument(
        "--model-size",
        type=str,
        default="medium",
        choices=list(MODEL_CLASSES),
        help="Model variant — must match training size (default: medium)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default=str(_DIR / "exported_models"),
        help="Directory to save exported ONNX files",
    )
    parser.add_argument(
        "--simplify",
        action=argparse.BooleanOptionalAction,
        default=True,
        # Some Python versions make BooleanOptionalAction append
        # "(default: ...)" to the help text automatically, so the default is
        # not spelled out here to avoid it appearing twice.
        help="Simplify ONNX graph (use --no-simplify to disable)",
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=640,
        help="Input resolution H=W (default: 640)",
    )
    args = parser.parse_args(argv)

    # ``type=int`` happily accepts zero and negative values; reject them as a
    # usage error (exit status 2) before any model is built.
    if args.resolution <= 0:
        parser.error(
            f"--resolution must be a positive integer, got {args.resolution}"
        )

    export_model(
        checkpoint=args.checkpoint,
        model_size=args.model_size,
        output_dir=args.output_dir,
        simplify=args.simplify,
        resolution=args.resolution,
    )


if __name__ == "__main__":
    main()