File size: 631 Bytes
76d828d
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torchvision

def create_fasterrcnn_model(num_classes, *, pretrained=True, freeze_pretrained=True):
    """Build a torchvision Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    The model is created via ``torchvision.models.detection.fasterrcnn_resnet50_fpn``.
    Its ROI-head box predictor is then replaced by a new ``FastRCNNPredictor``
    sized for ``num_classes``, which allows fine-tuning on a custom label set.

    Args:
        num_classes: Number of output classes for the new ``FastRCNNPredictor``.
            Per torchvision's documentation, this count includes the background
            class.
        pretrained: If True (default), load torchvision's default pretrained
            Faster R-CNN weights. If False, the detection weights are not
            pretrained. Torchvision may still load an ImageNet-pretrained
            backbone by default.
        freeze_pretrained: If True (default), set ``requires_grad = False`` on
            every parameter of the loaded model before the predictor is
            replaced. Only the newly created predictor then remains trainable.

    Returns:
        The configured ``FasterRCNN`` model.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes}")

    detection = torchvision.models.detection

    # torchvision >= 0.13 replaced the ``pretrained`` flag with the ``weights``
    # enum API and emits a deprecation warning for the old flag. Use the new
    # API when it exists and fall back to the legacy flag on older releases.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(
            weights=weights_enum.DEFAULT if pretrained else None
        )
    else:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=pretrained)

    if freeze_pretrained:
        # Freeze everything that was loaded. The predictor created below has
        # new parameters, which default to requires_grad=True, so it alone
        # is trained.
        for param in model.parameters():
            param.requires_grad = False

    # Swap the ROI head's box predictor for one with the requested number of
    # classes, keeping the same input feature width as the original layer.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )

    return model