import gradio as gr
import numpy as np
from PIL import Image
import os
import cv2
import math
import time
import json
import subprocess
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import concurrent.futures
from scipy.io import wavfile
from scipy.signal import medfilt, correlate, find_peaks
from functools import partial
from passlib.hash import pbkdf2_sha256
from tqdm import tqdm
import pandas as pd
import plotly.express as px
import onnxruntime as ort
import torch
from torchvision import transforms
import torchvision.transforms.functional as F
from huggingface_hub import hf_hub_download
from huggingface_hub import HfApi
from hls_download import download_clips
#plt.style.use('dark_background')
# ---------------------------------------------------------------------------
# Runtime configuration
# ---------------------------------------------------------------------------
# LOCAL switches the launcher to a 0.0.0.0 / share=True server and adds extra
# metrics (marks, phase count, time to speed) to the UI result message.
LOCAL = False
# Side length, in pixels, of the square frames handed to the model.
IMG_SIZE = 256
# When True, every API call appends its raw outputs to api_calls.tsv.
CACHE_API_CALLS = False

# Directory passed to download_clips() as the destination for fetched clips.
os.makedirs(os.path.join(os.getcwd(), 'clips'), exist_ok=True)

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
current_model = 'nextjump_speed'
# The weights are fetched from the Hugging Face Hub; DATASET_SECRET must be set
# in the environment (a missing variable raises KeyError at import time).
onnx_file = hf_hub_download(repo_id="lumos-motion/nextjump", filename=f"{current_model}.onnx", repo_type="model", token=os.environ['DATASET_SECRET'])
api = HfApi()

if torch.cuda.is_available():
    print("Using CUDA")
    # Hand ONNX Runtime the CUDA device and stream that torch is currently using.
    providers = [("CUDAExecutionProvider", {"device_id": torch.cuda.current_device(),
                                            "user_compute_stream": str(torch.cuda.current_stream().cuda_stream)})]
    sess_options = ort.SessionOptions()
    ort_sess = ort.InferenceSession(onnx_file, sess_options=sess_options, providers=providers)
    use_cuda = True
else:
    print("Using CPU")
    # Name the provider explicitly: ONNX Runtime builds that ship additional
    # execution providers refuse to create a session without a providers list.
    ort_sess = ort.InferenceSession(onnx_file, providers=["CPUExecutionProvider"])
    use_cuda = False

# warmup inference
ort_sess.run(None, {'video': np.zeros((4, 64, 3, IMG_SIZE, IMG_SIZE), dtype=np.float32)})
def square_pad_opencv(image):
    """Pad *image* with black borders so that its height equals its width.

    The image is kept centred. When the size difference is odd, the extra
    pixel goes to the bottom/right border so the result is always exactly
    square (splitting the difference with ``int(diff / 2)`` on both sides
    would leave the output one pixel short).

    Args:
        image: Array whose first two dimensions are (height, width).

    Returns:
        The padded image as returned by ``cv2.copyMakeBorder``.
    """
    h, w = image.shape[:2]
    side = max(w, h)
    pad_x = side - w
    pad_y = side - h
    left = pad_x // 2
    top = pad_y // 2
    right = pad_x - left
    bottom = pad_y - top
    return cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT, value=[0, 0, 0])
def preprocess_image(img, img_size):
    """Turn one frame array into a tensor with an extra leading dimension.

    The frame is wrapped in a PIL image, passed through torchvision's
    ``ToTensor`` transform and then given a new first axis via ``unsqueeze(0)``.

    Args:
        img: Frame array accepted by ``PIL.Image.fromarray``.
        img_size: Currently unused; kept so existing callers keep working.

    Returns:
        The transformed frame as a torch tensor.
    """
    to_tensor = transforms.Compose([transforms.ToTensor()])
    pil_frame = Image.fromarray(img)
    tensor = to_tensor(pil_frame)
    return tensor.unsqueeze(0)
def run_inference(batch_X):
    """Run the module-level ONNX session on a list of tensors.

    Args:
        batch_X: Sequence of torch tensors; they are concatenated with
            ``torch.cat`` before being fed to the model.

    Returns:
        The list of outputs produced by ``ort_sess.run`` for the ``'video'`` input.
    """
    model_input = torch.cat(batch_X).numpy()
    return ort_sess.run(None, {'video': model_input})
def sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of *x* (element-wise for arrays)."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator
def _mix_to_mono(samples):
    """Return *samples* as a float64 mono signal, averaging the channels if there are several.

    Converting to float before mixing avoids integer wrap-around when two
    16-bit channels are added together.
    """
    samples = np.asarray(samples, dtype=np.float64)
    if samples.ndim == 1:
        return samples
    return samples.mean(axis=1)


