# aba2 / app.py — rohansingh0 — initial commit 7e458d9 (verified)
"""
Wildlife Behavior Analyzer
Powered by X3D-KABR (WACV 2024) + YOLOv8
Author: rohansingh0 | HuggingFace Spaces
"""
import io
import os
import re
import tempfile
import time
import zipfile
from collections import deque, Counter

import cv2
import numpy as np
import torch
import gradio as gr
import matplotlib
matplotlib.use("Agg") # headless backend — must be before pyplot import
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.figure import Figure
from huggingface_hub import hf_hub_download
from pytorchvideo.models.x3d import create_x3d
from ultralytics import YOLO
# ─────────────────────────────────────────────────────────────
# Constants
# ─────────────────────────────────────────────────────────────
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_CLASSES = 8      # number of behaviour labels in LABEL_NAMES (incl. "Occluded")
NUM_FRAMES = 16      # crops per clip handed to the X3D classifier
CROP_SIZE = 300      # side length, in pixels, of every clip frame
MEAN = [0.45, 0.45, 0.45]      # per-channel normalisation mean
STD = [0.225, 0.225, 0.225]    # per-channel normalisation std
# 0-based COCO class ids (as reported by YOLOv8) of the animal categories:
# 14 bird, 15 cat, 16 dog, 17 horse, 18 sheep, 19 cow, 20 elephant,
# 21 bear, 22 zebra, 23 giraffe. Ids 24/25 are backpack/umbrella, so the
# range must stop before 24.
ANIMAL_IDS = set(range(14, 24))
LABEL_NAMES = ["Walk", "Trot", "Run", "Graze",
               "Browse", "Head Up", "Auto-Groom", "Occluded"]
LABEL_COLORS_HEX = [
    "#2196F3", "#9C27B0", "#F44336", "#4CAF50",
    "#009688", "#FF9800", "#E91E63", "#9E9E9E",
]
# OpenCV draws in BGR order; derive these from the hex palette so the video
# overlay and the matplotlib charts can never disagree on a colour.
LABEL_COLORS_BGR = [
    (int(hex_[5:7], 16), int(hex_[3:5], 16), int(hex_[1:3], 16))
    for hex_ in LABEL_COLORS_HEX
]
# ─────────────────────────────────────────────────────────────
# Model Loading (runs ONCE at server startup — not per session)
# Loading at import time shares one copy of the models across all Gradio
# sessions; the previous Streamlit app reloaded them for every session.
# ─────────────────────────────────────────────────────────────
def remap_key(k):
    """Rename one SlowFast-style X3D checkpoint key to pytorchvideo's naming.

    The rewrite rules below run in order, each on the result of the previous
    one. A key that matches no rule comes back unchanged.
    """
    def _stage_prefix(match):
        # SlowFast stage sN / residual block R  ->  pytorchvideo blocks.(N-1).res_blocks.R
        stage_no = int(match.group(1))
        block_no = int(match.group(2))
        return f"blocks.{stage_no - 1}.res_blocks.{block_no}."

    rules = (
        # Stem. Note the crossed names (conv_xy -> conv_t, conv -> conv_xy);
        # presumably mirrors how pytorchvideo names its X3D stem convs — confirm.
        (r"^s1\.pathway0_stem\.conv_xy\b", "blocks.0.conv.conv_t"),
        (r"^s1\.pathway0_stem\.conv\b", "blocks.0.conv.conv_xy"),
        (r"^s1\.pathway0_stem\.bn\b", "blocks.0.norm"),
        # Residual stages.
        (r"^s(\d)\.pathway0_res(\d+)\.", _stage_prefix),
        # Shortcut branch.
        (r"\.branch1\.weight$", ".branch1_conv.weight"),
        (r"\.branch1_bn\.", ".branch1_norm."),
        # Bottleneck branch (a / b + squeeze-excitation / c).
        (r"\.branch2\.a\.weight$", ".branch2.conv_a.weight"),
        (r"\.branch2\.a_bn\.", ".branch2.norm_a."),
        (r"\.branch2\.b\.weight$", ".branch2.conv_b.weight"),
        (r"\.branch2\.b_bn\.", ".branch2.norm_b.0."),
        (r"\.branch2\.se\.fc1\.", ".branch2.norm_b.1.block.0."),
        (r"\.branch2\.se\.fc2\.", ".branch2.norm_b.1.block.2."),
        (r"\.branch2\.c\.weight$", ".branch2.conv_c.weight"),
        (r"\.branch2\.c_bn\.", ".branch2.norm_c."),
        # Classification head.
        (r"^head\.conv_5\.", "blocks.5.pool.pre_conv."),
        (r"^head\.conv_5_bn\.", "blocks.5.pool.pre_norm."),
        (r"^head\.lin_5\.", "blocks.5.pool.post_conv."),
        (r"^head\.projection\.", "blocks.5.proj."),
    )
    for pattern, replacement in rules:
        k = re.sub(pattern, replacement, k)
    return k
