Spaces:
Runtime error
Runtime error
| """ | |
| Wildlife Behavior Analyzer | |
| Powered by X3D-KABR (WACV 2024) + YOLOv8 | |
| Author: rohansingh0 | HuggingFace Spaces | |
| """ | |
import os
import re
import io
import cv2
import time
import tempfile
import zipfile
from collections import deque, Counter

import numpy as np
import torch
import gradio as gr
import matplotlib
matplotlib.use("Agg")  # headless backend — must be before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from huggingface_hub import hf_hub_download
from pytorchvideo.models.x3d import create_x3d
from ultralytics import YOLO
# ─────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 8
NUM_FRAMES = 16        # crops per clip handed to the X3D classifier
CROP_SIZE = 300        # side length (pixels) of the square X3D input crop
MEAN = [0.45, 0.45, 0.45]
STD = [0.225, 0.225, 0.225]

# COCO class IDs (YOLOv8 80-class indexing) that are animals:
# 14 bird, 15 cat, 16 dog, 17 horse, 18 sheep, 19 cow, 20 elephant,
# 21 bear, 22 zebra, 23 giraffe.  24/25 are backpack/umbrella, so the
# range must stop at 24 (exclusive).
ANIMAL_IDS = set(range(14, 24))

LABEL_NAMES = ["Walk", "Trot", "Run", "Graze",
               "Browse", "Head Up", "Auto-Groom", "Occluded"]

# Matplotlib colours, one per entry of LABEL_NAMES.
LABEL_COLORS_HEX = [
    "#2196F3", "#9C27B0", "#F44336", "#4CAF50",
    "#009688", "#FF9800", "#E91E63", "#9E9E9E",
]

# OpenCV draws in BGR order; derive these from LABEL_COLORS_HEX so the
# two palettes can never drift apart.
LABEL_COLORS_BGR = [
    (int(hx[5:7], 16), int(hx[3:5], 16), int(hx[1:3], 16))
    for hx in LABEL_COLORS_HEX
]
# ─────────────────────────────────────────────────────────────
# Model Loading (runs ONCE at server startup — not per session)
# The models are built when this module is imported, so every Gradio
# request reuses the same instances instead of reloading them.
# ─────────────────────────────────────────────────────────────
def _slowfast_stage_prefix(match):
    """Turn ``s<S>.pathway0_res<R>.`` into ``blocks.<S-1>.res_blocks.<R>.``."""
    stage_no = int(match.group(1))
    res_no = int(match.group(2))
    return f"blocks.{stage_no - 1}.res_blocks.{res_no}."


# Rename rules, applied one after another; each rule sees the output of
# the rules before it.
# NOTE(review): the stem rules deliberately cross names (conv_xy -> conv_t,
# conv -> conv_xy); presumably this mirrors pytorchvideo's X3D stem layout —
# confirm against model.state_dict() before touching them.
_SLOWFAST_TO_PTV_RULES = tuple(
    (re.compile(pattern), replacement)
    for pattern, replacement in (
        (r"^s1\.pathway0_stem\.conv_xy\b", "blocks.0.conv.conv_t"),
        (r"^s1\.pathway0_stem\.conv\b", "blocks.0.conv.conv_xy"),
        (r"^s1\.pathway0_stem\.bn\b", "blocks.0.norm"),
        (r"^s(\d)\.pathway0_res(\d+)\.", _slowfast_stage_prefix),
        (r"\.branch1\.weight$", ".branch1_conv.weight"),
        (r"\.branch1_bn\.", ".branch1_norm."),
        (r"\.branch2\.a\.weight$", ".branch2.conv_a.weight"),
        (r"\.branch2\.a_bn\.", ".branch2.norm_a."),
        (r"\.branch2\.b\.weight$", ".branch2.conv_b.weight"),
        (r"\.branch2\.b_bn\.", ".branch2.norm_b.0."),
        (r"\.branch2\.se\.fc1\.", ".branch2.norm_b.1.block.0."),
        (r"\.branch2\.se\.fc2\.", ".branch2.norm_b.1.block.2."),
        (r"\.branch2\.c\.weight$", ".branch2.conv_c.weight"),
        (r"\.branch2\.c_bn\.", ".branch2.norm_c."),
        (r"^head\.conv_5\.", "blocks.5.pool.pre_conv."),
        (r"^head\.conv_5_bn\.", "blocks.5.pool.pre_norm."),
        (r"^head\.lin_5\.", "blocks.5.pool.post_conv."),
        (r"^head\.projection\.", "blocks.5.proj."),
    )
)


