"""Gradio app that counts jump-rope repetitions in a video with the NextJump ONNX model.

The module downloads the model at import time, exposes :func:`inference` both as
the UI callback and as a Gradio API endpoint, and can optionally restrict
counting to the interval delimited by start/stop beeps detected in the audio.
"""
import gradio as gr
import numpy as np
from PIL import Image
import os
import cv2
import math
import json
import time
import subprocess
import tempfile
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import concurrent.futures
from scipy.io import wavfile
from scipy.signal import medfilt, correlate, find_peaks
from functools import partial
from passlib.hash import pbkdf2_sha256
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import onnxruntime as ort
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi

from hls_download import download_clips

#plt.style.use('dark_background')

LOCAL = False
IMG_SIZE = 256
CACHE_API_CALLS = False
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)


def _load_session(model_name):
    """Download ``<model_name>.onnx`` and return ``(model_path, warmed_up_session)``.

    Uses the CUDA execution provider when torch reports a GPU, otherwise the
    onnxruntime default. Requires the ``DATASET_SECRET`` environment variable.
    """
    model_path = hf_hub_download(repo_id="lumos-motion/nextjump",
                                 filename=f"{model_name}.onnx",
                                 repo_type="model",
                                 token=os.environ['DATASET_SECRET'])
    if torch.cuda.is_available():
        print("Using CUDA")
        providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                                "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
        sess_options = ort.SessionOptions()
        #sess_options.log_severity_level = 0
        session = ort.InferenceSession(model_path, sess_options=sess_options, providers=providers)
    else:
        print("Using CPU")
        session = ort.InferenceSession(model_path)

    # warmup inference
    session.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
    return model_path, session


current_model = 'nextjump_speed'
onnx_file, ort_sess = _load_session(current_model)

api = HfApi()


def square_pad_opencv(image):
    """Pad ``image`` with black borders so that its height and width are equal.

    Any odd leftover pixel goes to the bottom/right border so the result is
    exactly square.
    """
    h, w = image.shape[:2]
    max_wh = max(w, h)
    left = (max_wh - w) // 2
    top = (max_wh - h) // 2
    right = (max_wh - w) - left
    bottom = (max_wh - h) - top
    return cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])


def preprocess_image(img, img_size):
    """Convert one RGB frame (numpy array) to a ``(1, C, H, W)`` float tensor.

    ``img_size`` is accepted for interface compatibility; frames are already
    cropped to the model size by the caller, so no resizing happens here.
    """
    #img = square_pad_opencv(img)
    #img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
    img = Image.fromarray(img)
    return transforms.ToTensor()(img).unsqueeze(0)


def run_inference(batch_X):
    """Concatenate a list of clip tensors and run them through the current ONNX session."""
    batch_X = torch.cat(batch_X)
    return ort_sess.run(None, {'video': batch_X.numpy()})


def sigmoid(x):
    """Element-wise logistic function."""
    return 1 / (1 + np.exp(-x))


def _beep_correlation(video_path, reference_file):
    """Cross-correlate the video's audio track with a reference beep.

    Returns ``(corr, fs, fps, length)`` where ``corr`` is the correlation
    min-max scaled to [-1, 1], ``fs`` the audio sample rate of the reference
    file, ``fps`` the (integer) video frame rate and ``length`` the frame count
    reported by the container.

    Raises:
        ValueError: if the video reports a non-positive frame rate.
        FileNotFoundError: if ffmpeg could not extract an audio track.
    """
    fs, beep = wavfile.read(reference_file)
    # combine stereo to mono; cast first so the int16 samples cannot overflow
    beep = beep[:, 0].astype(np.float64) + beep[:, 1].astype(np.float64)

    video = cv2.VideoCapture(video_path)
    try:
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(video.get(cv2.CAP_PROP_FPS))
    finally:
        video.release()
    if fps <= 0:
        raise ValueError(f"Could not read a valid frame rate from {video_path!r}")

    # A per-call temporary directory keeps concurrent requests from clobbering
    # each other's audio and guarantees cleanup on every exit path.
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, 'audio.wav')
        # Argument list (no shell) so the path is never interpreted by a shell.
        subprocess.call(['ffmpeg', '-i', video_path, '-vn', '-acodec', 'pcm_s16le',
                         '-ar', str(fs), '-ac', '2', wav_path])
        audio = wavfile.read(wav_path)[1]
    audio = (audio[:, 0].astype(np.float64) + audio[:, 1].astype(np.float64)) / 2  # combine stereo to mono

    corr = correlate(audio, beep, mode='same') / audio.size
    # min max scale to -1, 1
    span = np.max(corr) - np.min(corr)
    if span == 0:
        # flat correlation (e.g. silent audio): nothing can exceed a positive peak height
        corr = np.zeros_like(corr)
    else:
        corr = 2 * (corr - np.min(corr)) / span - 1
    return corr, fs, fps, length


