import gradio as gr import os import torch import cv2 import numpy as np import random from PIL import Image from utils import points_to_tensor from utils import visualize_tracking import mediapy as media # ── Colormap (matches your viz_utils.get_colors logic) ─────────────────────── def get_colors(n): """Generate n random but unique colors in RGB 0-255.""" random.seed(42) # remove this line if you want different colors each run # Spread hues evenly across 0-179 (HSV in OpenCV), then shuffle hues = list(range(0, 180, max(1, 180 // n)))[:n] random.shuffle(hues) colors = [] for hue in hues: # Randomize saturation and value slightly for more visual variety sat = random.randint(180, 255) val = random.randint(180, 255) hsv = np.uint8([[[hue, sat, val]]]) rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)[0][0] colors.append(tuple(int(c) for c in rgb)) return colors N_POINTS = 100 COLORMAP = get_colors(N_POINTS) select_points = [] # will hold np.array([x, y]) entries # ── Video helpers ───────────────────────────────────────────────────────────── def get_frame(video_path: str, frame_idx: int) -> np.ndarray: """Extract a single frame from video by index.""" cap = cv2.VideoCapture(video_path) cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) ret, frame = cap.read() cap.release() if not ret: raise ValueError(f"Could not read frame {frame_idx}") return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) def get_total_frames(video_path: str) -> int: cap = cv2.VideoCapture(video_path) total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() return total # ── Draw points on frame ────────────────────────────────────────────────────── def draw_points(frame: np.ndarray, points: list) -> np.ndarray: """Draw colored circle markers on frame for each selected point.""" out = frame.copy() for i, pt in enumerate(points): color = COLORMAP[i % N_POINTS] # RGB tuple bgr = (color[2], color[1], color[0]) # cv2 uses BGR cv2.circle(out, (pt[0], pt[1]), radius=6, color=bgr, thickness=-1) cv2.circle(out, (pt[0], 
pt[1]), radius=6, color=(255, 255, 255), thickness=2) # white border cv2.putText(out, str(i + 1), (pt[0] + 10, pt[1] - 6), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1) return out _SAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_samples") # JS injected into gr.Blocks — controls download availability on video players _DOWNLOAD_CTRL_JS = """ (function () { const EXAMPLE_IDS = ['video_upload_player', 'out_video_player']; const USER_IDS = ['out_video_player']; function applyNoDownload(ids) { ids.forEach(function (id) { var el = document.getElementById(id); if (!el) return; el.querySelectorAll('video').forEach(function (v) { v.setAttribute('controlsList', 'nodownload'); v.oncontextmenu = function (e) { e.preventDefault(); }; }); el.querySelectorAll('a').forEach(function (a) { a.style.cssText = 'display:none!important;pointer-events:none!important'; }); el.querySelectorAll('button').forEach(function (btn) { var lbl = (btn.getAttribute('aria-label') || btn.getAttribute('title') || '').toLowerCase(); if (lbl.includes('download') || lbl.includes('save')) { btn.style.cssText = 'display:none!important;pointer-events:none!important'; } }); }); } function clearNoDownload(ids) { ids.forEach(function (id) { var el = document.getElementById(id); if (!el) return; el.querySelectorAll('video').forEach(function (v) { v.removeAttribute('controlsList'); v.oncontextmenu = null; }); el.querySelectorAll('a').forEach(function (a) { a.style.cssText = ''; }); el.querySelectorAll('button').forEach(function (btn) { btn.style.cssText = ''; }); }); } window._isExampleMode = false; function applyCurrentMode() { if (window._isExampleMode) applyNoDownload(EXAMPLE_IDS); else clearNoDownload(USER_IDS); } /* Watch both containers for DOM changes (e.g. 
when video src updates) */ EXAMPLE_IDS.concat(['out_video_player']).forEach(function (id) { (function tryObserve() { var el = document.getElementById(id); if (!el) { setTimeout(tryObserve, 400); return; } new MutationObserver(applyCurrentMode) .observe(el, { childList: true, subtree: true }); })(); }); /* Intercept value setter on hidden textbox to receive mode signal from Python */ function hookTrigger() { var container = document.querySelector('#download_ctrl textarea'); if (!container) { setTimeout(hookTrigger, 300); return; } var desc = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value'); Object.defineProperty(container, 'value', { get: function () { return desc.get.call(this); }, set: function (v) { desc.set.call(this, v); window._isExampleMode = (v === '1'); applyCurrentMode(); }, configurable: true, }); } setTimeout(hookTrigger, 500); })(); """ # label → (path, is_ood) EXAMPLE_VIDEOS = { "A4C": (os.path.join(_SAMPLES_DIR, "input1.mp4"), False), "A4C (OOD)": (os.path.join(_SAMPLES_DIR, "input2.mp4"), True), "RV (OOD)": (os.path.join(_SAMPLES_DIR, "input3_RV.mp4"), True), "PSAX (OOD)": (os.path.join(_SAMPLES_DIR, "psax_video_crop.mp4"), True), } def _get_thumbnail(video_path: str) -> np.ndarray | None: """Extract a single frame near the middle of the video for use as a thumbnail.""" try: cap = cv2.VideoCapture(video_path) total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(total * 0.4))) ret, frame = cap.read() cap.release() if ret: return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) except Exception: pass return None THUMBNAILS = {label: _get_thumbnail(path) for label, (path, _) in EXAMPLE_VIDEOS.items()} # ── Gradio event handlers ───────────────────────────────────────────────────── def on_video_upload(video_path): """Called when video is uploaded — jump to 72% frame.""" if video_path is None: # return None, gr.update(value=0, maximum=0, interactive=False), "No video loaded.", [] return None total = 
get_total_frames(video_path) idx_72 = int(total * 0.72) frame = get_frame(video_path, idx_72) #drawn = draw_points(frame, select_points) frame_display_update = gr.update( value=frame, interactive=True, # enables click events via gr.SelectData ) slider_update = gr.update( value=idx_72, minimum=0, maximum=total - 1, step=1, interactive=True, label=f"Frame selector (total: {total} frames)" ) select_points.clear() # clear any existing points when new video is loaded status = f"📹 Loaded — {total} frames | 🎞️ Showing frame {idx_72} (72%)" # last value resets the download-control style (user upload → downloads allowed) return frame_display_update, slider_update, status, video_path, "" def load_example(video_path): """Load an example video, reset all output/selection fields, and hide downloads.""" frame_upd, slider_upd, status, state, _ = on_video_upload(video_path) return ( gr.update(value=video_path), # video_upload frame_upd, # frame_display slider_upd, # frame_slider status, # status_text state, # video_state gr.update(value=None), # out_video — clear previous result gr.update(value="No points selected yet."), # points_display "1", # download_ctrl — disable downloads ) def on_slider_release(frame_idx, video_path, points_display): """Called when slider is released — show new frame, keep existing points.""" if video_path is None: return None, "No video loaded.", points_display frame = get_frame(video_path, int(frame_idx)) select_points.clear() # clear any existing points when new video is loaded #print(f"Selected point: {select_points}") points_display = gr.update( value="No points selected yet.", label="📋 Selected Points", lines=5, interactive=False, ) #drawn = draw_points(frame, select_points) status = f"🎞️ Showing Frame {int(frame_idx)} ({int(frame_idx) / get_total_frames(video_path) * 100:.1f}%) | {len(select_points)} point(s) selected" return frame, status, points_display def on_point_select(frame_idx, video_path, evt: gr.SelectData): """Called when user clicks on 
the image — add point, redraw.""" if video_path is None: return None, "Upload a video first.", format_points() if len(select_points) >= N_POINTS: status = f"⚠️ Max {N_POINTS} points reached." frame = get_frame(video_path, int(frame_idx)) return draw_points(frame, select_points), status, format_points() x, y = int(evt.index[0]), int(evt.index[1]) select_points.append(np.array([x, y])) #print(f"Selected point: {select_points}") frame = get_frame(video_path, int(frame_idx)) drawn = draw_points(frame, select_points) status = f"✅ Point {len(select_points)} added at ({x}, {y}) | Frame {int(frame_idx)}" return drawn, status, format_points() def on_clear_points(frame_idx, video_path): """Clear all selected points.""" select_points.clear() if video_path is None: return None, "Points cleared.", format_points() frame = get_frame(video_path, int(frame_idx)) return draw_points(frame, select_points), "🗑️ All points cleared.", format_points() def on_undo_point(frame_idx, video_path): """Remove last selected point.""" if select_points: removed = select_points.pop() msg = f"↩️ Removed point at ({removed[0]}, {removed[1]})" else: msg = "No points to undo." if video_path is None: return None, msg, format_points() frame = get_frame(video_path, int(frame_idx)) return draw_points(frame, select_points), msg, format_points() def format_points(): """Format select_points for display in the textbox.""" if not select_points: return "No points selected yet." lines = [f" [{i+1}] x={p[0]}, y={p[1]}" for i, p in enumerate(select_points)] return "select_points:\n" + "\n".join(lines) def track(video_path, frame_idx, out_video, target_size=(256, 256)): """Placeholder for tracking function — replace with your actual tracking logic.""" if video_path is None: status = f"⚠️ No video loaded. Cannot run the tracker." return status if len(select_points) < 1: status = f"⚠️ No points selected. Please select at least one point to track." 
return status tracker, device = load_model("echotracker_cvamd_ts.pt") cap = cv2.VideoCapture(video_path) W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) frames = [] paint_frames = [] while cap.isOpened(): ret, frame = cap.read() if not ret: break paint_frames.append(frame) frame = cv2.resize(frame, target_size) frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY))) cap.release() paint_frames = np.array(paint_frames) frames = torch.from_numpy(np.array(frames)).unsqueeze(0).unsqueeze(2).float().to(device) # shape: [B, T, H, W] q_points = points_to_tensor(select_points, frame_idx, H, W, 256).to(device) # shape: [1, N, 3] #print(f"✅ Loaded video frames: {frames.shape} {paint_frames.shape}") # print(f"Selected points: {q_points.shape}") with torch.no_grad(): output = tracker(frames, q_points) trajs_e = output[-1].cpu().permute(0, 2, 1, 3) q_points[...,1] /= 256 - 1 q_points[...,2] /= 256 - 1 trajs_e[...,0] /= 256 - 1 trajs_e[...,1] /= 256 - 1 #print(f"Tracker output trajectories: {trajs_e.shape}") paint_frames = visualize_tracking( frames=paint_frames, points=trajs_e.squeeze().cpu().numpy(), vis_color='random', thickness=5, track_length=30, ) # Save or display paint_frames as needed (e.g., save as video or show in Gradio) out_vid = "outputs/output.mp4" os.makedirs("outputs", exist_ok=True) media.write_video(out_vid, paint_frames, fps=25) status = f"✅ Tracking completed! The output is visualized below." out_video = gr.update(value=out_vid, autoplay=True, loop=True) return out_video, status def load_model(model_path: str, device: str = "cuda" if torch.cuda.is_available() else "cpu"): """Load a torchscript model Args: model_path (str): path to the torchscript weights device (str, optional): Defaults to "cuda" if torch.cuda.is_available() else "cpu". 
Returns: model: the loaded torchscript model """ model = torch.jit.load(model_path, map_location=device).eval() #print(f"✅ TorchScript model loaded on {device}") return model, device # ── Gradio UI ───────────────────────────────────────────────────────────────── HEADER = """
Advancing Myocardial Point Tracking in Echocardiography
MICCAI 2024 · Azad, Chernyshov, Nyberg, Tveten, Lovstakken, Dalen, Grenne, Østvik
Model weights from: Taming Modern Point Tracking for Speckle Tracking Echocardiography via Impartial Motion · ICCV 2025 Workshop · Azad, Nyberg, Dalen, Grenne, Lovstakken, Østvik
np.array([x, y]) and passed to the tracker."
)
# ── Right column: frame viewer + controls ─────────────────────────────
with gr.Column(scale=2, min_width=400):
    # Clickable query frame — clicks are delivered as gr.SelectData events
    # (the image itself is set programmatically, hence sources=[]).
    frame_display = gr.Image(
        label="Query Frame — click to place tracking points",
        interactive=True,
        type="numpy",
        sources=[],
    )
    # Disabled placeholder range; re-enabled with the real frame count once
    # a video is loaded.
    frame_slider = gr.Slider(
        minimum=0, maximum=100, value=0, step=1,
        label="Frame",
        interactive=False,
    )
    # Single-line, read-only status feed shared by all handlers.
    status_text = gr.Textbox(
        label="Status", lines=1, interactive=False, show_label=False,
        placeholder="Status messages will appear here…",
    )
    with gr.Row():
        undo_btn = gr.Button("↩ Undo", scale=1)
        clear_btn = gr.Button("🗑 Clear All", variant="stop", scale=1)