def remap_key(k):
    """Translate one SlowFast checkpoint key into the pytorchvideo X3D name.

    Keys that match no rule are returned unchanged.
    """
    renamed = k
    for pattern, replacement in _SLOWFAST_TO_PTV_RULES:
        renamed = pattern.sub(replacement, renamed)
    return renamed
def _fetch_kabr_checkpoint(cache_dir="/tmp"):
    """Download the X3D-KABR archive and extract its ``.pyth`` checkpoint.

    Only the checkpoint member is extracted, into a dedicated sub-directory,
    so unrelated ``.pyth`` files elsewhere under ``cache_dir`` can never be
    picked up by mistake.

    Returns:
        Path of the extracted ``.pyth`` file.

    Raises:
        FileNotFoundError: if the archive contains no ``.pyth`` member.
    """
    ckpt_zip = hf_hub_download(
        repo_id="imageomics/x3d-kabr-kinetics",
        filename="checkpoint_epoch_00075.pyth.zip",
        local_dir=cache_dir,
    )
    extract_dir = os.path.join(cache_dir, "x3d_kabr")
    with zipfile.ZipFile(ckpt_zip, "r") as z:
        members = sorted(n for n in z.namelist() if n.endswith(".pyth"))
        if not members:
            raise FileNotFoundError(f"No .pyth checkpoint found inside {ckpt_zip}")
        # extract() returns the path it actually wrote (name is sanitised).
        return z.extract(members[0], extract_dir)


def _build_x3d():
    """Create the X3D network with the KABR head size and clip geometry."""
    return create_x3d(
        input_clip_length=NUM_FRAMES,
        input_crop_size=CROP_SIZE,
        model_num_class=NUM_CLASSES,
        depth_factor=5.0,
        width_factor=2.0,
        bottleneck_factor=2.25,
        dropout_rate=0.5,
        head_activation=None,
    )


def _load_kabr_weights(model, ckpt_path):
    """Copy checkpoint tensors into ``model`` after renaming them via ``remap_key``.

    A tensor is loaded only if its remapped name exists in the model with the
    same shape; every other model parameter keeps its initial value.

    Raises:
        RuntimeError: if not a single checkpoint tensor matched the model
            (the classifier would otherwise run on random weights).
    """
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    state_dict = checkpoint["model_state"]
    model_sd = model.state_dict()

    remapped = {}
    for ck, cv in state_dict.items():
        mk = remap_key(ck)
        target = model_sd.get(mk)
        if target is not None and target.shape == cv.shape:
            remapped[mk] = cv

    if not remapped:
        raise RuntimeError(
            f"No tensors from {ckpt_path} matched the X3D model; "
            "check remap_key() against model.state_dict()."
        )
    model.load_state_dict(remapped, strict=False)
    print(
        f"[Wildlife Analyzer] X3D weights: {len(remapped)}/{len(model_sd)} "
        f"model tensors loaded from {len(state_dict)} checkpoint entries"
    )


def _load_models():
    """Build the X3D-KABR behaviour classifier and the YOLOv8 detector.

    Returns:
        ``(model, yolo)``: the X3D network in eval mode on ``DEVICE`` with the
        KABR checkpoint loaded, and an ultralytics YOLO model moved to the CPU.
    """
    print(f"[Wildlife Analyzer] Loading models on {DEVICE} ...")

    # ── X3D-KABR ──────────────────────────────────────────────
    ckpt_path = _fetch_kabr_checkpoint()
    model = _build_x3d()
    _load_kabr_weights(model, ckpt_path)
    model = model.to(DEVICE).eval()

    # ── YOLOv8 ────────────────────────────────────────────────
    yolo = YOLO("yolov8n.pt")
    yolo.to("cpu")

    print(f"[Wildlife Analyzer] ✅ Models ready on {DEVICE}")
    return model, yolo


