import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


def _load_coco_faster_rcnn(pretrained):
    """Instantiate torchvision's Faster R-CNN ResNet-50 FPN, optionally COCO-pretrained.

    Uses the multi-weight API (``weights=``) when this torchvision exposes it
    and falls back to the legacy ``pretrained=`` flag on older releases.
    """
    detection = torchvision.models.detection
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is None:
        # Older torchvision (pre multi-weight API): only the boolean flag exists.
        return detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)
    # ``pretrained=True`` is deprecated in newer torchvision; DEFAULT selects
    # the COCO detection weights that the legacy flag used to load.
    weights = weights_enum.DEFAULT if pretrained else None
    return detection.fasterrcnn_resnet50_fpn(weights=weights)


def get_detection_model(num_classes, *, pretrained=True):
    """Build a Faster R-CNN (ResNet-50 FPN backbone) with a fresh box-predictor head.

    The detector is created through torchvision (COCO-pretrained by default;
    this may download weights on first use). Its classification/box-regression
    head is then replaced so the model predicts ``num_classes`` classes.

    Args:
        num_classes: Number of output classes for the new ``FastRCNNPredictor``
            (torchvision convention: this count includes the background class).
        pretrained: When True (default, matching the original behavior), load
            the pre-trained detector weights; when False, build the detector
            without them.

    Returns:
        The ``torchvision`` Faster R-CNN model with the replaced predictor head.
    """
    model = _load_coco_faster_rcnn(pretrained)

    # Size the new head to the existing predictor's input feature width so it
    # plugs into the unchanged RoI box head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features

    # Swap in a freshly initialised head sized for the target class count.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model