import os
import random

import cv2
import gradio as gr
import mediapy as media
import numpy as np
import torch
from PIL import Image

from utils import points_to_tensor, visualize_tracking

# ── Colormap (matches viz_utils.get_colors logic) ─────────────────────────────


def get_colors(n):
    """Generate n random but unique colors in RGB 0-255."""
    random.seed(42)  # remove this line if you want different colors each run
    # Spread hues evenly across 0-179 (HSV in OpenCV), then shuffle.
    hues = list(range(0, 180, max(1, 180 // n)))[:n]
    random.shuffle(hues)
    colors = []
    for hue in hues:
        # Randomize saturation and value slightly for more visual variety.
        sat = random.randint(180, 255)
        val = random.randint(180, 255)
        hsv = np.uint8([[[hue, sat, val]]])
        rgb = cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)[0][0]
        colors.append(tuple(int(c) for c in rgb))
    return colors


N_POINTS = 100
COLORMAP = get_colors(N_POINTS)
select_points = []  # module-level store of selected query points: np.array([x, y]) entries

# ── Video helpers ─────────────────────────────────────────────────────────────


def get_frame(video_path: str, frame_idx: int) -> np.ndarray:
    """Extract a single frame (RGB) from the video by index.

    Raises:
        ValueError: if the frame cannot be read.
    """
    cap = cv2.VideoCapture(video_path)
    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise ValueError(f"Could not read frame {frame_idx}")
    return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)


def get_total_frames(video_path: str) -> int:
    """Return the total frame count of the video."""
    cap = cv2.VideoCapture(video_path)
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    cap.release()
    return total


# ── Draw points on frame ──────────────────────────────────────────────────────


def draw_points(frame: np.ndarray, points: list) -> np.ndarray:
    """Draw colored circle markers (with index labels) on a copy of `frame`."""
    out = frame.copy()
    for i, pt in enumerate(points):
        # Cast to plain int: cv2 rejects some numpy integer types for point args.
        x, y = int(pt[0]), int(pt[1])
        color = COLORMAP[i % N_POINTS]          # RGB tuple
        bgr = (color[2], color[1], color[0])    # cv2 uses BGR
        cv2.circle(out, (x, y), radius=6, color=bgr, thickness=-1)
        cv2.circle(out, (x, y), radius=6, color=(255, 255, 255), thickness=2)  # white border
        cv2.putText(out, str(i + 1), (x + 10, y - 6),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
    return out


_SAMPLES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "example_samples")

# JS injected into gr.Blocks — controls download availability on video players.
_DOWNLOAD_CTRL_JS = """
(function () {
    const EXAMPLE_IDS = ['video_upload_player', 'out_video_player'];
    const USER_IDS = ['out_video_player'];
    function applyNoDownload(ids) {
        ids.forEach(function (id) {
            var el = document.getElementById(id);
            if (!el) return;
            el.querySelectorAll('video').forEach(function (v) {
                v.setAttribute('controlsList', 'nodownload');
                v.oncontextmenu = function (e) { e.preventDefault(); };
            });
            el.querySelectorAll('a').forEach(function (a) {
                a.style.cssText = 'display:none!important;pointer-events:none!important';
            });
            el.querySelectorAll('button').forEach(function (btn) {
                var lbl = (btn.getAttribute('aria-label') || btn.getAttribute('title') || '').toLowerCase();
                if (lbl.includes('download') || lbl.includes('save')) {
                    btn.style.cssText = 'display:none!important;pointer-events:none!important';
                }
            });
        });
    }
    function clearNoDownload(ids) {
        ids.forEach(function (id) {
            var el = document.getElementById(id);
            if (!el) return;
            el.querySelectorAll('video').forEach(function (v) {
                v.removeAttribute('controlsList');
                v.oncontextmenu = null;
            });
            el.querySelectorAll('a').forEach(function (a) { a.style.cssText = ''; });
            el.querySelectorAll('button').forEach(function (btn) { btn.style.cssText = ''; });
        });
    }
    window._isExampleMode = false;
    function applyCurrentMode() {
        if (window._isExampleMode) applyNoDownload(EXAMPLE_IDS);
        else clearNoDownload(USER_IDS);
    }
    /* Watch both containers for DOM changes (e.g. when video src updates) */
    EXAMPLE_IDS.concat(['out_video_player']).forEach(function (id) {
        (function tryObserve() {
            var el = document.getElementById(id);
            if (!el) { setTimeout(tryObserve, 400); return; }
            new MutationObserver(applyCurrentMode)
                .observe(el, { childList: true, subtree: true });
        })();
    });
    /* Intercept value setter on hidden textbox to receive mode signal from Python */
    function hookTrigger() {
        var container = document.querySelector('#download_ctrl textarea');
        if (!container) { setTimeout(hookTrigger, 300); return; }
        var desc = Object.getOwnPropertyDescriptor(HTMLTextAreaElement.prototype, 'value');
        Object.defineProperty(container, 'value', {
            get: function () { return desc.get.call(this); },
            set: function (v) {
                desc.set.call(this, v);
                window._isExampleMode = (v === '1');
                applyCurrentMode();
            },
            configurable: true,
        });
    }
    setTimeout(hookTrigger, 500);
})();
"""

# label → (path, is_ood)
EXAMPLE_VIDEOS = {
    "A4C": (os.path.join(_SAMPLES_DIR, "input1.mp4"), False),
    "A4C (OOD)": (os.path.join(_SAMPLES_DIR, "input2.mp4"), True),
    "RV (OOD)": (os.path.join(_SAMPLES_DIR, "input3_RV.mp4"), True),
    "PSAX (OOD)": (os.path.join(_SAMPLES_DIR, "psax_video_crop.mp4"), True),
}


def _get_thumbnail(video_path: str) -> np.ndarray | None:
    """Extract a single frame near the middle of the video for use as a thumbnail.

    Best-effort: returns None on any failure so a missing sample file does not
    crash app startup.
    """
    try:
        cap = cv2.VideoCapture(video_path)
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(total * 0.4)))
        ret, frame = cap.read()
        cap.release()
        if ret:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    except Exception:
        pass
    return None