# Module-level singleton — loaded once at import, shared across all requests.
MODEL, YOLO_MODEL = _load_models()
# ─────────────────────────────────────────────────────────────
# Inference Helpers
# ─────────────────────────────────────────────────────────────
def preprocess_clip(frames_bgr):
    """Turn a sequence of BGR frames into a normalised X3D input tensor.

    Each frame is converted to RGB, resized so its shorter side equals
    ``CROP_SIZE`` (aspect ratio kept), centre-cropped to
    ``CROP_SIZE x CROP_SIZE``, scaled to [0, 1] and normalised with
    ``MEAN`` / ``STD``.

    Args:
        frames_bgr: iterable of HxWx3 BGR images (OpenCV layout).

    Returns:
        float32 tensor shaped ``(1, 3, T, CROP_SIZE, CROP_SIZE)`` where
        ``T`` is the number of frames.

    Raises:
        ValueError: (from ``np.stack``) if ``frames_bgr`` is empty.
    """
    mean = np.asarray(MEAN, dtype=np.float32)
    std = np.asarray(STD, dtype=np.float32)
    out = []
    for f in frames_bgr:
        f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB)
        h, w = f.shape[:2]
        s = CROP_SIZE / min(h, w)
        # round + clamp: plain int() truncation can turn e.g. 299.9999 into
        # 299 and leave a crop smaller than CROP_SIZE.
        new_w = max(CROP_SIZE, int(round(w * s)))
        new_h = max(CROP_SIZE, int(round(h * s)))
        f = cv2.resize(f, (new_w, new_h))
        ch = (new_h - CROP_SIZE) // 2
        cw = (new_w - CROP_SIZE) // 2
        f = f[ch:ch + CROP_SIZE, cw:cw + CROP_SIZE]
        out.append((f.astype(np.float32) / 255.0 - mean) / std)
    clip = np.stack(out).transpose(3, 0, 1, 2)   # (T, H, W, C) -> (C, T, H, W)
    return torch.tensor(clip, dtype=torch.float32).unsqueeze(0)
def classify_behavior(crop_sequence):
    """Classify the behaviour shown in a sequence of animal crops.

    Returns:
        ``(index, label, scores)`` — the winning class index, its name from
        ``LABEL_NAMES``, and the per-class sigmoid scores as a numpy array.
        Only the first 7 classes compete, so index 7 ("Occluded") is never
        returned as the winner.
    """
    with torch.no_grad():
        batch = preprocess_clip(crop_sequence).to(DEVICE)
        logits = MODEL(batch)
        scores = torch.sigmoid(logits).squeeze().cpu().numpy()
        behaviour_scores = scores[:7]          # drop "Occluded" from the vote
        best = int(np.argmax(behaviour_scores))
        return best, LABEL_NAMES[best], scores
def detect_animals(frame_bgr, species_name, *, conf=0.15, iou=0.4,
                   min_area_ratio=0.005, max_area_ratio=0.80, min_side=30):
    """Detect animals in one video frame with YOLOv8.

    Boxes are clipped to the frame and kept only when the COCO class is in
    ``ANIMAL_IDS``, both sides are at least ``min_side`` pixels, and the box
    covers between ``min_area_ratio`` and ``max_area_ratio`` of the frame.

    Args:
        frame_bgr: HxWx3 frame in OpenCV BGR order.
        species_name: label attached to every returned detection.
        conf, iou: YOLO confidence / NMS IoU thresholds.
        min_area_ratio, max_area_ratio, min_side: box-size filters.

    Returns:
        List of ``(x1, y1, x2, y2, species_name, confidence)`` tuples.
    """
    H, W = frame_bgr.shape[:2]
    frame_area = W * H
    if frame_area == 0:
        return []
    # Ultralytics treats numpy inputs as BGR (OpenCV order), so the frame is
    # passed through unchanged — converting to RGB first would swap channels.
    results = YOLO_MODEL(frame_bgr, verbose=False, device="cpu",
                         conf=conf, iou=iou)[0]
    animals = []
    for box in results.boxes:
        if int(box.cls[0]) not in ANIMAL_IDS:
            continue
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        x1, y1 = max(0, x1), max(0, y1)
        x2, y2 = min(W, x2), min(H, y2)
        w_box, h_box = x2 - x1, y2 - y1
        if w_box < min_side or h_box < min_side:
            continue
        area_ratio = (w_box * h_box) / frame_area
        if not (min_area_ratio <= area_ratio <= max_area_ratio):
            continue
        animals.append((x1, y1, x2, y2, species_name, float(box.conf[0])))
    return animals
