car-detection / training /export_onnx.py
socks22's picture
first
f3f6f5d
"""Export a trained RF-DETR checkpoint to ONNX format."""
import argparse
from pathlib import Path
import rfdetr
# Directory containing this script; the CLI's default checkpoint and output
# paths are built relative to it.
_DIR = Path(__file__).resolve().parent

# Lookup table from the ``model_size`` names accepted by ``export_model`` (and
# the ``--model-size`` CLI choices) to the rfdetr class for that variant.
# Insertion order is significant: it is the order the choices are listed in.
MODEL_CLASSES: dict[str, type] = dict(
    nano=rfdetr.RFDETRNano,
    small=rfdetr.RFDETRSmall,
    base=rfdetr.RFDETRBase,
    medium=rfdetr.RFDETRMedium,
    large=rfdetr.RFDETRLarge,
)
def export_model(
    checkpoint: str | Path,
    model_size: str = "medium",
    output_dir: str | Path = "./exported_models",
    simplify: bool = True,
    resolution: int | None = 640,
) -> None:
    """Export an RF-DETR checkpoint to ONNX.

    Args:
        checkpoint: Path to trained checkpoint (.pth file). It is handed to the
            model constructor as ``pretrain_weights``.
        model_size: Model variant — must match the size used during training.
            Must be one of the keys of ``MODEL_CLASSES``.
        output_dir: Directory to write exported ONNX files.
        simplify: Whether to simplify the ONNX graph.
        resolution: Square input resolution, used for both height and width.
            ``None`` passes ``shape=None`` to ``model.export`` so rfdetr picks
            the export shape itself (presumably the variant's native
            resolution — confirm against the installed rfdetr version).

    Raises:
        ValueError: If ``model_size`` is not a known variant, or if
            ``resolution`` is given but is not a positive integer.
    """
    model_cls = MODEL_CLASSES.get(model_size)
    if model_cls is None:
        raise ValueError(
            f"Unknown model_size {model_size!r}, "
            f"choose from: {', '.join(MODEL_CLASSES)}"
        )
    # Reject an impossible shape before constructing the model, which loads
    # the checkpoint and is far more expensive than this check.
    if resolution is not None and resolution <= 0:
        raise ValueError(
            f"resolution must be a positive integer, got {resolution!r}"
        )

    model = model_cls(pretrain_weights=str(checkpoint))
    # NOTE(review): rfdetr may additionally require the shape to be divisible
    # by the backbone's patch size; verify for the chosen variant.
    shape = None if resolution is None else (resolution, resolution)
    model.export(
        output_dir=str(output_dir),
        simplify=simplify,
        shape=shape,
    )
    print(f"Exported ONNX model to {output_dir}/")
def main() -> None:
    """Command-line entry point: parse options, then call ``export_model``.

    Every option's argparse ``dest`` (dashes become underscores) is spelled
    exactly like one of ``export_model``'s parameters, so the parsed namespace
    is forwarded as keyword arguments.
    """
    cli = argparse.ArgumentParser(
        description="Export trained RF-DETR checkpoint to ONNX"
    )
    cli.add_argument(
        "--checkpoint",
        type=str,
        default=str(_DIR / "output" / "checkpoint_best_regular.pth"),
        help="Path to trained checkpoint",
    )
    cli.add_argument(
        "--model-size",
        type=str,
        default="medium",
        choices=list(MODEL_CLASSES),
        help="Model variant — must match training size (default: medium)",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default=str(_DIR / "exported_models"),
        help="Directory to save exported ONNX files",
    )
    # BooleanOptionalAction generates both --simplify and --no-simplify.
    cli.add_argument(
        "--simplify",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Simplify ONNX graph (default: enabled, use --no-simplify to disable)",
    )
    cli.add_argument(
        "--resolution",
        type=int,
        default=640,
        help="Input resolution H=W (default: 640)",
    )

    options = cli.parse_args()
    # dests: checkpoint, model_size, output_dir, simplify, resolution
    export_model(**vars(options))


if __name__ == "__main__":
    main()