dylanplummer committed on
Commit
8963f04
·
1 Parent(s): 1431cde

add beep detection

Browse files
Files changed (2) hide show
  1. app.py +57 -11
  2. beep.WAV +0 -0
app.py CHANGED
@@ -4,11 +4,13 @@ from PIL import Image
4
  import os
5
  import cv2
6
  import math
 
7
  import matplotlib
8
  matplotlib.use('Agg')
9
  import matplotlib.pyplot as plt
10
  import concurrent.futures
11
- from scipy.signal import medfilt, find_peaks
 
12
  from functools import partial
13
  from passlib.hash import pbkdf2_sha256
14
  from tqdm import tqdm
@@ -78,7 +80,37 @@ def sigmoid(x):
78
  return 1 / (1 + np.exp(-x))
79
 
80
 
81
- def inference(stream_url, start_time, end_time, count_only_api, api_key,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
83
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
84
  api_call=False,
@@ -91,7 +123,10 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
91
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
92
  if not has_access:
93
  return "Invalid API Key"
94
-
 
 
 
95
  cap = cv2.VideoCapture(in_video)
96
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
97
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -133,7 +168,7 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
133
  idx_list = []
134
  inference_futures = []
135
  with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
136
- for i in tqdm(range(0, length + stride_length - stride_pad, stride_length)):
137
  batch = all_frames[i:i + seq_len]
138
  Xlist = []
139
  preprocess_tasks = [(idx, executor.submit(preprocess_image, img, img_size)) for idx, img in enumerate(batch)]
@@ -163,7 +198,7 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
163
  inference_futures.append((batch_list, idx_list, future))
164
 
165
  # Collect and process the inference results
166
- for batch_list, idx_list, future in inference_futures:
167
  outputs = future.result()
168
  y1_out = outputs[0]
169
  y2_out = outputs[1]
@@ -195,6 +230,12 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
195
  periodLength = medfilt(periodLength, 5)
196
  periodicity = sigmoid(periodicity)
197
  full_marks = sigmoid(full_marks)
 
 
 
 
 
 
198
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
199
  full_marks_mask = np.zeros(len(full_marks))
200
  full_marks_mask[pred_marks_peaks] = 1
@@ -325,7 +366,7 @@ def inference(stream_url, start_time, end_time, count_only_api, api_key,
325
  title="Event Type Distribution",
326
  labels={'x': 'event type', 'y': 'probability'},
327
  range_y=[0, 1])
328
-
329
  return in_video, count_msg, fig, hist, bar
330
 
331
 
@@ -333,9 +374,11 @@ with gr.Blocks() as demo:
333
  with gr.Row():
334
  with gr.Column():
335
  in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
336
- with gr.Column():
337
  in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
338
  in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
 
 
 
339
  with gr.Column(min_width=480):
340
  out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
341
 
@@ -362,14 +405,17 @@ with gr.Blocks() as demo:
362
 
363
  demo_inference = partial(inference, count_only_api=False, api_key=None)
364
 
365
- run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end], outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
 
366
  api_inference = partial(inference, api_call=True)
367
- api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, count_only, api_token], outputs=[period_length], api_name='inference')
 