def _find_checkpoint(root):
    """Return the path of the first ``.pyth`` file under *root* (deterministic walk).

    Raises:
        FileNotFoundError: if *root* contains no ``.pyth`` file.
    """
    for dirpath, dirnames, filenames in os.walk(root):
        dirnames.sort()  # in-place sort fixes os.walk's traversal order
        for name in sorted(filenames):
            if name.endswith(".pyth"):
                return os.path.join(dirpath, name)
    raise FileNotFoundError(f"No .pyth checkpoint found under {root!r}")


def _load_models():
    """Build the X3D-KABR behaviour classifier and the YOLOv8 detector.

    The X3D checkpoint is downloaded from the HuggingFace Hub and unzipped.
    Its keys are renamed with ``remap_key``. Tensors whose renamed key or
    shape does not match the model are skipped, and the load coverage is
    printed.

    Returns:
        (model, yolo): the X3D model in eval mode on ``DEVICE``, and a
        YOLOv8n detector moved to the CPU.

    Raises:
        FileNotFoundError: if the downloaded archive holds no ``.pyth`` file.
    """
    print(f"[Wildlife Analyzer] Loading models on {DEVICE} ...")
    # ── X3D-KABR ──────────────────────────────────────────────
    ckpt_zip = hf_hub_download(
        repo_id="imageomics/x3d-kabr-kinetics",
        filename="checkpoint_epoch_00075.pyth.zip",
        local_dir="/tmp",
    )
    # Extract into a dedicated folder so the search below only sees this
    # archive's contents, never unrelated .pyth files elsewhere in /tmp.
    extract_dir = os.path.join("/tmp", "x3d_kabr_checkpoint")
    with zipfile.ZipFile(ckpt_zip, "r") as z:
        z.extractall(extract_dir)
    ckpt_path = _find_checkpoint(extract_dir)

    model = create_x3d(
        input_clip_length=NUM_FRAMES, input_crop_size=CROP_SIZE,
        model_num_class=NUM_CLASSES, depth_factor=5.0,
        width_factor=2.0, bottleneck_factor=2.25,
        dropout_rate=0.5, head_activation=None,
    )
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    # Weights are expected under "model_state"; fall back to treating the
    # file itself as a bare state dict.
    state_dict = checkpoint.get("model_state", checkpoint)

    model_sd = model.state_dict()
    remapped = {}
    skipped = 0
    for ck, cv in state_dict.items():
        mk = remap_key(ck)
        if mk in model_sd and model_sd[mk].shape == cv.shape:
            remapped[mk] = cv
        else:
            skipped += 1
    # strict=False leaves every unmatched parameter at its random init, so
    # surface how much of the checkpoint actually landed instead of hiding it.
    incompatible = model.load_state_dict(remapped, strict=False)
    print(
        f"[Wildlife Analyzer] X3D weights: {len(remapped)} tensors loaded, "
        f"{skipped} checkpoint entries skipped, "
        f"{len(incompatible.missing_keys)} model entries not in checkpoint"
    )
    model = model.to(DEVICE).eval()

    # ── YOLOv8 ────────────────────────────────────────────────
    yolo = YOLO("yolov8n.pt")
    yolo.to("cpu")
    print(f"[Wildlife Analyzer] ✅ Models ready on {DEVICE}")
    return model, yolo