def detect_beeps(video_path, target_event_length=30, beep_height=0.8):
    """
    Detects beep sounds in a video file and returns frame indices for start and end points.
    Finds the pair of correlation peaks whose spacing is closest to the target event length.

    The audio track is extracted with ffmpeg into ``temp.wav`` (in the current
    working directory) and cross-correlated with the reference recording
    ``beep.WAV``. A plot of the correlation is written to ``beep.png``.

    Args:
        video_path: Path to the video file
        target_event_length: Target duration of the event in seconds
        beep_height: Initial threshold for peak detection; it is lowered in
            steps of 0.05 down to 0.3 until a good pair is found

    Returns:
        event_start: Frame index for the start of the event
        event_end: Frame index for the end of the event
        If no pair of beeps is found, or the audio cannot be extracted,
        ``(0, frame_count)`` is returned so the whole video is used.
    """
    # Read reference beep
    reference_file = 'beep.WAV'
    fs, beep = wavfile.read(reference_file)
    beep = _mix_to_mono(beep)

    # Only the frame count and frame rate are needed from the video container
    video = cv2.VideoCapture(video_path)
    try:
        length = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = int(video.get(cv2.CAP_PROP_FPS))
    finally:
        video.release()

    # Clean up any previous temporary files so a failed extraction can never
    # be mistaken for a fresh one
    try:
        os.remove('temp.wav')
    except FileNotFoundError:
        pass

    # Extract audio from video, resampled to the reference beep's rate.
    # An argument list (no shell) keeps paths with spaces or shell
    # metacharacters from being interpreted.
    audio_convert_command = ['ffmpeg', '-y', '-i', video_path, '-vn',
                             '-acodec', 'pcm_s16le', '-ar', str(fs), '-ac', '2', 'temp.wav']
    print(' '.join(audio_convert_command))
    exit_code = subprocess.call(audio_convert_command)
    if exit_code != 0 or not os.path.exists('temp.wav'):
        print(f"Audio extraction failed (exit code {exit_code}); using the whole video")
        return 0, length

    # Read the extracted audio
    _, audio = wavfile.read('temp.wav')
    audio = _mix_to_mono(audio)
    if audio.size == 0:
        return 0, length

    # Cross-correlate with the reference beep
    corr = correlate(audio, beep, mode='same') / audio.size
    # Min-max scale correlation to [-1, 1]; a flat correlation carries no peaks
    corr_min = np.min(corr)
    corr_span = np.max(corr) - corr_min
    if corr_span > 0:
        corr = 2 * (corr - corr_min) / corr_span - 1
    else:
        corr = np.zeros_like(corr)

    # Target number of frames for the event
    target_frames = fps * target_event_length

    # Strategy: try progressively lower height thresholds to find peaks,
    # then select the pair closest to the target length
    best_pair = None
    best_diff = float('inf')
    min_height = 0.3  # Minimum threshold to consider
    height_step = 0.05  # Decrease step
    good_enough = 15  # stop searching once a pair is within this many frames of the target
    peaks = np.array([], dtype=int)  # defined even if the loop below never runs

    current_height = beep_height
    while current_height >= min_height:
        peaks, _ = find_peaks(corr, height=current_height, distance=fs // 2)
        if len(peaks) >= 2:
            # Convert sample positions to frame indices once per threshold
            peak_frames = [int(peak / fs * fps) for peak in peaks]
            # Check all possible pairs of peaks
            for i, start_frame in enumerate(peak_frames):
                for end_frame in peak_frames[i + 1:]:
                    # How close is this pair to the target length?
                    diff = abs((end_frame - start_frame) - target_frames)
                    if diff < best_diff:
                        best_diff = diff
                        best_pair = (start_frame, end_frame)
            if best_diff < good_enough:
                break
        # Reduce height threshold and try again
        current_height -= height_step

    if best_pair is not None:
        event_start, event_end = best_pair
    else:
        # Fallback: use the whole video
        event_start = 0
        event_end = length

    # Diagnostic plot of the correlation and the peaks from the last threshold tried
    plt.plot(corr)
    plt.plot(peaks, corr[peaks], "x")
    plt.savefig('beep.png')
    plt.close()
    return event_start, event_end
def upload_video(out_text, in_video):
    """Upload *in_video* to the contest dataset repo unless *out_text* is empty.

    Args:
        out_text: Result text; an empty string means there is nothing to upload.
        in_video: Path or file object passed to ``HfApi.upload_file``.

    Returns:
        None.
    """
    if out_text == '':
        return
    # name the uploaded file after the current Unix timestamp
    timestamp = int(time.time())
    api.upload_file(
        path_or_fileobj=in_video,
        path_in_repo=f"{timestamp}.mp4",
        repo_id="lumos-motion/single-rope-contest",
        repo_type="dataset",
    )
def count_phases(phase_sin, phase_cos, threshold=0.5):
    """
    Count the positions where both phase signals cross the threshold at once.

    A crossing at index ``i`` means the value at ``i - 1`` and the value at
    ``i`` lie on opposite sides of ``threshold`` (a value equal to the
    threshold counts as being above it). An index is reported only when the
    sine signal and the cosine signal both cross between the same two samples.

    Args:
        phase_sin: Sequence of sine phase values
        phase_cos: Sequence of cosine phase values
        threshold: Level that has to be crossed

    Returns:
        count: Number of simultaneous crossings
        phase_indices: Indices where they occur
    """
    def crosses(before, after):
        rising = before < threshold <= after
        falling = after < threshold <= before
        return rising or falling

    phase_indices = [
        pos
        for pos in range(1, len(phase_sin))
        if crosses(phase_sin[pos - 1], phase_sin[pos])
        and crosses(phase_cos[pos - 1], phase_cos[pos])
    ]
    return len(phase_indices), phase_indices
# array2string formatters shared by the API serialisation code below
_FLOAT_2DP = {'float_kind': lambda x: "%.2f" % x}
_INT_PLAIN = {'int': lambda x: str(x)}


def _array_to_line(values, formatter=_FLOAT_2DP):
    """Render *values* as a single-line string using numpy's array2string."""
    return np.array2string(np.asarray(values), formatter=formatter, threshold=np.inf).replace('\n', '')


def _phase_spiral_figure(phase_cos, phase_sin, color, color_scale, title, colorbar_title):
    """Build a square cos/sin scatter plot coloured by *color*, with a faint connecting line."""
    figure = px.scatter(x=phase_cos, y=phase_sin,
                        color=color,
                        color_continuous_scale=color_scale,
                        title=title,
                        template="plotly_dark")
    figure.update_traces(marker=dict(size=4, opacity=0.5))
    figure.update_layout(
        xaxis_title="Phase Cos",
        yaxis_title="Phase Sin",
        xaxis=dict(range=[-1, 1]),
        yaxis=dict(range=[-1, 1]),
        showlegend=False,
    )
    figure.update_coloraxes(colorbar=dict(title=colorbar_title))
    # make axes equal
    figure.update_layout(
        xaxis=dict(scaleanchor="y"),
        yaxis=dict(constrain="domain"),
    )
    # overlay line plot of phase sin and cos
    figure.add_traces(px.line(x=phase_cos, y=phase_sin).data)
    figure.update_traces(line=dict(width=0.5, color='rgba(255, 255, 255, 0.25)'))
    return figure


def inference(in_video, use_60fps,
              beep_detection_on, event_length,
              count_only_api, api_key, seq_len=64, stride_length=32, stride_pad=3, batch_size=2,
              miss_threshold=0.5, marks_threshold=0.5, median_pred_filter=True, both_feet=True,
              api_call=False,
              progress=gr.Progress()):
    """Count jumps in a video and build the result message, JSON or plots.

    The clip is fetched with ``download_clips``, resized to IMG_SIZE x IMG_SIZE,
    cut into overlapping windows of ``seq_len`` frames and run through the ONNX
    model. Per-frame predictions are averaged over the windows covering each
    frame and then turned into a cumulative repetition count.

    Args:
        in_video: Video reference handed to ``download_clips``; ``None`` aborts early.
        use_60fps: Forwarded to ``download_clips``; also selects a 60-frame
            (instead of 30-frame) window for the "fastest second" search.
        beep_detection_on: When truthy, ``detect_beeps`` locates the event and
            frames outside it may be masked out.
        event_length: Expected event length in seconds (converted with ``int``);
            only used when beep detection is on.
        count_only_api: API mode only - return just the count string instead of JSON.
        api_key: API mode only - hash verified against the DEV_API_TOKEN
            environment variable.
        seq_len: Number of frames per model window.
        stride_length: Frame offset between consecutive windows.
        stride_pad: Shortens the range of window start positions.
        batch_size: Number of windows per model call.
        miss_threshold: Periodicity (after sigmoid) above which a frame counts as jumping.
        marks_threshold: Minimum peak height for a jump mark.
        median_pred_filter: Apply a 5-tap median filter to periodicity and period length.
        both_feet: When False, all counts are halved.
        api_call: Selects API mode (string/JSON result, key check) over UI mode.
        progress: Gradio progress tracker.

    Returns:
        API mode: an error string, a count string, or a JSON document.
        UI mode: ``(message, speed_figure, phase_spiral, phase_spiral_marks, histogram)``;
        on early failure the four figures are ``None``.
    """
    def fail(message):
        # The UI handler is wired to five outputs, the API endpoint to one.
        return message if api_call else (message, None, None, None, None)

    print(in_video)
    if in_video is None:
        return fail("No video input provided.")

    # Check the key before doing any expensive work on behalf of the caller.
    if api_call:
        try:
            has_access = pbkdf2_sha256.verify(os.environ['DEV_API_TOKEN'], api_key)
        except (TypeError, ValueError):
            # missing or malformed key
            has_access = False
        if not has_access:
            return "Invalid API Key"

    in_video = download_clips(in_video, os.path.join(os.getcwd(), 'clips'), '00:00:00', '', use_60fps=use_60fps, use_cuda=use_cuda)
    progress(0, desc="Running inference...")

    if beep_detection_on:
        event_length = int(event_length)
        event_start, event_end = detect_beeps(in_video, event_length)
        print(event_start, event_end)
        # temp.wav is only needed during beep detection; remove it here so
        # every return path below leaves the working directory clean
        try:
            os.remove('temp.wav')
        except FileNotFoundError:
            pass

    # ------------------------------------------------------------------
    # Decode and resize every frame
    # ------------------------------------------------------------------
    cap = cv2.VideoCapture(in_video)
    reported_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    all_frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        frame = cv2.cvtColor(np.uint8(frame), cv2.COLOR_BGR2RGB)
        frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_CUBIC)
        all_frames.append(frame)
    cap.release()
    if not all_frames or fps <= 0:
        return fail("Could not read any frames from the video.")
    seconds = reported_length / fps
    length = len(all_frames)

    # ------------------------------------------------------------------
    # Per-frame accumulators (summed over overlapping windows)
    # ------------------------------------------------------------------
    padded_length = length + seq_len + stride_length
    period_lengths = np.zeros(padded_length)
    periodicities = np.zeros(padded_length)
    full_marks = np.zeros(padded_length)
    event_type_logits = np.zeros((padded_length, 7))
    phase_sin = np.zeros(padded_length)
    phase_cos = np.zeros(padded_length)
    overlaps = np.zeros(padded_length)  # number of windows covering each frame

    # pad with the last frame so every window is a full sequence
    all_frames.extend([all_frames[-1]] * (seq_len + stride_length))

    # ------------------------------------------------------------------
    # Preprocess windows and queue model calls
    # ------------------------------------------------------------------
    batch_list = []
    idx_list = []
    # (windows in batch, start index of each real window, future).
    # The tensors themselves are not stored here, so they can be freed as
    # soon as the model call that consumes them has finished.
    inference_futures = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
        for i in progress.tqdm(range(0, length + stride_length - stride_pad, stride_length)):
            window_frames = all_frames[i:i + seq_len]
            preprocess_tasks = [executor.submit(preprocess_image, img, IMG_SIZE) for img in window_frames]
            Xlist = [task.result() for task in preprocess_tasks]
            if len(Xlist) < seq_len:
                Xlist.extend([Xlist[-1]] * (seq_len - len(Xlist)))
            X = torch.cat(Xlist)
            X *= 255
            batch_list.append(X.unsqueeze(0))
            idx_list.append(i)
            if len(batch_list) == batch_size:
                future = executor.submit(run_inference, batch_list)
                inference_futures.append((len(batch_list), idx_list, future))
                batch_list = []
                idx_list = []
        # Process any remaining windows
        if batch_list:
            # Fill the batch up by repeating the last window. idx_list is NOT
            # extended, so the zip() below skips the duplicates and the last
            # window is not counted twice in the overlap average.
            while len(batch_list) < batch_size:
                batch_list.append(batch_list[-1])
            future = executor.submit(run_inference, batch_list)
            inference_futures.append((len(batch_list), idx_list, future))

    # ------------------------------------------------------------------
    # Collect and accumulate the model outputs
    # ------------------------------------------------------------------
    progress(0, desc="Processing results...")
    for windows_in_batch, idx_list, future in progress.tqdm(tqdm(inference_futures)):
        outputs = future.result()
        y1_out, y2_out, y3_out, y4_out = outputs[0], outputs[1], outputs[2], outputs[3]
        # outputs[4] is not used here; a sixth output (phase) is optional
        if len(outputs) > 5:
            y6_out = outputs[5]
        else:
            y6_out = np.zeros((windows_in_batch, seq_len, 2))
        for y1, y2, y3, y4, y6, idx in zip(y1_out, y2_out, y3_out, y4_out, y6_out, idx_list):
            window = slice(idx, idx + seq_len)
            phase = y6.squeeze()
            period_lengths[window] += y1[:, 0]
            periodicities[window] += y2.squeeze()
            full_marks[window] += y3.squeeze()
            event_type_logits[window] += y4.squeeze()
            phase_sin[window] += phase[:, 1]
            phase_cos[window] += phase[:, 0]
            overlaps[window] += 1
        del outputs, y1_out, y2_out, y3_out, y4_out, y6_out  # free up memory

    def average(total):
        # Divide by the per-frame window count; frames without coverage stay 0
        # (an explicit ``out`` is required, otherwise those slots are uninitialised).
        weights = overlaps.reshape((-1,) + (1,) * (total.ndim - 1))
        return np.divide(total, weights, out=np.zeros_like(total), where=weights != 0)[:length]

    periodLength = average(period_lengths)
    periodicity = average(periodicities)
    full_marks = average(full_marks)
    per_frame_event_type_logits = average(event_type_logits)
    # negate sin to make the bottom of the plot the start of the jump
    phase_sin = -average(phase_sin)
    phase_cos = average(phase_cos)

    # softmax of the clip-level event type logits (shifted by the max for stability)
    event_type_logits = np.mean(per_frame_event_type_logits, axis=0)
    shifted_exp = np.exp(event_type_logits - np.max(event_type_logits))
    event_type_probs = shifted_exp / np.sum(shifted_exp)
    per_frame_event_types = np.argmax(per_frame_event_type_logits, axis=1)

    if median_pred_filter:
        periodicity = medfilt(periodicity, 5)
        periodLength = medfilt(periodLength, 5)
    periodicity = sigmoid(periodicity)
    full_marks = sigmoid(full_marks)

    # if event_start and event_end (in frames) were detected, ignore everything outside them
    if beep_detection_on:
        # NOTE(review): the length check is one-sided (no abs()), so any pair
        # that is shorter than event_length also passes - confirm this is intended.
        if event_start > 0 and event_end > 0 and (event_end - event_start) - (event_length * fps) < 0.5:
            print(f"Event detected: {event_start} - {event_end}")
            periodicity[:event_start] = 0
            periodicity[event_end:] = 0

    pred_marks_peaks, _ = find_peaks(full_marks, distance=3, height=marks_threshold)
    full_marks_mask = np.zeros(len(full_marks))
    full_marks_mask[pred_marks_peaks] = 1
    periodicity_mask = np.int32(periodicity > miss_threshold)
    phase_count, phase_indices = count_phases(phase_sin, phase_cos, threshold=-0.5)

    # ------------------------------------------------------------------
    # Cumulative repetition count
    # ------------------------------------------------------------------
    numofReps = 0
    count = []
    miss_detected = True  # the clip starts in the "not jumping" state
    miss_frames = []  # first frame of every stop that follows jumping
    for i in range(len(periodLength)):
        if periodLength[i] < 2 or periodicity_mask[i] == 0:
            # not jumping: the count stands still
            if not miss_detected:
                miss_detected = True
                miss_frames.append(i)
        elif full_marks_mask[i]:  # high confidence mark detected
            if math.modf(numofReps)[0] < 0.2:  # probably false positive/late detection
                numofReps = float(int(numofReps))
            else:
                numofReps = float(int(numofReps) + 1.01)  # round up
            miss_detected = False
        else:
            # each frame contributes 1 / (frames per jump)
            numofReps += max(0, periodicity_mask[i] / (periodLength[i]))
            miss_detected = False
        count.append(round(float(numofReps), 2))
    count_pred = count[-1]

    # The final stop is the end of the event, not a miss - but only when the
    # clip actually ends in the stopped state.
    if miss_detected and miss_frames:
        counted_miss_frames = miss_frames[:-1]
    else:
        counted_miss_frames = miss_frames
    num_misses = len(counted_miss_frames)

    # a jump is counted where a mark coincides with high periodicity and the
    # next frame is not a mark as well (to avoid double counting)
    marks_count_pred = int(np.sum((full_marks_mask[:-1] > 0)
                                  & (periodicity_mask[:-1] > 0)
                                  & (full_marks_mask[1:] == 0)))

    if not both_feet:
        count_pred = count_pred / 2
        marks_count_pred = marks_count_pred / 2
        count = np.array(count) / 2

    # ------------------------------------------------------------------
    # Confidence
    # ------------------------------------------------------------------
    periodicity_mask = periodicity > miss_threshold  # boolean from here on
    any_jumping = bool(np.any(periodicity_mask))
    if any_jumping:
        confidence = (np.mean(periodicity[periodicity_mask]) - miss_threshold) / (1 - miss_threshold)
    else:
        confidence = 0
    self_err = abs(count_pred - marks_count_pred)
    self_pct_err = self_err / count_pred if count_pred else 0
    # clamp so a large disagreement cannot turn the confidence negative
    total_confidence = confidence * max(0, 1 - self_pct_err)

    # ------------------------------------------------------------------
    # Speed statistics
    # ------------------------------------------------------------------
    # find the fastest window (30 or 60 frames) based on the period length
    scan_window = 60 if use_60fps else 30
    fastest_frames_start = 0
    fastest_period = float('inf')
    for i in range(0, len(periodLength) - scan_window, scan_window // 2):
        avg_period = np.mean(periodLength[i:i + scan_window])
        if avg_period < fastest_period:
            fastest_period = avg_period
            fastest_frames_start = i
    fastest_frames_end = fastest_frames_start + scan_window
    fastest_jumps_per_second = np.clip(1 / ((fastest_period / fps) + 0.0001), 0, 10)
    print(f"Fastest jumps per second: {fastest_jumps_per_second:.2f} (from frames {fastest_frames_start} to {fastest_frames_end})")

    # reaction time to the beep: time until the period first drops below its average
    time_to_speed = 0
    if beep_detection_on and any_jumping:
        avg_speed = np.mean(periodLength[periodicity_mask])
        reaction_frame = np.argmax((periodLength < avg_speed) & (periodicity_mask))
        print(f"Reaction frame: {reaction_frame}, Avg Speed: {avg_speed}")
        time_to_speed = (reaction_frame - event_start) / fps

    # peak and lowest speed, as the 1% / 99% quantiles of the period length
    peak_speed = np.quantile(periodLength[periodicity_mask], 0.01) if any_jumping else 0
    lowest_speed = np.quantile(periodLength[periodicity_mask], 0.99) if any_jumping else 0
    peak_jps = np.clip(1 / ((peak_speed / fps) + 0.0001), 0, 10)
    lowest_jps = np.clip(1 / ((lowest_speed / fps) + 0.0001), 0, 10)
    slowdown = (lowest_jps - peak_jps)
    slowdown_percent = (slowdown / peak_jps) * 100 if peak_jps > 0 else 0
    print('slowdown', slowdown)
    print('percent', slowdown_percent)

    # ------------------------------------------------------------------
    # Score estimate assuming no misses
    # ------------------------------------------------------------------
    filled_periodLength = np.zeros(len(periodLength))
    started = False
    for i in range(len(periodLength)):
        if beep_detection_on and (i < event_start or i >= event_end):
            continue  # outside the event: stays 0
        if periodicity_mask[i]:
            started = True
            filled_periodLength[i] = periodLength[i]
        elif started:
            # fill in the gaps with the previous value
            filled_periodLength[i] = filled_periodLength[i - 1]
    # Every frame with a usable (filled) period contributes 1 / period. The
    # periodicity mask must not be applied here: it is 0 in exactly the gaps
    # that were just filled.
    usable = filled_periodLength >= 2
    estimated_score = float(np.sum(1 / filled_periodLength[usable]))
    print(f"Estimated score: {estimated_score:.2f}")

    # ------------------------------------------------------------------
    # Recovery times after each stop
    # ------------------------------------------------------------------
    recovery_times = []
    if len(miss_frames) > 0:
        avg_speed = np.mean(periodLength[periodicity_mask])
        for miss_frame in miss_frames:
            # NOTE(review): this looks for the first jumping frame whose period
            # is LONGER than average, whereas the reaction-time code above uses
            # "<" - confirm which direction is intended.
            recovery_frame = np.argmax((periodLength[miss_frame:] > avg_speed) & (periodicity_mask[miss_frame:])) + miss_frame
            if recovery_frame > miss_frame:  # otherwise: end of event
                recovery_times.append((recovery_frame - miss_frame) / fps)
    print(f"Recovery times: {recovery_times}")

    # ------------------------------------------------------------------
    # Result message
    # ------------------------------------------------------------------
    feet_label = "both feet" if both_feet else "one foot"
    count_msg = f"## Reps Count ({feet_label}): {count_pred:.1f}"
    if LOCAL:
        count_msg += f", Marks: {marks_count_pred:.1f}, Phase: {phase_count:.1f}"
    count_msg += f", Confidence: {total_confidence:.2f}"
    if LOCAL:
        count_msg += f", Time to Speed: {time_to_speed:.2f} seconds"

    # ------------------------------------------------------------------
    # API response
    # ------------------------------------------------------------------
    if api_call:
        if CACHE_API_CALLS:
            # write outputs as row of tsv
            with open('api_calls.tsv', 'a') as f:
                periodicity_str = _array_to_line(periodicity)
                periodLength_str = _array_to_line(periodLength)
                full_marks_str = _array_to_line(full_marks)
                f.write(f"{in_video}\t{beep_detection_on}\t{event_length}\t{periodicity_str}\t{periodLength_str}\t{full_marks_str}\t{count_pred}\t{total_confidence}\n")
        if count_only_api:
            return f"{count_pred:.2f} (conf: {total_confidence:.2f})"
        results_dict = {
            "periodLength": _array_to_line(periodLength),
            "periodicity": _array_to_line(periodicity),
            "full_marks": _array_to_line(full_marks),
            "cum_count": _array_to_line(np.array(count)),
            "count": f"{count_pred:.2f}",
            "marks": f"{marks_count_pred:.1f}",
            "phase_count": f"{phase_count:.1f}",
            "confidence": f"{total_confidence:.2f}",
            "fastest_frames_start": fastest_frames_start,
            "fastest_frames_end": fastest_frames_end,
            "fastest_jumps_per_second": f"{fastest_jumps_per_second:.2f}",
            "lowest_jumps_per_second": f"{lowest_jps:.2f}",
            "fastest_period_length": f"{fastest_period:.2f}",
            "lowest_period_length": f"{lowest_speed:.2f}",
            "time_to_speed": f"{time_to_speed:.2f}" if beep_detection_on else 0,
            "slowdown": f"{slowdown:.2f}",
            "slowdown_percent": f"{slowdown_percent:.2f}",
            "num_misses": num_misses,
            "miss_frames": _array_to_line(np.array(counted_miss_frames), _INT_PLAIN),
            "recovery_times": _array_to_line(np.array(recovery_times)),
            "no_miss_score": f"{estimated_score:.2f}" if num_misses > 0 else f"{count_pred:.2f}",
            "single_rope_speed": f"{event_type_probs[0]:.3f}",
            "double_dutch": f"{event_type_probs[1]:.3f}",
            "double_unders": f"{event_type_probs[2]:.3f}",
            "single_bounce": f"{event_type_probs[3]:.3f}"
        }
        if beep_detection_on:
            results_dict['event_start'] = event_start
            results_dict['event_end'] = event_end
        return json.dumps(results_dict)

    # ------------------------------------------------------------------
    # UI plots
    # ------------------------------------------------------------------
    jumps_per_second = np.clip(1 / ((periodLength / fps) + 0.0001), 0, 10)
    jumping_speed = np.copy(jumps_per_second)
    misses = periodicity < miss_threshold
    jumps_per_second[misses] = 0
    frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
    frame_type[full_marks > marks_threshold] = 'jump'
    # map the class index (0..6) onto [0, 1] for the continuous colour scale
    per_frame_event_types = np.clip(per_frame_event_types, 0, 6) / 6
    df = pd.DataFrame.from_dict({'period length': periodLength,
                                 'jumping speed': jumping_speed,
                                 'jumps per second': jumps_per_second,
                                 'periodicity': periodicity,
                                 'phase sin': phase_sin,
                                 'phase cos': phase_cos,
                                 'miss': misses,
                                 'frame_type': frame_type,
                                 'event_type': per_frame_event_types,
                                 'jumps': full_marks,
                                 'jumps_size': (full_marks + 0.05) * 10,
                                 'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 8),
                                 'seconds': np.linspace(0, seconds, num=len(periodLength))})
    event_type_tick_vals = np.linspace(0, 1, num=7)
    event_type_colors = ['red', 'orange', 'green', 'blue', 'purple', 'pink', 'black']
    fig = px.scatter(data_frame=df,
                     x='seconds',
                     y='jumps per second',
                     color='event_type',
                     size='jumps_size',
                     size_max=8,
                     color_continuous_scale=[(t, c) for t, c in zip(event_type_tick_vals, event_type_colors)],
                     range_color=(0, 1),
                     title="Jumping speed (jumps-per-second)",
                     trendline='rolling',
                     trendline_options=dict(window=16),
                     trendline_color_override="goldenrod",
                     trendline_scope='overall',
                     template="plotly_dark")
    if beep_detection_on:
        # shade the detected beep-to-beep event
        fig.add_vrect(x0=event_start / fps, x1=event_end / fps, fillcolor="LightSalmon", opacity=0.25, layer="below", line_width=0)
    fig.update_layout(legend=dict(
        orientation="h",
        yanchor="bottom",
        y=0.98,
        xanchor="right",
        x=1,
        font=dict(
            family="Courier",
            size=12,
            color="black"
        ),
        bgcolor="AliceBlue",
    ),
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)'
    )
    # remove white outline from marks
    fig.update_traces(marker_line_width=0)
    # NOTE(review): the labels contain a line break between the two words; if
    # they should wrap inside the colour bar, check which line-break markup
    # Plotly expects for tick text.
    fig.update_layout(coloraxis_colorbar=dict(
        tickvals=event_type_tick_vals,
        ticktext=['single\nrope', 'double\ndutch', 'double\nunders', 'single\nbounces',
                  'double\nbounces', 'triple\nunders', 'other'],
        title='event type'
    ))

    # phase spiral coloured by the detected phase transitions
    phase_jumps = np.zeros(len(phase_sin))
    phase_jumps[phase_indices] = 1
    fig_phase_spiral = _phase_spiral_figure(phase_cos, phase_sin, phase_jumps,
                                            'plasma', "Phase Spiral (speed)", "Phase Jumps")
    # phase spiral coloured by the mark predictions
    fig_phase_spiral_marks = _phase_spiral_figure(phase_cos, phase_sin, full_marks,
                                                  'Jet', "Phase Spiral (marks)", "Marks")

    hist = px.histogram(df,
                        x="jumps per second",
                        template="plotly_dark",
                        marginal="box",
                        histnorm='percent',
                        title="Distribution of jumping speed (jumps-per-second)")
    try:
        os.remove('temp.wav')
    except FileNotFoundError:
        pass
    return count_msg, fig, fig_phase_spiral, fig_phase_spiral_marks, hist
