Spaces:
Runtime error
Runtime error
Commit ·
319f52e
1
Parent(s): 57925f4
add conf and theme
Browse files- app.py +15 -10
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -11,7 +11,6 @@ import matplotlib
|
|
| 11 |
matplotlib.use('Agg')
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from scipy.signal import medfilt
|
| 14 |
-
from skimage.measure import block_reduce
|
| 15 |
from functools import partial
|
| 16 |
from passlib.hash import pbkdf2_sha256
|
| 17 |
from tqdm import tqdm
|
|
@@ -183,13 +182,11 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 183 |
|
| 184 |
if median_pred_filter:
|
| 185 |
periodicity = medfilt(periodicity, 5)
|
| 186 |
-
|
| 187 |
periodicity = sigmoid(periodicity)
|
| 188 |
full_marks = sigmoid(full_marks)
|
| 189 |
full_marks_mask = np.int32(full_marks > marks_threshold)
|
| 190 |
-
#full_marks_reduced = block_reduce(full_marks > marks_threshold, (3,), np.max)
|
| 191 |
periodicity_mask = np.int32(periodicity > miss_threshold)
|
| 192 |
-
#periodicity_mask_reduced = block_reduce(periodicity_mask, (3,), np.max)
|
| 193 |
numofReps = 0
|
| 194 |
count = []
|
| 195 |
for i in range(len(periodLength)):
|
|
@@ -209,10 +206,15 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 209 |
marks_count_pred = marks_count_pred / 2
|
| 210 |
count = np.array(count) / 2
|
| 211 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
if both_feet:
|
| 213 |
-
count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}"
|
| 214 |
else:
|
| 215 |
-
count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}"
|
| 216 |
|
| 217 |
if api_call:
|
| 218 |
if count_only_api:
|
|
@@ -229,20 +231,23 @@ def inference(x, count_only_api, api_key, img_size=192, seq_len=64, stride_lengt
|
|
| 229 |
jumping_speed = np.copy(jumps_per_second)
|
| 230 |
misses = periodicity < miss_threshold
|
| 231 |
jumps_per_second[misses] = 0
|
|
|
|
|
|
|
| 232 |
df = pd.DataFrame.from_dict({'period length': periodLength,
|
| 233 |
'jumping speed': jumping_speed,
|
| 234 |
'jumps per second': jumps_per_second,
|
| 235 |
'periodicity': periodicity,
|
| 236 |
'miss': misses,
|
|
|
|
| 237 |
'jumps': full_marks,
|
| 238 |
-
'jumps_size': (full_marks + 0.
|
| 239 |
'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
|
| 240 |
'seconds': np.linspace(0, seconds, num=len(periodLength))})
|
| 241 |
fig = px.scatter(data_frame=df,
|
| 242 |
x='seconds',
|
| 243 |
y='jumps per second',
|
| 244 |
-
symbol='
|
| 245 |
-
symbol_map={
|
| 246 |
color='periodicity',
|
| 247 |
size='jumps_size',
|
| 248 |
size_max=10,
|
|
@@ -294,7 +299,7 @@ DESCRIPTION += '\n## AI Counting for Competitive Jump Rope'
|
|
| 294 |
DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
|
| 295 |
|
| 296 |
|
| 297 |
-
with gr.Blocks() as demo:
|
| 298 |
gr.Markdown(DESCRIPTION)
|
| 299 |
with gr.Column():
|
| 300 |
with gr.Row():
|
|
|
|
| 11 |
matplotlib.use('Agg')
|
| 12 |
import matplotlib.pyplot as plt
|
| 13 |
from scipy.signal import medfilt
|
|
|
|
| 14 |
from functools import partial
|
| 15 |
from passlib.hash import pbkdf2_sha256
|
| 16 |
from tqdm import tqdm
|
|
|
|
| 182 |
|
| 183 |
if median_pred_filter:
|
| 184 |
periodicity = medfilt(periodicity, 5)
|
| 185 |
+
periodLength = medfilt(periodLength, 5)
|
| 186 |
periodicity = sigmoid(periodicity)
|
| 187 |
full_marks = sigmoid(full_marks)
|
| 188 |
full_marks_mask = np.int32(full_marks > marks_threshold)
|
|
|
|
| 189 |
periodicity_mask = np.int32(periodicity > miss_threshold)
|
|
|
|
| 190 |
numofReps = 0
|
| 191 |
count = []
|
| 192 |
for i in range(len(periodLength)):
|
|
|
|
| 206 |
marks_count_pred = marks_count_pred / 2
|
| 207 |
count = np.array(count) / 2
|
| 208 |
|
| 209 |
+
confidence = (np.mean(periodicity[periodicity > miss_threshold]) - miss_threshold) / (1 - miss_threshold)
|
| 210 |
+
self_err = abs(count_pred - marks_count_pred)
|
| 211 |
+
self_pct_err = self_err / count_pred
|
| 212 |
+
total_confidence = confidence * (1 - self_pct_err)
|
| 213 |
+
|
| 214 |
if both_feet:
|
| 215 |
+
count_msg = f"## Reps Count (both feet): {count_pred:.1f}, Marks Count (both feet): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 216 |
else:
|
| 217 |
+
count_msg = f"## Predicted Count (one foot): {count_pred:.1f}, Marks Count (one foot): {marks_count_pred:.1f}, Confidence: {total_confidence:.2f}"
|
| 218 |
|
| 219 |
if api_call:
|
| 220 |
if count_only_api:
|
|
|
|
| 231 |
jumping_speed = np.copy(jumps_per_second)
|
| 232 |
misses = periodicity < miss_threshold
|
| 233 |
jumps_per_second[misses] = 0
|
| 234 |
+
frame_type = np.array(['miss' if miss else 'frame' for miss in misses])
|
| 235 |
+
frame_type[full_marks > marks_threshold] = 'jump'
|
| 236 |
df = pd.DataFrame.from_dict({'period length': periodLength,
|
| 237 |
'jumping speed': jumping_speed,
|
| 238 |
'jumps per second': jumps_per_second,
|
| 239 |
'periodicity': periodicity,
|
| 240 |
'miss': misses,
|
| 241 |
+
'frame_type': frame_type,
|
| 242 |
'jumps': full_marks,
|
| 243 |
+
'jumps_size': (full_marks + 0.1) * 8,
|
| 244 |
'miss_size': np.clip((1 - periodicity) * 0.9 + 0.1, 1, 10),
|
| 245 |
'seconds': np.linspace(0, seconds, num=len(periodLength))})
|
| 246 |
fig = px.scatter(data_frame=df,
|
| 247 |
x='seconds',
|
| 248 |
y='jumps per second',
|
| 249 |
+
symbol='frame_type',
|
| 250 |
+
symbol_map={'frame': 'circle', 'miss': 'circle-open', 'jump': 'triangle-down'},
|
| 251 |
color='periodicity',
|
| 252 |
size='jumps_size',
|
| 253 |
size_max=10,
|
|
|
|
| 299 |
DESCRIPTION += '\nDemo created by [Dylan Plummer](https://dylan-plummer.github.io/). Check out the [NextJump iOS app](https://apps.apple.com/us/app/nextjump-jump-rope-counter/id6451026115).'
|
| 300 |
|
| 301 |
|
| 302 |
+
with gr.Blocks(theme='WeixuanYuan/Soft_dark') as demo:
|
| 303 |
gr.Markdown(DESCRIPTION)
|
| 304 |
with gr.Column():
|
| 305 |
with gr.Row():
|
requirements.txt
CHANGED
|
@@ -4,7 +4,6 @@ matplotlib
|
|
| 4 |
plotly
|
| 5 |
passlib
|
| 6 |
scipy
|
| 7 |
-
scikit-image
|
| 8 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 9 |
opencv-python-headless==4.7.0.68
|
| 10 |
openvino-dev==2022.3.0
|
|
|
|
| 4 |
plotly
|
| 5 |
passlib
|
| 6 |
scipy
|
|
|
|
| 7 |
--find-links https://download.pytorch.org/whl/torch_stable.html
|
| 8 |
opencv-python-headless==4.7.0.68
|
| 9 |
openvino-dev==2022.3.0
|