Update app.py
Browse files
app.py
CHANGED
|
@@ -4,6 +4,7 @@ from PIL import Image
|
|
| 4 |
import os
|
| 5 |
import cv2
|
| 6 |
import math
|
|
|
|
| 7 |
import subprocess
|
| 8 |
import matplotlib
|
| 9 |
matplotlib.use('Agg')
|
|
@@ -26,28 +27,22 @@ from huggingface_hub import HfApi
|
|
| 26 |
|
| 27 |
from hls_download import download_clips
|
| 28 |
|
| 29 |
-
plt.style.use('dark_background')
|
| 30 |
|
| 31 |
-
LOCAL = False
|
| 32 |
LOCAL = False
|
| 33 |
IMG_SIZE = 256
|
| 34 |
CACHE_API_CALLS = True
|
| 35 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 36 |
-
CACHE_API_CALLS = True
|
| 37 |
-
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
| 38 |
|
| 39 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 40 |
if torch.cuda.is_available():
|
| 41 |
-
print("Using CUDA")
|
| 42 |
print("Using CUDA")
|
| 43 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
| 44 |
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
|
| 45 |
sess_options = ort.SessionOptions()
|
| 46 |
#sess_options.log_severity_level = 0
|
| 47 |
-
#sess_options.log_severity_level = 0
|
| 48 |
ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
|
| 49 |
else:
|
| 50 |
-
print("Using CPU")
|
| 51 |
print("Using CPU")
|
| 52 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 53 |
|
|
@@ -245,10 +240,6 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 245 |
api_call=False,
|
| 246 |
progress=gr.Progress()):
|
| 247 |
progress(0, desc="Downloading clip...")
|
| 248 |
-
if in_video is None:
|
| 249 |
-
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 250 |
-
else: # local uploaded video (still resize with ffmpeg)
|
| 251 |
-
in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 252 |
if in_video is None:
|
| 253 |
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 254 |
else: # local uploaded video (still resize with ffmpeg)
|
|
@@ -448,13 +439,58 @@ def inference(in_video, stream_url, start_time, end_time, beep_detection_on, eve
|
|
| 448 |
if count_only_api:
|
| 449 |
return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
|
| 450 |
else:
|
| 451 |
-
|
| 452 |
-
|
| 453 |
-
np.array2string(
|
| 454 |
-
|
| 455 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
|
| 457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 458 |
jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
|
| 459 |
jumping_speed = np.copy(jumps_per_second)
|
| 460 |
misses = periodicity < miss_threshold
|
|
@@ -602,7 +638,7 @@ with gr.Blocks() as demo:
|
|
| 602 |
outputs=[period_length], api_name='inference')
|
| 603 |
examples = [
|
| 604 |
#['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
|
| 605 |
-
[
|
| 606 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
|
| 607 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
|
| 608 |
]
|
|
@@ -613,14 +649,6 @@ with gr.Blocks() as demo:
|
|
| 613 |
|
| 614 |
|
| 615 |
if __name__ == "__main__":
|
| 616 |
-
if LOCAL:
|
| 617 |
-
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 618 |
-
server_port=7860,
|
| 619 |
-
debug=False,
|
| 620 |
-
ssl_verify=False,
|
| 621 |
-
share=False)
|
| 622 |
-
else:
|
| 623 |
-
demo.queue(api_open=True, max_size=15).launch(share=False)
|
| 624 |
if LOCAL:
|
| 625 |
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 626 |
server_port=7860,
|
|
|
|
| 4 |
import os
|
| 5 |
import cv2
|
| 6 |
import math
|
| 7 |
+
import json
|
| 8 |
import subprocess
|
| 9 |
import matplotlib
|
| 10 |
matplotlib.use('Agg')
|
|
|
|
| 27 |
|
| 28 |
from hls_download import download_clips
|
| 29 |
|
| 30 |
+
#plt.style.use('dark_background')
|
| 31 |
|
|
|
|
| 32 |
LOCAL = False
|
| 33 |
IMG_SIZE = 256
|
| 34 |
CACHE_API_CALLS = True
|
| 35 |
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
|
|
|
|
|
|
|
| 36 |
|
| 37 |
onnx_file = hf_hub_download(repo_id="dylanplummer/ropenet", filename="nextjump.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
|
| 38 |
if torch.cuda.is_available():
|
|
|
|
| 39 |
print("Using CUDA")
|
| 40 |
providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
|
| 41 |
"user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
|
| 42 |
sess_options = ort.SessionOptions()
|
| 43 |
#sess_options.log_severity_level = 0
|
|
|
|
| 44 |
ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
|
| 45 |
else:
|
|
|
|
| 46 |
print("Using CPU")
|
| 47 |
ort_sess = ort.InferenceSession(onnx_file)
|
| 48 |
|
|
|
|
| 240 |
api_call=False,
|
| 241 |
progress=gr.Progress()):
|
| 242 |
progress(0, desc="Downloading clip...")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 243 |
if in_video is None:
|
| 244 |
in_video = download_clips(stream_url, os.path.join(os.getcwd(), 'clips'), start_time, end_time)
|
| 245 |
else: # local uploaded video (still resize with ffmpeg)
|
|
|
|
| 439 |
if count_only_api:
|
| 440 |
return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
|
| 441 |
else:
|
| 442 |
+
# create a nice json object to return
|
| 443 |
+
results_dict = {
|
| 444 |
+
"periodLength": np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
|
| 445 |
+
"periodicity": np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
|
| 446 |
+
"full_marks": np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
|
| 447 |
+
"cum_count": np.array2string(np.array(count), formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
|
| 448 |
+
"count": f"{count_pred:.2f}",
|
| 449 |
+
"marks": f"{marks_count_pred:.1f}",
|
| 450 |
+
"confidence": f"{total_confidence:.2f}",
|
| 451 |
+
"single_rope_speed": f"{event_type_probs[0]:.3f}",
|
| 452 |
+
"double_dutch": f"{event_type_probs[1]:.3f}",
|
| 453 |
+
"double_unders": f"{event_type_probs[2]:.3f}",
|
| 454 |
+
"single_bounce": f"{event_type_probs[3]:.3f}"
|
| 455 |
+
}
|
| 456 |
+
if beep_detection_on:
|
| 457 |
+
results_dict['event_start'] = event_start
|
| 458 |
+
results_dict['event_end'] = event_end
|
| 459 |
+
if relay_detection_on:
|
| 460 |
+
results_dict['relay_starts'] = relay_starts
|
| 461 |
+
results_dict['relay_ends'] = relay_ends
|
| 462 |
+
return json.dumps(results_dict)
|
| 463 |
+
|
| 464 |
|
| 465 |
|
| 466 |
+
fig, axs = plt.subplots(4, 1, figsize=(12, 10)) # Added a plot for count
|
| 467 |
+
|
| 468 |
+
# Ensure data exists before plotting
|
| 469 |
+
axs[0].plot(periodLength)
|
| 470 |
+
axs[0].set_title(f"Stream 0 - Period Length")
|
| 471 |
+
|
| 472 |
+
axs[1].plot(periodicity)
|
| 473 |
+
axs[1].set_title("Stream 0 - Periodicity")
|
| 474 |
+
axs[1].set_ylim(0, 1)
|
| 475 |
+
axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
axs[2].plot(full_marks, label='Raw Marks')
|
| 479 |
+
marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
|
| 480 |
+
axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
|
| 481 |
+
axs[2].set_title("Stream 0 - Marks")
|
| 482 |
+
axs[2].set_ylim(0, 1)
|
| 483 |
+
axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
axs[3].plot(count)
|
| 487 |
+
axs[3].set_title("Stream 0 - Calculated Count")
|
| 488 |
+
|
| 489 |
+
plt.tight_layout()
|
| 490 |
+
|
| 491 |
+
plt.savefig('plot.png')
|
| 492 |
+
plt.close()
|
| 493 |
+
|
| 494 |
jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
|
| 495 |
jumping_speed = np.copy(jumps_per_second)
|
| 496 |
misses = periodicity < miss_threshold
|
|
|
|
| 638 |
outputs=[period_length], api_name='inference')
|
| 639 |
examples = [
|
| 640 |
#['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
|
| 641 |
+
['files/wc2023.mp4', '', '00:00:00', '', True, 30, False, '', '0.2'],
|
| 642 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
|
| 643 |
#['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
|
| 644 |
]
|
|
|
|
| 649 |
|
| 650 |
|
| 651 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 652 |
if LOCAL:
|
| 653 |
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
|
| 654 |
server_port=7860,
|