"""Gradio app that runs the ``nextjump_speed`` ONNX model over a video.

The module downloads the model at import time, exposes a demo UI that plots
per-frame jumping speed and phase, and a hidden API endpoint (``inference``)
that returns either a short count string or a JSON document.
"""
import concurrent.futures
import json
import math
import os
import subprocess
import tempfile
import time
from functools import partial

import cv2
import gradio as gr
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
import numpy as np
import onnxruntime as ort
import pandas as pd
import plotly.express as px
import torch
import torchvision.transforms.functional as F
from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
from passlib.hash import pbkdf2_sha256
from PIL import Image
from scipy.io import wavfile
from scipy.signal import medfilt, correlate, find_peaks
from torchvision import transforms
from tqdm import tqdm

from hls_download import download_clips

#plt.style.use('dark_background')

LOCAL = False
IMG_SIZE = 256
CACHE_API_CALLS = False
# Default thresholds shared by the UI sliders and the API endpoint.
DEFAULT_MISS_THRESHOLD = 0.5
DEFAULT_MARKS_THRESHOLD = 0.5

os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)

current_model = 'nextjump_speed'
onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
#onnx_file = f'{current_model}.onnx'

api = HfApi()

if torch.cuda.is_available():
    print("Using CUDA")
    providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                            "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
    sess_options = ort.SessionOptions()
    #sess_options.log_severity_level = 0
    ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
    use_cuda = True
else:
    print("Using CPU")
    ort_sess = ort.InferenceSession(onnx_file)
    use_cuda = False

# warmup inference
ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})


def square_pad_opencv(image):
    """Pad *image* with black borders so that its width equals its height.

    Args:
        image: Array whose first two axes are (height, width).

    Returns:
        The padded image produced by ``cv2.copyMakeBorder``.
    """
    h, w = image.shape[:2]
    max_wh = max(w, h)
    pad_w = max_wh - w
    pad_h = max_wh - h
    # Put the odd pixel (if any) on the bottom/right so the result is exactly
    # square; halving both sides with int() dropped one row/column.
    left = pad_w // 2
    top = pad_h // 2
    return cv2.copyMakeBorder(image, top, pad_h - top, left, pad_w - left,
                              cv2.BORDER_CONSTANT, value=[0, 0, 0])


def preprocess_image(img, img_size):
    """Convert an image array to a tensor with a leading batch axis.

    Args:
        img: Image array accepted by ``PIL.Image.fromarray``.
        img_size: Unused; the resize step is currently disabled.

    Returns:
        The ``transforms.ToTensor()`` output with ``unsqueeze(0)`` applied.
    """
    #img = square_pad_opencv(img)
    #img = cv2.resize(img, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
    img = Image.fromarray(img)
    transforms_list = []
    transforms_list.append(transforms.ToTensor())
    preprocess = transforms.Compose(transforms_list)
    return preprocess(img).unsqueeze(0)


def run_inference(batch_X):
    """Concatenate a list of window arrays along axis 0 and run the ONNX session.

    Args:
        batch_X: List of arrays to concatenate into the model's ``video`` input.

    Returns:
        The list of outputs returned by ``ort_sess.run``.
    """
    batch_X = np.concatenate(batch_X, axis=0)
    return ort_sess.run(None, {'video': batch_X})


def sigmoid(x):
    """Return the element-wise logistic function of *x*."""
    return 1 / (1 + np.exp(-x))


def _to_mono(samples):
    """Return *samples* as a float64 mono signal (channel mean if multi-channel)."""
    # Convert before mixing: adding int16 channels wraps around on loud audio.
    samples = np.asarray(samples, dtype=np.float64)
    if samples.ndim == 1:
        return samples
    return samples.mean(axis=1)


def _average_overlaps(sums, counts, length):
    """Divide accumulated *sums* by *counts* (0 where count is 0) and keep the first *length* rows."""
    averaged = np.divide(sums, counts,
                         out=np.zeros_like(sums, dtype=np.float64),
                         where=counts != 0)
    return averaged[:length]


def _format_series(values, formatter=None):
    """Render *values* as a single-line ``np.array2string`` (two decimals by default)."""
    if formatter is None:
        formatter = {'float_kind': lambda x: "%.2f" % x}
    return np.array2string(np.asarray(values), formatter=formatter, threshold=np.inf).replace('\n', '')