def smooth_predictions(timeline, window=5):
    """Replace each label with the majority label of a trailing window.

    Args:
        timeline: iterable of ``(time, label, confidence)`` tuples in
            chronological order.
        window: number of most recent labels (including the current one)
            that vote; ties go to the label seen first within the window.

    Returns:
        List of ``(time, majority_label, confidence)`` — time and confidence
        are passed through untouched.
    """
    recent = deque(maxlen=window)

    def _vote(label):
        recent.append(label)
        winner, _votes = Counter(recent).most_common(1)[0]
        return winner

    return [(t, _vote(lbl), c) for t, lbl, c in timeline]
# ─────────────────────────────────────────────────────────────
# Visualization Helpers
# ─────────────────────────────────────────────────────────────
def make_time_budget_chart(counts, title):
    """Draw a pie chart of how often each behaviour (first 7 labels) occurs.

    Args:
        counts: mapping of behaviour label -> count (e.g. a ``Counter``).
        title: chart title.

    Returns:
        A matplotlib ``Figure``. When no behaviour has a positive count the
        figure shows a "No classifications" placeholder instead of a pie.
    """
    labels = [lbl for lbl in LABEL_NAMES[:7] if counts.get(lbl, 0) > 0]
    sizes = [counts[lbl] for lbl in labels]
    colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(lbl)] for lbl in labels]

    # A standalone Figure is not registered with pyplot's global figure
    # manager, so it is freed once Gradio has rendered it instead of piling
    # up (plt.subplots figures were never closed).
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    if sizes:
        ax.pie(sizes, labels=labels, colors=colors,
               autopct="%1.1f%%", startangle=90,
               textprops={"fontsize": 11})
    else:
        ax.text(0.5, 0.5, "No classifications", ha="center", va="center",
                fontsize=12)
        ax.axis("off")
    ax.set_title(title, fontsize=13, fontweight="bold", pad=15)
    fig.tight_layout()
    return fig
def make_timeline_chart(smoothed, video_name):
    """Plot classification confidence over time plus a colour-coded strip.

    Args:
        smoothed: list of ``(timestamp_s, label, confidence)`` tuples.
        video_name: shown in the chart title.

    Returns:
        A matplotlib ``Figure`` with a confidence scatter (top) and a
        behaviour strip (bottom), or ``None`` if ``smoothed`` is empty.
    """
    times = [t for t, _, _ in smoothed]
    if not times:
        return None
    label_to_idx = {lbl: i for i, lbl in enumerate(LABEL_NAMES)}

    # Standalone Figure: not tracked by pyplot, so it does not leak per request.
    fig = Figure(figsize=(14, 5))
    ax1, ax2 = fig.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]})

    for i, lbl in enumerate(LABEL_NAMES[:7]):
        pts = [(t, c) for t, l, c in smoothed if l == lbl]
        if pts:
            ts, cs = zip(*pts)
            ax1.scatter(ts, cs, label=lbl,
                        color=LABEL_COLORS_HEX[i], s=8, alpha=0.7)
    ax1.set_ylabel("Confidence", fontsize=11)
    ax1.set_ylim(0, 1.05)
    ax1.legend(loc="upper right", fontsize=8, ncol=4)
    ax1.set_title(f"Behavior Timeline — {video_name}",
                  fontsize=12, fontweight="bold")
    ax1.grid(True, alpha=0.3)

    # Strip cell width = smallest gap between distinct timestamps. Several
    # animals can share a timestamp, so times[1] - times[0] may be 0 and
    # would make every span invisible.
    distinct = sorted(set(times))
    gaps = [b - a for a, b in zip(distinct, distinct[1:])]
    dt = min(gaps) if gaps else 0.1
    for t, lbl, _ in smoothed:
        idx = label_to_idx.get(lbl, 7)
        ax2.axvspan(t - dt / 2, t + dt / 2,
                    color=LABEL_COLORS_HEX[idx], alpha=0.9)
    patches = [mpatches.Patch(color=LABEL_COLORS_HEX[i], label=lbl)
               for i, lbl in enumerate(LABEL_NAMES[:7])]
    ax2.legend(handles=patches, loc="upper right", fontsize=8, ncol=4)
    ax2.set_xlabel("Time (seconds)", fontsize=11)
    ax2.set_yticks([])

    fig.tight_layout()
    return fig
