Spaces:
Build error
Build error
File size: 631 Bytes
76d828d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 | import torchvision
def create_fasterrcnn_model(num_classes, *, freeze_pretrained=True):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a new box predictor.

    The detector is created with torchvision's default COCO weights (which
    may be downloaded on first use). Its final box predictor is then replaced
    by a freshly initialised ``FastRCNNPredictor`` sized for ``num_classes``.

    Args:
        num_classes: Number of output classes for the new predictor. In
            torchvision's Faster R-CNN this count includes the background
            class, so a dataset with K object categories needs K + 1.
        freeze_pretrained: When True (the default, matching the previous
            behaviour), every pretrained parameter is frozen, so only the new
            box predictor is trainable. Pass False to fine-tune the whole
            network.

    Returns:
        The torchvision Faster R-CNN model with the replaced box predictor.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    detection = torchvision.models.detection

    # torchvision >= 0.13 deprecated ``pretrained=True`` in favour of the
    # ``weights=`` API (DEFAULT resolves to the same COCO weights). Fall back
    # to the legacy flag on older releases that lack the weights enum.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    if freeze_pretrained:
        # Freeze before swapping the head: the replacement predictor is
        # created afterwards, so its parameters remain trainable.
        for param in model.parameters():
            param.requires_grad = False

    # Replace only the final box predictor (class scores + box regression);
    # the rest of the ROI heads keeps its pretrained weights.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model