Goutham204's picture
Upload 8 files
b29dd97 verified
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def get_model(num_classes):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
in_features = model.roi_heads.box_predictor.cls_score.in_features
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model.to(device)
model.eval()
return model, device