# Module-level singleton — loaded once, shared across all requests.
# _load_models() runs when this module is imported (downloading weights if
# needed), so this line must come after its definition. classify_behavior and
# detect_animals read MODEL / YOLO_MODEL as globals at call time.
MODEL, YOLO_MODEL = _load_models()
# ─────────────────────────────────────────────────────────────
# Inference Helpers (identical logic to original)
# ─────────────────────────────────────────────────────────────
def preprocess_clip(frames_bgr):
    """Turn a sequence of BGR frames into a normalised X3D input tensor.

    Processing steps for each frame:
    - convert it to RGB;
    - resize it so its shorter side is ``CROP_SIZE``;
    - centre-crop it to ``CROP_SIZE x CROP_SIZE``.

    Pixel values are then scaled to [0, 1] and normalised with ``MEAN``/``STD``.

    Args:
        frames_bgr: non-empty sequence of HxWx3 BGR images.

    Returns:
        float32 tensor of shape (1, 3, T, CROP_SIZE, CROP_SIZE), where
        T == len(frames_bgr).

    Raises:
        ValueError: if ``frames_bgr`` is empty.
    """
    if len(frames_bgr) == 0:
        raise ValueError("preprocess_clip needs at least one frame")
    # float32 statistics keep the arithmetic in float32 (list constants would
    # silently promote the whole clip to float64).
    mean = np.asarray(MEAN, dtype=np.float32)
    std = np.asarray(STD, dtype=np.float32)
    crops = []
    for frame in frames_bgr:
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        h, w = rgb.shape[:2]
        scale = CROP_SIZE / min(h, w)
        # Float rounding plus int() truncation may leave a side one pixel
        # short of CROP_SIZE; clamp so the centre crop is always full size.
        new_w = max(CROP_SIZE, int(w * scale))
        new_h = max(CROP_SIZE, int(h * scale))
        rgb = cv2.resize(rgb, (new_w, new_h))
        top = (new_h - CROP_SIZE) // 2
        left = (new_w - CROP_SIZE) // 2
        crops.append(rgb[top:top + CROP_SIZE, left:left + CROP_SIZE])
    # (T, H, W, C) uint8 -> normalised float32 -> (C, T, H, W) -> add batch dim
    clip = (np.stack(crops).astype(np.float32) / 255.0 - mean) / std
    clip = np.ascontiguousarray(clip.transpose(3, 0, 1, 2))
    return torch.from_numpy(clip).unsqueeze(0)
def classify_behavior(crop_sequence):
    """Classify one clip of BGR crops with the X3D-KABR model.

    Returns:
        (index, label, scores). ``index``/``label`` name the highest-scoring
        behaviour among the first seven classes (so "Occluded" is never
        chosen). ``scores`` is the NumPy array of per-class sigmoid
        activations.
    """
    with torch.no_grad():
        clip = preprocess_clip(crop_sequence).to(DEVICE)
        logits = MODEL(clip)
        scores = torch.sigmoid(logits).squeeze().cpu().numpy()
    best = int(scores[:7].argmax())
    return best, LABEL_NAMES[best], scores
def detect_animals(frame_bgr, species_name, *, conf=0.15, iou=0.4,
                   min_area_ratio=0.005, max_area_ratio=0.80, min_side=30):
    """Detect animals in one video frame with YOLOv8.

    Args:
        frame_bgr: frame as returned by ``cv2.VideoCapture.read`` (BGR order).
        species_name: label attached to every kept detection. YOLO's own COCO
            class is only used to decide whether a box is an animal.
        conf: YOLO confidence threshold.
        iou: YOLO NMS-IoU threshold.
        min_area_ratio, max_area_ratio: keep boxes whose area, as a fraction
            of the frame area, lies inside this inclusive range.
        min_side: minimum box width and height, in pixels.

    Returns:
        List of ``(x1, y1, x2, y2, species_name, confidence)`` tuples with
        integer pixel coordinates clipped to the frame.
    """
    H, W = frame_bgr.shape[:2]
    # Ultralytics treats NumPy inputs as BGR (OpenCV order), so the frame is
    # passed through unchanged; converting it to RGB would swap channels.
    results = YOLO_MODEL(frame_bgr, verbose=False, device="cpu",
                         conf=conf, iou=iou)[0]
    boxes = results.boxes
    animals = []
    for cls_id, (x1, y1, x2, y2), score in zip(
        boxes.cls.tolist(), boxes.xyxy.tolist(), boxes.conf.tolist()
    ):
        if int(cls_id) not in ANIMAL_IDS:
            continue
        x1, y1 = max(0, int(x1)), max(0, int(y1))
        x2, y2 = min(W, int(x2)), min(H, int(y2))
        w_box, h_box = x2 - x1, y2 - y1
        if w_box < min_side or h_box < min_side:
            continue
        area_ratio = (w_box * h_box) / (W * H)
        if not min_area_ratio <= area_ratio <= max_area_ratio:
            continue
        animals.append((x1, y1, x2, y2, species_name, float(score)))
    return animals
def smooth_predictions(timeline, window=5):
    """Replace each label by the majority of the last *window* labels.

    Args:
        timeline: iterable of ``(time, label, conf)`` tuples, in order.
        window: number of most recent labels (including the current one)
            that vote. Ties go to the label that appears first in the window.

    Returns:
        List of ``(time, majority_label, conf)`` tuples. ``conf`` is copied
        unchanged from the input entry.
    """
    recent = deque(maxlen=window)

    def _vote(label):
        recent.append(label)
        winner, _count = Counter(recent).most_common(1)[0]
        return winner

    return [(t, _vote(label), conf) for t, label, conf in timeline]