def save_combined_chart(smooth_cnt, smoothed):
    """Save a side-by-side pie + timeline strip as a PNG and return its path.

    Args:
        smooth_cnt: mapping of behaviour label -> count after smoothing.
        smoothed: list of ``(timestamp_s, label, confidence)`` tuples.

    Returns:
        Path to ``behavior_charts.png`` inside a fresh temporary directory,
        or ``None`` if ``smoothed`` is empty.
    """
    times = [t for t, _, _ in smoothed]
    if not times:
        return None
    label_to_idx = {lbl: i for i, lbl in enumerate(LABEL_NAMES)}

    # Standalone Figure: never enters pyplot's registry, so nothing to close.
    fig = Figure(figsize=(14, 5))
    pie_ax, strip_ax = fig.subplots(1, 2)

    labels = [lbl for lbl in LABEL_NAMES[:7] if smooth_cnt.get(lbl, 0) > 0]
    sizes = [smooth_cnt[lbl] for lbl in labels]
    colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(lbl)] for lbl in labels]
    if sizes:
        pie_ax.pie(sizes, labels=labels, colors=colors,
                   autopct="%1.1f%%", startangle=90)
    else:
        pie_ax.axis("off")
    pie_ax.set_title("Time Budget (Smoothed)")

    # Smallest positive gap between distinct timestamps (see timeline chart):
    # simultaneous detections would otherwise give a zero-width strip.
    distinct = sorted(set(times))
    gaps = [b - a for a, b in zip(distinct, distinct[1:])]
    dt = min(gaps) if gaps else 0.1
    for t, lbl, _ in smoothed:
        idx = label_to_idx.get(lbl, 7)
        strip_ax.axvspan(t - dt / 2, t + dt / 2,
                         color=LABEL_COLORS_HEX[idx], alpha=0.9)
    strip_ax.set_title("Behavior Timeline Strip")
    strip_ax.set_xlabel("Time (seconds)")
    strip_ax.set_yticks([])
    fig.tight_layout()

    # A per-call directory keeps concurrent requests from overwriting each
    # other's PNG while keeping a friendly download file name.
    out = os.path.join(tempfile.mkdtemp(prefix="wildlife_charts_"),
                       "behavior_charts.png")
    fig.savefig(out, format="png", dpi=150, bbox_inches="tight")
    return out