def detect_beeps(video_path, event_length=30, beep_height=0.8):
    """Locate the start and end beeps of an event in a video's audio track.

    The peak threshold starts at ``beep_height`` and is lowered in steps of 0.1
    until the first detected beep leaves room for ``event_length`` seconds of
    video after it. If the threshold drops to 0.1 the whole video is used.

    Args:
        video_path: path of the video file.
        event_length: expected event duration in seconds.
        beep_height: initial peak threshold on the scaled correlation.

    Returns:
        ``(event_start, event_end)`` as frame indices.
    """
    corr, fs, fps, length = _beep_correlation(video_path, 'beep.WAV')

    event_start = length
    event_end = length
    while length - event_start < fps * event_length:
        peaks, _ = find_peaks(corr, height=beep_height, distance=fs)
        if len(peaks) > 0:
            event_start = int(peaks[0] / fs * fps)
            event_end = int(peaks[-1] / fs * fps)
            if event_end == event_start:
                # only one beep found: assume the event lasts the expected length
                event_end = event_start + fps * event_length
        beep_height -= 0.1
        if beep_height <= 0.1:
            event_start = 0
            event_end = length
            break
    return event_start, event_end


def detect_relay_beeps(video_path, event_start, relay_length=30, n_jumpers=4, beep_height=0.8):
    """Split a relay event into per-jumper frame intervals using switch beeps.

    Args:
        video_path: path of the video file.
        event_start: frame index at which the whole event starts.
        relay_length: expected duration of each jumper's turn in seconds.
        n_jumpers: number of jumpers in the relay.
        beep_height: initial peak threshold on the scaled correlation.

    Returns:
        ``(starts, ends)``: two lists of frame indices, one entry per jumper.
        Intervals deviating more than 20% from ``relay_length`` are replaced by
        the expected length, and the last end is pinned to the expected event end.

    Side effects:
        Prints debug values and writes the correlation plot to ``beep.png``.
    """
    corr, fs, fps, length = _beep_correlation(video_path, 'relay_beep.WAV')

    # Calculate total event length in frames
    total_event_length_frames = fps * relay_length * n_jumpers
    print(event_start, total_event_length_frames)
    expected_event_end = event_start + total_event_length_frames

    # Find all significant peaks in the correlation
    peaks, _ = find_peaks(corr, height=beep_height, distance=fs)

    # For debugging
    plt.plot(corr)
    plt.plot(peaks, corr[peaks], "x")
    plt.savefig('beep.png')
    plt.close()

    # Event boundaries expressed as audio sample indices for comparison with peaks
    event_start_sample = int(event_start * fs / fps)
    expected_event_end_sample = int(expected_event_end * fs / fps)

    def peaks_inside_event(candidates):
        return [p for p in candidates if event_start_sample < p < expected_event_end_sample]

    relevant_peaks = peaks_inside_event(peaks)

    # If we don't have enough peaks, try lowering the threshold
    # (n_jumpers - 1 transitions are needed)
    if len(relevant_peaks) < n_jumpers - 1:
        for lower_height in [0.7, 0.6, 0.5, 0.4, 0.3]:
            peaks, _ = find_peaks(corr, height=lower_height, distance=fs)
            relevant_peaks = peaks_inside_event(peaks)
            if len(relevant_peaks) >= n_jumpers - 1:
                break

    relay_length_frames = fps * relay_length
    starts = [event_start]
    ends = []

    if len(relevant_peaks) >= n_jumpers - 1:
        # Ideal case: enough beeps; the first n_jumpers-1 (in time order) are the transitions
        relevant_peaks.sort()
        transition_frames = [int(p / fs * fps) for p in relevant_peaks[:n_jumpers - 1]]
        for transition in transition_frames:
            ends.append(transition)
            starts.append(transition)
        ends.append(expected_event_end)
    else:
        # Not enough peaks detected: lay the turns out back to back at the expected length
        for i in range(n_jumpers):
            jumper_end = starts[i] + relay_length_frames
            ends.append(jumper_end)
            if i < n_jumpers - 1:
                starts.append(jumper_end)

    # Validate: every interval must be within 20% of relay_length, otherwise reset it
    for i in range(n_jumpers):
        interval = ends[i] - starts[i]
        if abs(interval - relay_length_frames) > relay_length_frames * 0.2:
            ends[i] = starts[i] + relay_length_frames
            if i < n_jumpers - 1:
                starts[i + 1] = ends[i]

    # Final check: the last end always matches the expected total event end
    if ends[-1] != expected_event_end:
        ends[-1] = expected_event_end

    return starts, ends