# ─────────────────────────────────────────────────────────────
# Visualization Helpers (identical logic to original)
# ─────────────────────────────────────────────────────────────
def make_time_budget_chart(counts, title):
    """Build a pie chart of how often each behaviour was predicted.

    Only the first seven labels (everything except "Occluded") with a
    positive count get a wedge.

    Args:
        counts: mapping of behaviour label -> number of predictions.
        title: chart title.

    Returns:
        A ``matplotlib.figure.Figure``. It is built without pyplot, so it is
        not kept alive by pyplot's global figure registry, and it does not
        touch pyplot's shared "current figure" state across request threads.
        When no label has a positive count, it shows a placeholder message.
    """
    labels = [lbl for lbl in LABEL_NAMES[:7] if counts.get(lbl, 0) > 0]
    sizes = [counts[lbl] for lbl in labels]
    colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(lbl)] for lbl in labels]
    fig = Figure(figsize=(6, 6))
    ax = fig.subplots()
    if sizes:
        ax.pie(sizes, labels=labels, colors=colors,
               autopct="%1.1f%%", startangle=90,
               textprops={"fontsize": 11})
    else:
        ax.text(0.5, 0.5, "No classifications", ha="center", va="center",
                fontsize=12, transform=ax.transAxes)
        ax.set_axis_off()
    ax.set_title(title, fontsize=13, fontweight="bold", pad=15)
    fig.tight_layout()
    return fig
def make_timeline_chart(smoothed, video_name):
    """Plot smoothed predictions over time.

    The figure has two panels:
    - top: confidence scatter, one colour per behaviour;
    - bottom: colour strip showing the label at each classified timestamp.

    Args:
        smoothed: sequence of ``(time_seconds, label, confidence)`` tuples.
            The confidence comes from the raw prediction, which
            ``smooth_predictions`` passes through, so it may belong to a
            different label than the smoothed one.
        video_name: shown in the chart title.

    Returns:
        A ``matplotlib.figure.Figure`` built without pyplot, or ``None`` when
        ``smoothed`` is empty.
    """
    times = [t for t, _, _ in smoothed]
    if not times:
        return None
    label_to_idx = {lbl: i for i, lbl in enumerate(LABEL_NAMES)}
    fig = Figure(figsize=(14, 5))
    # sharex keeps the scatter and the strip on one common time axis.
    ax1, ax2 = fig.subplots(2, 1, sharex=True,
                            gridspec_kw={"height_ratios": [3, 1]})
    for i, lbl in enumerate(LABEL_NAMES[:7]):
        pts = [(t, c) for t, lab, c in smoothed if lab == lbl]
        if pts:
            ts, cs = zip(*pts)
            ax1.scatter(ts, cs, label=lbl,
                        color=LABEL_COLORS_HEX[i], s=8, alpha=0.7)
    ax1.set_ylabel("Confidence", fontsize=11)
    ax1.set_ylim(0, 1.05)
    ax1.legend(loc="upper right", fontsize=8, ncol=4)
    ax1.set_title(f"Behavior Timeline — {video_name}",
                  fontsize=12, fontweight="bold")
    ax1.grid(True, alpha=0.3)
    # Several animals can be classified at the same timestamp, so the first
    # two entries may share a time (gap 0 -> invisible strip). Size each
    # strip cell by the smallest positive gap between distinct timestamps.
    distinct = sorted(set(times))
    gaps = [b - a for a, b in zip(distinct, distinct[1:])]
    dt = min(gaps) if gaps else 0.1
    for t, lbl, _ in smoothed:
        ax2.axvspan(t - dt / 2, t + dt / 2,
                    color=LABEL_COLORS_HEX[label_to_idx.get(lbl, 7)],
                    alpha=0.9)
    patches = [mpatches.Patch(color=LABEL_COLORS_HEX[i], label=lbl)
               for i, lbl in enumerate(LABEL_NAMES[:7])]
    ax2.legend(handles=patches, loc="upper right", fontsize=8, ncol=4)
    ax2.set_xlabel("Time (seconds)", fontsize=11)
    ax2.set_yticks([])
    fig.tight_layout()
    return fig