THUMBNAILS = {label: _get_thumbnail(path) for label, (path, _) in EXAMPLE_VIDEOS.items()}

# ── Gradio event handlers ─────────────────────────────────────────────────────


def on_video_upload(video_path):
    """Called when a video is uploaded — jump to the 72% frame.

    Returns a 5-tuple matching outputs:
    (frame_display, frame_slider, status_text, video_state, download_ctrl).
    """
    if video_path is None:
        # FIX: this event is wired to 5 outputs; returning a single None made
        # Gradio fail when the video was cleared. Return a full no-op tuple.
        return (
            None,
            gr.update(value=0, minimum=0, maximum=0, interactive=False),
            "No video loaded.",
            None,
            "",
        )
    total = get_total_frames(video_path)
    idx_72 = int(total * 0.72)
    frame = get_frame(video_path, idx_72)
    frame_display_update = gr.update(
        value=frame,
        interactive=True,  # enables click events via gr.SelectData
    )
    slider_update = gr.update(
        value=idx_72, minimum=0, maximum=total - 1, step=1, interactive=True,
        label=f"Frame selector (total: {total} frames)"
    )
    select_points.clear()  # clear any existing points when a new video is loaded
    status = f"📹 Loaded — {total} frames | 🎞️ Showing frame {idx_72} (72%)"
    # Last value resets the download-control style (user upload → downloads allowed).
    return frame_display_update, slider_update, status, video_path, ""


def load_example(video_path):
    """Load an example video, reset all output/selection fields, and hide downloads."""
    frame_upd, slider_upd, status, state, _ = on_video_upload(video_path)
    return (
        gr.update(value=video_path),                   # video_upload
        frame_upd,                                     # frame_display
        slider_upd,                                    # frame_slider
        status,                                        # status_text
        state,                                         # video_state
        gr.update(value=None),                         # out_video — clear previous result
        gr.update(value="No points selected yet."),    # points_display
        "1",                                           # download_ctrl — disable downloads
    )


def on_slider_release(frame_idx, video_path, points_display):
    """Called when the slider is released — show the new frame.

    Point selections are reset because they are only valid for the frame they
    were placed on.
    """
    if video_path is None:
        return None, "No video loaded.", points_display
    frame = get_frame(video_path, int(frame_idx))
    select_points.clear()  # selections are frame-specific; reset on frame change
    points_display = gr.update(
        value="No points selected yet.",
        label="📋 Selected Points",
        lines=5,
        interactive=False,
    )
    status = (
        f"🎞️ Showing Frame {int(frame_idx)} "
        f"({int(frame_idx) / get_total_frames(video_path) * 100:.1f}%) | "
        f"{len(select_points)} point(s) selected"
    )
    return frame, status, points_display


