Ahmed-El-Sharkawy committed on
Commit
56be1f0
·
verified ·
1 Parent(s): 72c5129

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -8
app.py CHANGED
@@ -8,18 +8,18 @@ from torchvision.models.detection import FasterRCNN
8
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
 
10
  # Load Models
11
- def load_model( backbone_name, num_classes):
12
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  if backbone_name == "resnet50":
14
  model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
15
  in_features = model.roi_heads.box_predictor.cls_score.in_features
16
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
17
- model.load_state_dict(torch.load("fasterrcnnResnet.pth", map_location=device))
18
  elif backbone_name == "mobilenet":
19
  model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False)
20
  in_features = model.roi_heads.box_predictor.cls_score.in_features
21
  model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
22
- model.load_state_dict(torch.load("fasterrcnnMobilenet", map_location=device))
23
  model.to(device)
24
  model.eval()
25
  return model
@@ -60,7 +60,15 @@ def predict_video(video_path, model):
60
  cv2.putText(frame, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
61
  frames.append(frame)
62
  cap.release()
63
- return frames[0] if frames else None
 
 
 
 
 
 
 
 
64
 
65
  # Gradio Interface for Image and Video Inference
66
 
@@ -70,14 +78,14 @@ inputs_image = [gr.Image(type="filepath", label="Upload Image"), model_selection
70
  outputs_image = gr.Image(type="numpy", label="Detection Output")
71
 
72
  inputs_video = [gr.Video(label="Upload Video"), model_selection]
73
- outputs_video = gr.Image(type="numpy", label="Detection Output")
74
 
75
 
76
 
77
  with gr.Blocks() as demo:
78
  with gr.TabItem("Image"):
79
  gr.Interface(
80
- fn=lambda img, model_name: predict_image(img, load_model( model_name.lower(), num_classes=6)),
81
  inputs=inputs_image,
82
  outputs=outputs_image,
83
  title="Image Inference"
@@ -85,11 +93,10 @@ with gr.Blocks() as demo:
85
 
86
  with gr.TabItem("Video"):
87
  gr.Interface(
88
- fn=lambda vid, model_name: predict_video(vid, load_model(model_name.lower(), num_classes=6)),
89
  inputs=inputs_video,
90
  outputs=outputs_video,
91
  title="Video Inference"
92
  )
93
 
94
  demo.launch()
95
-
 
8
  from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
9
 
10
  # Load Models
11
def load_model(model_path, backbone_name, num_classes):
    """Build a Faster R-CNN detector and load fine-tuned weights from disk.

    Args:
        model_path: Path to a ``state_dict`` checkpoint (``.pth``) matching
            the chosen backbone.
        backbone_name: Backbone identifier, either ``"resnet50"`` or
            ``"mobilenet"``.
        num_classes: Number of output classes for the detection head
            (presumably includes the background class — standard for
            torchvision Faster R-CNN; confirm against training code).

    Returns:
        The model in ``eval()`` mode, moved to CUDA when available, else CPU.

    Raises:
        ValueError: If ``backbone_name`` is not a supported backbone.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if backbone_name == "resnet50":
        model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=False)
    elif backbone_name == "mobilenet":
        model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False)
    else:
        # Fail fast with a clear message instead of hitting an
        # UnboundLocalError on `model` further down.
        raise ValueError(
            f"Unsupported backbone: {backbone_name!r}; "
            "expected 'resnet50' or 'mobilenet'"
        )
    # Swap the detection head so the class count matches the fine-tuned
    # checkpoint before loading its weights (identical for both backbones,
    # so it is hoisted out of the branches).
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model
 
60
  cv2.putText(frame, f"{class_names[label]}: {score:.2f}", (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
61
  frames.append(frame)
62
  cap.release()
63
+ output_path = 'output_video.mp4'
64
+ height, width, _ = frames[0].shape
65
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 20, (width, height))
66
+
67
+ for frame in frames:
68
+ out.write(frame)
69
+
70
+ out.release()
71
+ return output_path
72
 
73
  # Gradio Interface for Image and Video Inference
74
 
 
# Gradio Interface for Image and Video Inference
outputs_image = gr.Image(type="numpy", label="Detection Output")

inputs_video = [gr.Video(label="Upload Video"), model_selection]
outputs_video = gr.Video(label="Detection Output")


def _run_image_inference(img, model_name):
    # Checkpoints follow the naming scheme fasterrcnn<Backbone>.pth.
    model = load_model(f'fasterrcnn{model_name}.pth', model_name.lower(), num_classes=6)
    return predict_image(img, model)


def _run_video_inference(vid, model_name):
    # Same checkpoint naming scheme as the image path.
    model = load_model(f'fasterrcnn{model_name}.pth', model_name.lower(), num_classes=6)
    return predict_video(vid, model)


# One tab per media type; each tab hosts its own Interface wired to the
# shared model selector.
with gr.Blocks() as demo:
    with gr.TabItem("Image"):
        gr.Interface(
            fn=_run_image_inference,
            inputs=inputs_image,
            outputs=outputs_image,
            title="Image Inference",
        )

    with gr.TabItem("Video"):
        gr.Interface(
            fn=_run_video_inference,
            inputs=inputs_video,
            outputs=outputs_video,
            title="Video Inference",
        )

demo.launch()