""" Wildlife Behavior Analyzer Powered by X3D-KABR (WACV 2024) + YOLOv8 Author: rohansingh0 | HuggingFace Spaces """ import os import re import io import cv2 import time import tempfile import zipfile import numpy as np import torch import gradio as gr import matplotlib matplotlib.use("Agg") # headless backend — must be before pyplot import import matplotlib.pyplot as plt import matplotlib.patches as mpatches from collections import deque, Counter from huggingface_hub import hf_hub_download from pytorchvideo.models.x3d import create_x3d from ultralytics import YOLO # ───────────────────────────────────────────────────────────── # Constants # ───────────────────────────────────────────────────────────── DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") NUM_CLASSES = 8 NUM_FRAMES = 16 CROP_SIZE = 300 MEAN = [0.45, 0.45, 0.45] STD = [0.225, 0.225, 0.225] ANIMAL_IDS = set(range(14, 26)) # COCO animal class IDs LABEL_NAMES = ["Walk", "Trot", "Run", "Graze", "Browse", "Head Up", "Auto-Groom", "Occluded"] LABEL_COLORS_HEX = [ "#2196F3", "#9C27B0", "#F44336", "#4CAF50", "#009688", "#FF9800", "#E91E63", "#9E9E9E" ] LABEL_COLORS_BGR = [ (243, 150, 33), (176, 39, 156), (54, 67, 244), (80, 175, 76), (136, 150, 0), (0, 152, 255), (99, 30, 233), (158, 158, 158) ] # ───────────────────────────────────────────────────────────── # Model Loading (runs ONCE at server startup — not per session) # This is the key fix: Streamlit reloads per session; Gradio does not. # ───────────────────────────────────────────────────────────── def remap_key(k): """Remap SlowFast checkpoint keys → pytorchvideo format.""" k = re.sub(r"^s1\.pathway0_stem\.conv_xy\b", "blocks.0.conv.conv_t", k) k = re.sub(r"^s1\.pathway0_stem\.conv\b", "blocks.0.conv.conv_xy", k) k = re.sub(r"^s1\.pathway0_stem\.bn\b", "blocks.0.norm", k) def stage(m): return f"blocks.{int(m.group(1))-1}.res_blocks.{int(m.group(2))}." k = re.sub(r"^s(\d)\.pathway0_res(\d+)\.", stage, k) k = re.sub(r"\.branch1\.weight$", ".branch1_conv.weight", k) k = re.sub(r"\.branch1_bn\.", ".branch1_norm.", k) k = re.sub(r"\.branch2\.a\.weight$", ".branch2.conv_a.weight", k) k = re.sub(r"\.branch2\.a_bn\.", ".branch2.norm_a.", k) k = re.sub(r"\.branch2\.b\.weight$", ".branch2.conv_b.weight", k) k = re.sub(r"\.branch2\.b_bn\.", ".branch2.norm_b.0.", k) k = re.sub(r"\.branch2\.se\.fc1\.", ".branch2.norm_b.1.block.0.", k) k = re.sub(r"\.branch2\.se\.fc2\.", ".branch2.norm_b.1.block.2.", k) k = re.sub(r"\.branch2\.c\.weight$", ".branch2.conv_c.weight", k) k = re.sub(r"\.branch2\.c_bn\.", ".branch2.norm_c.", k) k = re.sub(r"^head\.conv_5\.", "blocks.5.pool.pre_conv.", k) k = re.sub(r"^head\.conv_5_bn\.", "blocks.5.pool.pre_norm.", k) k = re.sub(r"^head\.lin_5\.", "blocks.5.pool.post_conv.", k) k = re.sub(r"^head\.projection\.", "blocks.5.proj.", k) return k def _load_models(): print(f"[Wildlife Analyzer] Loading models on {DEVICE} ...") # ── X3D-KABR ────────────────────────────────────────────── ckpt_zip = hf_hub_download( repo_id = "imageomics/x3d-kabr-kinetics", filename = "checkpoint_epoch_00075.pyth.zip", local_dir = "/tmp" ) with zipfile.ZipFile(ckpt_zip, "r") as z: z.extractall("/tmp") ckpt_path = None for root, _, files in os.walk("/tmp"): for f in files: if f.endswith(".pyth"): ckpt_path = os.path.join(root, f) break model = create_x3d( input_clip_length=16, input_crop_size=300, model_num_class=NUM_CLASSES, depth_factor=5.0, width_factor=2.0, bottleneck_factor=2.25, dropout_rate=0.5, head_activation=None, ) checkpoint = torch.load(ckpt_path, map_location="cpu") state_dict = checkpoint["model_state"] model_sd = model.state_dict() model_keys = set(model_sd.keys()) remapped = {} for ck, cv in state_dict.items(): mk = remap_key(ck) if mk in model_keys and model_sd[mk].shape == cv.shape: remapped[mk] = cv model.load_state_dict(remapped, strict=False) model = model.to(DEVICE).eval() # ── YOLOv8 ──────────────────────────────────────────────── yolo = YOLO("yolov8n.pt") yolo.to("cpu") print(f"[Wildlife Analyzer] ✅ Models ready on {DEVICE}") return model, yolo # Module-level singleton — loaded once, shared across all requests MODEL, YOLO_MODEL = _load_models() # ───────────────────────────────────────────────────────────── # Inference Helpers (identical logic to original) # ───────────────────────────────────────────────────────────── def preprocess_clip(frames_bgr): out = [] for f in frames_bgr: f = cv2.cvtColor(f, cv2.COLOR_BGR2RGB) h, w = f.shape[:2] s = CROP_SIZE / min(h, w) f = cv2.resize(f, (int(w * s), int(h * s))) ch = (f.shape[0] - CROP_SIZE) // 2 cw = (f.shape[1] - CROP_SIZE) // 2 f = f[ch:ch + CROP_SIZE, cw:cw + CROP_SIZE] f = (f.astype(np.float32) / 255. - MEAN) / STD out.append(f) clip = np.stack(out).transpose(3, 0, 1, 2) return torch.tensor(clip, dtype=torch.float32).unsqueeze(0) def classify_behavior(crop_sequence): with torch.no_grad(): x = preprocess_clip(crop_sequence).to(DEVICE) out = MODEL(x) probs = torch.sigmoid(out).squeeze().cpu().numpy() pred = int(np.argmax(probs[:7])) return pred, LABEL_NAMES[pred], probs def detect_animals(frame_bgr, species_name): rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB) H, W = frame_bgr.shape[:2] results = YOLO_MODEL(rgb, verbose=False, device="cpu", conf=0.15, iou=0.4)[0] animals = [] for box in results.boxes: if int(box.cls[0]) not in ANIMAL_IDS: continue x1, y1, x2, y2 = map(int, box.xyxy[0]) x1, y1 = max(0, x1), max(0, y1) x2, y2 = min(W, x2), min(H, y2) w_box, h_box = x2 - x1, y2 - y1 area_ratio = (w_box * h_box) / (W * H) if area_ratio < 0.005 or area_ratio > 0.80: continue if w_box < 30 or h_box < 30: continue animals.append((x1, y1, x2, y2, species_name, float(box.conf[0]))) return animals def smooth_predictions(timeline, window=5): smoothed = [] label_buffer = deque(maxlen=window) for t, label, conf in timeline: label_buffer.append(label) majority = Counter(label_buffer).most_common(1)[0][0] smoothed.append((t, majority, conf)) return smoothed # ───────────────────────────────────────────────────────────── # Visualization Helpers (identical logic to original) # ───────────────────────────────────────────────────────────── def make_time_budget_chart(counts, title): labels = [l for l in LABEL_NAMES[:7] if counts.get(l, 0) > 0] sizes = [counts[l] for l in labels] colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(l)] for l in labels] fig, ax = plt.subplots(figsize=(6, 6)) ax.pie(sizes, labels=labels, colors=colors, autopct="%1.1f%%", startangle=90, textprops={"fontsize": 11}) ax.set_title(title, fontsize=13, fontweight="bold", pad=15) plt.tight_layout() return fig def make_timeline_chart(smoothed, video_name): times = [t for t, _, _ in smoothed] if not times: return None label_to_idx = {l: i for i, l in enumerate(LABEL_NAMES)} fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(14, 5), gridspec_kw={"height_ratios": [3, 1]}) for i, lbl in enumerate(LABEL_NAMES[:7]): pts = [(t, c) for t, l, c in smoothed if l == lbl] if pts: ts, cs = zip(*pts) ax1.scatter(ts, cs, label=lbl, color=LABEL_COLORS_HEX[i], s=8, alpha=0.7) ax1.set_ylabel("Confidence", fontsize=11) ax1.set_ylim(0, 1.05) ax1.legend(loc="upper right", fontsize=8, ncol=4) ax1.set_title(f"Behavior Timeline — {video_name}", fontsize=12, fontweight="bold") ax1.grid(True, alpha=0.3) dt = (times[1] - times[0]) if len(times) > 1 else 0.1 for t, lbl, _ in smoothed: idx = label_to_idx.get(lbl, 7) ax2.axvspan(t - dt / 2, t + dt / 2, color=LABEL_COLORS_HEX[idx], alpha=0.9) patches = [mpatches.Patch(color=LABEL_COLORS_HEX[i], label=l) for i, l in enumerate(LABEL_NAMES[:7])] ax2.legend(handles=patches, loc="upper right", fontsize=8, ncol=4) ax2.set_xlabel("Time (seconds)", fontsize=11) ax2.set_yticks([]) plt.tight_layout() return fig def save_combined_chart(smooth_cnt, smoothed): """Save side-by-side pie + timeline strip as a PNG and return its path.""" times = [t for t, _, _ in smoothed] if not times: return None label_to_idx = {l: i for i, l in enumerate(LABEL_NAMES)} fig, axes = plt.subplots(1, 2, figsize=(14, 5)) labels = [l for l in LABEL_NAMES[:7] if smooth_cnt.get(l, 0) > 0] sizes = [smooth_cnt[l] for l in labels] colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(l)] for l in labels] axes[0].pie(sizes, labels=labels, colors=colors, autopct="%1.1f%%", startangle=90) axes[0].set_title("Time Budget (Smoothed)") dt = (times[1] - times[0]) if len(times) > 1 else 0.1 for t, lbl, _ in smoothed: idx = label_to_idx.get(lbl, 7) axes[1].axvspan(t - dt / 2, t + dt / 2, color=LABEL_COLORS_HEX[idx], alpha=0.9) axes[1].set_title("Behavior Timeline Strip") axes[1].set_xlabel("Time (seconds)") axes[1].set_yticks([]) plt.tight_layout() out = os.path.join(tempfile.gettempdir(), "behavior_charts.png") plt.savefig(out, format="png", dpi=150, bbox_inches="tight") plt.close(fig) return out # ───────────────────────────────────────────────────────────── # Core Processing (refactored to yield progress updates) # ───────────────────────────────────────────────────────────── def _process_video_frames(video_path, species_name, frame_skip, progress): """ Runs detection + classification frame by frame. Calls progress(0..1, desc=...) to stream updates to the UI. Returns (output_video_path, counts_counter, timeline_list, last_preview_rgb). """ cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) or 25 W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) total_f = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) out_path = os.path.join(tempfile.gettempdir(), "output_annotated.mp4") # Prefer H.264 for browser-compatible playback; fall back to mp4v fourcc = cv2.VideoWriter_fourcc(*"avc1") writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H)) if not writer.isOpened(): fourcc = cv2.VideoWriter_fourcc(*"mp4v") writer = cv2.VideoWriter(out_path, fourcc, fps, (W, H)) buffers = {} all_preds = [] timeline = [] frame_idx = 0 last_annotated = None preview_rgb = None preview_every = max(1, total_f // 20) progress(0.0, desc="Starting analysis...") while True: ret, frame = cap.read() if not ret: break if frame_idx % frame_skip == 0: annotated = frame.copy() animals = detect_animals(frame, species_name) timestamp = frame_idx / fps for idx, (x1, y1, x2, y2, name, _) in enumerate(animals): crop = frame[y1:y2, x1:x2] if crop.size == 0: continue crop = cv2.resize(crop, (CROP_SIZE, CROP_SIZE)) if idx not in buffers: buffers[idx] = deque(maxlen=NUM_FRAMES) buffers[idx].append(crop) if len(buffers[idx]) == NUM_FRAMES: pred_idx, pred_label, probs = classify_behavior(list(buffers[idx])) conf = float(probs[pred_idx]) all_preds.append(pred_label) timeline.append((timestamp, pred_label, conf)) else: pred_label = "Buffering..." pred_idx = 7 conf = 0.0 color = LABEL_COLORS_BGR[pred_idx] cv2.rectangle(annotated, (x1, y1), (x2, y2), color, 2) txt = f"{name}: {pred_label} ({conf:.2f})" fs = max(0.4, min(0.6, (x2 - x1) / 400)) (tw, th), _ = cv2.getTextSize(txt, cv2.FONT_HERSHEY_SIMPLEX, fs, 2) cv2.rectangle(annotated, (x1, y1 - th - 8), (x1 + tw + 4, y1), color, -1) cv2.putText(annotated, txt, (x1 + 2, y1 - 4), cv2.FONT_HERSHEY_SIMPLEX, fs, (255, 255, 255), 2) cv2.putText(annotated, f"t={timestamp:.1f}s | {len(animals)} detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.65, (255, 255, 255), 2) last_annotated = annotated if frame_idx % preview_every == 0: preview_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB) writer.write(last_annotated if last_annotated is not None else frame) frame_idx += 1 progress( min(frame_idx / total_f, 1.0), desc=( f"Frame {frame_idx}/{total_f} " f"({frame_idx / total_f * 100:.0f}%) — " f"{len(all_preds)} classifications" ) ) cap.release() writer.release() return out_path, Counter(all_preds), timeline, preview_rgb # ───────────────────────────────────────────────────────────── # Gradio Event Handler # ───────────────────────────────────────────────────────────── def analyze_video(video_file, species_name, frame_skip, progress=gr.Progress()): """ Main handler wired to the Analyze button. Returns 7 outputs: summary_md, fig_raw, fig_smooth, fig_timeline, video_out_path, charts_png_path, preview_img """ if video_file is None: raise gr.Error("Please upload a video file first.") video_name = os.path.basename(video_file) start_time = time.time() # ── Video metadata ──────────────────────────────────────── cap = cv2.VideoCapture(video_file) fps_val = cap.get(cv2.CAP_PROP_FPS) or 25 total_f = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) dur_sec = total_f / fps_val cap.release() # ── Run processing ──────────────────────────────────────── out_path, counts, timeline, preview_rgb = _process_video_frames( video_file, species_name, int(frame_skip), progress ) elapsed = time.time() - start_time # ── Guard: no detections ────────────────────────────────── total_preds = sum(counts.values()) if total_preds == 0: raise gr.Error( "No animals detected in this video. " "Try a lower frame skip, ensure animals are clearly visible, " "and use drone footage at 10–50 m altitude." ) dominant = max(counts, key=counts.get) dom_pct = counts[dominant] / total_preds * 100 # ── Summary markdown ────────────────────────────────────── info_row = ( f"📹 **{video_name}**  |  " f"⏱ {dur_sec:.1f}s  |  " f"🖥 {W}×{H}  |  " f"🎞 {fps_val:.0f} fps  |  " f"🖼 {total_f:,} frames" ) rows = "\n".join( f"| {lbl} | {counts.get(lbl, 0):,} | " f"{'█' * int(counts.get(lbl,0)/total_preds*20)}" f"{'░' * (20 - int(counts.get(lbl,0)/total_preds*20))} | " f"{counts.get(lbl,0)/total_preds*100:.1f}% |" for lbl in LABEL_NAMES[:7] if counts.get(lbl, 0) > 0 ) summary_md = f""" ## ✅ Analysis Complete — processed in {elapsed:.0f}s {info_row} --- | Metric | Value | |---|---| | Total Classifications | **{total_preds:,}** | | Dominant Behavior | **{dominant}** | | Dominance | **{dom_pct:.1f}%** | | Device | **{DEVICE}** | ### Behavior Breakdown | Behavior | Count | Distribution | % | |---|---|---|---| {rows} """ # ── Charts ──────────────────────────────────────────────── smoothed = smooth_predictions(timeline, window=5) smooth_cnt = Counter(l for _, l, _ in smoothed) fig_raw = make_time_budget_chart(counts, "Raw Predictions") fig_smooth = make_time_budget_chart(smooth_cnt, "Smoothed (5-frame vote)") fig_timeline = make_timeline_chart(smoothed, video_name) # ── Downloadable combined chart PNG ─────────────────────── charts_path = save_combined_chart(smooth_cnt, smoothed) return ( summary_md, # gr.Markdown — Summary tab fig_raw, # gr.Plot — Charts tab fig_smooth, # gr.Plot — Charts tab fig_timeline, # gr.Plot — Charts tab out_path, # gr.Video — Annotated Video tab charts_path, # gr.File — Downloads tab preview_rgb, # gr.Image — last live preview frame ) # ───────────────────────────────────────────────────────────── # Gradio UI Layout # ───────────────────────────────────────────────────────────── CSS = """ .gr-button-primary { font-size: 1.1rem !important; } .status-box { background: #f0fdf4; border: 1px solid #86efac; border-radius: 8px; padding: 12px; } """ with gr.Blocks(css=CSS, title="🦒 Wildlife Behavior Analyzer") as demo: # ── Header ──────────────────────────────────────────────── gr.Markdown( """ # 🦒 Wildlife Behavior Analyzer **AI-powered animal behavior recognition from drone footage** Powered by [X3D-KABR](https://huggingface.co/imageomics/x3d-kabr-kinetics) (WACV 2024) + YOLOv8  |  Detects: **Walk · Trot · Run · Graze · Browse · Head Up · Auto-Groom** """ ) # ── Input Panel ─────────────────────────────────────────── with gr.Row(): # Left: settings with gr.Column(scale=1, min_width=260): gr.Markdown("### ⚙️ Settings") species_input = gr.Dropdown( choices=["Giraffe", "Zebra", "Elephant", "Deer", "Horse", "Other Wildlife"], value="Giraffe", label="Animal species in video" ) frame_skip_input = gr.Slider( minimum=1, maximum=6, value=3, step=1, label="Frame skip (higher = faster, less detail)", info="Process every Nth frame. 1 = all frames (slow), 6 = fastest" ) gr.Markdown( """ --- ### 📌 About the Model **X3D-KABR** — trained on the KABR dataset: 10+ hours of drone footage of Kenyan wildlife. Published at **WACV 2024**. ⚠️ Best results with **drone/aerial footage** of ungulates (giraffe, zebra, deer, horse) at 10–50 m altitude. --- ### ⏱️ Estimated Processing Time - ~2–4 min for a 30-second video - ~5–8 min for a 1-minute video """ ) # Right: upload + analyze with gr.Column(scale=2): gr.Markdown("### 📤 Upload Drone Wildlife Video") video_input = gr.Video( label="Upload video (MP4, AVI, MOV, MKV)", sources=["upload"], height=320, ) analyze_btn = gr.Button( "🚀 Analyze Behavior", variant="primary", size="lg" ) gr.Markdown("---") # ── Output Tabs ─────────────────────────────────────────── with gr.Tabs(): with gr.Tab("📋 Summary & Metrics"): summary_output = gr.Markdown( "*Upload a video and click **Analyze Behavior** to see results.*" ) with gr.Tab("📊 Charts"): with gr.Row(): pie_raw = gr.Plot(label="Raw Prediction Budget") pie_smooth = gr.Plot(label="Smoothed Prediction Budget (5-frame vote)") timeline_plot = gr.Plot(label="Behavior Timeline") with gr.Tab("🎬 Annotated Video"): gr.Markdown("*The annotated video with bounding boxes and behavior labels.*") video_output = gr.Video(label="Annotated Output", height=480) with gr.Tab("⬇️ Downloads"): gr.Markdown("### Download Results") with gr.Row(): charts_file = gr.File(label="📈 Behavior Charts PNG") # Note: the annotated video is also available here as a file video_file_dl = gr.File(label="🎬 Annotated Video (MP4)") gr.Markdown( "_Tip: right-click the annotated video above and choose " "**Save Video As** for the highest-quality download._" ) with gr.Tab("🔍 Live Preview"): gr.Markdown("*Last frame captured during processing.*") preview_img = gr.Image( label="Last Processed Frame", type="numpy", height=400, ) # ── Wire button → handler ───────────────────────────────── analyze_btn.click( fn=analyze_video, inputs=[video_input, species_input, frame_skip_input], outputs=[ summary_output, # Tab 1 pie_raw, # Tab 2 pie_smooth, # Tab 2 timeline_plot, # Tab 2 video_output, # Tab 3 charts_file, # Tab 4 (charts PNG) preview_img, # Tab 5 ], show_progress="full", api_name="analyze", ) # Also wire annotated video → download file tab # (same path fed to both gr.Video and gr.File) analyze_btn.click( fn=lambda v, s, f: analyze_video(v, s, f)[4], # index 4 = video path inputs=[video_input, species_input, frame_skip_input], outputs=[video_file_dl], show_progress=False, api_name=False, ) gr.Markdown( """ --- **Credits:** [X3D-KABR model](https://huggingface.co/imageomics/x3d-kabr-kinetics) · [KABR Dataset](https://huggingface.co/datasets/imageomics/KABR) · YOLOv8 by Ultralytics · Built by rohansingh0 """ ) if __name__ == "__main__": demo.launch( server_name="0.0.0.0", # needed for HuggingFace Spaces server_port=7860, show_error=True, )