def on_point_select(frame_idx, video_path, evt: gr.SelectData):
    """Called when the user clicks on the image — add a point and redraw."""
    if video_path is None:
        return None, "Upload a video first.", format_points()
    if len(select_points) >= N_POINTS:
        status = f"⚠️ Max {N_POINTS} points reached."
        frame = get_frame(video_path, int(frame_idx))
        return draw_points(frame, select_points), status, format_points()
    x, y = int(evt.index[0]), int(evt.index[1])
    select_points.append(np.array([x, y]))
    frame = get_frame(video_path, int(frame_idx))
    drawn = draw_points(frame, select_points)
    status = f"✅ Point {len(select_points)} added at ({x}, {y}) | Frame {int(frame_idx)}"
    return drawn, status, format_points()


def on_clear_points(frame_idx, video_path):
    """Clear all selected points."""
    select_points.clear()
    if video_path is None:
        return None, "Points cleared.", format_points()
    frame = get_frame(video_path, int(frame_idx))
    return draw_points(frame, select_points), "🗑️ All points cleared.", format_points()


def on_undo_point(frame_idx, video_path):
    """Remove the last selected point."""
    if select_points:
        removed = select_points.pop()
        msg = f"↩️ Removed point at ({removed[0]}, {removed[1]})"
    else:
        msg = "No points to undo."
    if video_path is None:
        return None, msg, format_points()
    frame = get_frame(video_path, int(frame_idx))
    return draw_points(frame, select_points), msg, format_points()


def format_points():
    """Format select_points for display in the textbox."""
    if not select_points:
        return "No points selected yet."
    lines = [f"  [{i+1}] x={p[0]}, y={p[1]}" for i, p in enumerate(select_points)]
    return "select_points:\n" + "\n".join(lines)


def track(video_path, frame_idx, out_video, target_size=(256, 256)):
    """Run the EchoTracker model on the selected points and render the result.

    Returns (out_video update, status string) — the event is wired to
    outputs=[out_video, status_text], so BOTH paths must return two values.
    """
    if video_path is None:
        # FIX: previously returned a single string while the event expects two
        # outputs, which made Gradio error instead of showing the message.
        return gr.update(), "⚠️ No video loaded. Cannot run the tracker."
    if len(select_points) < 1:
        return gr.update(), "⚠️ No points selected. Please select at least one point to track."

    tracker, device = load_model("echotracker_cvamd_ts.pt")

    # Read all frames: keep originals for painting, grayscale-resized for the model.
    cap = cv2.VideoCapture(video_path)
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    paint_frames = []
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            break
        paint_frames.append(frame)
        frame = cv2.resize(frame, target_size)
        frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)))
    cap.release()

    paint_frames = np.array(paint_frames)
    # shape: [B, T, C, H, W] with B=1, C=1 (grayscale)
    frames = torch.from_numpy(np.array(frames)).unsqueeze(0).unsqueeze(2).float().to(device)
    # shape: [1, N, 3] — (t, x, y) in model (256x256) coordinates
    q_points = points_to_tensor(select_points, frame_idx, H, W, 256).to(device)

    with torch.no_grad():
        output = tracker(frames, q_points)
    trajs_e = output[-1].cpu().permute(0, 2, 1, 3)

    # Normalize coordinates to [0, 1] of the model input size.
    # NOTE(review): visualize_tracking is assumed to rescale these normalized
    # trajectories to paint_frames' resolution — confirm against utils.
    q_points[..., 1] /= 256 - 1
    q_points[..., 2] /= 256 - 1
    trajs_e[..., 0] /= 256 - 1
    trajs_e[..., 1] /= 256 - 1

    paint_frames = visualize_tracking(
        frames=paint_frames,
        points=trajs_e.squeeze().cpu().numpy(),
        vis_color='random',
        thickness=5,
        track_length=30,
    )

    # Save the rendered video so Gradio can play it.
    out_vid = "outputs/output.mp4"
    os.makedirs("outputs", exist_ok=True)
    media.write_video(out_vid, paint_frames, fps=25)
    status = "✅ Tracking completed! The output is visualized below."
    out_video = gr.update(value=out_vid, autoplay=True, loop=True)
    return out_video, status


