# Spaces: Build error
import torchvision


def create_fasterrcnn_model(
    num_classes: int, *, freeze_pretrained: bool = True
) -> torchvision.models.detection.FasterRCNN:
    """Build a pretrained Faster R-CNN (ResNet-50 FPN) with a new box predictor.

    The model is loaded with torchvision's default pretrained weights for
    ``fasterrcnn_resnet50_fpn``. Its ``roi_heads.box_predictor`` is then
    replaced by a newly constructed ``FastRCNNPredictor`` sized for
    ``num_classes``.

    Args:
        num_classes: Number of output classes *including* the background
            class (torchvision reserves label 0 for background). A dataset
            with ``k`` object categories therefore needs ``k + 1``.
        freeze_pretrained: When true (the default, matching the previous
            behaviour), every pretrained parameter is frozen, so only the
            new box predictor is trainable.

    Returns:
        The configured ``FasterRCNN`` model.

    Raises:
        ValueError: If ``num_classes`` is less than 2 (background plus at
            least one object class).
    """
    if num_classes < 2:
        raise ValueError(
            "num_classes must be >= 2 (background + at least one object "
            f"class), got {num_classes}"
        )

    detection = torchvision.models.detection

    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights`` enum (DEFAULT selects the same COCO weights). Fall back to
    # the legacy flag on releases that predate the enum.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    if freeze_pretrained:
        # Freeze before swapping the head. The replacement predictor is
        # created afterwards with requires_grad=True, so it alone is trained.
        for param in model.parameters():
            param.requires_grad = False

    # Size the new head to match the features the existing ROI box head emits.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model