gr.Markdown("---")
gr.Markdown("### Step 2 — Run Tracker & View Output")
with gr.Row():
    with gr.Column(scale=1):
        run_btn = gr.Button("▶ Run EchoTracker", variant="primary", size="lg")
    with gr.Column(scale=2):
        # Output-only player (sources=[], interactive=False); its elem_id is
        # targeted by _DOWNLOAD_CTRL_JS to toggle download controls.
        out_video = gr.Video(
            label="Tracking Output",
            sources=[],
            include_audio=False,
            interactive=False,
            autoplay=True,
            loop=True,
            elem_id="out_video_player",
        )
gr.Markdown("---")
gr.Markdown(
    "**Or try an example clip** "
    "— OOD = out-of-distribution (different scanner / view not seen during training)"
)
gr.Markdown(
    "> ⚠️ **Example videos are provided for demonstration purposes only. "
    "They should not be downloaded, reproduced, or used for any purpose outside this demo.**"
)
# One button per example clip; collected here so they can be wired to
# load_example after layout construction.
ex_btns = []
with gr.Row(equal_height=True):
    # Render a thumbnail + load button for every entry in EXAMPLE_VIDEOS.
    for label, (path, is_ood) in EXAMPLE_VIDEOS.items():
        with gr.Column(min_width=120):
            gr.Image(
                value=THUMBNAILS[label],
                show_label=False,
                interactive=False,
                height=110,
                container=False,
            )
            # 🔶 marks clips flagged as out-of-distribution.
            btn_label = f"{label} 🔶" if is_ood else label
            ex_btns.append(gr.Button(btn_label, size="sm"))
# ── Like request ──────────────────────────────────────────────────────────
gr.Markdown(
"