mastefan committed on
Commit
9c22091
·
verified ·
1 Parent(s): acfa2c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +457 -212
app.py CHANGED
@@ -13,130 +13,91 @@
13
  # 5) Gradio UI: video in → gallery of clips + status text out.
14
  #
15
  # Fencing Scoreboard Clips - YOLO x AutoGluon (Gradio)
16
- # -*- coding: utf-8 -*-
17
- # Fencing Scoreboard Detector — YOLO × AutoGluon × Gradio (Stable Hugging Face Build)
18
 
19
- import os, sys, zipfile, shutil, tempfile, subprocess, threading, time, types, warnings
20
  from typing import List, Tuple
21
- import pathlib
 
22
  import numpy as np
23
  import pandas as pd
24
  import cv2
25
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  from ultralytics import YOLO
27
- from autogluon.tabular import TabularPredictor
28
- from huggingface_hub import hf_hub_download
29
-
30
- # Silence noisy warnings
31
- warnings.filterwarnings("ignore", category=UserWarning)
32
- warnings.filterwarnings("ignore", category=FutureWarning)
33
- warnings.filterwarnings("ignore", category=DeprecationWarning)
34
-
35
- # ---------------- Configuration ----------------
36
- YOLO_REPO_ID = os.getenv("YOLO_REPO_ID", "mastefan/fencing-scoreboard-yolov8")
37
- YOLO_FILENAME = os.getenv("YOLO_FILENAME", "best.pt")
38
- AG_REPO_ID = os.getenv("AG_REPO_ID", "emkessle/2024-24679-fencing-touch-predictor")
39
- AG_ZIP_NAME = os.getenv("AG_ZIP_NAME", "autogluon_predictor_dir.zip")
40
-
41
- FRAME_SKIP = 2
42
- KEEP_CONF = 0.70
43
- YOLO_CONF = 0.20
44
- YOLO_IOU = 0.50
45
- GROUP_GAP_S = 1.5
46
- CLIP_PAD_S = 2.0
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  CACHE_DIR = pathlib.Path("hf_assets")
49
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
50
 
51
- _YOLO = None
52
- _AG_PRED = None
53
- from threading import Lock
54
- _load_lock = Lock()
55
-
56
- # ---------------- Model loaders ----------------
57
  def load_yolo_from_hub() -> YOLO:
58
  w = hf_hub_download(repo_id=YOLO_REPO_ID, filename=YOLO_FILENAME, cache_dir=CACHE_DIR)
59
  return YOLO(w)
60
 
61
  def load_autogluon_tabular_from_hub() -> TabularPredictor:
62
- """Download + extract predictor ZIP, stub FastAI classes, and load predictor."""
63
- z = hf_hub_download(repo_id=AG_REPO_ID, filename=AG_ZIP_NAME,
64
- cache_dir=CACHE_DIR, force_download=True)
65
  extract_dir = CACHE_DIR / "ag_predictor_native"
66
  if extract_dir.exists():
67
  shutil.rmtree(extract_dir)
68
  with zipfile.ZipFile(z, "r") as zip_ref:
69
  zip_ref.extractall(extract_dir)
 
70
 
71
- # ---- Patch FastAI / fasttransform stubs (final ultra-safe version) ----
72
- try:
73
- import sys, types
74
-
75
- # ---- fasttransform stubs ----
76
- if "fasttransform" not in sys.modules:
77
- fasttransform_stub = types.ModuleType("fasttransform")
78
- sys.modules["fasttransform"] = fasttransform_stub
79
-
80
- if "fasttransform.transform" not in sys.modules:
81
- ft_trans_stub = types.ModuleType("fasttransform.transform")
82
-
83
- class Pipeline:
84
- def __init__(self, *a, **kw): pass
85
- def fit(self, *a, **kw): return self
86
- def transform(self, X): return X
87
- def fit_transform(self, X): return X
88
- def __call__(self, *a, **kw): return self
89
- def __getattr__(self, name): return self
90
-
91
- ft_trans_stub.Pipeline = Pipeline
92
- sys.modules["fasttransform.transform"] = ft_trans_stub
93
-
94
- # ---- TabWeightedDL stub ----
95
- import fastai.tabular.core as ftc
96
-
97
- class Dummy:
98
- """Generic self-replicating dummy for chained calls like .new().to().cpu()"""
99
- def __init__(self, *a, **kw): pass
100
- def __call__(self, *a, **kw): return self
101
- def __getattr__(self, name): return self
102
- def __iter__(self): return iter([])
103
- def __next__(self): raise StopIteration
104
- def __len__(self): return 0
105
-
106
- ftc.TabWeightedDL = Dummy
107
-
108
- print("[INFO] Ultra-safe FastAI/fasttransform stubs patched.")
109
- except Exception as e:
110
- print("[WARN] Could not patch FastAI stubs:", e)
111
-
112
-
113
- # ---- Load predictor safely ----
114
- from autogluon.tabular import TabularPredictor
115
- pred = TabularPredictor.load(
116
- str(extract_dir),
117
- require_py_version_match=False,
118
- require_version_match=False
119
- )
120
-
121
- # ------------- SAFE LOAD PATCH -------------
122
- try:
123
- model_names = []
124
- # Compatible with all AutoGluon builds
125
- if hasattr(pred, "get_model_names"):
126
- model_names = pred.get_model_names()
127
- elif hasattr(pred, "_trainer") and hasattr(pred._trainer, "get_model_names"):
128
- model_names = pred._trainer.get_model_names()
129
- elif hasattr(pred, "_learner") and hasattr(pred._learner, "trainer") \
130
- and hasattr(pred._learner.trainer, "get_model_names"):
131
- model_names = pred._learner.trainer.get_model_names()
132
-
133
- bad_models = [m for m in model_names if "NN" in m or "fastai" in m]
134
- if bad_models:
135
- print("[INFO] Removing unusable FastAI models:", bad_models)
136
- pred.delete_models(models_to_delete=bad_models, dry_run=False)
137
- except Exception as e:
138
- print("[WARN] Could not prune predictor:", e)
139
- # -------------------------------------------
140
 