def load_model(model_path: str,
               device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Load a TorchScript model.

    Args:
        model_path: path to the TorchScript weights.
        device: target device; defaults to CUDA when available, else CPU.

    Returns:
        (model, device): the loaded model in eval mode, and the device string.
    """
    model = torch.jit.load(model_path, map_location=device).eval()
    return model, device


# ── Gradio UI ─────────────────────────────────────────────────────────────────
# NOTE(review): the original HEADER contained HTML markup and hyperlink hrefs
# that were garbled in this copy of the file; the visible text is preserved
# below — restore the original tags/links from version control.
HEADER = """
<div style="text-align:center">
<h1>🫀 EchoTracker</h1>
<h3>Advancing Myocardial Point Tracking in Echocardiography</h3>
<p>MICCAI 2024 &nbsp;·&nbsp; Azad, Chernyshov, Nyberg, Tveten, Lovstakken, Dalen, Grenne, Østvik</p>
<p>Model weights from: Taming Modern Point Tracking for Speckle Tracking Echocardiography via
Impartial Motion &nbsp;·&nbsp; ICCV 2025 Workshop &nbsp;·&nbsp;
Azad, Nyberg, Dalen, Grenne, Lovstakken, Østvik</p>
<p>📄 Paper (MICCAI 2024) &nbsp; 📄 Paper (ICCV 2025 Workshop) &nbsp; 📝 ArXiv (EchoTracker) &nbsp;
📝 ArXiv (Taming) &nbsp; 💻 GitHub &nbsp; 🌐 Project Page</p>
</div>
"""

CITATION_MD = """
If you use EchoTracker or the model weights in this demo, please cite both papers:

```bibtex
@InProceedings{azad2024echotracker,
    author    = {Azad, Md Abulkalam and Chernyshov, Artem and Nyberg, John and Tveten, Ingrid and Lovstakken, Lasse and Dalen, H{\\aa}vard and Grenne, Bj{\\o}rnar and {\\O}stvik, Andreas},
    title     = {EchoTracker: Advancing Myocardial Point Tracking in Echocardiography},
    booktitle = {Medical Image Computing and Computer Assisted Intervention -- MICCAI 2024},
    year      = {2024},
    publisher = {Springer Nature Switzerland},
    doi       = {10.1007/978-3-031-72083-3_60}
}

@InProceedings{Azad_2025_ICCV,
    author    = {Azad, Md Abulkalam and Nyberg, John and Dalen, H{\\aa}vard and Grenne, Bj{\\o}rnar and Lovstakken, Lasse and {\\O}stvik, Andreas},
    title     = {Taming Modern Point Tracking for Speckle Tracking Echocardiography via Impartial Motion},
    booktitle = {Proceedings of the IEEE/CVF International Conference on Computer Vision (ICCV) Workshops},
    month     = {October},
    year      = {2025},
    pages     = {1115--1124}
}
```
"""

with gr.Blocks(title="EchoTracker", theme=gr.themes.Soft(), css="""
    .gr-button { font-weight: 600; }
    :root { --echo-muted: #444; --echo-subtle: #666; }
    .dark { --echo-muted: #c0c0c0; --echo-subtle: #a8a8a8; }
""", js=_DOWNLOAD_CTRL_JS) as demo:
    gr.HTML(HEADER)
    gr.Markdown("---")

    # ── Instructions ──────────────────────────────────────────────────────────
    with gr.Accordion("ℹ️ How to use", open=False):
        gr.Markdown("""
1. **Load a video** — upload your own echocardiography clip, or click one of the provided example videos below the panel.
2. **Navigate** to the desired query frame using the frame slider.
3. **Click** on the frame image to place tracking points on cardiac tissue surfaces (e.g. LV/RV walls, myocardium).
4. Use **Undo** or **Clear All** to adjust your selection.
5. Press **▶ Run EchoTracker** to generate tracked trajectories for all selected points.

> **Tip:** Select points at the *end-diastolic* frame for best results. Up to 100 points are supported.
> Example clips cover apical 4-chamber (A4C), right-ventricle (RV), and parasternal short-axis (PSAX) views.
> Clips marked **OOD** (🔶) are out-of-distribution — different scanner or view not seen during training, showcasing EchoTracker's generalisation ability.
""")

    # Hidden state holding the current video path.
    video_state = gr.State(value=None)
    # Hidden textbox: "1" hides download buttons on example videos (see _DOWNLOAD_CTRL_JS).
    download_ctrl = gr.Textbox(value="0", visible=False, elem_id="download_ctrl")

    gr.Markdown("### Step 1 — Upload & Select Query Points")
    gr.Markdown(
        "Upload your own echocardiography video, or click one of the **example clips** "
        "below to get started."
    )
    with gr.Row(equal_height=False):
        # ── Left column: input + points ───────────────────────────────────────
        with gr.Column(scale=1, min_width=300):
            video_upload = gr.Video(
                label="Echocardiography Video — upload yours or use an example below",
                sources="upload",
                include_audio=False,
                autoplay=True,
                loop=True,
                elem_id="video_upload_player",
            )
            points_display = gr.Textbox(
                value="No points selected yet.",
                label="📋 Selected Query Points",
                lines=5,
                max_lines=5,
                interactive=False,
            )
            gr.Markdown(
                "Coordinates are stored as "
                "np.array([x, y]) and passed to the tracker."
            )

        # ── Right column: frame viewer + controls ─────────────────────────────
        with gr.Column(scale=2, min_width=400):
            frame_display = gr.Image(
                label="Query Frame — click to place tracking points",
                interactive=True,
                type="numpy",
                sources=[],
            )
            frame_slider = gr.Slider(
                minimum=0, maximum=100, value=0, step=1,
                label="Frame", interactive=False,
            )
            status_text = gr.Textbox(
                label="Status", lines=1, interactive=False, show_label=False,
                placeholder="Status messages will appear here…",
            )
            with gr.Row():
                undo_btn = gr.Button("↩ Undo", scale=1)
                clear_btn = gr.Button("🗑 Clear All", variant="stop", scale=1)

    gr.Markdown("---")
    gr.Markdown("### Step 2 — Run Tracker & View Output")
    with gr.Row():
        with gr.Column(scale=1):
            run_btn = gr.Button("▶ Run EchoTracker", variant="primary", size="lg")
        with gr.Column(scale=2):
            out_video = gr.Video(
                label="Tracking Output",
                sources=[],
                include_audio=False,
                interactive=False,
                autoplay=True,
                loop=True,
                elem_id="out_video_player",
            )

    gr.Markdown("---")
    gr.Markdown(
        "**Or try an example clip** "
        "— OOD = out-of-distribution (different scanner / view not seen during training)"
    )
    gr.Markdown(
        "> ⚠️ **Example videos are provided for demonstration purposes only. "
        "They should not be downloaded, reproduced, or used for any purpose outside this demo.**"
    )
    ex_btns = []
    with gr.Row(equal_height=True):
        for label, (path, is_ood) in EXAMPLE_VIDEOS.items():
            with gr.Column(min_width=120):
                gr.Image(
                    value=THUMBNAILS[label],
                    show_label=False,
                    interactive=False,
                    height=110,
                    container=False,
                )
                btn_label = f"{label} 🔶" if is_ood else label
                ex_btns.append(gr.Button(btn_label, size="sm"))

    # ── Like request ──────────────────────────────────────────────────────────
    # NOTE(review): surrounding HTML wrapper markup was lost in this copy.
    gr.Markdown(
        "If you find this demo useful, please click the ❤️ Like button at the top of "
        "this Space — it helps others discover this work and supports open research "
        "in cardiac image analysis."
    )

    # ── Citation ──────────────────────────────────────────────────────────────
    with gr.Accordion("📝 Citation", open=False):
        gr.Markdown(CITATION_MD)

    # ── Wire events ───────────────────────────────────────────────────────────
    video_upload.upload(
        fn=on_video_upload,
        inputs=[video_upload],
        outputs=[frame_display, frame_slider, status_text, video_state, download_ctrl]
    )
    frame_slider.release(
        fn=on_slider_release,
        inputs=[frame_slider, video_state, points_display],
        outputs=[frame_display, status_text, points_display]
    )
    frame_display.select(
        fn=on_point_select,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    undo_btn.click(
        fn=on_undo_point,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    clear_btn.click(
        fn=on_clear_points,
        inputs=[frame_slider, video_state],
        outputs=[frame_display, status_text, points_display]
    )
    for btn, (path, _) in zip(ex_btns, EXAMPLE_VIDEOS.values()):
        # gr.State(path) binds the example's path per-iteration (avoids late binding).
        btn.click(
            fn=load_example,
            inputs=gr.State(path),
            outputs=[video_upload, frame_display, frame_slider, status_text,
                     video_state, out_video, points_display, download_ctrl]
        )
    run_btn.click(
        fn=track,
        inputs=[video_state, frame_slider, out_video],
        outputs=[out_video, status_text]
    )


if __name__ == "__main__":
    demo.launch(share=False)