# ─────────────────────────────────────────────────────────────
# Core Processing (reports progress through Gradio's callback)
# ─────────────────────────────────────────────────────────────
def _draw_overlays(image, overlays, header):
    """Draw labelled boxes and the header text onto ``image`` in place.

    ``overlays`` holds ``(x1, y1, x2, y2, color_bgr, text)`` tuples.
    """
    for x1, y1, x2, y2, color, txt in overlays:
        cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
        fs = max(0.4, min(0.6, (x2 - x1) / 400))
        (tw, th), _ = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, fs, 2)
        # Label sits above the box; if that would leave the frame (box at
        # the top edge) it is drawn just inside the box instead.
        top = y1 - th - 8 if y1 - th - 8 >= 0 else y1
        cv2.rectangle(image, (x1, top), (x1 + tw + 4, top + th + 8), color, -1)
        cv2.putText(image, txt, (x1 + 2, top + th + 4),
                    cv2.FONT_HERSHEY_SIMPLEX, fs, (255, 255, 255), 2)
    if header:
        cv2.putText(image, header, (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                    0.65, (255, 255, 255), 2)


def _process_video_frames(video_path, species_name, frame_skip, progress):
    """Run detection + behaviour classification and write an annotated video.

    Every ``frame_skip``-th frame goes through ``detect_animals``; each
    detection's crop is appended to a rolling buffer of ``NUM_FRAMES`` crops
    and, once the buffer is full, classified with ``classify_behavior``.
    Frames in between are written with the most recent boxes redrawn on them.

    Args:
        video_path: path of the input video.
        species_name: text attached to every detection label.
        frame_skip: process every N-th frame (values below 1 act as 1).
        progress: Gradio progress callback, called as ``progress(v, desc=...)``.

    Returns:
        ``(output_video_path, Counter_of_labels, timeline, preview_rgb)`` where
        ``timeline`` is a list of ``(timestamp_s, label, confidence)`` and
        ``preview_rgb`` is an RGB image of a recently processed frame, or
        ``None`` if no frame could be read.

    Raises:
        gr.Error: if the input cannot be opened or the output cannot be created.
    """
    frame_skip = max(1, int(frame_skip))
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video. Please try another file.")

    writer = None
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # May be 0 or negative for some containers; progress then falls back
        # to a plain step count instead of dividing by it.
        total_f = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Fresh directory per call so concurrent runs never share an output file.
        out_path = os.path.join(tempfile.mkdtemp(prefix="wildlife_video_"),
                                "output_annotated.mp4")
        # Prefer H.264 for browser-compatible playback; fall back to mp4v.
        writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"avc1"), fps, (W, H))
        if not writer.isOpened():
            writer.release()
            writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (W, H))
        if not writer.isOpened():
            raise gr.Error("Could not create the annotated output video.")

        buffers = {}
        all_preds = []
        timeline = []
        overlays = []          # boxes/labels from the latest processed frame
        header = ""
        preview_rgb = None
        preview_every = max(1, total_f // 20)
        next_preview = 0
        frame_idx = 0

        progress(0.0, desc="Starting analysis...")
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            processed = frame_idx % frame_skip == 0
            if processed:
                animals = detect_animals(frame, species_name)
                timestamp = frame_idx / fps
                overlays = []
                for idx, (x1, y1, x2, y2, name, _) in enumerate(animals):
                    crop = frame[y1:y2, x1:x2]
                    if crop.size == 0:
                        continue
                    # NOTE(review): buffers are keyed by the detection's position
                    # in this frame's result list, not by a tracked identity, so
                    # one buffer can mix crops of different animals when the
                    # detection order changes. Proper tracking would fix that.
                    if idx not in buffers:
                        buffers[idx] = deque(maxlen=NUM_FRAMES)
                    buf = buffers[idx]
                    buf.append(cv2.resize(crop, (CROP_SIZE, CROP_SIZE)))
                    if len(buf) == NUM_FRAMES:
                        pred_idx, pred_label, probs = classify_behavior(list(buf))
                        conf = float(probs[pred_idx])
                        all_preds.append(pred_label)
                        timeline.append((timestamp, pred_label, conf))
                    else:
                        pred_idx, pred_label, conf = 7, "Buffering...", 0.0
                    overlays.append((x1, y1, x2, y2, LABEL_COLORS_BGR[pred_idx],
                                     f"{name}: {pred_label} ({conf:.2f})"))
                header = f"t={timestamp:.1f}s | {len(animals)} detected"

            # Draw on the *current* frame so skipped frames keep the source
            # motion instead of repeating the last processed frame.
            annotated = frame.copy()
            _draw_overlays(annotated, overlays, header)
            if processed and frame_idx >= next_preview:
                preview_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
                next_preview = frame_idx + preview_every
            writer.write(annotated)

            frame_idx += 1
            if total_f > 0:
                frac = min(frame_idx / total_f, 1.0)
                progress(
                    frac,
                    desc=(
                        f"Frame {frame_idx}/{total_f} "
                        f"({frac * 100:.0f}%) — "
                        f"{len(all_preds)} classifications"
                    ),
                )
            else:
                progress(
                    (frame_idx, None),
                    desc=f"Frame {frame_idx} — {len(all_preds)} classifications",
                )

        return out_path, Counter(all_preds), timeline, preview_rgb
    finally:
        cap.release()
        if writer is not None:
            writer.release()
# ─────────────────────────────────────────────────────────────
# Gradio Event Handler
# ─────────────────────────────────────────────────────────────
def analyze_video(video_file, species_name, frame_skip, progress=gr.Progress()):
    """Handle the Analyze button: process the video and build every output.

    Args:
        video_file: path of the uploaded video (``None`` if nothing uploaded).
        species_name: species label shown on detections.
        frame_skip: process every N-th frame.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        7 outputs: ``summary_md, fig_raw, fig_smooth, fig_timeline,
        video_out_path, charts_png_path, preview_img``.

    Raises:
        gr.Error: when no video was uploaded or no animal was classified.
    """
    if video_file is None:
        raise gr.Error("Please upload a video file first.")
    video_name = os.path.basename(video_file)
    start_time = time.time()

    # ── Video metadata ────────────────────────────────────────
    cap = cv2.VideoCapture(video_file)
    try:
        fps_val = cap.get(cv2.CAP_PROP_FPS) or 25
        # Some containers report -1 frames; never show a negative count.
        total_f = max(0, int(cap.get(cv2.CAP_PROP_FRAME_COUNT)))
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    dur_sec = total_f / fps_val

    # ── Run processing ────────────────────────────────────────
    out_path, counts, timeline, preview_rgb = _process_video_frames(
        video_file, species_name, int(frame_skip), progress
    )
    elapsed = time.time() - start_time

    # ── Guard: no detections ──────────────────────────────────
    total_preds = sum(counts.values())
    if total_preds == 0:
        raise gr.Error(
            "No animals detected in this video. "
            "Try a lower frame skip, ensure animals are clearly visible, "
            "and use drone footage at 10–50 m altitude."
        )
    dominant = max(counts, key=counts.get)
    dom_pct = counts[dominant] / total_preds * 100

    # ── Summary markdown ──────────────────────────────────────
    info_row = (
        f"📹 **{video_name}** | "
        f"⏱ {dur_sec:.1f}s | "
        f"🖥 {W}×{H} | "
        f"🎞 {fps_val:.0f} fps | "
        f"🖼 {total_f:,} frames"
    )
    rows = []
    for lbl in LABEL_NAMES[:7]:
        n = counts.get(lbl, 0)
        if n <= 0:
            continue
        share = n / total_preds
        filled = int(share * 20)          # 20-character text bar
        rows.append(
            f"| {lbl} | {n:,} | {'█' * filled}{'░' * (20 - filled)} | "
            f"{share * 100:.1f}% |"
        )
    # Blank lines matter: without one, "---" after a text line turns that
    # line into a setext heading, and tables need to start a new block.
    summary_md = "\n".join([
        f"## ✅ Analysis Complete — processed in {elapsed:.0f}s",
        "",
        info_row,
        "",
        "---",
        "",
        "| Metric | Value |",
        "|---|---|",
        f"| Total Classifications | **{total_preds:,}** |",
        f"| Dominant Behavior | **{dominant}** |",
        f"| Dominance | **{dom_pct:.1f}%** |",
        f"| Device | **{DEVICE}** |",
        "",
        "### Behavior Breakdown",
        "",
        "| Behavior | Count | Distribution | % |",
        "|---|---|---|---|",
        *rows,
    ])

    # ── Charts ────────────────────────────────────────────────
    smoothed = smooth_predictions(timeline, window=5)
    smooth_cnt = Counter(l for _, l, _ in smoothed)
    fig_raw = make_time_budget_chart(counts, "Raw Predictions")
    fig_smooth = make_time_budget_chart(smooth_cnt, "Smoothed (5-frame vote)")
    fig_timeline = make_timeline_chart(smoothed, video_name)

    # ── Downloadable combined chart PNG ───────────────────────
    charts_path = save_combined_chart(smooth_cnt, smoothed)

    return (
        summary_md,     # gr.Markdown — Summary tab
        fig_raw,        # gr.Plot     — Charts tab
        fig_smooth,     # gr.Plot     — Charts tab
        fig_timeline,   # gr.Plot     — Charts tab
        out_path,       # gr.Video    — Annotated Video tab
        charts_path,    # gr.File     — Downloads tab
        preview_rgb,    # gr.Image    — live preview frame
    )
# ─────────────────────────────────────────────────────────────
# Gradio UI Layout
# ─────────────────────────────────────────────────────────────
CSS = """
.gr-button-primary { font-size: 1.1rem !important; }
.status-box { background: #f0fdf4; border: 1px solid #86efac;
              border-radius: 8px; padding: 12px; }
"""

with gr.Blocks(css=CSS, title="🦒 Wildlife Behavior Analyzer") as demo:
    # ── Header ────────────────────────────────────────────────
    gr.Markdown(
        """
        # 🦒 Wildlife Behavior Analyzer

        **AI-powered animal behavior recognition from drone footage**

        Powered by [X3D-KABR](https://huggingface.co/imageomics/x3d-kabr-kinetics)
        (WACV 2024) + YOLOv8 |
        Detects: **Walk · Trot · Run · Graze · Browse · Head Up · Auto-Groom**
        """
    )

    # ── Input Panel ───────────────────────────────────────────
    with gr.Row():
        # Left: settings
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("### ⚙️ Settings")
            species_input = gr.Dropdown(
                choices=["Giraffe", "Zebra", "Elephant",
                         "Deer", "Horse", "Other Wildlife"],
                value="Giraffe",
                label="Animal species in video",
            )
            frame_skip_input = gr.Slider(
                minimum=1, maximum=6, value=3, step=1,
                label="Frame skip (higher = faster, less detail)",
                info="Process every Nth frame. 1 = all frames (slow), 6 = fastest",
            )
            # Blank lines before each "---" keep it a horizontal rule instead
            # of turning the preceding paragraph into a heading.
            gr.Markdown(
                """
                ---

                ### 📌 About the Model

                **X3D-KABR** — trained on the KABR dataset:
                10+ hours of drone footage of Kenyan wildlife.
                Published at **WACV 2024**.

                ⚠️ Best results with **drone/aerial footage** of ungulates
                (giraffe, zebra, deer, horse) at 10–50 m altitude.

                ---

                ### ⏱️ Estimated Processing Time

                - ~2–4 min for a 30-second video
                - ~5–8 min for a 1-minute video
                """
            )

        # Right: upload + analyze
        with gr.Column(scale=2):
            gr.Markdown("### 📤 Upload Drone Wildlife Video")
            video_input = gr.Video(
                label="Upload video (MP4, AVI, MOV, MKV)",
                sources=["upload"],
                height=320,
            )
            analyze_btn = gr.Button(
                "🚀 Analyze Behavior",
                variant="primary",
                size="lg",
            )

    gr.Markdown("---")

    # ── Output Tabs ───────────────────────────────────────────
    with gr.Tabs():
        with gr.Tab("📋 Summary & Metrics"):
            summary_output = gr.Markdown(
                "*Upload a video and click **Analyze Behavior** to see results.*"
            )
        with gr.Tab("📊 Charts"):
            with gr.Row():
                pie_raw = gr.Plot(label="Raw Prediction Budget")
                pie_smooth = gr.Plot(label="Smoothed Prediction Budget (5-frame vote)")
            timeline_plot = gr.Plot(label="Behavior Timeline")
        with gr.Tab("🎬 Annotated Video"):
            gr.Markdown("*The annotated video with bounding boxes and behavior labels.*")
            video_output = gr.Video(label="Annotated Output", height=480)
        with gr.Tab("⬇️ Downloads"):
            gr.Markdown("### Download Results")
            with gr.Row():
                charts_file = gr.File(label="📈 Behavior Charts PNG")
                # The annotated video is also offered here as a plain file.
                video_file_dl = gr.File(label="🎬 Annotated Video (MP4)")
            gr.Markdown(
                "_Tip: you can also right-click the video in the "
                "**🎬 Annotated Video** tab and choose **Save Video As**._"
            )
        with gr.Tab("🔍 Live Preview"):
            gr.Markdown("*Last frame captured during processing.*")
            preview_img = gr.Image(
                label="Last Processed Frame",
                type="numpy",
                height=400,
            )

    # ── Wire button → handler ─────────────────────────────────
    analyze_event = analyze_btn.click(
        fn=analyze_video,
        inputs=[video_input, species_input, frame_skip_input],
        outputs=[
            summary_output,   # Tab 1
            pie_raw,          # Tab 2
            pie_smooth,       # Tab 2
            timeline_plot,    # Tab 2
            video_output,     # Tab 3
            charts_file,      # Tab 4 (charts PNG)
            preview_img,      # Tab 5
        ],
        show_progress="full",
        api_name="analyze",
    )
    # After a successful run, copy the annotated video into the Downloads tab.
    # Chaining reuses the finished result; a second .click() listener would
    # re-run the whole analysis in parallel just to obtain the same path.
    analyze_event.success(
        fn=lambda video_path: video_path,
        inputs=[video_output],
        outputs=[video_file_dl],
        show_progress="hidden",
        api_name=False,
    )

    gr.Markdown(
        """
        ---

        **Credits:**
        [X3D-KABR model](https://huggingface.co/imageomics/x3d-kabr-kinetics) ·
        [KABR Dataset](https://huggingface.co/datasets/imageomics/KABR) ·
        YOLOv8 by Ultralytics · Built by rohansingh0
        """
    )
if __name__ == "__main__":
    # Bind on all interfaces at port 7860 — needed for HuggingFace Spaces.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)