def save_combined_chart(smooth_cnt, smoothed):
    """Save a side-by-side pie chart and timeline strip as a PNG.

    Args:
        smooth_cnt: mapping of smoothed label -> count (pie wedges).
        smoothed: sequence of ``(time_seconds, label, confidence)`` tuples
            (timeline strip).

    Returns:
        Path of the written ``behavior_charts.png``, or ``None`` when
        ``smoothed`` is empty.
    """
    times = [t for t, _, _ in smoothed]
    if not times:
        return None
    label_to_idx = {lbl: i for i, lbl in enumerate(LABEL_NAMES)}
    # Built without pyplot: savefig targets exactly this figure (not pyplot's
    # shared "current figure") and nothing lingers in pyplot's registry.
    fig = Figure(figsize=(14, 5))
    ax_pie, ax_strip = fig.subplots(1, 2)

    labels = [lbl for lbl in LABEL_NAMES[:7] if smooth_cnt.get(lbl, 0) > 0]
    sizes = [smooth_cnt[lbl] for lbl in labels]
    colors = [LABEL_COLORS_HEX[LABEL_NAMES.index(lbl)] for lbl in labels]
    if sizes:
        ax_pie.pie(sizes, labels=labels, colors=colors,
                   autopct="%1.1f%%", startangle=90)
    else:
        ax_pie.text(0.5, 0.5, "No classifications", ha="center", va="center",
                    transform=ax_pie.transAxes)
        ax_pie.set_axis_off()
    ax_pie.set_title("Time Budget (Smoothed)")

    # Multiple animals can share a timestamp; use the smallest positive gap
    # between distinct timestamps as the strip cell width.
    distinct = sorted(set(times))
    gaps = [b - a for a, b in zip(distinct, distinct[1:])]
    dt = min(gaps) if gaps else 0.1
    for t, lbl, _ in smoothed:
        ax_strip.axvspan(t - dt / 2, t + dt / 2,
                         color=LABEL_COLORS_HEX[label_to_idx.get(lbl, 7)],
                         alpha=0.9)
    ax_strip.set_title("Behavior Timeline Strip")
    ax_strip.set_xlabel("Time (seconds)")
    ax_strip.set_yticks([])
    fig.tight_layout()

    # A fresh directory per call stops concurrent sessions from overwriting
    # each other's file while keeping a friendly download name.
    out_dir = tempfile.mkdtemp(prefix="wildlife_charts_")
    out = os.path.join(out_dir, "behavior_charts.png")
    fig.savefig(out, format="png", dpi=150, bbox_inches="tight")
    return out