141
  def yolo() -> YOLO:
142
  global _YOLO
@@ -146,61 +107,130 @@ def yolo() -> YOLO:
146
 
147
  def ag_predictor() -> TabularPredictor:
148
  global _AG_PRED
149
- with _load_lock:
150
- if _AG_PRED is None:
151
- print("[INFO] Loading AutoGluon predictor…")
152
- _AG_PRED = load_autogluon_tabular_from_hub()
153
  return _AG_PRED
154
 
155
- # ---------------- Vision helpers ----------------
 
 
 
 
 
156
  def isolate_scoreboard_color(frame_bgr: np.ndarray,
157
  conf: float = YOLO_CONF,
158
  iou: float = YOLO_IOU,
159
- keep_conf: float = KEEP_CONF) -> np.ndarray:
160
- """Grayscale everything except the largest scoreboard bounding box."""
 
 
 
 
 
 
 
 
 
161
  H, W = frame_bgr.shape[:2]
 
 
162
  gray = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
163
  gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
 
 
 
 
 
164
  chosen_box = None
165
  res = yolo().predict(frame_bgr, conf=conf, iou=iou, verbose=False)
166
  if len(res):
167
  r = res[0]
168
  if getattr(r, "boxes", None) is not None and len(r.boxes) > 0:
169
- boxes = r.boxes.xyxy.cpu().numpy()
170
  scores = r.boxes.conf.cpu().numpy()
171
- candidates = [(b, s) for b, s in zip(boxes, scores) if s >= 0.78]
172
- if candidates:
173
- chosen_box, _ = max(
174
- candidates,
175
- key=lambda bs: (bs[0][2]-bs[0][0])*(bs[0][3]-bs[0][1])
176
- )
 
 
 
 
 
 
 
177
  x1, y1, x2, y2 = [int(round(v)) for v in chosen_box]
178
- gray[y1:y2, x1:x2] = frame_bgr[y1:y2, x1:x2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  return gray
180
 
181
- def color_pixel_ratio(rgb: np.ndarray, ch: int) -> float:
 
 
 
 
182
  R, G, B = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
183
  if ch == 0:
184
- mask = (R > 150) & (R > 1.2*G) & (R > 1.2*B)
 
 
185
  else:
186
- mask = (G > 100) & (G > 1.05*R) & (G > 1.05*B)
187
- return np.sum(mask) / (rgb.shape[0]*rgb.shape[1] + 1e-9)
 
 
 
188
 
189
  def rolling_z(series: pd.Series, win: int = 45) -> pd.Series:
190
  med = series.rolling(win, min_periods=5).median()
191
  mad = series.rolling(win, min_periods=5).apply(
192
- lambda x: np.median(np.abs(x - np.median(x))), raw=True)
 
193
  mad = mad.replace(0, mad[mad > 0].min() if (mad > 0).any() else 1.0)
194
- return (series - med) / mad
195
-
196
- # ---------------- Frame & feature extraction ----------------
197
- def extract_feature_timeseries(video_path: str, frame_skip: int = FRAME_SKIP) -> Tuple[pd.DataFrame, float]:
 
 
 
 
 
198
  cap = cv2.VideoCapture(video_path)
199
  if not cap.isOpened():
 
200
  return pd.DataFrame(), 0.0
 
201
  fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
202
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
203
  records, frame_idx = [], 0
 
204
 
205
  while True:
206
  ret, frame = cap.read()
@@ -208,7 +238,7 @@ def extract_feature_timeseries(video_path: str, frame_skip: int = FRAME_SKIP) ->
208
  break
209
  if frame_idx % frame_skip == 0:
210
  ts = frame_idx / fps
211
- masked = isolate_scoreboard_color(frame)
212
  rgb = cv2.cvtColor(masked, cv2.COLOR_BGR2RGB)
213
  red_ratio = color_pixel_ratio(rgb, 0)
214
  green_ratio = color_pixel_ratio(rgb, 1)
@@ -216,112 +246,288 @@ def extract_feature_timeseries(video_path: str, frame_skip: int = FRAME_SKIP) ->
216
  "frame_id": frame_idx,
217
  "timestamp": ts,
218
  "red_ratio": red_ratio,
219
- "green_ratio": green_ratio
220
  })
221
  frame_idx += 1
222
- cap.release()
223
 
 
224
  df = pd.DataFrame(records)
225
- if df.empty: return df, fps
226
- df["red_diff"] = df["red_ratio"].diff().fillna(0)
 
 
 
 
227
  df["green_diff"] = df["green_ratio"].diff().fillna(0)
228
- df["z_red"] = rolling_z(df["red_ratio"])
229
- df["z_green"] = rolling_z(df["green_ratio"])
 
 
 
 
 
 
230
  return df, fps
231
 
232
- # ---------------- Predictor & detection ----------------
 
 
233
  def predict_scores(df: pd.DataFrame) -> pd.Series:
234
- feat_cols = ["red_ratio","green_ratio","red_diff","green_diff","z_red","z_green"]
235
  X = df[feat_cols].copy()
236
- pred_model = ag_predictor()
237
- if pred_model is None:
238
- return pd.Series(np.zeros(len(df)))
239
  try:
240
- proba = pred_model.predict_proba(X)
241
  if isinstance(proba, pd.DataFrame) and (1 in proba.columns):
242
  return proba[1]
243
  except Exception:
244
  pass
245
- s = pd.Series(pred_model.predict(X)).astype(float)
 
 
246
  rng = (s.quantile(0.95) - s.quantile(0.05)) or 1.0
247
- return ((s - s.quantile(0.05)) / rng).clip(0,1)
248
 
249
  def pick_events(df: pd.DataFrame, score: pd.Series, fps: float) -> List[float]:
 
 
 
 
 
 
 
 
