foryahasake committed on
Commit
4643a08
·
verified ·
1 Parent(s): 0a9d64b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -1
app.py CHANGED
@@ -75,6 +75,78 @@ def inference(image_url, image, min_score):
75
  return out.get_image()
76
 
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
def infer_video(video_path):
    """Annotate *video_path* frame-by-frame and return the result file path.

    Delegates the per-frame work to ``predict_frame`` via supervision's
    ``process_video`` helper.
    """
    result_path = "result.mp4"
    sv.process_video(
        source_path=video_path,
        target_path=result_path,
        callback=predict_frame,
    )
    return result_path
@@ -128,7 +200,7 @@ img_interface = gr.Interface(
128
  inputs=[input_url,input_image,sliderr], outputs=[output_image], api_name="find"
129
  )
130
  video_interface = gr.Interface(
131
- fn=infer_video,
132
  inputs=[input_video], outputs=[output_video], api_name="vid"
133
  )
134
  demo = gr.TabbedInterface([img_interface, video_interface], ["Image Upload", "Video Upload"])
 
75
  return out.get_image()
76
 
77
 
78
def process_vid(video_path):
    """Run the detectron2 predictor on every frame of a video and write an
    annotated MP4 next to the app.

    Parameters
    ----------
    video_path : str
        Path to the input video (supplied by the Gradio Video component).

    Returns
    -------
    str
        Path of the annotated output video ("output.mp4").
    """
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.3  # keep detections above 30% confidence
    torch.cuda.empty_cache()
    # Choose the device at call time so the same code runs on CPU-only hosts.
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = DefaultPredictor(cfg)
    visualizer = VideoVisualizer(my_metadata, ColorMode.IMAGE)

    cap = cv2.VideoCapture(video_path)
    frame_size = (
        int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)),
    )
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    # FIX: write to a local relative path instead of a hard-coded Google-Drive
    # Colab path that does not exist on the deployment host, and use an
    # MP4-compatible codec — MJPG inside an .mp4 container is rejected or
    # produces unplayable files with many OpenCV builds.
    output_path = "output.mp4"
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(output_path, fourcc, fps, frame_size)

    def _annotated_frames(video, max_frames):
        """Yield BGR frames with instance predictions drawn, up to max_frames."""
        read_count = 0
        # FIX: original condition (readFrames > maxFrames, checked after the
        # yield) processed max_frames + 1 frames.
        while read_count < max_frames:
            has_frame, frame = video.read()
            if not has_frame:
                break
            outputs = predictor(frame)
            # OpenCV delivers BGR; the visualizer draws on RGB.
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            vis = visualizer.draw_instance_predictions(
                rgb, outputs["instances"].to("cpu")
            )
            # Back to BGR for VideoWriter.
            yield cv2.cvtColor(vis.get_image(), cv2.COLOR_RGB2BGR)
            read_count += 1

    num_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    try:
        # FIX: dropped the per-frame debug cv2.imwrite('POSE detectron2.png')
        # the original performed on every iteration.
        for visualization in tqdm.tqdm(
            _annotated_frames(cap, num_frames), total=num_frames
        ):
            video_writer.write(visualization)
    finally:
        # FIX: release resources even when prediction fails mid-video,
        # otherwise the capture and the partially written file stay locked.
        cap.release()
        video_writer.release()
    return output_path
146
+
147
+
148
+
149
+
150
def infer_video(video_path):
    """Process *video_path* through ``predict_frame`` and return the output path."""
    target = "result.mp4"
    sv.process_video(source_path=video_path, target_path=target, callback=predict_frame)
    return target
 
200
  inputs=[input_url,input_image,sliderr], outputs=[output_image], api_name="find"
201
  )
202
  video_interface = gr.Interface(
203
+ fn=process_vid,
204
  inputs=[input_video], outputs=[output_video], api_name="vid"
205
  )
206
  demo = gr.TabbedInterface([img_interface, video_interface], ["Image Upload", "Video Upload"])