#css = '#phase-spiral {transform: rotate(0.25turn);}\n#phase-spiral-marks {transform: rotate(0.25turn);}'
with gr.Blocks() as demo:
    # ---- input clip -------------------------------------------------------
    with gr.Row():
        in_video = gr.PlayableVideo(
            label="Input Video",
            elem_id='input-video',
            format='mp4',
            width=400,
            height=400,
            interactive=True,
            container=True,
            max_length=300,
        )

    # ---- options ----------------------------------------------------------
    with gr.Row():
        with gr.Column():
            gr.Markdown(
                "\n"
                "### Inference Options\n"
                "Select the framerate for inference.\n",
                elem_id='inference-options',
            )
            use_60fps = gr.Checkbox(label="Use 60 FPS", elem_id='use-60fps', visible=True)
        with gr.Column():
            gr.Markdown(
                "\n"
                "### Beep Detection Options\n"
                "Must be using official IJRU timing tracks.\n",
                elem_id='beep-detection-options',
            )
            beep_detection_on = gr.Checkbox(label="Detect Beeps", elem_id='detect-beeps', visible=True)
            event_length = gr.Textbox(label="Expected Event Length (s)", elem_id='event-length', visible=True)

    # ---- actions (the hidden widgets only exist to expose the API endpoint) -
    with gr.Row():
        run_button = gr.Button(value="Run", elem_id='run-button', scale=1)
        api_dummy_button = gr.Button(value="Run (No Viz)", elem_id='count-only', visible=False, scale=2)
        count_only = gr.Checkbox(label="Count Only", visible=False)
        api_token = gr.Textbox(label="API Key", elem_id='api-token', visible=False)

    # ---- results ----------------------------------------------------------
    with gr.Column(elem_id='output-video-container'):
        with gr.Row():
            with gr.Column():
                out_text = gr.Markdown(label="Predicted Count", elem_id='output-text')
                period_length = gr.Textbox(label="Period Length", elem_id='period-length', visible=False)
                periodicity = gr.Textbox(label="Periodicity", elem_id='periodicity', visible=False)
        with gr.Row():
            out_plot = gr.Plot(label="Jumping Speed", elem_id='output-plot')
        with gr.Row():
            with gr.Column():
                out_phase_spiral = gr.Plot(label="Phase Spiral", elem_id='phase-spiral')
            with gr.Column():
                out_phase = gr.Plot(label="Phase Sin/Cos", elem_id='phase-spiral-marks')
        with gr.Row():
            with gr.Column():
                out_hist = gr.Plot(label="Speed Histogram", elem_id='output-hist')

    # ---- wiring -----------------------------------------------------------
    # UI run: no API key, full visual output
    demo_inference = partial(inference, count_only_api=False, api_key=None)
    run_button.click(
        fn=demo_inference,
        inputs=[in_video, use_60fps, beep_detection_on, event_length],
        outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist],
    )

    # API run: exposed under the name 'inference', result goes to a hidden textbox
    api_inference = partial(inference, api_call=True)
    api_dummy_button.click(
        fn=api_inference,
        inputs=[in_video, use_60fps, beep_detection_on, event_length, count_only, api_token],
        outputs=[period_length],
        api_name='inference',
    )

    examples = [
        ['files/wc2023.mp4', True, True, 30],
    ]
    gr.Examples(
        examples=examples,
        inputs=[in_video, use_60fps, beep_detection_on, event_length],
        outputs=[out_text, out_plot, out_phase_spiral, out_phase, out_hist],
        fn=demo_inference,
        cache_examples=False,
    )
if __name__ == "__main__":
    # Same request queue in both modes; only the launch arguments differ.
    queued_demo = demo.queue(api_open=True, max_size=15)
    if not LOCAL:
        queued_demo.launch(share=False)
    else:
        # local run: listen on all interfaces and create a share link
        queued_demo.launch(
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
            ssl_verify=False,
            share=True,
        )