#!/usr/bin/env python3
import argparse
import time
from pathlib import Path
import torch
from boxmot.appearance.exporters.base_exporter import BaseExporter
from boxmot.appearance.exporters.onnx_exporter import ONNXExporter
from boxmot.appearance.exporters.openvino_exporter import OpenVINOExporter
from boxmot.appearance.exporters.tflite_exporter import TFLiteExporter
from boxmot.appearance.exporters.torchscript_exporter import TorchScriptExporter
from boxmot.appearance.exporters.tensorrt_exporter import EngineExporter
from boxmot.appearance.reid import export_formats
from boxmot.appearance.reid.auto_backend import ReidAutoBackend
from boxmot.appearance.reid.registry import ReIDModelRegistry
from boxmot.utils import WEIGHTS, logger as LOGGER
from boxmot.utils.torch_utils import select_device
def parse_args():
"""
Parse command-line arguments for the ReID export script.
"""
parser = argparse.ArgumentParser(description="ReID Export Script")
parser.add_argument("--batch-size", type=int, default=1, help="Batch size for export")
parser.add_argument("--imgsz", "--img", "--img-size",
nargs="+", type=int, default=[256, 128],
help="Image size in the format: height width")
parser.add_argument("--device", default="cpu",
help="CUDA device (e.g., '0', '0,1,2,3', or 'cpu')")
parser.add_argument("--optimize", action="store_true",
help="Optimize TorchScript for mobile (CPU export only)")
parser.add_argument("--dynamic", action="store_true",
help="Enable dynamic axes for ONNX/TF/TensorRT export")
parser.add_argument("--simplify", action="store_true",
help="Simplify ONNX model")
parser.add_argument("--opset", type=int, default=12,
help="ONNX opset version")
parser.add_argument("--workspace", type=int, default=4,
help="TensorRT workspace size (GB)")
parser.add_argument("--verbose", action="store_true",
help="Enable verbose logging for TensorRT")
parser.add_argument("--weights", type=Path,
default=WEIGHTS / "osnet_x0_25_msmt17.pt",
help="Path to the model weights (.pt file)")
parser.add_argument("--half", action="store_true",
help="Enable FP16 half-precision export (GPU only)")
parser.add_argument("--include", nargs="+",
default=["torchscript"],
help=("Export formats to include. Options: torchscript, onnx, "
"openvino, engine, tflite"))
return parser.parse_args()
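

# Example invocations (illustrative only; the script filename is assumed here):
#   python reid_export.py --include onnx --device cpu
#   python reid_export.py --include torchscript engine --device 0 --half --dynamic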
def validate_export_formats(include):
"""
Validate the provided export formats and return corresponding flags.
Args:
include (list): List of export formats provided via the command line.
Returns:
tuple: Boolean flags for each export format in the order:
(torchscript, onnx, openvino, engine, tflite)
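    Example:
        Assuming export_formats() lists the formats in the documented order,
        validate_export_formats(["onnx", "engine"]) returns
        (False, True, False, True, False).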
"""
available_formats = tuple(export_formats()["Argument"][1:])
include_lower = [fmt.lower() for fmt in include]
flags = [fmt in include_lower for fmt in available_formats]
    if sum(flags) != len(include_lower):
        raise ValueError(
            f"Invalid --include {include}; valid arguments are {available_formats}"
        )
return tuple(flags)
def setup_model(args):
"""
Initialize and prepare the ReID model for export.
Args:
args: Parsed command-line arguments.
Returns:
tuple: (model (torch.nn.Module), dummy_input (torch.Tensor))
"""
# Select the correct device
args.device = select_device(args.device)
    if args.half and args.device.type == "cpu":
        raise ValueError("--half is only compatible with GPU export; use --device 0 for GPU")
# Initialize backend model using the auto backend
auto_backend = ReidAutoBackend(weights=args.weights, device=args.device, half=args.half)
    _ = auto_backend.get_backend()  # return value is intentionally unused here
# Build and load the ReID model from the registry
model_name = ReIDModelRegistry.get_model_name(args.weights)
nr_classes = ReIDModelRegistry.get_nr_classes(args.weights)
pretrained = not (args.weights and args.weights.is_file() and args.weights.suffix == ".pt")
model = ReIDModelRegistry.build_model(
model_name,
num_classes=nr_classes,
pretrained=pretrained,
        use_gpu=args.device.type != "cpu",  # pass a boolean rather than the device object
).to(args.device)
ReIDModelRegistry.load_pretrained_weights(model, args.weights)
model.eval()
# Ensure --optimize is only used with CPU exports
    if args.optimize and args.device.type != "cpu":
        raise ValueError("--optimize is not compatible with CUDA devices; use --device cpu")
    # LMBN weights require 384x128 inputs, so override the requested image size
if "lmbn" in str(args.weights):
args.imgsz = [384, 128]
    # Create a dummy input tensor and warm the model up with two forward passes
    dummy_input = torch.empty(args.batch_size, 3, args.imgsz[0], args.imgsz[1]).to(args.device)
    with torch.no_grad():  # gradients are not needed for warm-up
        for _ in range(2):
            _ = model(dummy_input)
# Convert to half precision if required
if args.half:
dummy_input = dummy_input.half()
model = model.half()
return model, dummy_input
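

# Minimal sketch of how setup_model is consumed (mirrored by main() below);
# the feature-dimension comment is an assumption, not a guarantee:
#
#     args = parse_args()
#     model, dummy_input = setup_model(args)  # eval-mode model + warm-up tensor
#     features = model(dummy_input)           # shape (batch, feat_dim) for most ReID nets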
def create_export_tasks(args, model, dummy_input):
"""
Create a mapping of export tasks with associated flags, exporter classes, and parameters.
Args:
args: Parsed command-line arguments.
model: Prepared ReID model.
dummy_input: Dummy input tensor.
Returns:
dict: Mapping of export format to a tuple (flag, exporter_class, export_args)
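    Example:
        With the default arguments, export_tasks["onnx"] is roughly
        (False, ONNXExporter, (model, dummy_input, weights, 12, False, False, False)),
        mirroring the mapping constructed below.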
"""
torchscript_flag, onnx_flag, openvino_flag, engine_flag, tflite_flag = validate_export_formats(args.include)
return {
"torchscript": (
torchscript_flag,
TorchScriptExporter,
(model, dummy_input, args.weights, args.optimize)
),
"engine": (
engine_flag,
EngineExporter,
(model, dummy_input, args.weights, args.half, args.dynamic, args.simplify, args.verbose)
),
"onnx": (
onnx_flag,
ONNXExporter,
(model, dummy_input, args.weights, args.opset, args.dynamic, args.half, args.simplify)
),
"tflite": (
tflite_flag,
TFLiteExporter,
(model, dummy_input, args.weights)
),
"openvino": (
openvino_flag,
OpenVINOExporter,
(model, dummy_input, args.weights, args.half)
)
}
def perform_exports(export_tasks):
"""
Iterate over export tasks and perform export for enabled formats.
Args:
export_tasks (dict): Mapping of export tasks.
Returns:
dict: Mapping of export format to export results.
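    Example:
        If only the "onnx" flag is enabled, the result is a single-entry
        dict such as {"onnx": <whatever ONNXExporter.export() returns>},
        typically the path of the exported file.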
"""
exported_files = {}
for fmt, (flag, exporter_class, exp_args) in export_tasks.items():
if flag:
exporter = exporter_class(*exp_args)
export_result = exporter.export()
exported_files[fmt] = export_result
return exported_files
def main():
"""Main function to execute the ReID export process."""
args = parse_args()
start_time = time.time()
# Ensure the weights directory exists
    WEIGHTS.mkdir(parents=True, exist_ok=True)
# Setup model and create a dummy input tensor
model, dummy_input = setup_model(args)
# Log model output shape and file size
    with torch.no_grad():
        output = model(dummy_input)
output_tensor = output[0] if isinstance(output, tuple) else output
output_shape = tuple(output_tensor.shape)
LOGGER.info(
f"\nStarting from {args.weights} with output shape {output_shape} "
f"({BaseExporter.file_size(args.weights):.1f} MB)"
)
# Create export tasks
export_tasks = create_export_tasks(args, model, dummy_input)
# Perform exports for enabled formats
exported_files = perform_exports(export_tasks)
if exported_files:
elapsed_time = time.time() - start_time
LOGGER.info(
f"\nExport complete ({elapsed_time:.1f}s)"
f"\nResults saved to {args.weights.parent.resolve()}"
f"\nVisualize: https://netron.app"
)
if __name__ == "__main__":
main()