def detect_beeps(video_path, target_event_length=30, beep_height=0.8):
    """
    Detects beep sounds in a video file and returns frame indices for start and end points.
    Finds the pair of peaks that are closest to the target event length.

    The audio track is extracted with ffmpeg into a temporary WAV file and
    cross-correlated with the reference recording ``beep.WAV``.

    Args:
        video_path: Path to the video file
        target_event_length: Target duration of the event in seconds
        beep_height: Initial threshold for peak detection

    Returns:
        event_start: Frame index for the start of the event
        event_end: Frame index for the end of the event

        ``(0, frame_count)`` is returned when no pair of beeps is found or the
        audio cannot be extracted.
    """
    # Read reference beep
    reference_file = 'beep.WAV'
    fs, beep = wavfile.read(reference_file)
    beep = _to_mono(beep)

    # Only the frame count and frame rate are needed from the video.
    video = cv2.VideoCapture(video_path)
    try:
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(video.get(cv2.CAP_PROP_FPS))
    finally:
        video.release()

    # Extract audio into a unique temp file so concurrent requests cannot
    # overwrite each other's audio.
    fd, audio_path = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    try:
        # Argument list (no shell): paths with spaces or shell metacharacters are safe.
        audio_convert_command = ['ffmpeg', '-y', '-i', video_path, '-vn',
                                 '-acodec', 'pcm_s16le', '-ar', str(fs), '-ac', '2', audio_path]
        print(' '.join(audio_convert_command))
        if subprocess.call(audio_convert_command) != 0:
            print("ffmpeg could not extract audio; using the whole video")
            return 0, length
        _, audio = wavfile.read(audio_path)
    finally:
        try:
            os.remove(audio_path)
        except FileNotFoundError:
            pass

    audio = _to_mono(audio)
    if audio.size == 0:
        return 0, length

    # Cross-correlate with the reference beep
    corr = correlate(audio, beep, mode='same') / audio.size
    corr_min = np.min(corr)
    corr_span = np.max(corr) - corr_min
    if corr_span == 0:
        # Constant correlation (e.g. silent audio): nothing to detect.
        return 0, length
    # Min-max scale correlation to [-1, 1]
    corr = 2 * (corr - corr_min) / corr_span - 1

    # Target number of frames for the event
    target_frames = fps * target_event_length

    # Strategy: Try different height thresholds to find peaks,
    # then select the pair closest to the target length
    best_pair = None
    best_diff = float('inf')
    min_height = 0.3    # Minimum threshold to consider
    height_step = 0.05  # Decrease step
    peaks = np.array([], dtype=int)  # defined even if the loop below never runs

    # Try different height thresholds
    current_height = beep_height
    while current_height >= min_height:
        peaks, _ = find_peaks(corr, height=current_height, distance=fs // 2)
        if len(peaks) >= 2:
            # Check all possible pairs of peaks
            for i in range(len(peaks)):
                for j in range(i + 1, len(peaks)):
                    start_frame = int(peaks[i] / fs * fps)
                    end_frame = int(peaks[j] / fs * fps)
                    # How close this pair is to the target length
                    diff = abs((end_frame - start_frame) - target_frames)
                    if diff < best_diff:
                        best_diff = diff
                        best_pair = (start_frame, end_frame)
        if best_diff < 15:  # If we found a good pair, break early
            break
        # Reduce height threshold and try again
        current_height -= height_step

    if best_pair:
        event_start, event_end = best_pair
    else:
        # Fallback: use the whole video
        event_start = 0
        event_end = length

    # Debug visualization of the correlation and the last set of peaks
    plt.plot(corr)
    plt.plot(peaks, corr[peaks], "x")
    plt.savefig('beep.png')
    plt.close()

    return event_start, event_end


def upload_video(out_text, in_video):
    """Upload *in_video* to the ``lumos-motion/single-rope-contest`` dataset repo.

    Nothing is uploaded when *out_text* is the empty string.

    Args:
        out_text: Result text; only checked for being non-empty.
        in_video: Path or file object passed to ``HfApi.upload_file``.
    """
    if out_text != '':
        # generate a timestamp name for the video
        upload_path = f"{int(time.time())}.mp4"
        api.upload_file(
            path_or_fileobj=in_video,
            path_in_repo=upload_path,
            repo_id="lumos-motion/single-rope-contest",
            repo_type="dataset",
        )


def count_phases(phase_sin, phase_cos, threshold=0.5):
    """
    Count the number of phase transitions in the sine and cosine phases.

    A transition is a frame where both signals cross *threshold* (in either
    direction) relative to the previous frame.

    Args:
        phase_sin: Numpy array of sine phase values
        phase_cos: Numpy array of cosine phase values
        threshold: Threshold to consider a transition

    Returns:
        count: Number of phase transitions
        phase_indices: Indices where transitions occur
    """
    sin_crosses = (phase_sin[:-1] < threshold) != (phase_sin[1:] < threshold)
    cos_crosses = (phase_cos[:-1] < threshold) != (phase_cos[1:] < threshold)
    both_cross = sin_crosses & cos_crosses
    phase_indices = (np.where(both_cross)[0] + 1).tolist()
    count = len(phase_indices)
    return count, phase_indices


def inference(in_video, use_60fps, beep_detection_on, event_length, miss_threshold, marks_threshold,
              count_only_api, api_key,
              seq_len=64, stride_length=32, stride_pad=3, batch_size=2,
              median_pred_filter=True, both_feet=True,
              api_call=False,
              progress=gr.Progress()):
    """Run the model over a video and summarise the predicted jumps.

    Args:
        in_video: Video reference passed to ``download_clips``; ``None`` aborts.
        use_60fps: Forwarded to ``download_clips``; also selects a 60-frame
            (instead of 30-frame) window when searching for the fastest second.
        beep_detection_on: When true, ``detect_beeps`` is used to locate the event.
        event_length: Expected event length in seconds (converted with ``int``
            when beep detection is on).
        miss_threshold: Periodicity (after sigmoid) at or below which a frame is a miss.
        marks_threshold: Minimum peak height for a mark to be accepted.
        count_only_api: For API calls, return only a short count string.
        api_key: Value verified against ``DEV_API_TOKEN`` for API calls.
        seq_len: Number of frames per model window.
        stride_length: Frame offset between consecutive windows.
        stride_pad: Subtracted from the end of the window start range.
        batch_size: Number of windows per model call.
        median_pred_filter: Apply a 5-tap median filter to periodicity and period length.
        both_feet: When false, counts are halved.
        api_call: Selects the API return format and enables the key check.
        progress: Gradio progress tracker.

    Returns:
        A string for error conditions and for ``api_call`` (a count string or
        a JSON document); otherwise a tuple ``(markdown, speed figure,
        phase figure, phase/marks figure, histogram)``.
    """
    print(in_video)
    if in_video is None:
        return "No video input provided."

    # Check the key before doing any expensive work for an API caller.
    if api_call:
        try:
            has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        except (TypeError, ValueError):
            # passlib raises on a missing or malformed hash string.
            has_access = False
        if not has_access:
            return "Invalid API Key"

    in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps, use_cuda=use_cuda)
    progress(0, desc="Running inference...")

    if beep_detection_on:
        event_length = int(event_length)
        event_start, event_end = detect_beeps(in_video, event_length)
        print(event_start, event_end)

    cap = cv2.VideoCapture(in_video)
    length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    if fps <= 0:
        # Container did not report a frame rate; use the requested one.
        fps = 60 if use_60fps else 30
    seconds = length / fps

    all_frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_LINEAR)
            all_frames.append(frame)
    finally:
        cap.release()

    if not all_frames:
        return "Could not read any frames from the video."

    # From here on `length` is the number of frames actually decoded.
    length = len(all_frames)
    padded_length = length + seq_len + stride_length
    period_lengths = np.zeros(padded_length)
    periodicities = np.zeros(padded_length)
    full_marks = np.zeros(padded_length)
    event_type_logits = np.zeros((padded_length, 7))
    phase_sin = np.zeros(padded_length)
    phase_cos = np.zeros(padded_length)
    overlaps = np.zeros(padded_length)  # number of windows covering each frame

    # pad full sequence with the last frame
    all_frames.extend([all_frames[-1]] * (seq_len + stride_length))

    batch_list = []
    idx_list = []
    inference_futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
            batch = all_frames[i:i + seq_len]
            if len(batch) < seq_len:
                batch = batch + [batch[-1]] * (seq_len - len(batch))
            # Vectorized preprocessing: stack, transpose HWC->CHW, convert to float32
            X = np.ascontiguousarray(
                np.stack(batch).transpose(0, 3, 1, 2), dtype=np.float32
            )
            batch_list.append(X[np.newaxis])  # add batch dim: (1, seq_len, 3, H, W)
            idx_list.append(i)

            if len(batch_list) == batch_size:
                future = executor.submit(run_inference, batch_list)
                # Keep only the window count: holding the input arrays here
                # would keep every window of the video alive until the end.
                inference_futures.append((len(batch_list), idx_list, future))
                batch_list = []
                idx_list = []

        # Process any remaining windows, repeating the last one to fill the batch
        if batch_list:
            while len(batch_list) < batch_size:
                batch_list.append(batch_list[-1])
                idx_list.append(idx_list[-1])
            future = executor.submit(run_inference, batch_list)
            inference_futures.append((len(batch_list), idx_list, future))
            batch_list = []

        progress(0, desc="Processing results...")
        # Collect and accumulate the inference results
        for n_windows, window_starts, future in progress.tqdm(tqdm(inference_futures)):
            outputs = future.result()
            y1_out, y2_out, y3_out, y4_out = outputs[:4]
            # The sixth output (phase) is optional; fall back to zeros.
            y6_out = outputs[5] if len(outputs) > 5 else np.zeros((n_windows, seq_len, 2))
            for y1, y2, y3, y4, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y6_out, window_starts):
                window = slice(idx, idx + seq_len)
                phase = y6.squeeze()
                period_lengths[window] += y1[:, 0]
                periodicities[window] += y2.squeeze()
                full_marks[window] += y3.squeeze()
                event_type_logits[window] += y4.squeeze()
                phase_sin[window] += phase[:, 1]
                phase_cos[window] += phase[:, 0]
                overlaps[window] += 1
            del outputs, y1_out, y2_out, y3_out, y4_out, y6_out  # free up memory

    # Average the overlapping window predictions per frame
    periodLength = _average_overlaps(period_lengths, overlaps, length)
    periodicity = _average_overlaps(periodicities, overlaps, length)
    full_marks = _average_overlaps(full_marks, overlaps, length)
    per_frame_event_type_logits = _average_overlaps(event_type_logits, overlaps[:, np.newaxis], length)
    # negate sin to make the bottom of the plot the start of the jump
    phase_sin = -_average_overlaps(phase_sin, overlaps, length)
    phase_cos = _average_overlaps(phase_cos, overlaps, length)

    # softmax of the mean event type logits (shifted by the max for stability)
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    shifted_logits = np.exp(event_type_logits - np.max(event_type_logits))
    event_type_probs = shifted_logits / np.sum(shifted_logits)
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)

    # if the event_start and event_end (in frames) are detected and form a valid event of event_length (in seconds)
    if beep_detection_on:
        # NOTE(review): there is no abs() here, so any detected span shorter than
        # the target also passes — confirm whether that is intended.
        if event_start > 0 and event_end > 0 and (event_end - event_start) - (event_length * fps) < 0.5:
            print(f"Event detected: {event_start} - {event_end}")
            periodicity[:event_start] = 0
            periodicity[event_end:] = 0

    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)
    phase_count, phase_indices = count_phases(phase_sin, phase_cos, threshold=-0.5)

    # Integrate 1/period over periodic frames; marks snap the running count to an integer.
    numofReps = 0
    count = []
    miss_detected = True  # start in the "not jumping" state
    miss_frames = []      # first frame of every periodic -> non-periodic transition
    for i in range(len(periodLength)):
        if periodLength[i] < 2 or periodicity_mask[i] == 0:
            if not miss_detected:
                miss_detected = True
                miss_frames.append(i)
        elif full_marks_mask[i]:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
            miss_detected = False
        else:
            numofReps += max(0, periodicity_mask[i] / (periodLength[i]))
            miss_detected = False
        count.append(round(float(numofReps), 2))
    count_pred = count[-1]

    # The end of the event is not counted as a miss: if the sequence ends in the
    # stopped state, the last transition is that trailing stop. This never goes
    # negative, unlike starting the counter at -1.
    ended_stopped = miss_detected and len(miss_frames) > 0
    num_misses = len(miss_frames) - (1 if ended_stopped else 0)

    # A mark counts if it is periodic and the next frame is not also a mark
    # (to avoid double counting).
    marks_count_pred = int(np.sum((full_marks_mask[:-1] > 0)
                                  & (periodicity_mask[:-1] > 0)
                                  & (full_marks_mask[1:] == 0)))

    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2

    # From here on the mask is boolean so it can be used for indexing.
    periodicity_mask = periodicity > miss_threshold
    has_periodic = bool(np.any(periodicity_mask))
    if has_periodic and miss_threshold < 1:
        confidence = (np.mean(periodicity[periodicity_mask]) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0
    self_err = abs(count_pred - marks_count_pred)
    self_pct_err = self_err / count_pred if count_pred else 0
    total_confidence = confidence * (1 - self_pct_err)

    # find the fastest second (30 frames if 30fps and 60 frames if 60fps) based on the period length
    scan_window = 60 if use_60fps else 30
    fastest_frames_start = 0
    fastest_period = float('inf')
    for i in range(0, len(periodLength) - scan_window, scan_window // 2):
        avg_period = np.mean(periodLength[i:i + scan_window])
        if avg_period < fastest_period:
            fastest_period = avg_period
            fastest_frames_start = i
    fastest_frames_end = fastest_frames_start + scan_window
    fastest_jumps_per_second = np.clip(1 / ((fastest_period / fps) + 0.0001), 0, 10)
    print(f"Fastest jumps per second: {fastest_jumps_per_second:.2f} (from frames {fastest_frames_start} to {fastest_frames_end})")

    # Mean period length over periodic frames (0 when there are none, instead of NaN).
    avg_speed = float(np.mean(periodLength[periodicity_mask])) if has_periodic else 0.0

    # measure the reaction time to the beep (if beep detection is on) as the time to reach average speed
    time_to_speed = 0
    if beep_detection_on:
        reaction_frame = np.argmax((periodLength < avg_speed) & (periodicity_mask))
        print(f"Reaction frame: {reaction_frame}, Avg Speed: {avg_speed}")
        time_to_speed = (reaction_frame - event_start) / fps

    # get peak speed and lowest speed (as period lengths)
    peak_speed = np.quantile(periodLength[periodicity_mask], 0.01) if has_periodic else 0
    lowest_speed = np.quantile(periodLength[periodicity_mask], 0.99) if has_periodic else 0
    peak_jps = np.clip(1 / ((peak_speed / fps) + 0.0001), 0, 10)
    lowest_jps = np.clip(1 / ((lowest_speed / fps) + 0.0001), 0, 10)
    slowdown = (lowest_jps - peak_jps)
    slowdown_percent = (slowdown / peak_jps) * 100 if peak_jps > 0 else 0
    print('slowdown', slowdown)
    print('percent', slowdown_percent)

    # estimate the score assuming no misses and fill in the gaps
    filled_periodLength = np.zeros(len(periodLength))
    started = False
    for i in range(len(periodLength)):
        if beep_detection_on and (i < event_start or i >= event_end):
            filled_periodLength[i] = 0
        elif periodicity_mask[i] > 0:
            started = True
            filled_periodLength[i] = periodLength[i]
        elif not started:
            filled_periodLength[i] = 0
        else:
            # fill in the gaps with the previous value
            filled_periodLength[i] = filled_periodLength[i - 1]
    estimated_score = 0
    for i in range(len(filled_periodLength)):
        if filled_periodLength[i] >= 2:
            estimated_score += max(0, periodicity_mask[i] / (filled_periodLength[i]))
    print(f"Estimated score: {estimated_score:.2f}")

    # find the recovery times after each miss
    recovery_times = []
    for miss_frame in miss_frames:
        # NOTE(review): this looks for a period length *above* the average
        # (i.e. a slower-than-average frame) — confirm the intended direction.
        recovery_frame = np.argmax((periodLength[miss_frame:] > avg_speed) & (periodicity_mask[miss_frame:])) + miss_frame
        if recovery_frame > miss_frame:
            recovery_times.append((recovery_frame - miss_frame) / fps)
        # otherwise: end of event, nothing to record
    print(f"Recovery times: {recovery_times}")

    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    average_jps = float(np.mean(jumps_per_second[periodicity_mask])) if has_periodic else 0.0

    foot_label = "both feet" if both_feet else "one foot"
    count_msg = f"## 🏅 Results ({foot_label}: {count_pred:.1f})\n\n"
    count_msg += "| Metric | Value |\n|---|---|\n"
    count_msg += f"| **Count** | {count_pred:.1f} |\n"
    count_msg += f"| **Confidence** | {total_confidence:.2f} |\n"
    count_msg += f"| **Fastest Speed** | {fastest_jumps_per_second:.2f} jumps/sec |\n"
    count_msg += f"| **Average Speed** | {average_jps:.2f} jumps/sec |\n"
    count_msg += f"| **Slowest Speed** | {lowest_jps:.2f} jumps/sec |\n"
    count_msg += f"| **Slowdown** | {abs(slowdown):.2f} jumps/sec ({abs(slowdown_percent):.1f}%) |\n"
    if num_misses > 0:
        count_msg += f"| **Misses** | {num_misses} |\n"
    if recovery_times:
        avg_recovery = sum(recovery_times) / len(recovery_times)
        count_msg += f"| **Avg Recovery Time** | {avg_recovery:.2f}s |\n"
    if beep_detection_on and time_to_speed > 0:
        count_msg += f"| **Time to Speed** | {time_to_speed:.2f}s |\n"

    if api_call:
        if CACHE_API_CALLS:
            # write outputs as a row of a tab-separated file
            with open('api_calls.tsv', 'a') as f:
                periodicity_str = _format_series(periodicity)
                periodLength_str = _format_series(periodLength)
                full_marks_str = _format_series(full_marks)
                f.write(f"{in_video}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"

        results_dict = {
            "periodLength": _format_series(periodLength),
            "periodicity": _format_series(periodicity),
            "full_marks": _format_series(full_marks),
            "cum_count": _format_series(count),
            "count": f"{count_pred:.2f}",
            "marks": f"{marks_count_pred:.1f}",
            "phase_count": f"{phase_count:.1f}",
            "confidence": f"{total_confidence:.2f}",
            "fastest_frames_start": fastest_frames_start,
            "fastest_frames_end": fastest_frames_end,
            "fastest_jumps_per_second": f"{fastest_jumps_per_second:.2f}",
            "lowest_jumps_per_second": f"{lowest_jps:.2f}",
            "fastest_period_length": f"{fastest_period:.2f}",
            "lowest_period_length": f"{lowest_speed:.2f}",
            "time_to_speed": f"{time_to_speed:.2f}" if beep_detection_on else 0,
            "slowdown": f"{slowdown:.2f}",
            "slowdown_percent": f"{slowdown_percent:.2f}",
            "num_misses": num_misses,
            "miss_frames": _format_series(miss_frames[:num_misses], formatter={'int': lambda x: str(x)}),
            "recovery_times": _format_series(recovery_times),
            "no_miss_score": f"{estimated_score:.2f}" if num_misses > 0 else f"{count_pred:.2f}",
            "single_rope_speed": f"{event_type_probs[0]:.3f}",
            "double_dutch": f"{event_type_probs[1]:.3f}",
            "double_unders": f"{event_type_probs[2]:.3f}",
            "single_bounce": f"{event_type_probs[3]:.3f}"
        }
        if beep_detection_on:
            results_dict['event_start'] = event_start
            results_dict['event_end'] = event_end
        return json.dumps(results_dict)

    # ----- visualisations for the demo UI -----
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'phase sin': phase_sin,
                                 'phase cos': phase_cos,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})

    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     color='jumping speed',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale='Turbo',
                     range_color=(0, 10),
                     title="Jumping speed (jumps-per-second)",
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override="goldenrod",
                     trendline_scope='overall',
                     template="plotly_dark")

    if beep_detection_on:
        # shade the detected beep-to-beep event
        fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)

    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="bottom",
        y=0.98,
        xanchor="right",
        x=1,
        font=dict(
            family="Courier",
            size=12,
            color="black"
        ),
        bgcolor="AliceBlue",
    ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width=0)
    fig.update_layout(coloraxis_colorbar=dict(
        title='jumps/sec'
    ))

    # phase plot coloured by the frames flagged by count_phases
    phase_jumps = np.zeros(len(phase_sin))
    phase_jumps[phase_indices] = 1
    fig_phase_spiral = px.scatter(x=phase_cos, y=phase_sin,
                                  color=phase_jumps,
                                  color_continuous_scale='plasma',
                                  title="Phase Spiral (speed)",
                                  template="plotly_dark")
    fig_phase_spiral.update_traces(marker=dict(size=4, opacity=0.5))
    fig_phase_spiral.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    fig_phase_spiral.update_coloraxes(colorbar=dict(
        title="Phase Jumps",))
    # make axes equal
    fig_phase_spiral.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    fig_phase_spiral.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    fig_phase_spiral.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))

    # phase plot coloured by the mark predictions
    fig_phase_spiral_marks = px.scatter(x=phase_cos, y=phase_sin,
                                        color=full_marks,
                                        color_continuous_scale='Jet',
                                        title="Phase Spiral (marks)",
                                        template="plotly_dark")
    fig_phase_spiral_marks.update_traces(marker=dict(size=4, opacity=0.5))
    fig_phase_spiral_marks.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    fig_phase_spiral_marks.update_coloraxes(colorbar=dict(
        title="Marks"))
    # make axes equal
    fig_phase_spiral_marks.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    fig_phase_spiral_marks.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    fig_phase_spiral_marks.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))

    hist = px.histogram(df,
                        x="jumps per second",
                        template="plotly_dark",
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")

    return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist


def api_inference(in_video, use_60fps, beep_detection_on, event_length, count_only_api, api_key,
                  progress=gr.Progress()):
    """API entry point: run ``inference`` with ``api_call=True`` and the default thresholds.

    The API endpoint supplies six inputs and no thresholds, so they are filled
    in here; passing the six inputs straight to ``inference`` left
    ``count_only_api`` and ``api_key`` unbound.
    """
    return inference(in_video, use_60fps, beep_detection_on, event_length,
                     DEFAULT_MISS_THRESHOLD, DEFAULT_MARKS_THRESHOLD,
                     count_only_api, api_key,
                     api_call=True, progress=progress)


#css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
with gr.Blocks() as demo:
    with gr.Row():
        in_video = gr.PlayableVideo(label="Input Video", elem_id='input-video', format='mp4',
                                    width=400, height=400, interactive=True, container=True,
                                    max_length=300)
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
                ### Inference Options
                Select the framerate and thresholds for inference. Default values should work well for most videos.
                """,
                elem_id='inference-options',
            )
            use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
            miss_threshold = gr.Slider(label="Periodicity Threshold", minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_MISS_THRESHOLD, elem_id='miss-threshold')
            marks_threshold = gr.Slider(label="Marks Threshold", minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_MARKS_THRESHOLD, elem_id='marks-threshold')
        with gr.Column():
            gr.Markdown(
                """
                ### Beep Detection Options
                Must be using official Single Rope Contest timing tracks.
                """,
                elem_id='beep-detection-options',
            )
            beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
            event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True, value="30")

    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)

    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
            with gr.Column():
                out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')

    demo_inference = partial(inference, count_only_api=False, api_key=None)

    run_button.click(demo_inference, [in_video, use_60fps, beep_detection_on, event_length, miss_threshold, marks_threshold],
                     outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist])
    api_dummy_button.click(api_inference, [in_video, use_60fps, beep_detection_on, event_length, count_only, api_token],
                           outputs=[period_length], api_name='inference')
    examples = [
        ['files/wc2023.mp4', False, True, 30, 0.5, 0.5],
    ]
    gr.Examples(examples,
                inputs=[in_video, use_60fps, beep_detection_on, event_length, miss_threshold, marks_threshold],
                outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist],
                fn=demo_inference, cache_examples=False)


if __name__ == "__main__":
    if LOCAL:
        demo.queue(api_open=True, max_size=15).launch(server_name="0.0.0.0",
                                                      server_port=7860,
                                                      debug=False,
                                                      ssl_verify=False,
                                                      share=True)
    else:
        demo.queue(api_open=True, max_size=15).launch(share=False)