import gradio as gr
import numpy as np
from PIL import Image
import os
import cv2
import math
import json
import time
import subprocess
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import concurrent.futures
from scipy.io import wavfile
from scipy.signal import medfilt, correlate, find_peaks
from functools import partial
from passlib.hash import pbkdf2_sha256
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import onnxruntime as ort
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
from hls_download import download_clips
#plt.style.use('dark_background')
LOCAL = False  # toggles local-launch settings in the __main__ guard and extra debug output in count_msg
IMG_SIZE = 256  # square crop size (pixels) fed to the ONNX model
CACHE_API_CALLS = False  # when True, API call results are appended to api_calls.tsv
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)
current_model = 'nextjump_speed'  # name of the currently loaded ONNX model (hot-swapped in inference())
# Fetch the default model's ONNX weights from the private HF model repo.
onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
api = HfApi()
if torch.cuda.is_available():
    print("Using CUDA")
    # Share PyTorch's current CUDA stream with onnxruntime so both libraries
    # run on the same device and stream.
    providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                            "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
    sess_options = ort.SessionOptions()
    #sess_options.log_severity_level = 0
    ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
else:
    print("Using CPU")
    ort_sess = ort.InferenceSession(onnx_file)
# warmup inference: one dummy batch of shape (batch, seq_len, C, H, W) so the
# first real request does not pay session-initialization costs
ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
def square_pad_opencv(image):
    """Pad an image with black borders so it becomes square.

    Fixes an off-by-one in the original: when the difference between the two
    dimensions was odd, ``int((max_wh - w) / 2)`` dropped one pixel from the
    total padding and the result was not actually square. The remainder is
    now added to the bottom/right border.

    Args:
        image: HxW (grayscale) or HxWxC numpy array.

    Returns:
        A zero-padded numpy array of spatial size max(H, W) x max(H, W),
        same dtype and channel count as the input.
    """
    h, w = image.shape[:2]
    max_wh = max(w, h)
    left = (max_wh - w) // 2
    right = max_wh - w - left   # remainder goes to the right edge
    top = (max_wh - h) // 2
    bottom = max_wh - h - top   # remainder goes to the bottom edge
    pad_width = [(top, bottom), (left, right)] + [(0, 0)] * (image.ndim - 2)
    # np.pad with zeros is equivalent to cv2.copyMakeBorder BORDER_CONSTANT(0).
    return np.pad(image, pad_width, mode='constant', constant_values=0)
def preprocess_image(img, img_size):
    """Convert an HxWxC uint8 RGB frame to a normalized (1, C, H, W) tensor.

    Equivalent to torchvision ``transforms.ToTensor()`` for uint8 images
    (HWC -> CHW, scaled to [0, 1]), but avoids the numpy -> PIL round-trip
    and avoids rebuilding a Compose pipeline on every call — this function
    runs once per video frame.

    Args:
        img: HxW (grayscale) or HxWxC uint8 numpy array.
        img_size: unused; kept for interface compatibility with callers that
            pass IMG_SIZE.

    Returns:
        torch.float32 tensor of shape (1, C, H, W) with values in [0, 1].
    """
    arr = np.asarray(img)
    if arr.ndim == 2:
        arr = arr[:, :, None]  # promote grayscale to a single channel
    tensor = torch.from_numpy(np.ascontiguousarray(arr)).permute(2, 0, 1)
    tensor = tensor.to(torch.float32).div_(255.0)
    return tensor.unsqueeze(0)
def run_inference(batch_X):
    """Run the ONNX session on a list of (1, T, C, H, W) video tensors.

    The list is stacked into one batch along dim 0 and fed to the global
    onnxruntime session under the 'video' input name.
    """
    video_batch = torch.cat(batch_X).numpy()
    return ort_sess.run(None, {'video': video_batch})
