Ahmed-El-Sharkawy commited on
Commit
07cd71e
·
verified ·
1 Parent(s): 31c8c7a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -1
app.py CHANGED
@@ -8,6 +8,14 @@ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
8
 
9
  # Load Models
10
  def load_model(model_path, backbone_name, num_classes):
 
 
 
 
 
 
 
 
11
  if backbone_name == "resnet50":
12
  model = torch.load(model_path)
13
  elif backbone_name == "mobilenet":
@@ -79,4 +87,4 @@ video_interface = gr.Interface(
79
  title="Video Inference"
80
  )
81
 
82
- gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch()
 
8
 
9
  # Load Models
10
  def load_model(model_path, backbone_name, num_classes):
11
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
+ if backbone_name == "resnet50":
13
+ model = torch.load(model_path, map_location=device)
14
+ elif backbone_name == "mobilenet":
15
+ model = torch.load(model_path, map_location=device)
16
+ model.to(device)
17
+ model.eval()
18
+ return model
19
  if backbone_name == "resnet50":
20
  model = torch.load(model_path)
21
  elif backbone_name == "mobilenet":
 
87
  title="Video Inference"
88
  )
89
 
90
+ gr.TabbedInterface([image_interface, video_interface], ["Image", "Video"]).launch()