mediscan-api / detection /model.py
Mittalyash's picture
Upload folder using huggingface_hub
9916246 verified
raw
history blame contribute delete
640 Bytes
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def get_detection_model(num_classes: int, *, pretrained: bool = True):
    """
    Build a torchvision Faster R-CNN (ResNet-50 + FPN) with a fresh box-predictor head.

    The detector is created via ``torchvision.models.detection.fasterrcnn_resnet50_fpn``
    and its ``roi_heads.box_predictor`` is replaced by a new ``FastRCNNPredictor``
    sized for ``num_classes`` outputs, so the classification/box-regression head
    can be fine-tuned on a custom label set.

    Args:
        num_classes: Number of output classes for the new head. Per torchvision's
            Faster R-CNN convention this count includes the background class.
        pretrained: If True (default, matching the original behavior), load the
            COCO-trained detector weights before swapping the head. Pass False
            when the caller will load its own ``state_dict`` afterwards and does
            not need the COCO checkpoint. Note: torchvision's defaults may still
            load ImageNet weights for the backbone in that case.

    Returns:
        The ``FasterRCNN`` model instance with the replaced box predictor.
    """
    builder = torchvision.models.detection.fasterrcnn_resnet50_fpn
    # torchvision >= 0.13 exposes a weights enum and deprecates ``pretrained=``;
    # older releases only understand the legacy boolean flag.
    weights_enum = getattr(
        torchvision.models.detection, "FasterRCNN_ResNet50_FPN_Weights", None
    )
    if weights_enum is None:
        model = builder(pretrained=pretrained)
    else:
        # DEFAULT is the same COCO checkpoint that ``pretrained=True`` selected.
        model = builder(weights=weights_enum.DEFAULT if pretrained else None)

    # Input width of the existing classifier; the new head must consume the
    # same pooled-feature size produced by the ROI box head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # Swap in an untrained predictor with the requested number of classes.
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model