def sigmoid(x):
    """Numerically stable elementwise logistic function.

    The naive ``1 / (1 + np.exp(-x))`` overflows ``exp`` for large negative
    inputs and emits RuntimeWarnings. This formulation only exponentiates
    non-positive values, so it never overflows while producing the same
    values elementwise.

    Args:
        x: scalar or numpy array of logits.

    Returns:
        float64 numpy array of values in (0, 1), same shape as ``x``.
    """
    x = np.asarray(x, dtype=np.float64)
    z = np.exp(-np.abs(x))  # always <= 1, never overflows
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))
def detect_beeps(video_path, event_length=30, beep_height=0.8):
    """Locate the start/end beeps of an event in a video's audio track.

    Extracts the audio with ffmpeg, cross-correlates it with a reference
    beep recording, and searches for correlation peaks. The peak-height
    threshold is relaxed in 0.1 steps until the detected span covers at
    least ``event_length`` seconds; if the threshold drops to <= 0.1 without
    success, the whole video is treated as the event.

    Fixes over the original: ffmpeg is invoked with an argument list instead
    of ``shell=True`` plus an interpolated path (shell-injection / paths
    with spaces); an empty ``find_peaks`` result no longer raises
    IndexError; the VideoCapture handle is released.

    Args:
        video_path: path to the input video file.
        event_length: expected event duration in seconds.
        beep_height: initial normalized-correlation peak threshold.

    Returns:
        (event_start, event_end) as frame indices into the video.
    """
    reference_file = 'beep.WAV'
    fs, beep = wavfile.read(reference_file)
    beep = beep[:, 0] + beep[:, 1]  # combine stereo to mono
    video = cv2.VideoCapture(video_path)
    try:
        os.remove('temp.wav')
    except FileNotFoundError:
        pass
    # Argument-list form (shell=False) so the path cannot break or inject
    # into a shell command line.
    audio_convert_command = ['ffmpeg', '-i', video_path, '-vn',
                             '-acodec', 'pcm_s16le', '-ar', str(fs), '-ac', '2',
                             'temp.wav']
    subprocess.call(audio_convert_command)
    length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(video.get(cv2.CAP_PROP_FPS))
    video.release()  # frame count / fps read; handle no longer needed
    audio = wavfile.read('temp.wav')[1]
    audio = (audio[:, 0] + audio[:, 1]) / 2  # combine stereo to mono
    corr = correlate(audio, beep, mode='same') / audio.size
    # min-max scale the correlation to [-1, 1]
    corr = 2 * (corr - np.min(corr)) / (np.max(corr) - np.min(corr)) - 1
    event_start = length
    event_end = length
    # Relax the threshold until the detected span is at least event_length
    # seconds long (loop condition compares in frames).
    while length - event_start < fps * event_length:
        peaks, _ = find_peaks(corr, height=beep_height, distance=fs)
        if len(peaks) > 0:  # guard: no peaks at this threshold -> just relax it
            event_start = int(peaks[0] / fs * fps)
            event_end = int(peaks[-1] / fs * fps)
            if event_end == event_start:
                # Only one beep found: assume the event runs its full length.
                event_end = event_start + fps * event_length
        beep_height -= 0.1
        if beep_height <= 0.1:
            # Give up and treat the entire video as the event.
            event_start = 0
            event_end = length
            break
    return event_start, event_end
