Update app.py
Browse files
app.py
CHANGED
|
@@ -28,20 +28,26 @@ from hls_download import download_clips
|
|
| 28 |
|
| 29 |
plt.style.use('dark_background')
|
| 30 |
|
|
|
|
| 31 |
LOCAL = False
|
| 32 |
IMG_SIZE = 256
|
| 33 |
CACHE_API_CALLS = True
|
| 34 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
|
|
|
|
|
|
| 35 |
|
| 36 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 37 |
if torch.cuda.is_available():
|
|
|
|
| 38 |
print("Using CUDA")
|
| 39 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
| 40 |
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
|
| 41 |
sess_options = ort.SessionOptions()
|
| 42 |
#sess_options.log_severity_level = 0
|
|
|
|
| 43 |
ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
|
| 44 |
else:
|
|
|
|
| 45 |
print("Using CPU")
|
| 46 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 47 |
|
|
@@ -117,12 +123,132 @@ def detect_beeps(video_path, event_length=30, beep_height=0.8):
|
|
| 117 |
return event_start, event_end
|
| 118 |
|
| 119 |
|
| 120 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
| 122 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
|
| 123 |
api_call=False,
|
| 124 |
progress=gr.Progress()):
|
| 125 |
progress(0, desc="Downloading clip...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
if in_video is None:
|
| 127 |
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 128 |
else: # local uploaded video (still resize with ffmpeg)
|
|
@@ -138,6 +264,10 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 138 |
event_length = int(event_length)
|
| 139 |
event_start, event_end = detect_beeps(in_video, event_length)
|
| 140 |
print(event_start, event_end)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
cap = cv2.VideoCapture(in_video)
|
| 143 |
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
@@ -184,7 +314,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 184 |
batch_list = []
|
| 185 |
idx_list = []
|
| 186 |
inference_futures = []
|
| 187 |
-
with concurrent.futures.ThreadPoolExecutor(max_workers=
|
| 188 |
for i in range(0, length + stride_length - stride_pad, stride_length):
|
| 189 |
batch = all_frames[i:i + seq_len]
|
| 190 |
Xlist = []
|
|
@@ -233,6 +363,7 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 233 |
period_length_overlaps[idx:idx+seq_len] += 1
|
| 234 |
event_type_logit_overlaps[idx:idx+seq_len] += 1
|
| 235 |
del y1_out, y2_out, y3_out, y4_out # free up memory
|
|
|
|
| 236 |
|
| 237 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 238 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
|
@@ -254,6 +385,12 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 254 |
print(f"Event detected: {event_start} - {event_end}")
|
| 255 |
periodicity[:event_start] = 0
|
| 256 |
periodicity[event_end:] = 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 257 |
pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
|
| 258 |
full_marks_mask = np.zeros(len(full_marks))
|
| 259 |
full_marks_mask[pred_marks_peaks] = 1
|
|
@@ -287,6 +424,11 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 287 |
confidence = 0
|
| 288 |
else:
|
| 289 |
confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
except ZeroDivisionError:
|
| 291 |
confidence = 0
|
| 292 |
self_err = abs(count_pred - marks_count_pred)
|
|
@@ -300,8 +442,16 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 300 |
count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 301 |
else:
|
| 302 |
count_msg = f"## Reps Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
|
|
|
| 303 |
|
| 304 |
if api_call:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
if CACHE_API_CALLS:
|
| 306 |
# write outputs as row of csv
|
| 307 |
with open('api_calls.tsv', 'a') as f:
|
|
@@ -359,6 +509,12 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 359 |
if beep_detection_on:
|
| 360 |
# add vertical lines for beep event
|
| 361 |
fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 362 |
|
| 363 |
|
| 364 |
fig.update_layout(legend=dict(
|
|
@@ -405,10 +561,19 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 405 |
except FileNotFoundError:
|
| 406 |
pass
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
return in_video, count_msg, fig, hist, bar
|
| 409 |
|
| 410 |
|
| 411 |
with gr.Blocks() as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 412 |
with gr.Row():
|
| 413 |
in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
|
| 414 |
width=400, height=400, interactive=True, container=True,
|
|
@@ -417,11 +582,16 @@ with gr.Blocks() as demo:
|
|
| 417 |
with gr.Column():
|
| 418 |
in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
|
| 419 |
|
|
|
|
|
|
|
| 420 |
in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
|
| 421 |
in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
|
| 422 |
with gr.Column():
|
| 423 |
beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
|
| 424 |
event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
|
|
|
|
|
|
|
|
|
|
| 425 |
with gr.Column(min_width=480):
|
| 426 |
out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
|
| 427 |
|
|
@@ -448,24 +618,33 @@ with gr.Blocks() as demo:
|
|
| 448 |
|
| 449 |
demo_inference = partial(inference, count_only_api=False, api_key=None)
|
| 450 |
|
| 451 |
-
run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length],
|
| 452 |
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
|
| 453 |
api_inference = partial(inference, api_call=True)
|
| 454 |
-
api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, count_only, api_token],
|
| 455 |
outputs=[period_length], api_name='inference')
|
| 456 |
examples = [
|
| 457 |
#['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
|
| 458 |
[None, 'https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_UGEhqlMh/vod', '00:00:18', '00:00:55', True, 30],
|
|
|
|
| 459 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
|
| 460 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
|
| 461 |
]
|
| 462 |
gr.Examples(examples,
|
| 463 |
-
inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length],
|
| 464 |
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
|
| 465 |
fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
|
| 466 |
|
| 467 |
|
| 468 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 469 |
if LOCAL:
|
| 470 |
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 471 |
server_port=7860,
|
|
|
|
| 28 |
|
| 29 |
plt.style.use('dark_background')
|
| 30 |
|
| 31 |
+
LOCAL = False
|
| 32 |
LOCAL = False
|
| 33 |
IMG_SIZE = 256
|
| 34 |
CACHE_API_CALLS = True
|
| 35 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 36 |
+
CACHE_API_CALLS = True
|
| 37 |
+
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 38 |
|
| 39 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 40 |
if torch.cuda.is_available():
|
| 41 |
+
print("Using CUDA")
|
| 42 |
print("Using CUDA")
|
| 43 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
| 44 |
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
|
| 45 |
sess_options = ort.SessionOptions()
|
| 46 |
#sess_options.log_severity_level = 0
|
| 47 |
+
#sess_options.log_severity_level = 0
|
| 48 |
ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
|
| 49 |
else:
|
| 50 |
+
print("Using CPU")
|
| 51 |
print("Using CPU")
|
| 52 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 53 |
|
|
|
|
| 123 |
return event_start, event_end
|
| 124 |
|
| 125 |
|
| 126 |
+
def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, beep_height=0.8):
|
| 127 |
+
reference_file = 'relay_beep.WAV'
|
| 128 |
+
fs, beep = wavfile.read(reference_file)
|
| 129 |
+
beep = beep[:, 0] + beep[:, 1] # combine stereo to mono
|
| 130 |
+
video = cv2.VideoCapture(video_path)
|
| 131 |
+
try:
|
| 132 |
+
os.remove('temp.wav')
|
| 133 |
+
except FileNotFoundError:
|
| 134 |
+
pass
|
| 135 |
+
audio_convert_command = f'ffmpeg -i {video_path} -vn -acodec pcm_s16le -ar {fs} -ac 2 temp.wav'
|
| 136 |
+
subprocess.call(audio_convert_command, shell=True)
|
| 137 |
+
length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 138 |
+
fps = int(video.get(cv2.CAP_PROP_FPS))
|
| 139 |
+
audio = wavfile.read('temp.wav')[1]
|
| 140 |
+
audio = (audio[:, 0] + audio[:, 1]) / 2 # combine stereo to mono
|
| 141 |
+
corr = correlate(audio, beep, mode='same') / audio.size
|
| 142 |
+
# min max scale to -1, 1
|
| 143 |
+
corr = 2 * (corr - np.min(corr)) / (np.max(corr) - np.min(corr)) - 1
|
| 144 |
+
|
| 145 |
+
# Calculate total event length in frames
|
| 146 |
+
total_event_length_frames = fps * relay_length * n_jumpers
|
| 147 |
+
print(event_start, total_event_length_frames)
|
| 148 |
+
expected_event_end = event_start + total_event_length_frames
|
| 149 |
+
|
| 150 |
+
# Find all significant peaks in the correlation
|
| 151 |
+
peaks, _ = find_peaks(corr, height=beep_height, distance=fs)
|
| 152 |
+
|
| 153 |
+
# Convert peaks from sample indices to frame indices
|
| 154 |
+
peak_frames = [int(peak / fs * fps) for peak in peaks]
|
| 155 |
+
|
| 156 |
+
# For debugging
|
| 157 |
+
plt.plot(corr)
|
| 158 |
+
plt.plot(peaks, corr[peaks], "x")
|
| 159 |
+
plt.savefig('beep.png')
|
| 160 |
+
plt.close()
|
| 161 |
+
|
| 162 |
+
starts = []
|
| 163 |
+
ends = []
|
| 164 |
+
|
| 165 |
+
# Add the event start for the first jumper
|
| 166 |
+
starts.append(event_start)
|
| 167 |
+
|
| 168 |
+
# Convert event_start back to sample index for comparison
|
| 169 |
+
event_start_sample = int(event_start * fs / fps)
|
| 170 |
+
|
| 171 |
+
# Find peaks that come after the event start but before the expected end
|
| 172 |
+
# Convert expected_event_end to sample index
|
| 173 |
+
expected_event_end_sample = int(expected_event_end * fs / fps)
|
| 174 |
+
relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
|
| 175 |
+
|
| 176 |
+
# If we don't have enough peaks, try lowering the threshold
|
| 177 |
+
if len(relevant_peaks) < n_jumpers - 1: # We need n_jumpers-1 transitions
|
| 178 |
+
for lower_height in [0.7, 0.6, 0.5, 0.4, 0.3]:
|
| 179 |
+
peaks, _ = find_peaks(corr, height=lower_height, distance=fs)
|
| 180 |
+
relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
|
| 181 |
+
if len(relevant_peaks) >= n_jumpers - 1:
|
| 182 |
+
break
|
| 183 |
+
|
| 184 |
+
# If we still don't have enough peaks, we'll need to estimate some transitions
|
| 185 |
+
relay_length_frames = fps * relay_length
|
| 186 |
+
|
| 187 |
+
# Process peaks to identify jumper transitions
|
| 188 |
+
if len(relevant_peaks) >= n_jumpers - 1:
|
| 189 |
+
# Ideal case: we found enough beeps for transitions
|
| 190 |
+
# Sort peaks by time to ensure correct order
|
| 191 |
+
relevant_peaks.sort()
|
| 192 |
+
|
| 193 |
+
# Use the first n_jumpers-1 peaks as transition points
|
| 194 |
+
transition_frames = [int(p / fs * fps) for p in relevant_peaks[:n_jumpers-1]]
|
| 195 |
+
|
| 196 |
+
# Set ends for jumpers based on transition points
|
| 197 |
+
for i in range(n_jumpers - 1):
|
| 198 |
+
ends.append(transition_frames[i])
|
| 199 |
+
starts.append(transition_frames[i])
|
| 200 |
+
|
| 201 |
+
# Add end for the last jumper
|
| 202 |
+
ends.append(expected_event_end)
|
| 203 |
+
else:
|
| 204 |
+
# Not enough peaks detected, use expected relay_length to estimate
|
| 205 |
+
for i in range(n_jumpers):
|
| 206 |
+
if i == 0:
|
| 207 |
+
# First jumper starts at event_start (already added to starts)
|
| 208 |
+
jumper_end = event_start + relay_length_frames
|
| 209 |
+
ends.append(jumper_end)
|
| 210 |
+
if i < n_jumpers - 1:
|
| 211 |
+
starts.append(jumper_end)
|
| 212 |
+
elif i < n_jumpers - 1:
|
| 213 |
+
jumper_end = starts[i] + relay_length_frames
|
| 214 |
+
ends.append(jumper_end)
|
| 215 |
+
starts.append(jumper_end)
|
| 216 |
+
else:
|
| 217 |
+
# Last jumper
|
| 218 |
+
jumper_end = starts[i] + relay_length_frames
|
| 219 |
+
ends.append(jumper_end)
|
| 220 |
+
|
| 221 |
+
# Validate and adjust if necessary
|
| 222 |
+
# Make sure all intervals are close to relay_length
|
| 223 |
+
for i in range(n_jumpers):
|
| 224 |
+
interval = ends[i] - starts[i]
|
| 225 |
+
# If an interval is significantly different from relay_length, adjust it
|
| 226 |
+
if abs(interval - relay_length_frames) > relay_length_frames * 0.2: # 20% tolerance
|
| 227 |
+
# Adjust the end time to match expected relay_length
|
| 228 |
+
ends[i] = starts[i] + relay_length_frames
|
| 229 |
+
# If not the last jumper, adjust the next start time
|
| 230 |
+
if i < n_jumpers - 1:
|
| 231 |
+
starts[i + 1] = ends[i]
|
| 232 |
+
|
| 233 |
+
# Final check: ensure the total length matches expected
|
| 234 |
+
if ends[-1] != expected_event_end:
|
| 235 |
+
# Adjust the last end to match the expected total event end
|
| 236 |
+
ends[-1] = expected_event_end
|
| 237 |
+
|
| 238 |
+
return starts, ends
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def inference(in_video, stream_url, start_time, end_time, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
|
| 242 |
+
count_only_api, api_key,
|
| 243 |
img_size=256, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
|
| 244 |
miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
|
| 245 |
api_call=False,
|
| 246 |
progress=gr.Progress()):
|
| 247 |
progress(0, desc="Downloading clip...")
|
| 248 |
+
if in_video is None:
|
| 249 |
+
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 250 |
+
else: # local uploaded video (still resize with ffmpeg)
|
| 251 |
+
in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 252 |
if in_video is None:
|
| 253 |
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 254 |
else: # local uploaded video (still resize with ffmpeg)
|
|
|
|
| 264 |
event_length = int(event_length)
|
| 265 |
event_start, event_end = detect_beeps(in_video, event_length)
|
| 266 |
print(event_start, event_end)
|
| 267 |
+
if relay_detection_on:
|
| 268 |
+
n_jumpers = int(int(event_length) / int(relay_length))
|
| 269 |
+
relay_starts, relay_ends = detect_relay_beeps(in_video, event_start, int(relay_length), n_jumpers)
|
| 270 |
+
print(relay_starts, relay_ends)
|
| 271 |
|
| 272 |
cap = cv2.VideoCapture(in_video)
|
| 273 |
length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
|
|
|
| 314 |
batch_list = []
|
| 315 |
idx_list = []
|
| 316 |
inference_futures = []
|
| 317 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
| 318 |
for i in range(0, length + stride_length - stride_pad, stride_length):
|
| 319 |
batch = all_frames[i:i + seq_len]
|
| 320 |
Xlist = []
|
|
|
|
| 363 |
period_length_overlaps[idx:idx+seq_len] += 1
|
| 364 |
event_type_logit_overlaps[idx:idx+seq_len] += 1
|
| 365 |
del y1_out, y2_out, y3_out, y4_out # free up memory
|
| 366 |
+
del y1_out, y2_out, y3_out, y4_out # free up memory
|
| 367 |
|
| 368 |
periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
| 369 |
periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
|
|
|
|
| 385 |
print(f"Event detected: {event_start} - {event_end}")
|
| 386 |
periodicity[:event_start] = 0
|
| 387 |
periodicity[event_end:] = 0
|
| 388 |
+
if relay_detection_on:
|
| 389 |
+
for start, end in zip(relay_starts, relay_ends):
|
| 390 |
+
if start > 0 and end > 0:
|
| 391 |
+
print(f"Relay Event detected: {start} - {end}")
|
| 392 |
+
# immediately after the beep set periodicity to 0 for switch_delay seconds
|
| 393 |
+
periodicity[start:start + int(float(switch_delay) * fps)] = 0
|
| 394 |
pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
|
| 395 |
full_marks_mask = np.zeros(len(full_marks))
|
| 396 |
full_marks_mask[pred_marks_peaks] = 1
|
|
|
|
| 424 |
confidence = 0
|
| 425 |
else:
|
| 426 |
confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
|
| 427 |
+
periodicity_mask = periodicity > miss_threshold
|
| 428 |
+
if np.sum(periodicity_mask) == 0:
|
| 429 |
+
confidence = 0
|
| 430 |
+
else:
|
| 431 |
+
confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
|
| 432 |
except ZeroDivisionError:
|
| 433 |
confidence = 0
|
| 434 |
self_err = abs(count_pred - marks_count_pred)
|
|
|
|
| 442 |
count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 443 |
else:
|
| 444 |
count_msg = f"## Reps Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 445 |
+
count_msg = f"## Reps Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 446 |
|
| 447 |
if api_call:
|
| 448 |
+
if CACHE_API_CALLS:
|
| 449 |
+
# write outputs as row of csv
|
| 450 |
+
with open('api_calls.tsv', 'a') as f:
|
| 451 |
+
periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
|
| 452 |
+
periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
|
| 453 |
+
full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
|
| 454 |
+
f.write(f"{stream_url}\t{start_time}\t{end_time}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
|
| 455 |
if CACHE_API_CALLS:
|
| 456 |
# write outputs as row of csv
|
| 457 |
with open('api_calls.tsv', 'a') as f:
|
|
|
|
| 509 |
if beep_detection_on:
|
| 510 |
# add vertical lines for beep event
|
| 511 |
fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
|
| 512 |
+
if relay_detection_on:
|
| 513 |
+
for start, end in zip(relay_starts, relay_ends):
|
| 514 |
+
start += 10 # add some padding
|
| 515 |
+
end -= 10
|
| 516 |
+
fig.add_vrect(x0=start / fps, x1=end / fps, fillcolor="LightGreen", opacity=0.25, layer="below",
|
| 517 |
+
line_width=0)
|
| 518 |
|
| 519 |
|
| 520 |
fig.update_layout(legend=dict(
|
|
|
|
| 561 |
except FileNotFoundError:
|
| 562 |
pass
|
| 563 |
|
| 564 |
+
try:
|
| 565 |
+
os.remove('temp.wav')
|
| 566 |
+
except FileNotFoundError:
|
| 567 |
+
pass
|
| 568 |
+
|
| 569 |
return in_video, count_msg, fig, hist, bar
|
| 570 |
|
| 571 |
|
| 572 |
with gr.Blocks() as demo:
|
| 573 |
+
with gr.Row():
|
| 574 |
+
in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
|
| 575 |
+
width=400, height=400, interactive=True, container=True,
|
| 576 |
+
max_length=300)
|
| 577 |
with gr.Row():
|
| 578 |
in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
|
| 579 |
width=400, height=400, interactive=True, container=True,
|
|
|
|
| 582 |
with gr.Column():
|
| 583 |
in_stream_url = gr.Textbox(label="Stream URL", elem_id='stream-url', visible=True)
|
| 584 |
|
| 585 |
+
in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
|
| 586 |
+
|
| 587 |
in_stream_start = gr.Textbox(label="Start Time", elem_id='stream-start', visible=True, value='00:00:00')
|
| 588 |
in_stream_end = gr.Textbox(label="End Time", elem_id='stream-end', visible=True)
|
| 589 |
with gr.Column():
|
| 590 |
beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
|
| 591 |
event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
|
| 592 |
+
relay_detection_on = gr.Checkbox(label="Relay Event", elem_id='relay-beeps', visible=True)
|
| 593 |
+
relay_length = gr.Textbox(label="Relay Length (s)", elem_id='relay-length', visible=True, value='30')
|
| 594 |
+
switch_delay = gr.Textbox(label="Expected Switch Delay (s)", elem_id='event-length', visible=True, value='0.2')
|
| 595 |
with gr.Column(min_width=480):
|
| 596 |
out_video = gr.PlayableVideo(label="Video Clip", elem_id='output-video', format='mp4', width=400, height=400)
|
| 597 |
|
|
|
|
| 618 |
|
| 619 |
demo_inference = partial(inference, count_only_api=False, api_key=None)
|
| 620 |
|
| 621 |
+
run_button.click(demo_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 622 |
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist])
|
| 623 |
api_inference = partial(inference, api_call=True)
|
| 624 |
+
api_dummy_button.click(api_inference, [in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
|
| 625 |
outputs=[period_length], api_name='inference')
|
| 626 |
examples = [
|
| 627 |
#['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
|
| 628 |
[None, 'https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_UGEhqlMh/vod', '00:00:18', '00:00:55', True, 30],
|
| 629 |
+
[None, 'https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_UGEhqlMh/vod', '00:00:18', '00:00:55', True, 30],
|
| 630 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
|
| 631 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
|
| 632 |
]
|
| 633 |
gr.Examples(examples,
|
| 634 |
+
inputs=[in_video, in_stream_url, in_stream_start, in_stream_end, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
|
| 635 |
outputs=[out_video, out_text, out_plot, out_hist, out_event_type_dist],
|
| 636 |
fn=demo_inference, cache_examples=os.getenv('SYSTEM') == 'spaces')
|
| 637 |
|
| 638 |
|
| 639 |
if __name__ == "__main__":
|
| 640 |
+
if LOCAL:
|
| 641 |
+
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 642 |
+
server_port=7860,
|
| 643 |
+
debug=False,
|
| 644 |
+
ssl_verify=False,
|
| 645 |
+
share=False)
|
| 646 |
+
else:
|
| 647 |
+
demo.queue(api_open=True, max_size=15).launch(share=False)
|
| 648 |
if LOCAL:
|
| 649 |
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 650 |
server_port=7860,
|