368
  examples = [
369
- ['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '00:43:10', '00:43:40'],
 
370
  ]
371
  gr.Examples(examples,
372
- inputs=[in_stream_url, in_stream_start, in_stream_end],
373
  outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
374
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
375
 
 
4
  import os
5
  import cv2
6
  import math
7
+ import subprocess
8
  import matplotlib
9
  matplotlib.use('Agg')
10
  import matplotlib.pyplot as plt
11
  import concurrent.futures
12
+ from scipy.io import wavfile
13
+ from scipy.signal import medfilt, correlate, find_peaks
14
  from functools import partial
15
  from passlib.hash import pbkdf2_sha256
16
  from tqdm import tqdm
 
80
  return 1 / (1 + np.exp(-x))
81
 
82
 
83
def detect_beeps(video_path, event_length=30):
    """Locate the start/end of an event in a video by cross-correlating
    its audio track with a reference beep recording (``beep.WAV``).

    Parameters
    ----------
    video_path : str
        Path to the input video file; must contain an audio track.
    event_length : int, optional
        Expected event length in seconds. Unused here; kept for interface
        compatibility — the caller validates the detected span against it.

    Returns
    -------
    (int, int)
        ``(event_start, event_end)`` as frame indices of the first and
        last detected beep, or ``(-1, -1)`` when no beep is found (the
        caller's ``event_start > 0 and event_end > 0`` check then fails).

    Side effects: writes (and overwrites) ``temp.wav`` in the working
    directory; the caller is responsible for removing it.
    """
    reference_file = 'beep.WAV'
    fs, beep = wavfile.read(reference_file)
    if beep.ndim > 1:
        beep = beep[:, 0] + beep[:, 1]  # combine stereo to mono
    video = cv2.VideoCapture(video_path)
    fps = int(video.get(cv2.CAP_PROP_FPS))
    video.release()  # release the capture handle as soon as fps is read
    # Extract the audio track, resampled to the reference sample rate.
    # Use an argument list (shell=False) so the path cannot inject shell
    # commands, and -y so ffmpeg overwrites a stale temp.wav instead of
    # blocking on an interactive overwrite prompt.
    subprocess.run(
        ['ffmpeg', '-y', '-i', video_path, '-vn', '-acodec', 'pcm_s16le',
         '-ar', str(fs), '-ac', '2', 'temp.wav'],
        check=True)
    audio = wavfile.read('temp.wav')[1]
    if audio.ndim > 1:
        audio = (audio[:, 0] + audio[:, 1]) / 2  # combine stereo to mono
    corr = correlate(audio, beep, mode='same') / audio.size
    # min-max scale to [-1, 1] so the peak-height threshold is
    # independent of the recording's absolute level
    corr = 2 * (corr - np.min(corr)) / (np.max(corr) - np.min(corr)) - 1
    # peaks at least 1 s apart (distance=fs samples) whose normalized
    # correlation clears 0.7
    peaks, _ = find_peaks(corr, height=0.7, distance=fs)
    if len(peaks) == 0:
        # no beep detected: sentinel values fail the caller's validity check
        return -1, -1
    # convert first/last peak positions from audio samples to video frames
    event_start = int(peaks[0] / fs * fps)
    event_end = int(peaks[-1] / fs * fps)
    return event_start, event_end
111
+
112
+
113
+ def inference(stream_url, start_time, end_time, beep_detection_on, event_length, count_only_api, api_key,
114
  img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
115
  miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, center_crop=True, both_feet=True,
116
  api_call=False,
 
123
  has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
124
  if not has_access:
125
  return "Invalid API Key"
126
+ if beep_detection_on:
127
+ event_start, event_end = detect_beeps(in_video, event_length)
128
+ print(event_start, event_end)
129
+ event_length = int(event_length)
130
  cap = cv2.VideoCapture(in_video)
131
  length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
132
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
168
  idx_list = []
169
  inference_futures = []
170
  with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
171
+ for i in range(0, length + stride_length - stride_pad, stride_length):
172
  batch = all_frames[i:i + seq_len]
173
  Xlist = []
174
  preprocess_tasks = [(idx, executor.submit(preprocess_image, img, img_size)) for idx, img in enumerate(batch)]
 
198
  inference_futures.append((batch_list, idx_list, future))
199
 
200
  # Collect and process the inference results
201
+ for batch_list, idx_list, future in tqdm(inference_futures):
202
  outputs = future.result()
203
  y1_out = outputs[0]
204
  y2_out = outputs[1]
 
230
  periodLength = medfilt(periodLength, 5)
231
  periodicity = sigmoid(periodicity)
232
  full_marks = sigmoid(full_marks)
233
+ # if the event_start and event_end (in frames) are detected and form a valid event of event_length (in seconds)
234
+ if beep_detection_on:
235
+ if event_start > 0 and event_end > 0 and (event_end - event_start) - (event_length * fps) < 0.5:
236
+ print(f"Event detected: {event_start} - {event_end}")
237
+ periodicity[:event_start] = 0
238
+ periodicity[event_end:] = 0
239
  pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
240
  full_marks_mask = np.zeros(len(full_marks))
241
  full_marks_mask[pred_marks_peaks] = 1
 
366
  title="Event Type Distribution",
367
  labels={'x': 'event type', 'y': 'probability'},
368
  range_y=[0, 1])
369
+ os.remove('temp.wav')
370
  return in_video, count_msg, fig, hist, bar
371
 
372
 
 
374
  with gr.Row():
375
  with gr.Column():
376
  in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
 
377
  in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True)
378
  in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
379
+ with gr.Column():
380
+ beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
381
+ event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
382
  with gr.Column(min_width=480):
383
  out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
384
 
 
405
 
406
  demo_inference = partial(inference, count_only_api=False, api_key=None)
407
 
408
+ run_button.click(demo_inference, [in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length],
409
+ outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
410
  api_inference = partial(inference, api_call=True)
411
+ api_dummy_button.click(api_inference, [in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, count_only, api_token],
412
+ outputs=[period_length], api_name='inference')
413
  examples = [
414
+ ['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '00:43:10', '00:43:45', True, 30],
415
+ ['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_UGEhqlMh/vod', '00:00:18', '00:00:55', True, 30]
416
  ]
417
  gr.Examples(examples,
418
+ inputs=[in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length],
419
  outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
420
  fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
421
 
beep.WAV ADDED
Binary file (70.7 kB). View file