from typing import Optional, Tuple, Union

import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor


def _load_coco_faster_rcnn() -> torch.nn.Module:
    """Build a Faster R-CNN ResNet-50 FPN model with COCO-pretrained weights.

    Uses the ``weights=`` API on torchvision >= 0.13, where ``pretrained=`` is
    deprecated, and falls back to ``pretrained=True`` on older releases that
    lack the weights enums.
    """
    try:
        from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
    except ImportError:
        # torchvision < 0.13 only understands the legacy boolean flag.
        return torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # COCO_V1 is exactly what the legacy ``pretrained=True`` resolved to.
    return torchvision.models.detection.fasterrcnn_resnet50_fpn(
        weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1
    )


def get_model(
    num_classes: int,
    *,
    device: Optional[Union[torch.device, str]] = None,
) -> Tuple[torch.nn.Module, torch.device]:
    """Return a COCO-pretrained Faster R-CNN whose box head predicts ``num_classes``.

    The pretrained box predictor is replaced by a new ``FastRCNNPredictor``.
    Its layers are freshly constructed, not pretrained, so the head must be
    trained or have weights loaded before its detections are meaningful.

    Args:
        num_classes: Number of output classes for the new box predictor.
            Per torchvision's ``FastRCNNPredictor`` contract this count
            includes the background class.
        device: Device to place the model on. Defaults to ``"cuda"`` when
            CUDA is available, otherwise ``"cpu"``.

    Returns:
        ``(model, device)``: the model moved to ``device`` and already
        switched to eval mode, and the resolved ``torch.device``.

    Raises:
        ValueError: If ``num_classes`` is less than 1. The check runs before
            the pretrained weights are downloaded.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model = _load_coco_faster_rcnn()
    # Reuse the feature width the existing head consumes so the new head fits.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    model.to(device)  # nn.Module.to moves parameters in place
    model.eval()
    return model, device