Ahmed-El-Sharkawy commited on
Commit
7d3d12a
·
verified ·
1 Parent(s): 07cd71e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -8,6 +8,20 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
8
 
9
  # Load Models
10
  def load_model(model_path, backbone_name, num_classes):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
  if backbone_name == "resnet50":
13
  model = torch.load(model_path, map_location=device)
 
8
 
9
  # Load Models
10
  def load_model(model_path, backbone_name, num_classes):
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if backbone_name == "resnet50":
13
+ model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
14
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
15
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
16
+ model.load_state_dict(torch.load(model_path, map_location=device))
17
+ elif backbone_name == "mobilenet":
18
+ model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False)
19
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
20
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
21
+ model.load_state_dict(torch.load(model_path, map_location=device))
22
+ model.to(device)
23
+ model.eval()
24
+ return model
25
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
26
  if backbone_name == "resnet50":
27
  model = torch.load(model_path, map_location=device)