250
  max_score = score.max()
251
  raw_cutoff = 0.7 * max_score if max_score > 0 else 0.4
 
252
  z = rolling_z(score, win=45)
253
- z_cutoff = max(2.0, 0.6 * z.max())
254
- out_times, min_dist_frames = [], int(fps)
255
- y, last_kept = score.values, -int(fps)
 
 
 
 
 
 
 
 
256
  for i in range(1, len(y)-1):
257
  ts = float(df.iloc[i]["timestamp"])
258
  local_peak = y[i] > y[i-1] and y[i] > y[i+1]
259
- if ts >= 1.0 and ((z.iloc[i] > z_cutoff) or (y[i] > raw_cutoff)) \
260
- and local_peak and (i - last_kept) >= min_dist_frames:
261
  out_times.append(ts)
262
  last_kept = i
263
- if not out_times and len(y)>0:
 
264
  best_idx = int(np.argmax(y))
265
  ts = float(df.iloc[best_idx]["timestamp"])
266
- if ts >= 1.0: out_times=[ts]
267
- grouped=[]
 
 
 
 
 
 
 
268
  for t in out_times:
269
- if not grouped or (t - grouped[-1]) > GROUP_GAP_S:
270
  grouped.append(t)
 
 
271
  return grouped
272
 
273
- # ---------------- Video clipping ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
  def _probe_duration(video_path: str) -> float:
275
  try:
276
- import ffmpeg
 
277
  meta = ffmpeg.probe(video_path)
278
  return float(meta["format"]["duration"])
279
  except Exception:
280
  return 0.0
281
 
282
  def cut_clip(video_path: str, start: float, end: float, out_path: str) -> str:
 
283
  try:
284
- cmd = ["ffmpeg", "-y","-ss",str(max(0,start)),"-to",str(max(start,end)),"-i",video_path,"-c","copy",out_path]
285
- sp = subprocess.run(cmd,stdout=subprocess.PIPE,stderr=subprocess.PIPE)
286
- if sp.returncode==0 and os.path.exists(out_path): return out_path
287
- except: pass
 
 
 
 
 
288
  from moviepy.editor import VideoFileClip
289
- clip=VideoFileClip(video_path).subclip(max(0,start),max(start,end))
290
- clip.write_videofile(out_path,codec="libx264",audio_codec="aac",verbose=False,logger=None)
291
  return out_path
292
 
293
- def extract_score_clips(video_path: str) -> Tuple[List[Tuple[str,str]], str]:
294
- df,fps=extract_feature_timeseries(video_path)
295
- if df.empty: return [],"No frames processed."
296
- score=predict_scores(df)
297
- if score.max()<=1e-6: return [],"⚠️ No scoreboard detected or illumination scores flat."
298
- events=pick_events(df,score,fps)
299
- if not events: return [],"⚠️ No touches confidently detected."
300
- duration=_probe_duration(video_path) or df["timestamp"].max()+CLIP_PAD_S+0.5
301
- clips=[]
302
- base=os.path.splitext(os.path.basename(video_path))[0]
303
- for i,t in enumerate(events):
304
- s=max(0.0,t-CLIP_PAD_S)
305
- e=min(duration,t+CLIP_PAD_S)
306
- out=os.path.join(tempfile.gettempdir(),f"{base}_score_{i+1:02d}.mp4")
307
- cut_clip(video_path,s,e,out)
308
- clips.append((out,f"Touch {i+1} @ {t:.2f}s"))
309
- return clips,f" Detected {len(clips)} event(s)."
310
-
311
- # ---------------- Gradio GUI ----------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  CSS = """
313
  .gradio-container {max-width: 900px; margin: auto;}
314
  .header {text-align: center; margin-bottom: 20px;}
315
  .full-width {width: 100% !important;}