def upload_video(out_text, in_video):
    """Upload the processed video to the contest dataset under a timestamp name.

    Nothing is uploaded when ``out_text`` is empty or no video is available.
    """
    if out_text != '' and in_video is not None:
        # generate a timestamp name for the video
        upload_path = f"{int(time.time())}.mp4"
        api.upload_file(
            path_or_fileobj=in_video,
            path_in_repo=upload_path,
            repo_id="lumos-motion/single-rope-contest",
            repo_type="dataset",
        )


def _positive_int(value, label):
    """Parse a UI text value as a positive integer or raise a user-facing error."""
    try:
        parsed = int(value)
    except (TypeError, ValueError):
        raise gr.Error(f"{label} must be a whole number of seconds") from None
    if parsed <= 0:
        raise gr.Error(f"{label} must be greater than zero")
    return parsed


def _read_center_cropped_frames(video_path):
    """Read every frame as RGB, scaled and center-cropped to IMG_SIZE x IMG_SIZE.

    Returns ``(frames, fps, seconds)`` where ``seconds`` is derived from the
    frame count reported by the container.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the input video")
        reported_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        fps = int(cap.get(cv2.CAP_PROP_FPS))
        if fps <= 0 or frame_width <= 0 or frame_height <= 0:
            raise gr.Error("Could not read the frame rate or size of the input video")

        # scale so the shorter side is IMG_SIZE + 64, leaving a margin around the crop
        resize_amount = max((IMG_SIZE + 64) / frame_width, (IMG_SIZE + 64) / frame_height)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (0, 0), fx=resize_amount, fy=resize_amount,
                               interpolation=cv2.INTER_CUBIC)
            crop_x = frame.shape[1] // 2 - IMG_SIZE // 2
            crop_y = frame.shape[0] // 2 - IMG_SIZE // 2
            frames.append(frame[crop_y:crop_y + IMG_SIZE, crop_x:crop_x + IMG_SIZE])
    finally:
        cap.release()

    if not frames:
        raise gr.Error("The input video contains no readable frames")
    return frames, fps, reported_length / fps


def _window_mean(totals, counts, length):
    """Average accumulated window outputs per frame and trim to ``length`` frames."""
    # `out` is required with `where`, otherwise masked-out entries are uninitialised
    return np.divide(totals, counts, out=np.zeros_like(totals), where=counts != 0)[:length]


def _cumulative_count(period_length, periodicity_mask, marks_mask):
    """Return the running repetition count for every frame.

    Frames with a period shorter than 2 or below the periodicity threshold add
    nothing. A detected mark snaps the count to a whole number (down when the
    fractional part is below 0.2, otherwise up). Every other frame adds
    ``1 / period``.
    """
    reps = 0
    running = []
    for period, periodic, marked in zip(period_length, periodicity_mask, marks_mask):
        if period < 2 or periodic == 0:
            pass
        elif marked:  # high confidence mark detected
            if math.modf(reps)[0] < 0.2:  # probably false positive/late detection
                reps = float(int(reps))
            else:
                reps = float(int(reps) + 1.01)  # round up
        else:
            reps += max(0, periodic / period)
        running.append(round(float(reps), 2))
    return running


def _array_to_str(values):
    """Format an array as a single-line string with two decimals per value."""
    return np.array2string(values, formatter={'float_kind': lambda x: "%.2f" % x},
                           threshold=np.inf).replace('\n', '')


def _phase_spiral_figure(phase_cos, phase_sin, color, color_scale, title, colorbar_title):
    """Build a phase-plane scatter (cos vs sin) with a faint connecting line."""
    figure = px.scatter(x=phase_cos, y=phase_sin,
                        color=color,
                        color_continuous_scale=color_scale,
                        title=title,
                        template="plotly_dark")
    figure.update_traces(marker=dict(size=4, opacity=0.5))
    figure.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    figure.update_coloraxes(colorbar=dict(title=colorbar_title))
    # make axes equal
    figure.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    figure.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    figure.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
    return figure


def inference(in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only_api, api_key,
              seq_len=64, stride_length=32, stride_pad=3, batch_size=4,
              miss_threshold=0.8, marks_threshold=0.5, median_pred_filter=True,
              both_feet=True, api_call=False, progress=gr.Progress()):
    """Count jumps in a video and build the result plots (or an API payload).

    Args:
        in_video: video source handed to ``download_clips``.
        use_60fps: forwarded to ``download_clips``.
        model_choice: model name; a different name than the loaded one triggers
            a model reload. An empty value keeps the current model.
        beep_detection_on: restrict counting to the interval between beeps.
        event_length: expected event length in seconds (used with beep detection).
        relay_detection_on: split the event into relay turns. Requires beep
            detection, because the relay split starts from the detected event start.
        relay_length: length of one relay turn in seconds.
        switch_delay: seconds after each relay beep during which nothing is counted.
        count_only_api: for API calls, return only the count string.
        api_key: credential checked against ``DEV_API_TOKEN`` for API calls.
        seq_len, stride_length, stride_pad, batch_size: sliding-window settings.
        miss_threshold: periodicity below this is treated as a miss.
        marks_threshold: minimum mark probability for a detected jump.
        median_pred_filter: median-filter periodicity and period length.
        both_feet: when False, counts are halved.
        api_call: run in API mode (authenticated, returns a string).
        progress: Gradio progress tracker.

    Returns:
        UI mode: ``(count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks,
        hist, regression_plot)``. API mode: a count string, a JSON string, or
        ``"Invalid API Key"``.
    """
    global current_model, onnx_file, ort_sess

    # Authenticate first so unauthorised callers cannot trigger downloads or model loads.
    has_access = False
    if api_call:
        try:
            has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        except (TypeError, ValueError):
            # missing or malformed key
            has_access = False
        if not has_access:
            return "Invalid API Key"

    if model_choice and model_choice != current_model:
        # Rebind the module-level session: run_inference reads the global.
        # current_model is only updated once the new session loaded successfully.
        onnx_file, ort_sess = _load_session(model_choice)
        current_model = model_choice

    in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps)
    progress(0, desc="Running inference...")

    event_start = event_end = None
    relay_starts, relay_ends = [], []
    # The relay split starts from the event start found by beep detection.
    relay_active = bool(relay_detection_on and beep_detection_on)
    if relay_detection_on and not beep_detection_on:
        print("Relay detection skipped: it requires beep detection to locate the event start")

    if beep_detection_on:
        event_length = _positive_int(event_length, "Expected event length")
        event_start, event_end = detect_beeps(in_video, event_length)
        print(event_start, event_end)
    if relay_active:
        relay_seconds = _positive_int(relay_length, "Relay length")
        n_jumpers = int(event_length / relay_seconds)
        relay_starts, relay_ends = detect_relay_beeps(in_video, event_start, relay_seconds, n_jumpers)
        print(relay_starts, relay_ends)

    all_frames, fps, seconds = _read_center_cropped_frames(in_video)
    length = len(all_frames)

    padded_length = length + seq_len + stride_length
    period_lengths = np.zeros(padded_length)
    periodicities = np.zeros(padded_length)
    full_marks = np.zeros(padded_length)
    event_type_logits = np.zeros((padded_length, 7))
    phase_sin = np.zeros(padded_length)
    phase_cos = np.zeros(padded_length)
    period_length_overlaps = np.zeros(padded_length)
    event_type_logit_overlaps = np.zeros((padded_length, 7))

    # pad full sequence with the last frame so every window is complete
    all_frames.extend([all_frames[-1]] * (seq_len + stride_length))

    batch_list = []
    idx_list = []
    # Only (window start indices, future) are kept: retaining the input tensors
    # would hold every batch in memory until the end of the run.
    inference_futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
        for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
            batch = all_frames[i:i + seq_len]
            preprocess_tasks = [executor.submit(preprocess_image, img, IMG_SIZE) for img in batch]
            Xlist = [task.result() for task in preprocess_tasks]
            if len(Xlist) < seq_len:
                Xlist.extend([Xlist[-1]] * (seq_len - len(Xlist)))
            X = torch.cat(Xlist)
            X *= 255
            batch_list.append(X.unsqueeze(0))
            idx_list.append(i)

            if len(batch_list) == batch_size:
                inference_futures.append((idx_list, executor.submit(run_inference, batch_list)))
                batch_list = []
                idx_list = []

        # Process any remaining batches
        if batch_list:
            # Pad the batch to the size the model expects. idx_list is NOT padded,
            # so the duplicated windows' outputs are dropped instead of being
            # accumulated several times.
            batch_list.extend([batch_list[-1]] * (batch_size - len(batch_list)))
            inference_futures.append((idx_list, executor.submit(run_inference, batch_list)))

    progress(0, desc="Processing results...")
    # Collect and process the inference results
    for idx_list, future in progress.tqdm(inference_futures):
        outputs = future.result()
        # outputs[4] (foot type) is produced by the model but not used here
        for y1, y2, y3, y4, y6, idx in zip(outputs[0], outputs[1], outputs[2], outputs[3], outputs[5], idx_list):
            window = slice(idx, idx + seq_len)
            phase = y6.squeeze()
            period_lengths[window] += y1[:, 0]
            periodicities[window] += y2.squeeze()
            full_marks[window] += y3.squeeze()
            event_type_logits[window] += y4.squeeze()
            phase_sin[window] += phase[:, 1]
            phase_cos[window] += phase[:, 0]
            period_length_overlaps[window] += 1
            event_type_logit_overlaps[window] += 1
        del outputs  # free up memory

    periodLength = _window_mean(period_lengths, period_length_overlaps, length)
    periodicity = _window_mean(periodicities, period_length_overlaps, length)
    full_marks = _window_mean(full_marks, period_length_overlaps, length)
    per_frame_event_type_logits = _window_mean(event_type_logits, event_type_logit_overlaps, length)
    # negate sin to make the bottom of the plot the start of the jump
    phase_sin = -_window_mean(phase_sin, period_length_overlaps, length)
    phase_cos = _window_mean(phase_cos, period_length_overlaps, length)

    # softmax of event type logits (shifted by the max for numerical stability)
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    shifted_logits = np.exp(event_type_logits - np.max(event_type_logits))
    event_type_probs = shifted_logits / np.sum(shifted_logits)
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)

    # if the event_start and event_end (in frames) are detected and form a valid event of event_length (in seconds)
    if beep_detection_on:
        if event_start > 0 and event_end > 0 and (event_end - event_start) - (event_length * fps) < 0.5:
            print(f"Event detected: {event_start} - {event_end}")
            periodicity[:event_start] = 0
            periodicity[event_end:] = 0
    if relay_active:
        for start, end in zip(relay_starts, relay_ends):
            if start > 0 and end > 0:
                print(f"Relay Event detected: {start} - {end}")
                # immediately after the beep set periodicity to 0 for switch_delay seconds
                periodicity[start:start + int(float(switch_delay) * fps)] = 0

    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)

    count = _cumulative_count(periodLength, periodicity_mask, full_marks_mask)
    count_pred = count[-1]

    # a mark counts when periodicity is high and the next frame is not also a
    # mark (to avoid double counting)
    marks_count_pred = int(np.sum((full_marks_mask[:-1] > 0)
                                  & (periodicity_mask[:-1] > 0)
                                  & (full_marks_mask[1:] == 0)))

    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2

    confident_frames = periodicity > miss_threshold
    if np.any(confident_frames) and miss_threshold != 1:
        confidence = (np.mean(periodicity[confident_frames]) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0
    self_err = abs(count_pred - marks_count_pred)
    self_pct_err = self_err / count_pred if count_pred else 0
    total_confidence = confidence * (1 - self_pct_err)

    if LOCAL:
        if both_feet:
            count_msg = f"## Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
        else:
            count_msg = f"## Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
    else:
        if both_feet:
            count_msg = f"## Count (both feet): {count_pred:.1f}, Confidence: {total_confidence:.2f}"
        else:
            count_msg = f"## Count (one foot): {count_pred:.1f}, Confidence: {total_confidence:.2f}"

    if api_call:
        if CACHE_API_CALLS:
            # write outputs as row of csv
            with open('api_calls.tsv', 'a') as f:
                periodicity_str = _array_to_str(periodicity)
                periodLength_str = _array_to_str(periodLength)
                full_marks_str = _array_to_str(full_marks)
                f.write(f"{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
        # create a nice json object to return
        results_dict = {
            "periodLength": _array_to_str(periodLength),
            "periodicity": _array_to_str(periodicity),
            "full_marks": _array_to_str(full_marks),
            "cum_count": _array_to_str(np.array(count)),
            "count": f"{count_pred:.2f}",
            "marks": f"{marks_count_pred:.1f}",
            "confidence": f"{total_confidence:.2f}",
            "single_rope_speed": f"{event_type_probs[0]:.3f}",
            "double_dutch": f"{event_type_probs[1]:.3f}",
            "double_unders": f"{event_type_probs[2]:.3f}",
            "single_bounce": f"{event_type_probs[3]:.3f}",
        }
        if beep_detection_on:
            results_dict['event_start'] = event_start
            results_dict['event_end'] = event_end
        if relay_active:
            results_dict['relay_starts'] = relay_starts
            results_dict['relay_ends'] = relay_ends
        return json.dumps(results_dict)

    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'phase sin': phase_sin,
                                 'phase cos': phase_cos,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})

    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     #symbol='frame_type',
                     #symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
                     color='periodicity',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale='rainbow',
                     range_color=(0, 1),
                     title="Jumping speed (jumps-per-second)",
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override="goldenrod",
                     trendline_scope='overall',
                     template="plotly_dark")

    if beep_detection_on:
        # shade the detected beep-to-beep event
        fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
    if relay_active:
        for start, end in zip(relay_starts, relay_ends):
            start += 10  # add some padding
            end -= 10
            fig.add_vrect(x0=start / fps, x1=end / fps, fillcolor="LightGreen", opacity=0.25, layer="below", line_width=0)

    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="bottom",
        y=0.98,
        xanchor="right",
        x=1,
        font=dict(
            family="Courier",
            size=12,
            color="black"
        ),
        bgcolor="AliceBlue",
    ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width=0)

    fig_phase_spiral = _phase_spiral_figure(phase_cos, phase_sin, jumps_per_second, 'plasma',
                                            "Phase Spiral (speed)", "Jumps per second")
    fig_phase_spiral_marks = _phase_spiral_figure(phase_cos, phase_sin, full_marks, 'Jet',
                                                  "Phase Spiral (marks)", "Marks")

    hist = px.histogram(df,
                        x="jumps per second",
                        template="plotly_dark",
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")

    # plot the full count and predict a count for 30s, 60s, and 180s
    count = np.array(count)
    regression_plot = px.scatter(x=np.arange(len(count)), y=count,
                                 color=periodicity,
                                 color_continuous_scale='rainbow',
                                 title="Count Prediction (Perfect Run)",
                                 template="plotly_dark")
    regression_plot.update_coloraxes(colorbar=dict(title="Periodicity"))
    regression_plot.update_traces(marker=dict(size=6, opacity=0.5))
    regression_plot.update_layout(
        xaxis_title="Frame",
        yaxis_title="Count",
        xaxis=dict(range=[0, len(count)]),
        yaxis=dict(range=[0, max(count) * 1.2]),
        showlegend=False,
    )

    # add 30s, 60s, and 180s predictions; with no jumping frames the median is
    # undefined, so fall back to a speed of 0
    jumping_frames = jumps_per_second[~misses]
    median_speed = float(np.median(jumping_frames)) if jumping_frames.size else 0.0
    pred_count_30s = int(median_speed * 30)
    pred_count_60s = int(median_speed * 60)
    pred_count_180s = int(median_speed * 180)
    # add text to the plot (Plotly renders <br> as a line break)
    regression_plot.add_annotation(
        x=0.5,
        y=0.95,
        xref="paper",
        yref="paper",
        text=f"No-Miss Count (30s): {pred_count_30s}<br>No-Miss Count (60s): {pred_count_60s}<br>No-Miss Count (180s): {pred_count_180s}",
        showarrow=False,
        font=dict(
            size=14,
            color="white"
        ),
        align="center",
        bgcolor="rgba(0, 0, 0, 0.5)",
        bordercolor="white",
        borderwidth=2,
        borderpad=4,
        opacity=0.8
    )

    return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist, regression_plot


#css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # NextJump🦘Tournament Judge
        ### Jump rope competition scoring based on the [NextJump](https://nextjump.app) AI model
        Developed by [Dylan Plummer](https://dylan-plummer.github.io/). Examples can be found at the bottom of the page.
        Please contact us for usage at your event: nextjumpapp@gmail.com
        """
    )
    with gr.Row():
        in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
                                    width=400, height=400, interactive=True, container=True,
                                    max_length=300)
    with gr.Row():
        with gr.Column():
            use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
            model_choice = gr.Dropdown(
                ["nextjump_speed", "nextjump_all"], label="Model Choice", info="For now just speed-only or general model",
            )
        with gr.Column():
            beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
            event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)
            relay_detection_on = gr.Checkbox(label="Relay Event", elem_id='relay-beeps', visible=True)
            relay_length = gr.Textbox(label="Relay Length (s)", elem_id='relay-length', visible=True, value='30')
            # NOTE(review): this elem_id duplicates the one on event_length above;
            # confirm nothing targets it before renaming.
            switch_delay = gr.Textbox(label="Expected Switch Delay (s)", elem_id='event-length', visible=True, value='0.2')

    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)

    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
            with gr.Column():
                out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')
            with gr.Column():
                out_event_type_dist = gr.Plot(label="Event Type Distribution", elem_id='output-event-type-dist')

    # UI runs are unauthenticated and always return the full set of plots
    demo_inference = partial(inference, count_only_api=False, api_key=None)

    run_button.click(demo_inference,
                     [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
                     outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist]).then(upload_video, inputs=[out_text, in_video])
    # hidden button that exposes the authenticated API endpoint
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(api_inference,
                           [in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay, count_only, api_token],
                           outputs=[period_length], api_name='inference')
    examples = [
        #['https://hiemdall-dev2.azurewebsites.net/api/clip/clp_vrpWTyjM/mp4', '00:00:00', '00:01:10', True, 60],
        ['files/wc2023.mp4', True, 'nextjump_speed', True, 30, False, '30', '0.2'],
        #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_rd2FAyUo/vod', '01:24:22', '01:25:35', True, 60]
        #['https://hiemdall-dev2.azurewebsites.net/api/playlist/rec_PY5Ukaua/vod, '00:52:53', '00:55:00', True, 120]
    ]
    gr.Examples(examples,
                inputs=[in_video, use_60fps, model_choice, beep_detection_on, event_length, relay_detection_on, relay_length, switch_delay],
                outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist, out_event_type_dist],
                fn=demo_inference, cache_examples=False)


if __name__ == "__main__":
    if LOCAL:
        demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
                                                      server_port=7860,
                                                      debug=False,
                                                      ssl_verify=False,
                                                      share=False)
    else:
        demo.queue(api_open=True, max_size=15).launch(share=False)