def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, beep_height=0.8):
    """Split a relay event into per-jumper (start, end) frame ranges.

    Cross-correlates the video's audio with a reference relay beep to find
    the switch-over beeps between jumpers. Falls back to evenly spaced
    ``relay_length``-second segments when too few beeps are detected, then
    re-validates every interval against the expected length.

    Args:
        video_path: path to the input video file.
        event_start: frame index of the overall event start (from detect_beeps).
        relay_length: expected seconds per jumper.
        n_jumpers: number of jumpers in the relay.
        beep_height: initial normalized-correlation peak threshold.

    Returns:
        (starts, ends): parallel lists of frame indices, one entry per jumper.
    """
    reference_file = 'relay_beep.WAV'
    fs, beep = wavfile.read(reference_file)
    beep = beep[:, 0] + beep[:, 1]  # combine stereo to mono
    video = cv2.VideoCapture(video_path)
    try:
        os.remove('temp.wav')
    except FileNotFoundError:
        pass
    # NOTE(review): shell=True with an interpolated path is shell-injection
    # prone and breaks on paths with spaces — confirm inputs are trusted.
    audio_convert_command = f'ffmpeg -i {video_path} -vn -acodec pcm_s16le -ar {fs} -ac 2 temp.wav'
    subprocess.call(audio_convert_command, shell=True)
    length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(video.get(cv2.CAP_PROP_FPS))
    audio = wavfile.read('temp.wav')[1]
    audio = (audio[:, 0] + audio[:, 1]) / 2  # combine stereo to mono
    corr = correlate(audio, beep, mode='same') / audio.size
    # min max scale to -1, 1
    corr = 2 * (corr - np.min(corr)) / (np.max(corr) - np.min(corr)) - 1
    # Calculate total event length in frames
    total_event_length_frames = fps * relay_length * n_jumpers
    print(event_start, total_event_length_frames)
    expected_event_end = event_start + total_event_length_frames
    # Find all significant peaks in the correlation
    peaks, _ = find_peaks(corr, height=beep_height, distance=fs)
    # Convert peaks from sample indices to frame indices
    peak_frames = [int(peak / fs * fps) for peak in peaks]
    # For debugging: dump the correlation trace with detected peaks marked
    plt.plot(corr)
    plt.plot(peaks, corr[peaks], "x")
    plt.savefig('beep.png')
    plt.close()
    starts = []
    ends = []
    # Add the event start for the first jumper
    starts.append(event_start)
    # Convert event_start back to sample index for comparison
    event_start_sample = int(event_start * fs / fps)
    # Find peaks that come after the event start but before the expected end
    # Convert expected_event_end to sample index
    expected_event_end_sample = int(expected_event_end * fs / fps)
    relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
    # If we don't have enough peaks, try lowering the threshold
    if len(relevant_peaks) < n_jumpers - 1:  # We need n_jumpers-1 transitions
        for lower_height in [0.7, 0.6, 0.5, 0.4, 0.3]:
            peaks, _ = find_peaks(corr, height=lower_height, distance=fs)
            relevant_peaks = [p for p in peaks if event_start_sample < p < expected_event_end_sample]
            if len(relevant_peaks) >= n_jumpers - 1:
                break
    # If we still don't have enough peaks, we'll need to estimate some transitions
    relay_length_frames = fps * relay_length
    # Process peaks to identify jumper transitions
    if len(relevant_peaks) >= n_jumpers - 1:
        # Ideal case: we found enough beeps for transitions
        # Sort peaks by time to ensure correct order
        relevant_peaks.sort()
        # Use the first n_jumpers-1 peaks as transition points
        transition_frames = [int(p / fs * fps) for p in relevant_peaks[:n_jumpers-1]]
        # Set ends for jumpers based on transition points: each transition is
        # simultaneously the end of jumper i and the start of jumper i+1
        for i in range(n_jumpers - 1):
            ends.append(transition_frames[i])
            starts.append(transition_frames[i])
        # Add end for the last jumper
        ends.append(expected_event_end)
    else:
        # Not enough peaks detected, use expected relay_length to estimate
        for i in range(n_jumpers):
            if i == 0:
                # First jumper starts at event_start (already added to starts)
                jumper_end = event_start + relay_length_frames
                ends.append(jumper_end)
                if i < n_jumpers - 1:
                    starts.append(jumper_end)
            elif i < n_jumpers - 1:
                jumper_end = starts[i] + relay_length_frames
                ends.append(jumper_end)
                starts.append(jumper_end)
            else:
                # Last jumper
                jumper_end = starts[i] + relay_length_frames
                ends.append(jumper_end)
    # Validate and adjust if necessary
    # Make sure all intervals are close to relay_length
    for i in range(n_jumpers):
        interval = ends[i] - starts[i]
        # If an interval is significantly different from relay_length, adjust it
        if abs(interval - relay_length_frames) > relay_length_frames * 0.2:  # 20% tolerance
            # Adjust the end time to match expected relay_length
            ends[i] = starts[i] + relay_length_frames
            # If not the last jumper, adjust the next start time
            if i < n_jumpers - 1:
                starts[i + 1] = ends[i]
    # Final check: ensure the total length matches expected
    if ends[-1] != expected_event_end:
        # Adjust the last end to match the expected total event end
        ends[-1] = expected_event_end
    return starts, ends
def upload_video(out_text, in_video):
    """Archive the analyzed clip to the HF dataset repo, named by timestamp.

    Does nothing when the predicted-count markdown is the empty string
    (i.e. inference produced no result to associate with the clip).
    """
    if out_text == '':
        return
    # Name the uploaded clip after the current unix timestamp.
    upload_path = f"{int(time.time())}.mp4"
    api.upload_file(
        path_or_fileobj=in_video,
        path_in_repo=upload_path,
        repo_id="lumos-motion/single-rope-contest",
        repo_type="dataset",
    )