316
- .progress-bar {width:100%;height:30px;background:#e0e0e0;border-radius:15px;margin:15px 0;position:relative;}
317
- .progress-fill {height:100%;background:#4CAF50;border-radius:15px;text-align:center;line-height:30px;color:white;font-weight:bold;transition:width 0.3s;}
318
- .fencer {position:absolute;top:-5px;font-size:24px;transition:left 0.3s;transform:scaleX(-1);}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
  """
320
 
321
- def _make_fencer_strip(percent: int):
 
322
  return f"""
323
  <div class="progress-bar">
324
- <div id="progress-fill" class="progress-fill" style="width:{percent}%"></div>
325
  <div id="fencer" class="fencer" style="left:{percent}%">🤺</div>
326
  </div>
327
  """
@@ -331,39 +537,78 @@ def run_with_progress(video_file):
331
  yield [], "Please upload a video file.", gr.update(visible=False)
332
  return
333
 
334
- result = {}
335
- def run_pipeline():
336
- try:
337
- clips, msg = extract_score_clips(video_file)
338
- except Exception as e:
339
- clips, msg = [], f"❌ Error: {e}"
340
- result["clips"], result["msg"] = clips, msg
341
-
342
- t = threading.Thread(target=run_pipeline)
343
- t.start()
344
- pos = 0
345
- while t.is_alive():
346
- pos = (pos + 5) % 100
347
- yield gr.update(value=[], visible=False), "Processing...", gr.update(value=_make_fencer_strip(pos), visible=True)
348
- time.sleep(0.1)
349
- clips = result.get("clips", [])
350
- msg = result.get("msg", "⚠️ Unknown error during processing.")
351
- yield gr.update(value=clips, visible=True), msg, gr.update(value=_make_fencer_strip(100), visible=True)
352
 
353
  with gr.Blocks(css=CSS, title="Fencing Scoreboard Detector") as demo:
354
  with gr.Row(elem_classes="header"):
355
- gr.Markdown("## 🤺 Fencing Score Detector\nUpload a fencing bout video. The system detects scoreboard lights and returns highlight clips around each scoring event.")
 
 
 
 
 
356
  in_video = gr.Video(label="Upload Bout Video", elem_classes="full-width", height=400)
357
  run_btn = gr.Button("⚡ Detect Touches", elem_classes="full-width")
358
- progress_html = gr.HTML(value="", visible=False)
359
- status = gr.Markdown("Ready.")
360
- gallery = gr.Gallery(label="Detected Clips", columns=1, height=400,
361
- preview=True, allow_preview=True,
362
- show_download_button=True, visible=False)
363
- run_btn.click(fn=run_with_progress,
364
- inputs=in_video,
365
- outputs=[gallery, status, progress_html])
366
 
367
- if __name__ == "__main__":
368
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
369
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  # 5) Gradio UI: video in → gallery of clips + status text out.
14
  #
15
  # Fencing Scoreboard Clips - YOLO x AutoGluon (Gradio)
 
 
16
 
17
+ import os, sys, zipfile, shutil, subprocess, tempfile, pathlib
18
  from typing import List, Tuple
19
+ import uuid
20
+
21
  import numpy as np
22
  import pandas as pd
23
  import cv2
24
  import gradio as gr
25
+
26
# ---- Robust imports/installs for Colab/Spaces ----
def _pip(pkgs: List[str]):
    """Quietly pip-install the given packages into the current interpreter."""
    import subprocess, sys
    cmd = [sys.executable, "-m", "pip", "install", "--quiet"]
    cmd.extend(pkgs)
    subprocess.check_call(cmd)
30
+
31
+ try:
32
+ import ultralytics
33
+ except:
34
+ _pip(["ultralytics"])
35
+ import ultralytics
36
+
37
+ try:
38
+ import ffmpeg # optional helper for duration probe
39
+ except:
40
+ try:
41
+ _pip(["ffmpeg-python"])
42
+ import ffmpeg
43
+ except Exception:
44
+ ffmpeg = None
45
+
46
+ try:
47
+ from autogluon.tabular import TabularPredictor
48
+ except:
49
+ _pip(["autogluon.tabular"])
50
+ from autogluon.tabular import TabularPredictor
51
+
52
+ try:
53
+ from huggingface_hub import hf_hub_download
54
+ except:
55
+ _pip(["huggingface_hub"])
56
+ from huggingface_hub import hf_hub_download
57
+
58
  from ultralytics import YOLO
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # ----------------------------
61
+ # Config — HF Hub repositories
62
+ # ----------------------------
63
+ # YOLO scoreboard detector weights (pushed by your training file)
64
+ YOLO_REPO_ID = os.getenv("YOLO_REPO_ID", "mastefan/fencing-scoreboard-yolov8")
65
+ YOLO_FILENAME = os.getenv("YOLO_FILENAME", "best.pt")
66
+
67
+ # AutoGluon Tabular detector (your color/timeseries model zip)
68
+ AG_REPO_ID = os.getenv("AG_REPO_ID", "emkessle/2024-24679-fencing-touch-predictor")
69
+ AG_ZIP_NAME = os.getenv("AG_ZIP_NAME", "autogluon_predictor_dir.zip")
70
+
71
+ # Processing parameters
72
+ FRAME_SKIP = int(os.getenv("FRAME_SKIP", "2")) # process every Nth frame
73
+ KEEP_CONF = float(os.getenv("KEEP_CONF", "0.85"))# YOLO conf to keep color inside bbox
74
+ YOLO_CONF = float(os.getenv("YOLO_CONF", "0.25"))
75
+ YOLO_IOU = float(os.getenv("YOLO_IOU", "0.50"))
76
+ MIN_SEP_S = float(os.getenv("MIN_SEP_S", "1.2")) # min gap between events (s)
77
+ CLIP_PAD_S = float(os.getenv("CLIP_PAD_S","2.0")) # before/after padding each hit
78
+ GROUP_GAP_S = float(os.getenv("GROUP_GAP_S","1.5"))# cluster close frames to single event
79
+
80
+ # ----------------
81
+ # Model loaders
82
+ # ----------------
83
  CACHE_DIR = pathlib.Path("hf_assets")
84
  CACHE_DIR.mkdir(parents=True, exist_ok=True)
85
 
 
 
 
 
 
 
86
def load_yolo_from_hub() -> YOLO:
    """Download the scoreboard-detector weights from the HF Hub and build a YOLO model."""
    weights = hf_hub_download(
        repo_id=YOLO_REPO_ID,
        filename=YOLO_FILENAME,
        cache_dir=CACHE_DIR,
    )
    return YOLO(weights)
89
 
90
def load_autogluon_tabular_from_hub() -> TabularPredictor:
    """Fetch the predictor ZIP from the HF Hub, unpack it fresh, and load it."""
    archive = hf_hub_download(repo_id=AG_REPO_ID, filename=AG_ZIP_NAME, cache_dir=CACHE_DIR)
    extract_dir = CACHE_DIR / "ag_predictor_native"
    # Unpack into a clean directory so stale files from a prior extract cannot linger.
    if extract_dir.exists():
        shutil.rmtree(extract_dir)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(extract_dir)
    return TabularPredictor.load(str(extract_dir))
98
 
99
+ _YOLO = None
100
+ _AG_PRED = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
  def yolo() -> YOLO:
103
  global _YOLO
 
107
 
108
def ag_predictor() -> TabularPredictor:
    """Return the process-wide AutoGluon predictor, loading it lazily on first use.

    NOTE(review): an earlier revision guarded this lazy init with a lock; this
    version does not, so concurrent first calls could load twice — confirm the
    caller context is effectively single-threaded.
    """
    global _AG_PRED
    if _AG_PRED is not None:
        return _AG_PRED
    _AG_PRED = load_autogluon_tabular_from_hub()
    return _AG_PRED
113
 
114
+ # ----------------------------
115
+ # Vision helpers
116
+ # ----------------------------
117
+ DEBUG_DIR = pathlib.Path("debug_frames")
118
+ DEBUG_DIR.mkdir(exist_ok=True)
119
+
120
def isolate_scoreboard_color(frame_bgr: np.ndarray,
                             conf: float = YOLO_CONF,
                             iou: float = YOLO_IOU,
                             keep_conf: float = KEEP_CONF,
                             debug: bool = False,
                             frame_id: int = None) -> np.ndarray:
    """Grayscale the whole frame except the detected scoreboard bounding box.

    Selection policy:
      - primary threshold  = max(0.80, keep_conf); among boxes at/above it,
        the largest-area one wins,
      - fallback threshold = max(0.70, primary - 0.05); used only when no box
        clears the primary threshold,
      - safeguard: a chosen ROI with very low mean HSV saturation (< 25) is
        rejected as a flat/neutral false positive,
      - the accepted box is restored to full color; everything else stays gray.

    With ``debug`` set and a ``frame_id`` given, an annotated copy of the
    result is written into ``DEBUG_DIR``.
    """
    height, width = frame_bgr.shape[:2]

    # Start fully grayscale (3-channel so BGR pixels can be pasted back in).
    out = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2GRAY)
    out = cv2.cvtColor(out, cv2.COLOR_GRAY2BGR)

    primary_thr = max(0.80, keep_conf)
    fallback_thr = max(0.7, primary_thr - 0.05)

    chosen_box = None
    results = yolo().predict(frame_bgr, conf=conf, iou=iou, verbose=False)
    if len(results):
        first = results[0]
        det = getattr(first, "boxes", None)
        if det is not None and len(det) > 0:
            xyxy = det.xyxy.cpu().numpy()
            confs = det.conf.cpu().numpy()
            pairs = list(zip(xyxy, confs))

            def _area(pair):
                b = pair[0]
                return (b[2] - b[0]) * (b[3] - b[1])

            # Prefer the largest box meeting the primary threshold ...
            strong = [p for p in pairs if float(p[1]) >= primary_thr]
            if strong:
                chosen_box = max(strong, key=_area)[0]
            else:
                # ... otherwise the largest box meeting the fallback threshold.
                medium = [p for p in pairs if float(p[1]) >= fallback_thr]
                if medium:
                    chosen_box = max(medium, key=_area)[0]

    if chosen_box is not None:
        x1, y1, x2, y2 = [int(round(v)) for v in chosen_box]
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(width - 1, x2), min(height - 1, y2)

        if x2 > x1 and y2 > y1:
            # Safeguard: a near-gray ROI is almost certainly not a lit scoreboard.
            roi = frame_bgr[y1:y2, x1:x2]
            if roi.size > 0:
                hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)
                sat_mean = hsv[:, :, 1].mean()
                if sat_mean < 25:  # flat/neutral area → reject
                    print(f"[WARN] Rejected bbox due to low saturation (mean={sat_mean:.1f})")
                    chosen_box = None

            # If accepted, restore the whole bbox to color.
            if chosen_box is not None:
                out[y1:y2, x1:x2] = frame_bgr[y1:y2, x1:x2]

    # Optional debug save with the chosen box outlined.
    if debug and frame_id is not None:
        dbg = out.copy()
        if chosen_box is not None:
            x1, y1, x2, y2 = [int(round(v)) for v in chosen_box]
            cv2.rectangle(dbg, (x1, y1), (x2, y2), (0, 255, 0), 2)
        out_path = DEBUG_DIR / f"frame_{frame_id:06d}.jpg"
        cv2.imwrite(str(out_path), dbg)
        print(f"[DEBUG] Saved debug frame → {out_path}")

    return out
193
 
194
+
195
+ # Color features
196
+ def _count_color_pixels(rgb: np.ndarray, ch: int,
197
+ red_thresh=150, green_thresh=100,
198
+ red_dom=1.2, green_dom=1.05) -> int:
199
  R, G, B = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
200
  if ch == 0:
201
+ mask = (R > red_thresh) & (R > red_dom*G) & (R > red_dom*B)
202
+ elif ch == 1:
203
+ mask = (G > green_thresh) & (G > green_dom*R) & (G > green_dom*B)
204
  else:
205
+ raise ValueError("ch must be 0 (red) or 1 (green)")
206
+ return int(np.sum(mask))
207
+
208
+ def color_pixel_ratio(rgb: np.ndarray, ch: int) -> float:
209
+ return _count_color_pixels(rgb, ch) / float(rgb.shape[0]*rgb.shape[1] + 1e-9)
210
 
211
def rolling_z(series: pd.Series, win: int = 45) -> pd.Series:
    """Robust rolling z-score: (x - rolling median) / rolling MAD."""
    rolling = series.rolling(win, min_periods=5)
    center = rolling.median()
    # Median absolute deviation over the same window.
    spread = rolling.apply(lambda x: np.median(np.abs(x - np.median(x))), raw=True)
    # Avoid division by zero: exact-zero MADs become the smallest positive MAD
    # seen (or 1.0 when every window is flat).
    fill = spread[spread > 0].min() if (spread > 0).any() else 1.0
    spread = spread.replace(0, fill)
    return (series - center) / spread
218
+
219
# ----------------------------
# Video feature table
# ----------------------------
def extract_feature_timeseries(video_path: str,
                               frame_skip: int = FRAME_SKIP,
                               debug: bool = False) -> Tuple[pd.DataFrame, float]:
    """Build a per-sampled-frame color-feature table for the whole video.

    Every ``frame_skip``-th frame is masked to the scoreboard region, its
    red/green pixel ratios recorded, then frame-to-frame diffs and robust
    rolling z-scores are appended. Returns (features_df, fps); the DataFrame
    is empty when the video cannot be opened or yields no frames.
    """
    print("[INFO] Starting frame extraction...")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("[ERROR] Could not open video.")
        return pd.DataFrame(), 0.0

    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # guard against a 0-fps report
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    rows = []
    frame_idx = 0
    while True:
        ok, frame = cap.read()
        if not ok:
            break
        if frame_idx % frame_skip == 0:
            ts = frame_idx / fps
            masked = isolate_scoreboard_color(frame, debug=debug, frame_id=frame_idx)
            rgb = cv2.cvtColor(masked, cv2.COLOR_BGR2RGB)
            rows.append({
                "frame_id": frame_idx,
                "timestamp": ts,
                "red_ratio": color_pixel_ratio(rgb, 0),
                "green_ratio": color_pixel_ratio(rgb, 1),
            })
        frame_idx += 1

    cap.release()
    df = pd.DataFrame(rows)
    print(f"[INFO] Processed {len(df)} frames out of {total_frames} (fps={fps:.2f})")

    if df.empty:
        return df, fps

    df["red_diff"] = df["red_ratio"].diff().fillna(0)
    df["green_diff"] = df["green_ratio"].diff().fillna(0)
    df["z_red"] = rolling_z(df["red_ratio"])
    df["z_green"] = rolling_z(df["green_ratio"])

    if debug:
        out_csv = DEBUG_DIR / f"features_{uuid.uuid4().hex}.csv"
        df.to_csv(out_csv, index=False)
        print(f"[DEBUG] Saved feature CSV → {out_csv}")

    return df, fps
271
 
272
# ----------------------------
# AutoGluon inference + event picking
# ----------------------------
def predict_scores(df: pd.DataFrame) -> pd.Series:
    """Score each feature row with the AutoGluon model, returning values in [0, 1].

    Prefers the positive-class probability when the predictor is a classifier;
    otherwise rescales the raw predictions robustly (5th–95th percentile) to
    the 0..1 range.
    """
    feat_cols = ["red_ratio", "green_ratio", "red_diff", "green_diff", "z_red", "z_green"]
    X = df[feat_cols].copy()
    predictor = ag_predictor()  # fetch the singleton once, not per call below

    # Prefer classification probabilities if the model exposes them.
    try:
        proba = predictor.predict_proba(X)
        if isinstance(proba, pd.DataFrame) and (1 in proba.columns):
            return proba[1]
    except Exception:
        pass

    # Fallback: run predict() only when actually needed (previously it ran
    # eagerly even when predict_proba succeeded — a wasted full inference
    # pass), then normalize the regression-like output robustly to 0..1.
    s = pd.Series(predictor.predict(X)).astype(float)
    rng = (s.quantile(0.95) - s.quantile(0.05)) or 1.0
    return ((s - s.quantile(0.05)) / rng).clip(0, 1)
292
 
293
def pick_events(df: pd.DataFrame, score: pd.Series, fps: float) -> List[float]:
    """Adaptive hybrid event picker over the per-frame score series.

    Keeps local peaks that clear either an adaptive raw threshold
    (0.7 × max score) or an adaptive z threshold (max(2.0, 0.6 × max z)),
    ignores anything before t = 1.0 s, enforces a minimum index spacing,
    and finally clusters events closer than GROUP_GAP_S into one.
    Falls back to the single global maximum when nothing qualifies.
    """
    max_score = score.max()
    raw_cutoff = 0.7 * max_score if max_score > 0 else 0.4

    z = rolling_z(score, win=45)
    z_cutoff = max(2.0, 0.6 * z.max())

    print(f"[DEBUG] Predictor score stats: min={score.min():.3f}, max={max_score:.3f}, mean={score.mean():.3f}")
    print(f"[DEBUG] Adaptive thresholds: raw>{raw_cutoff:.3f}, z>{z_cutoff:.2f}")

    # NOTE(review): i indexes sampled rows (every FRAME_SKIP-th frame), so this
    # spacing is in rows, not raw frames — confirm the intended 1.0 s gap. The
    # MIN_SEP_S config constant is defined elsewhere but unused here.
    min_dist = max(1, int(1.0 * max(1.0, fps)))
    y = score.values
    picked = []
    last_idx = -min_dist

    for i in range(1, len(y) - 1):
        ts = float(df.iloc[i]["timestamp"])
        if ts < 1.0:
            continue
        is_peak = y[i] > y[i - 1] and y[i] > y[i + 1]
        passes = (z.iloc[i] > z_cutoff) or (y[i] > raw_cutoff)
        if is_peak and passes and (i - last_idx) >= min_dist:
            picked.append(ts)
            last_idx = i

    if not picked and len(y) > 0:
        best = int(np.argmax(y))
        ts = float(df.iloc[best]["timestamp"])
        if ts >= 1.0:
            picked = [ts]
            print(f"[DEBUG] Fallback → using global max at {ts:.2f}s")
        else:
            print(f"[DEBUG] Ignored fallback at {ts:.2f}s (within first second)")

    picked.sort()

    grouped = []
    for t in picked:
        if not grouped or (t - grouped[-1]) > GROUP_GAP_S:
            grouped.append(t)

    print(f"[DEBUG] Final detected events: {grouped}")
    return grouped
342
 
343
def save_event_snapshot(video_path: str, timestamp: float, out_path: str, fps: float):
    """Write a scoreboard-masked frame at *timestamp*, with the YOLO box drawn.

    Returns ``out_path`` on success, ``None`` when the video cannot be opened
    or the frame cannot be grabbed.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("[ERROR] Could not open video for snapshot.")
        return None

    target = int(timestamp * fps)
    cap.set(cv2.CAP_PROP_POS_FRAMES, target)
    ok, frame = cap.read()
    cap.release()

    if not ok or frame is None:
        print(f"[WARN] Could not grab frame at {timestamp:.2f}s")
        return None

    snapshot = isolate_scoreboard_color(frame, debug=False)

    # Re-run detection on the raw frame so the scoreboard can be outlined.
    detections = yolo().predict(frame, conf=YOLO_CONF, iou=YOLO_IOU, verbose=False)
    if len(detections):
        det = getattr(detections[0], "boxes", None)
        if det is not None and len(det) > 0:
            xyxy = det.xyxy.cpu().numpy()
            confs = det.conf.cpu().numpy()
            valid = [(b, s) for b, s in zip(xyxy, confs) if float(s) >= KEEP_CONF]
            if valid:
                box = max(valid, key=lambda bs: (bs[0][2] - bs[0][0]) * (bs[0][3] - bs[0][1]))[0]
                x1, y1, x2, y2 = [int(round(v)) for v in box]
                cv2.rectangle(snapshot, (x1, y1), (x2, y2), (0, 255, 0), 3)

    cv2.imwrite(out_path, snapshot)
    print(f"[DEBUG] Saved snapshot → {out_path}")
    return out_path
374
+
375
import matplotlib.pyplot as plt
def save_debug_plot(df: pd.DataFrame, score: pd.Series, events: List[float], base_name="debug_plot"):
    """Plot score vs. time with detected events marked; save a PNG to DEBUG_DIR."""
    plt.figure(figsize=(12, 5))
    plt.plot(df["timestamp"], score, label="Predicted Score")
    plt.axhline(y=0.5, color="gray", linestyle="--", alpha=0.5)
    # Label only the first event line so the legend carries a single entry.
    for idx, ev in enumerate(events):
        plt.axvline(x=ev, color="red", linestyle="--",
                    label="Detected Event" if idx == 0 else None)
    plt.xlabel("Time (s)")
    plt.ylabel("Score")
    plt.title("AutoGluon Score vs Time")
    plt.legend()
    out_path = DEBUG_DIR / f"{base_name}.png"
    plt.savefig(out_path)
    plt.close()
    print(f"[DEBUG] Saved debug score plot → {out_path}")
392
+
393
+
394
+ # ----------------------------
395
+ # Clip cutting (ffmpeg w/ moviepy fallback)
396
+ # ----------------------------
397
  def _probe_duration(video_path: str) -> float:
398
  try:
399
+ if ffmpeg is None:
400
+ raise RuntimeError("ffmpeg-python not available")
401
  meta = ffmpeg.probe(video_path)
402
  return float(meta["format"]["duration"])
403
  except Exception:
404
  return 0.0
405
 
406
def cut_clip(video_path: str, start: float, end: float, out_path: str) -> str:
    """Cut [start, end] out of *video_path* into *out_path* and return out_path.

    Tries a fast lossless stream-copy via the ffmpeg CLI first; falls back to
    a moviepy re-encode when ffmpeg is missing or fails.
    """
    s, e = max(0, start), max(start, end)  # clamp: non-negative start, end >= start

    # Fast path (copy) if the ffmpeg binary is available.
    try:
        cmd = ["ffmpeg", "-y", "-ss", str(s), "-to", str(e),
               "-i", video_path, "-c", "copy", out_path]
        sp = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        if sp.returncode == 0 and os.path.exists(out_path):
            return out_path
    except Exception:
        pass

    # Fallback: moviepy re-encode. Close the reader explicitly — the original
    # never closed the VideoFileClip, leaking a file handle per call, which
    # can exhaust descriptors when many clips are cut in one session.
    from moviepy.editor import VideoFileClip
    source = VideoFileClip(video_path)
    try:
        clip = source.subclip(s, e)
        clip.write_videofile(out_path, codec="libx264", audio_codec="aac",
                             verbose=False, logger=None)
    finally:
        source.close()
    return out_path
422
 
423
# ----------------------------
# Orchestrator: detect + cut + debug
# ----------------------------
def extract_score_clips(video_path: str, debug: bool = False) -> Tuple[List[Tuple[str, str]], str]:
    """End-to-end pipeline: features → scores → events → clips + snapshots.

    Returns (clips, status) where clips is a list of (file_path, label) pairs
    and status is a human-readable summary string for the UI.
    """
    print("[INFO] Running full detection pipeline...")
    df, fps = extract_feature_timeseries(video_path, frame_skip=FRAME_SKIP, debug=debug)
    if df.empty:
        print("[WARN] Empty dataframe — no frames processed.")
        return [], "No frames processed."

    print("[INFO] Feature extraction done. Running predictor...")
    score = predict_scores(df)

    # Bail early if the model produced no signal at all.
    if score.max() <= 1e-6:
        print("[WARN] Flat scores from predictor (possible YOLO miss or feature mismatch).")
        return [], "⚠️ No scoreboard detected or illumination scores flat. Please check video or model."

    print("[INFO] Picking events from predictor scores...")
    events = pick_events(df, score, fps)
    print(f"[INFO] Picked {len(events)} event(s): {events}")

    if not events:
        # Dump the strongest peaks to help diagnose why nothing qualified.
        topk = np.argsort(score.values)[-5:][::-1]
        dbg = [(float(df.iloc[i]['timestamp']), float(score.iloc[i])) for i in topk]
        print(f"[DEBUG] Top-5 peaks (ts,score): {dbg}")
        return [], "⚠️ No touches confidently detected in this video."

    duration = _probe_duration(video_path)
    if duration <= 0:
        duration = float(df["timestamp"].max() + CLIP_PAD_S + 0.5)

    base = os.path.splitext(os.path.basename(video_path))[0]
    clips, snapshots = [], []
    for i, t in enumerate(events):
        start = max(0.0, t - CLIP_PAD_S)
        end = min(duration, t + CLIP_PAD_S)
        clip_path = os.path.join(tempfile.gettempdir(), f"{base}_score_{i+1:02d}.mp4")
        img_path = os.path.join(tempfile.gettempdir(), f"{base}_score_{i+1:02d}.jpg")
        cut_clip(video_path, start, end, clip_path)
        save_event_snapshot(video_path, t, img_path, fps)
        clips.append((clip_path, f"Touch {i+1} @ {t:.2f}s"))
        snapshots.append(img_path)

    if debug:
        debug_csv = DEBUG_DIR / f"scores_{base}.csv"
        pd.DataFrame({"timestamp": df["timestamp"], "score": score}).to_csv(debug_csv, index=False)
        print(f"[DEBUG] Saved score debug CSV → {debug_csv}")
        save_debug_plot(df, score, events, base_name=base)
        print(f"[DEBUG] Saved debug frames in {DEBUG_DIR}/")

    return clips, f"✅ Detected {len(clips)} event(s). Snapshots saved to temp."
477
+
478
+ import time
479
+
480
def looping_progress():
    """
    Endless progress animation: repeatedly sweep the fencer bar from 0 to 100%.

    Yields Gradio HTML updates forever; the consumer stops pulling once the
    real pipeline has finished.
    """
    while True:
        percent = 0
        while percent <= 100:
            yield gr.update(value=_make_progress_bar(percent), visible=True)
            time.sleep(0.05)  # controls speed of march (~5s per loop)
            percent += 1
491
+ # ----------------------------
492
+ # Gradio UI
493
+ # ----------------------------
494
# Custom stylesheet injected into the Gradio app: page layout, the animated
# progress bar (.progress-bar / .progress-fill) and the fencer emoji marker
# (.fencer) that _make_progress_bar positions via inline left/width styles.
CSS = """
.gradio-container {max-width: 900px; margin: auto;}
.header {text-align: center; margin-bottom: 20px;}
.full-width {width: 100% !important;}
.progress-bar {
width: 100%;
height: 30px;
background-color: #e0e0e0;
border-radius: 15px;
margin: 15px 0;
position: relative;
overflow: hidden;
}
.progress-fill {
height: 100%;
background-color: #4CAF50;
border-radius: 15px;
text-align: center;
line-height: 30px;
color: white;
font-weight: bold;
transition: width 0.3s;
}
.fencer {
position: absolute;
top: -5px;
font-size: 24px;
transition: left 0.3s;
transform: scaleX(-1); /* flip to face right */
}
"""
525
 
526
+ def _make_progress_bar(percent: int, final_text: str = None):
527
+ text = f"{percent}%" if not final_text else final_text
528
  return f"""
529
  <div class="progress-bar">
530
+ <div id="progress-fill" class="progress-fill" style="width:{percent}%">{text}</div>
531
  <div id="fencer" class="fencer" style="left:{percent}%">🤺</div>
532
  </div>
533
  """
 
537
  yield [], "Please upload a video file.", gr.update(visible=False)
538
  return
539
 
540
+ # Step 1: Extract frames + features
541
+ yield [], "🔄 Extracting frames...", _make_progress_bar(20)
542
+ df, fps = extract_feature_timeseries(video_file, frame_skip=FRAME_SKIP, debug=False)
543
+ if df.empty:
544
+ yield [], "❌ No frames processed!", _make_progress_bar(100, "No Frames ❌")
545
+ return
546
+
547
+ # Step 2–4: Predict & pick events via the single orchestrator
548
+ yield [], "🔄 Scoring & detecting touches...", _make_progress_bar(80)
549
+ clips, status_msg = extract_score_clips(video_file, debug=True)
550
+
551
+ # Step 5: Done (and cutting already handled in orchestrator)
552
+ final_bar = _make_progress_bar(100, f"Detected {len(clips)} Touches ⚡" if clips else "No Touches")
553
+ yield clips, status_msg, final_bar
 
 
 
 
554
 
555
# Gradio UI: video upload → background pipeline thread → looping progress
# animation → gallery of highlight clips.
with gr.Blocks(css=CSS, title="Fencing Scoreboard Detector") as demo:
    with gr.Row(elem_classes="header"):
        gr.Markdown(
            "## 🤺 Fencing Score Detector\n"
            "Upload a fencing bout video. We’ll detect scoreboard lights (YOLO + AutoGluon), "
            "and return 4-second highlight clips around each scoring event."
        )

    in_video = gr.Video(label="Upload Bout Video", elem_classes="full-width", height=400)
    run_btn = gr.Button("⚡ Detect Touches", elem_classes="full-width")

    progress_html = gr.HTML(value="", label="Processing Progress", visible=False)
    status = gr.Markdown("Ready.")
    gallery = gr.Gallery(
        label="Detected Clips",
        columns=1,
        height=400,
        preview=True,
        allow_preview=True,
        show_download_button=True,
        visible=False,
    )

    def wrapped_run(video_file):
        """Generator event handler: animate progress while the detection
        pipeline runs in a worker thread, then emit the final gallery/status.

        Yields (gallery update, status markdown, progress HTML) tuples.
        """
        if not video_file:
            yield gr.update(value=[], visible=False), "Please upload a video file.", gr.update(value=_make_progress_bar(0), visible=False)
            return

        # Start looping animation
        progress_iter = looping_progress()

        # Run the pipeline in a background thread so we can keep yielding
        # progress frames to the browser while it works. `threading` is the
        # module imported at the top of the file.
        result = {}

        def run_pipeline():
            # Capture exceptions here: an uncaught error in the thread would
            # otherwise surface later as a confusing KeyError on `result`.
            try:
                clips, status_msg = extract_score_clips(video_file, debug=False)
                result["clips"], result["status"] = clips, status_msg
            except Exception as e:
                result["clips"], result["status"] = [], f"❌ Pipeline error: {e}"

        t = threading.Thread(target=run_pipeline, daemon=True)
        t.start()

        while t.is_alive():
            yield gr.update(value=[], visible=False), "Processing...", next(progress_iter)
        t.join()  # ensure the thread has fully finished writing `result`

        # When pipeline is done → final bar at 100% + output
        clips = result.get("clips", [])
        status_msg = result.get("status", "❌ Unknown error.")
        final_bar = _make_progress_bar(100, "✅ Done")
        yield gr.update(value=clips, visible=True), status_msg, final_bar

    run_btn.click(
        fn=wrapped_run,
        inputs=in_video,
        outputs=[gallery, status, progress_html],
    )
612
+
613
# Entry point: launch the Gradio server (debug=True enables verbose
# tracebacks and auto-reload-friendly logging).
if __name__ == "__main__":
    demo.launch(debug=True)