# ─────────────────────────────────────────────────────────────
# Core Processing (reports progress through Gradio's progress callback)
# ─────────────────────────────────────────────────────────────
def _iou(a, b):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes."""
    ix1, iy1 = max(a[0], b[0]), max(a[1], b[1])
    ix2, iy2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, ix2 - ix1) * max(0, iy2 - iy1)
    if inter == 0:
        return 0.0
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return inter / (area_a + area_b - inter)


class _BoxTracker:
    """Greedy IoU tracker giving each detected animal a stable id across frames."""

    def __init__(self, iou_thresh=0.3, max_missed=5):
        self.iou_thresh = iou_thresh
        self.max_missed = max_missed  # processed frames a track may go unseen
        self._boxes = {}    # track id -> box from the last frame it matched
        self._missed = {}   # track id -> consecutive processed frames unmatched
        self._next_id = 0

    def update(self, boxes):
        """Assign a track id to every box in *boxes*.

        Returns:
            (ids, retired). ``ids[i]`` is the track id of ``boxes[i]``.
            ``retired`` lists the ids dropped in this call because they went
            unmatched for more than ``max_missed`` consecutive updates.
        """
        # Greedy matching: walk box/track pairs from the highest IoU down.
        candidates = sorted(
            (
                (_iou(box, prev), i, tid)
                for i, box in enumerate(boxes)
                for tid, prev in self._boxes.items()
            ),
            reverse=True,
        )
        ids = [None] * len(boxes)
        taken = set()
        for score, i, tid in candidates:
            if score < self.iou_thresh:
                break
            if ids[i] is None and tid not in taken:
                ids[i] = tid
                taken.add(tid)
        for i, box in enumerate(boxes):
            if ids[i] is None:  # no sufficient overlap: start a new track
                ids[i] = self._next_id
                self._next_id += 1
            self._boxes[ids[i]] = box
            self._missed[ids[i]] = 0
        current = set(ids)
        retired = []
        for tid in list(self._boxes):
            if tid in current:
                continue
            self._missed[tid] += 1
            if self._missed[tid] > self.max_missed:
                del self._boxes[tid]
                del self._missed[tid]
                retired.append(tid)
        return ids, retired


def _open_writer(path, fps, size):
    """Open an MP4 writer at *path*, preferring browser-friendly H.264 over mp4v."""
    writer = None
    for code in ("avc1", "mp4v"):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*code), fps, size)
        if writer.isOpened():
            return writer
    # Best effort: keep analysing so the charts are still produced, but make
    # the missing annotated video explainable.
    print(f"[Wildlife Analyzer] ⚠️ could not open a video writer for {path}")
    return writer


def _draw_label(img, box, text, color):
    """Draw *box* and a filled caption bar containing *text* onto *img* in place."""
    x1, y1, x2, y2 = box
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
    fs = max(0.4, min(0.6, (x2 - x1) / 400))
    (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, fs, 2)
    # Caption normally sits above the box; if that would leave the frame,
    # draw it just inside the box's top edge so it stays readable.
    bar_top, bar_bottom = y1 - th - 8, y1
    if bar_top < 0:
        bar_top, bar_bottom = y1, y1 + th + 8
    cv2.rectangle(img, (x1, bar_top), (x1 + tw + 4, bar_bottom), color, -1)
    cv2.putText(img, text, (x1 + 2, bar_bottom - 4),
                cv2.FONT_HERSHEY_SIMPLEX, fs, (255, 255, 255), 2)


def _process_video_frames(video_path, species_name, frame_skip, progress):
    """Run detection + behaviour classification over a video, frame by frame.

    Every ``frame_skip``-th frame is processed as follows:
    - detect animals;
    - follow each animal across frames with an IoU tracker, so each animal
      keeps its own clip buffer;
    - classify a track once it has ``NUM_FRAMES`` crops;
    - draw the annotated frame.

    Skipped frames repeat the last annotated frame in the output video.
    ``progress(fraction, desc=...)`` is called after every frame.

    Returns:
        (output_video_path, counts_counter, timeline_list, last_preview_rgb)

    Raises:
        gr.Error: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise gr.Error("Could not open the uploaded video file.")
    fps = cap.get(cv2.CAP_PROP_FPS) or 25
    W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 (or negative) frame counts; only trust a
    # positive value, and never divide by it otherwise.
    total_f = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    denom = max(total_f, 1)
    frame_skip = max(1, int(frame_skip))

    # Unique per-call directory: concurrent requests cannot clobber each
    # other's output, and the download keeps a friendly file name.
    out_dir = tempfile.mkdtemp(prefix="wildlife_video_")
    out_path = os.path.join(out_dir, "output_annotated.mp4")
    writer = _open_writer(out_path, fps, (W, H))

    tracker = _BoxTracker()
    buffers = {}          # track id -> deque of the latest NUM_FRAMES crops
    all_preds = []
    timeline = []
    frame_idx = 0
    last_annotated = None
    preview_rgb = None
    preview_every = max(1, total_f // 20)

    progress(0.0, desc="Starting analysis...")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx % frame_skip == 0:
                annotated = frame.copy()
                animals = detect_animals(frame, species_name)
                timestamp = frame_idx / fps
                track_ids, retired = tracker.update([a[:4] for a in animals])
                for tid in retired:
                    buffers.pop(tid, None)
                for tid, (x1, y1, x2, y2, name, _) in zip(track_ids, animals):
                    crop = frame[y1:y2, x1:x2]
                    if crop.size == 0:
                        continue
                    buf = buffers.setdefault(tid, deque(maxlen=NUM_FRAMES))
                    buf.append(cv2.resize(crop, (CROP_SIZE, CROP_SIZE)))
                    if len(buf) == NUM_FRAMES:
                        pred_idx, pred_label, probs = classify_behavior(list(buf))
                        conf = float(probs[pred_idx])
                        all_preds.append(pred_label)
                        timeline.append((timestamp, pred_label, conf))
                    else:
                        pred_idx, pred_label, conf = 7, "Buffering...", 0.0
                    _draw_label(annotated, (x1, y1, x2, y2),
                                f"{name}: {pred_label} ({conf:.2f})",
                                LABEL_COLORS_BGR[pred_idx])
                cv2.putText(annotated,
                            f"t={timestamp:.1f}s | {len(animals)} detected",
                            (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                            0.65, (255, 255, 255), 2)
                last_annotated = annotated
                if frame_idx % preview_every == 0:
                    preview_rgb = cv2.cvtColor(annotated, cv2.COLOR_BGR2RGB)
            writer.write(last_annotated if last_annotated is not None else frame)
            frame_idx += 1
            done = min(frame_idx / denom, 1.0)
            position = f"{frame_idx}/{total_f}" if total_f > 0 else f"{frame_idx}"
            progress(
                done,
                desc=(
                    f"Frame {position} "
                    f"({done * 100:.0f}%) — "
                    f"{len(all_preds)} classifications"
                ),
            )
    finally:
        # Release even when detection/classification raises mid-video.
        cap.release()
        writer.release()
    return out_path, Counter(all_preds), timeline, preview_rgb
# ─────────────────────────────────────────────────────────────
# Gradio Event Handler
# ─────────────────────────────────────────────────────────────
def _summary_markdown(video_name, dur_sec, W, H, fps_val, total_f, elapsed, counts):
    """Render the Summary tab: video info line, headline metrics, per-behaviour breakdown."""
    total_preds = sum(counts.values())
    dominant = max(counts, key=counts.get)
    dom_pct = counts[dominant] / total_preds * 100
    info_row = (
        f"📹 **{video_name}** &nbsp;|&nbsp; "
        f"⏱ {dur_sec:.1f}s &nbsp;|&nbsp; "
        f"🖥 {W}×{H} &nbsp;|&nbsp; "
        f"🎞 {fps_val:.0f} fps &nbsp;|&nbsp; "
        f"🖼 {total_f:,} frames"
    )
    rows = []
    for lbl in LABEL_NAMES[:7]:
        n = counts.get(lbl, 0)
        if n <= 0:
            continue
        share = n / total_preds
        filled = int(share * 20)  # 20-character text bar
        rows.append(
            f"| {lbl} | {n:,} | {'█' * filled}{'░' * (20 - filled)} | "
            f"{share * 100:.1f}% |"
        )
    # Blank lines are significant: a text line directly above "---" would be
    # rendered as a setext heading instead of text followed by a rule.
    lines = [
        f"## ✅ Analysis Complete — processed in {elapsed:.0f}s",
        "",
        info_row,
        "",
        "---",
        "",
        "| Metric | Value |",
        "|---|---|",
        f"| Total Classifications | **{total_preds:,}** |",
        f"| Dominant Behavior | **{dominant}** |",
        f"| Dominance | **{dom_pct:.1f}%** |",
        f"| Device | **{DEVICE}** |",
        "",
        "### Behavior Breakdown",
        "",
        "| Behavior | Count | Distribution | % |",
        "|---|---|---|---|",
        *rows,
    ]
    return "\n".join(lines) + "\n"


def analyze_video(video_file, species_name, frame_skip, progress=gr.Progress()):
    """Main handler wired to the Analyze button.

    Args:
        video_file: path of the uploaded video, or ``None`` if nothing uploaded.
        species_name: label drawn on every detection.
        frame_skip: process every N-th frame.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        7 outputs: summary_md, fig_raw, fig_smooth, fig_timeline,
        video_out_path, charts_png_path, preview_img.

    Raises:
        gr.Error: if no video was uploaded, or no animal was ever classified.
    """
    if video_file is None:
        raise gr.Error("Please upload a video file first.")
    video_name = os.path.basename(video_file)
    start_time = time.time()
    # ── Video metadata ────────────────────────────────────────
    cap = cv2.VideoCapture(video_file)
    try:
        fps_val = cap.get(cv2.CAP_PROP_FPS) or 25
        # Some containers report 0 or negative frame counts; clamp for display.
        total_f = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
        W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    finally:
        cap.release()
    dur_sec = total_f / fps_val
    # ── Run processing ────────────────────────────────────────
    out_path, counts, timeline, preview_rgb = _process_video_frames(
        video_file, species_name, int(frame_skip), progress
    )
    elapsed = time.time() - start_time
    # ── Guard: no detections ──────────────────────────────────
    if sum(counts.values()) == 0:
        raise gr.Error(
            "No animals detected in this video. "
            "Try a lower frame skip, ensure animals are clearly visible, "
            "and use drone footage at 10–50 m altitude."
        )
    summary_md = _summary_markdown(
        video_name, dur_sec, W, H, fps_val, total_f, elapsed, counts
    )
    # ── Charts ────────────────────────────────────────────────
    smoothed = smooth_predictions(timeline, window=5)
    smooth_cnt = Counter(lbl for _, lbl, _ in smoothed)
    fig_raw = make_time_budget_chart(counts, "Raw Predictions")
    fig_smooth = make_time_budget_chart(smooth_cnt, "Smoothed (5-frame vote)")
    fig_timeline = make_timeline_chart(smoothed, video_name)
    # ── Downloadable combined chart PNG ───────────────────────
    charts_path = save_combined_chart(smooth_cnt, smoothed)
    return (
        summary_md,    # gr.Markdown — Summary tab
        fig_raw,       # gr.Plot — Charts tab
        fig_smooth,    # gr.Plot — Charts tab
        fig_timeline,  # gr.Plot — Charts tab
        out_path,      # gr.Video — Annotated Video tab
        charts_path,   # gr.File — Downloads tab
        preview_rgb,   # gr.Image — last live preview frame
    )
# ─────────────────────────────────────────────────────────────
# Gradio UI Layout
# ─────────────────────────────────────────────────────────────
CSS = """
.gr-button-primary { font-size: 1.1rem !important; }
.status-box { background: #f0fdf4; border: 1px solid #86efac;
border-radius: 8px; padding: 12px; }
"""

# UI layout: settings + upload on top, result tabs below. The Analyze button
# runs analyze_video once; the Downloads tab reuses its video output.
with gr.Blocks(css=CSS, title="🦒 Wildlife Behavior Analyzer") as demo:
    # ── Header ────────────────────────────────────────────────
    gr.Markdown(
        """
# 🦒 Wildlife Behavior Analyzer
**AI-powered animal behavior recognition from drone footage**
Powered by [X3D-KABR](https://huggingface.co/imageomics/x3d-kabr-kinetics)
(WACV 2024) + YOLOv8 &nbsp;|&nbsp;
Detects: **Walk · Trot · Run · Graze · Browse · Head Up · Auto-Groom**
"""
    )
    # ── Input Panel ───────────────────────────────────────────
    with gr.Row():
        # Left: settings
        with gr.Column(scale=1, min_width=260):
            gr.Markdown("### ⚙️ Settings")
            species_input = gr.Dropdown(
                choices=["Giraffe", "Zebra", "Elephant",
                         "Deer", "Horse", "Other Wildlife"],
                value="Giraffe",
                label="Animal species in video",
            )
            frame_skip_input = gr.Slider(
                minimum=1, maximum=6, value=3, step=1,
                label="Frame skip (higher = faster, less detail)",
                info="Process every Nth frame. 1 = all frames (slow), 6 = fastest",
            )
            gr.Markdown(
                """
---
### 📌 About the Model
**X3D-KABR** — trained on the KABR dataset:
10+ hours of drone footage of Kenyan wildlife.
Published at **WACV 2024**.
⚠️ Best results with **drone/aerial footage** of ungulates
(giraffe, zebra, deer, horse) at 10–50 m altitude.
---
### ⏱️ Estimated Processing Time
- ~2–4 min for a 30-second video
- ~5–8 min for a 1-minute video
"""
            )
        # Right: upload + analyze
        with gr.Column(scale=2):
            gr.Markdown("### 📤 Upload Drone Wildlife Video")
            video_input = gr.Video(
                label="Upload video (MP4, AVI, MOV, MKV)",
                sources=["upload"],
                height=320,
            )
            analyze_btn = gr.Button(
                "🚀 Analyze Behavior",
                variant="primary",
                size="lg",
            )
    gr.Markdown("---")
    # ── Output Tabs ───────────────────────────────────────────
    with gr.Tabs():
        with gr.Tab("📋 Summary & Metrics"):
            summary_output = gr.Markdown(
                "*Upload a video and click **Analyze Behavior** to see results.*"
            )
        with gr.Tab("📊 Charts"):
            with gr.Row():
                pie_raw = gr.Plot(label="Raw Prediction Budget")
                pie_smooth = gr.Plot(label="Smoothed Prediction Budget (5-frame vote)")
            timeline_plot = gr.Plot(label="Behavior Timeline")
        with gr.Tab("🎬 Annotated Video"):
            gr.Markdown("*The annotated video with bounding boxes and behavior labels.*")
            video_output = gr.Video(label="Annotated Output", height=480)
        with gr.Tab("⬇️ Downloads"):
            gr.Markdown("### Download Results")
            with gr.Row():
                charts_file = gr.File(label="📈 Behavior Charts PNG")
                # The annotated video is also offered here as a plain file.
                video_file_dl = gr.File(label="🎬 Annotated Video (MP4)")
            gr.Markdown(
                "_Tip: you can also right-click the video in the "
                "**🎬 Annotated Video** tab and choose **Save Video As**._"
            )
        with gr.Tab("🔍 Live Preview"):
            gr.Markdown("*Last frame captured during processing.*")
            preview_img = gr.Image(
                label="Last Processed Frame",
                type="numpy",
                height=400,
            )
    # ── Wire button → handler ─────────────────────────────────
    analyze_event = analyze_btn.click(
        fn=analyze_video,
        inputs=[video_input, species_input, frame_skip_input],
        outputs=[
            summary_output,  # Tab 1
            pie_raw,         # Tab 2
            pie_smooth,      # Tab 2
            timeline_plot,   # Tab 2
            video_output,    # Tab 3
            charts_file,     # Tab 4 (charts PNG)
            preview_img,     # Tab 5
        ],
        show_progress="full",
        api_name="analyze",
    )
    # Mirror the annotated video into the Downloads tab once the analysis has
    # succeeded. Reusing the path already shown in video_output avoids running
    # the whole (slow) analysis a second time just to get the same file.
    analyze_event.success(
        fn=lambda path: path,
        inputs=[video_output],
        outputs=[video_file_dl],
        show_progress="hidden",
        api_name=False,
    )
    gr.Markdown(
        """
---
**Credits:**
[X3D-KABR model](https://huggingface.co/imageomics/x3d-kabr-kinetics) ·
[KABR Dataset](https://huggingface.co/datasets/imageomics/KABR) ·
YOLOv8 by Ultralytics · Built by rohansingh0
"""
    )
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (needed for HuggingFace Spaces) and
    # show full error messages in the UI.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)