def inference(in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay,
              count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
              miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
              api_call=False,
              progress=gr.Progress()):
    """End-to-end jump-counting pipeline.

    Downloads/normalizes the clip, runs the ONNX model over overlapping
    windows of frames, averages the overlapping per-frame predictions,
    applies beep/relay masking, derives a jump count and a confidence
    score, and returns either an API response (string/JSON) or a markdown
    message plus a set of plotly figures for the UI.

    Args:
        in_video: video path/URL handed to download_clips.
        use_60fps: forwarded to download_clips.
        model_choice: ONNX model name; triggers a reload when changed.
        beep_detection_on / event_length: enable masking outside the beeped event.
        relay_detection_on / relay_length / switch_delay: relay segmentation.
        count_only_api / api_key / api_call: API-mode behavior and auth.
        seq_len, stride_length, stride_pad, batch_size: windowing parameters.
        miss_threshold, marks_threshold: periodicity / jump-mark cutoffs.
        median_pred_filter: median-smooth the raw predictions.
        both_feet: when False, halve counts (single-foot events).
        progress: Gradio progress tracker.

    Returns:
        API mode: a string (plain count or JSON). UI mode: a tuple of
        (count_msg, speed fig, phase spiral fig, marks spiral fig,
        histogram fig, count regression fig).
    """
    global current_model
    if model_choice != current_model:
        # Reload weights when the user selects a different model.
        # NOTE(review): this rebinds a LOCAL ort_sess — run_inference() still
        # reads the module-level session, so the model switch may not take
        # effect for the actual inference calls; confirm intent.
        current_model = model_choice
        onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
        if torch.cuda.is_available():
            print("Using CUDA")
            providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                    "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
            sess_options = ort.SessionOptions()
            #sess_options.log_severity_level = 0
            ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
        else:
            print("Using CPU")
            ort_sess = ort.InferenceSession(onnx_file)
        # warmup inference
        ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
    # Fetch/convert the clip into the local clips/ directory.
    in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps)
    progress(0, desc="Running inference...")
    has_access = False
    if api_call:
        # API callers must present a key matching the hashed dev token.
        has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        if not has_access:
            return "Invalid API Key"
    if beep_detection_on:
        event_length = int(event_length)
        event_start, event_end = detect_beeps(in_video, event_length)
        print(event_start, event_end)
        if relay_detection_on:
            # Number of jumpers implied by total event length / per-jumper length.
            n_jumpers = int(int(event_length) / int(relay_length))
            relay_starts, relay_ends = detect_relay_beeps(in_video, event_start, int(relay_length), n_jumpers)
            print(relay_starts, relay_ends)
    # ---- Frame extraction: resize so the short side is IMG_SIZE+64, then
    # center-crop to IMG_SIZE x IMG_SIZE RGB frames. ----
    cap = cv2.VideoCapture(in_video)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    period_length_overlaps = np.zeros(length + seq_len)
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    seconds = length / fps
    all_frames = []
    frame_i = 0
    resize_amount = max((IMG_SIZE + 64) / frame_width, (IMG_SIZE + 64) / frame_height)
    while cap.isOpened():
        frame_i += 1
        ret, frame = cap.read()
        if ret is False:
            frame = all_frames[-1]  # padding will be with last frame
            break
        frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
        # add square padding with opencv
        #frame = square_pad_opencv(frame)
        frame_center_x = frame.shape[1] // 2
        frame_center_y = frame.shape[0] // 2
        frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount, interpolation=cv2.INTER_CUBIC)
        frame_center_x = frame.shape[1] // 2
        frame_center_y = frame.shape[0] // 2
        crop_x = frame_center_x - IMG_SIZE // 2
        crop_y = frame_center_y - IMG_SIZE // 2
        frame = frame[crop_y:crop_y+IMG_SIZE, crop_x:crop_x+IMG_SIZE]
        all_frames.append(frame)
    cap.release()
    length = len(all_frames)
    # Accumulators for overlapping window predictions; oversized by
    # seq_len + stride_length so the last (padded) windows fit.
    period_lengths = np.zeros(len(all_frames) + seq_len + stride_length)
    period_lengths_rope = np.zeros(len(all_frames) + seq_len + stride_length)
    periodicities = np.zeros(len(all_frames) + seq_len + stride_length)
    full_marks = np.zeros(len(all_frames) + seq_len + stride_length)
    event_type_logits = np.zeros((len(all_frames) + seq_len + stride_length, 7))
    phase_sin = np.zeros(len(all_frames) + seq_len + stride_length)
    phase_cos = np.zeros(len(all_frames) + seq_len + stride_length)
    period_length_overlaps = np.zeros(len(all_frames) + seq_len + stride_length)
    event_type_logit_overlaps = np.zeros((len(all_frames) + seq_len + stride_length, 7))
    for _ in range(seq_len + stride_length):  # pad full sequence
        all_frames.append(all_frames[-1])
    # ---- Windowed inference: preprocess frames in threads, batch windows,
    # and submit ONNX runs to the same small thread pool. ----
    batch_list = []
    idx_list = []
    inference_futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
            batch = all_frames[i:i + seq_len]
            Xlist = []
            # Preprocess frames concurrently, then reassemble in order.
            preprocess_tasks = [(idx, executor.submit(preprocess_image, img, IMG_SIZE)) for idx, img in enumerate(batch)]
            for idx, future in sorted(preprocess_tasks, key=lambda x: x[0]):
                Xlist.append(future.result())
            if len(Xlist) < seq_len:
                # Pad a short tail window by repeating the last frame tensor.
                for _ in range(seq_len - len(Xlist)):
                    Xlist.append(Xlist[-1])
            X = torch.cat(Xlist)
            X *= 255  # model expects [0, 255] pixel values
            batch_list.append(X.unsqueeze(0))
            idx_list.append(i)
            if len(batch_list) == batch_size:
                future = executor.submit(run_inference, batch_list)
                inference_futures.append((batch_list, idx_list, future))
                batch_list = []
                idx_list = []
        # Process any remaining batches (pad to a full batch by repetition)
        if batch_list:
            while len(batch_list) != batch_size:
                batch_list.append(batch_list[-1])
                idx_list.append(idx_list[-1])
            future = executor.submit(run_inference, batch_list)
            inference_futures.append((batch_list, idx_list, future))
    progress(0, desc="Processing results...")
    # Collect and process the inference results, summing each model head
    # into its accumulator over the window it covers.
    for batch_list, idx_list, future in progress.tqdm(inference_futures):
        outputs = future.result()
        y1_out = outputs[0]
        y2_out = outputs[1]
        y3_out = outputs[2]
        y4_out = outputs[3]
        y5_out = outputs[4]
        y6_out = outputs[5]
        for y1, y2, y3, y4, y5, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y5_out, y6_out, idx_list):
            periodLength = y1
            periodicity = y2.squeeze()
            marks = y3.squeeze()
            event_type = y4.squeeze()
            foot_type = y5.squeeze()
            phase = y6.squeeze()
            period_lengths[idx:idx+seq_len] += periodLength[:, 0]
            #period_lengths_rope[idx:idx+seq_len] += periodLength[:, 1]
            periodicities[idx:idx+seq_len] += periodicity
            full_marks[idx:idx+seq_len] += marks
            event_type_logits[idx:idx+seq_len] += event_type
            phase_sin[idx:idx+seq_len] += phase[:, 1]
            phase_cos[idx:idx+seq_len] += phase[:, 0]
            period_length_overlaps[idx:idx+seq_len] += 1
            event_type_logit_overlaps[idx:idx+seq_len] += 1
        del y1_out, y2_out, y3_out, y4_out  # free up memory
    # ---- Average overlapping windows (divide by overlap counts) and trim
    # the padding back to the true video length. ----
    periodLength = np.divide(period_lengths, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    #periodLength_rope = np.divide(period_lengths_rope, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    periodicity = np.divide(periodicities, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    full_marks = np.divide(full_marks, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    per_frame_event_type_logits = np.divide(event_type_logits, event_type_logit_overlaps, where=event_type_logit_overlaps!=0)[:length]
    phase_sin = np.divide(phase_sin, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    # negate sin to make the bottom of the plot the start of the jump
    phase_sin = -phase_sin
    phase_cos = np.divide(phase_cos, period_length_overlaps, where=period_length_overlaps!=0)[:length]
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    # softmax of event type logits
    event_type_probs = np.exp(event_type_logits) / np.sum(np.exp(event_type_logits))
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)
    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)
    # if the event_start and event_end (in frames) are detected and form a valid event of event_length (in seconds)
    if beep_detection_on:
        if event_start > 0 and event_end > 0 and (event_end - event_start) - (event_length * fps) < 0.5:
            print(f"Event detected: {event_start} - {event_end}")
            # Zero periodicity outside the beeped event so nothing there counts.
            periodicity[:event_start] = 0
            periodicity[event_end:] = 0
        if relay_detection_on:
            for start, end in zip(relay_starts, relay_ends):
                if start > 0 and end > 0:
                    print(f"Relay Event detected: {start} - {end}")
                    # immediately after the beep set periodicity to 0 for switch_delay seconds
                    periodicity[start:start + int(float(switch_delay) * fps)] = 0
    # ---- Count jumps: integrate 1/periodLength per frame, snapping to
    # integers at high-confidence "mark" peaks. ----
    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)
    numofReps = 0
    count = []
    for i in range(len(periodLength)):
        if periodLength[i] < 2 or periodicity_mask[i] == 0:
            numofReps += 0
        elif full_marks_mask[i]:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
        else:
            numofReps += max(0, periodicity_mask[i]/(periodLength[i]))
        count.append(round(float(numofReps), 2))
    count_pred = count[-1]
    marks_count_pred = 0
    for i in range(len(full_marks) - 1):
        # if a jump was counted, and periodicity is high, and the next frame was not counted (to avoid double counting)
        if full_marks_mask[i] > 0 and periodicity_mask[i] > 0 and full_marks_mask[i + 1] == 0:
            marks_count_pred += 1
    if not both_feet:
        # Single-foot events: halve everything derived from the rep count.
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2
    # Confidence: how far above the miss threshold the counted frames sit,
    # scaled to [0, 1].
    try:
        periodicity_mask = periodicity > miss_threshold
        if np.sum(periodicity_mask) == 0:
            confidence = 0
        else:
            confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
    except ZeroDivisionError:
        confidence = 0
    # Penalize confidence by the disagreement between the two count estimates.
    self_err = abs(count_pred - marks_count_pred)
    try:
        self_pct_err = self_err / count_pred
    except ZeroDivisionError:
        self_pct_err = 0
    total_confidence = confidence * (1 - self_pct_err)
    if LOCAL:
        if both_feet:
            count_msg = f"## Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
        else:
            count_msg = f"## Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
    else:
        if both_feet:
            count_msg = f"## Count (both feet): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
        else:
            count_msg = f"## Count (one foot): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
    if api_call:
        if CACHE_API_CALLS:
            # write outputs as row of csv
            with open('api_calls.tsv', 'a') as f:
                periodicity_str = np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
                periodLength_str = np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
                full_marks_str = np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', '')
                f.write(f"{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
        else:
            # create a nice json object to return
            results_dict = {
                "periodLength": np.array2string(periodLength, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
                "periodicity": np.array2string(periodicity, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
                "full_marks": np.array2string(full_marks, formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
                "cum_count": np.array2string(np.array(count), formatter={'float_kind':lambda x: "%.2f" % x}, threshold=np.inf).replace('\n', ''),
                "count": f"{count_pred:.2f}",
                "marks": f"{marks_count_pred:.1f}",
                "confidence": f"{total_confidence:.2f}",
                "single_rope_speed": f"{event_type_probs[0]:.3f}",
                "double_dutch": f"{event_type_probs[1]:.3f}",
                "double_unders": f"{event_type_probs[2]:.3f}",
                "single_bounce": f"{event_type_probs[3]:.3f}"
            }
            if beep_detection_on:
                results_dict['event_start'] = event_start
                results_dict['event_end'] = event_end
                if relay_detection_on:
                    results_dict['relay_starts'] = relay_starts
                    results_dict['relay_ends'] = relay_ends
            return json.dumps(results_dict)
    # fig, axs = plt.subplots(5, 1, figsize=(14, 10)) # Added a plot for count
    # # Ensure data exists before plotting
    # axs[0].plot(periodLength, label='Period Length')
    # axs[0].plot(periodLength_rope, label='Period Length (Rope)')
    # axs[0].set_title(f"Stream 0 - Period Length")
    # axs[0].legend()
    # axs[1].plot(periodicity)
    # axs[1].set_title("Stream 0 - Periodicity")
    # axs[1].set_ylim(0, 1)
    # axs[1].axhline(miss_threshold, color='r', linestyle=':', label=f'Miss Thresh ({miss_threshold})')
    # axs[2].plot(full_marks, label='Raw Marks')
    # marks_peaks_vis, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    # axs[2].plot(marks_peaks_vis, np.array(full_marks)[marks_peaks_vis], "x", label='Detected Peaks')
    # axs[2].set_title("Stream 0 - Marks")
    # axs[2].set_ylim(0, 1)
    # axs[2].axhline(marks_threshold, color='r', linestyle=':', label=f'Mark Thresh ({marks_threshold})')
    # # plot phase
    # axs[3].plot(phase_sin, label='Phase Sin')
    # axs[3].plot(phase_cos, label='Phase Cos')
    # axs[3].set_title("Stream 0 - Phase")
    # axs[3].set_ylim(-1, 1)
    # axs[3].axhline(0, color='r', linestyle=':', label='Zero Line')
    # axs[3].legend()
    # axs[4].plot(count)
    # axs[4].set_title("Stream 0 - Calculated Count")
    # plt.tight_layout()
    # plt.savefig('plot.png')
    # plt.close()
    # ---- Build UI figures. ----
    # Instantaneous speed; epsilon avoids division by zero, clip caps outliers.
    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'phase sin': phase_sin,
                                 'phase cos': phase_cos,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})
    event_type_tick_vals = np.linspace(0, 1, num=7)
    event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     #symbol='frame_type',
                     #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
                     color='periodicity',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale='rainbow',
                     range_color=(0,1),
                     title="Jumping speed (jumps-per-second)",
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override="goldenrod",
                     trendline_scope='overall',
                     template="plotly_dark")
    if beep_detection_on:
        # add vertical lines for beep event
        fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
        if relay_detection_on:
            for start, end in zip(relay_starts, relay_ends):
                start += 10  # add some padding
                end -= 10
                fig.add_vrect(x0=start / fps, x1=end / fps, fillcolor="LightGreen", opacity=0.25, layer="below",
                              line_width=0)
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="bottom",
        y=0.98,
        xanchor="right",
        x=1,
        font=dict(
            family="Courier",
            size=12,
            color="black"
        ),
        bgcolor="AliceBlue",
        ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width = 0)
    # fig.update_layout(coloraxis_colorbar=dict(
    #     tickvals=event_type_tick_vals,
    #     ticktext=['single rope', 'double dutch', 'double unders', 'single bounces',
    #               'double bounces', 'triple unders', 'other'],
    #     title='event type'
    # ))
    # -pi/2 phase offset to make the bottom of the plot the start of the jump
    # phase_sin = np.sin(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
    # phase_cos = np.cos(np.arctan2(phase_sin, phase_cos) - np.pi / 2)
    # plot phase spiral using plotly
    fig_phase_spiral = px.scatter(x=phase_cos, y=phase_sin,
                                  color=jumps_per_second,
                                  color_continuous_scale='plasma',
                                  title="Phase Spiral (speed)",
                                  template="plotly_dark")
    fig_phase_spiral.update_traces(marker=dict(size=4, opacity=0.5))
    fig_phase_spiral.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    # label colorbar as time
    fig_phase_spiral.update_coloraxes(colorbar=dict(
        title="Jumps per second"))
    # make axes equal
    fig_phase_spiral.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    fig_phase_spiral.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    fig_phase_spiral.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
    # plot phase consistency (sin^2 + cos^2 = 1) as a line plot
    # phase_consistency = phase_sin**2 + phase_cos**2
    # #phase_consistency = medfilt(phase_consistency, 5)
    # fig_phase = px.line(x=np.linspace(0, 1, len(phase_sin)), y=phase_consistency,
    #                     title="Phase Consistency (sin^2 + cos^2)",
    #                     labels={'x': 'Frame', 'y': 'Phase Consistency'},
    #                     template="plotly_dark")
    # plot phase spiral colored by mark_preds
    fig_phase_spiral_marks = px.scatter(x=phase_cos, y=phase_sin,
                                        color=full_marks,
                                        color_continuous_scale='Jet',
                                        title="Phase Spiral (marks)",
                                        template="plotly_dark")
    fig_phase_spiral_marks.update_traces(marker=dict(size=4, opacity=0.5))
    fig_phase_spiral_marks.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    # label colorbar as time
    fig_phase_spiral_marks.update_coloraxes(colorbar=dict(
        title="Marks"))
    # make axes equal
    fig_phase_spiral_marks.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    fig_phase_spiral_marks.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    fig_phase_spiral_marks.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
    hist = px.histogram(df,
                        x="jumps per second",
                        template="plotly_dark",
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")
    # plot the full count and predict a count for 30s, 60s, and 180s if the video is shorter than that
    count = np.array(count)
    regression_plot = px.scatter(x=np.arange(len(count)), y=count,
                                 color=periodicity,
                                 color_continuous_scale='rainbow',
                                 title="Count Prediction (Perfect Run)",
                                 template="plotly_dark")
    regression_plot.update_coloraxes(colorbar=dict(
        title="Periodicity"))
    regression_plot.update_traces(marker=dict(size=6, opacity=0.5))
    regression_plot.update_layout(
        xaxis_title="Frame",
        yaxis_title="Count",
        xaxis=dict(range=[0, len(count)]),
        yaxis=dict(range=[0, max(count) * 1.2]),
        showlegend=False,
    )
    # add 30s, 60s, and 180s predictions (median no-miss speed extrapolated)
    pred_count_30s = int(np.median(jumps_per_second[~misses]) * 30)
    pred_count_60s = int(np.median(jumps_per_second[~misses]) * 60)
    pred_count_180s = int(np.median(jumps_per_second[~misses]) * 180)
    # add text to the plot
    # NOTE(review): this annotation text was garbled across lines in the
    # source; reconstructed with <br> line breaks (plotly's annotation
    # line-break markup) — confirm against the original file.
    regression_plot.add_annotation(
        x=0.5,
        y=0.95,
        xref="paper",
        yref="paper",
        text=f"No-Miss Count (30s): {pred_count_30s}<br>No-Miss Count (60s): {pred_count_60s}<br>No-Miss Count (180s): {pred_count_180s}",
        showarrow=False,
        font=dict(
            size=14,
            color="white"
        ),
        align="center",
        bgcolor="rgba(0, 0, 0, 0.5)",
        bordercolor="white",
        borderwidth=2,
        borderpad=4,
        opacity=0.8
    )
    # Clean up the extracted audio scratch file.
    try:
        os.remove('temp.wav')
    except FileNotFoundError:
        pass
    return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, regression_plot
#css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
# ---- Gradio UI definition and event wiring ----
with gr.Blocks() as demo:
    gr.Markdown(
        """
# NextJump🦘Tournament Judge
### Jump rope competition scoring based on the [NextJump](https://nextjump.app) AI model
Developed by [Dylan Plummer](https://dylan-plummer.github.io/). Examples can be found at the bottom of the page. Please contact us for usage at your event: nextjumpapp@gmail.com
"""
    )
    with gr.Row():
        in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
                                    width=400, height=400, interactive=True, container=True,
                                    max_length=300)
    with gr.Row():
        with gr.Column():
            use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
            model_choice = gr.Dropdown(
                ["nextjump_speed", "nextjump_all"], label="Model Choice", info="For now just speed-only or general model",
            )
        with gr.Column():
            beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
            event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
            relay_detection_on = gr.Checkbox(label="Relay Event", elem_id='relay-beeps', visible=True)
            relay_length = gr.Textbox(label="Relay Length (s)", elem_id='relay-length', visible=True, value='30')
            switch_delay = gr.Textbox(label="Expected Switch Delay (s)", elem_id='event-length', visible=True, value='0.2')
    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        # Hidden controls used only by the programmatic API endpoint below.
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)
    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                # Hidden textboxes: targets for the API-only outputs.
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
            with gr.Column():
                out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
            with gr.Column():
                out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')
    # UI path: full visualization, no API auth; upload the clip afterwards.
    demo_inference = partial(inference, count_only_api=False, api_key=None)
    run_button.click(demo_inference, [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
                     outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist]).then(upload_video, inputs=[out_text, in_video])
    # API path: exposed as api_name='inference'; requires the dev API key.
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(api_inference, [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
                           outputs=[period_length], api_name='inference')
    examples = [
        #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
        ['files/wc2023.mp4', True, 'nextjump_speed', True, 30, False, '30', '0.2'],
        #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
        #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
    ]
    gr.Examples(examples,
                inputs=[in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
                outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
                fn=demo_inference, cache_examples=False)
if __name__ == "__main__":
if LOCAL:
demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
server_port=7860,
debug=False,
ssl_verify=False,
share=False)
else:
demo.queue(api_open=True, max_size=15).launch(share=False)