# Mordehay's picture
# init2
# 20a9b6d
import gradio as gr
import torch
import torchvision
import torch.nn as nn
import cv2
import numpy as np
import json
import random
import math
import os
# --- 1. HELPER FUNCTIONS ---
def resize_keep_ar(frame, target_size=112):
    """Resize ``frame`` to fit a ``target_size`` square, preserving aspect ratio.

    The frame is scaled so that its longer side becomes ``target_size``
    (scaled sizes are truncated to integers), then centred on a square canvas
    padded with zeros.

    Args:
        frame: Image array whose first two dimensions are (height, width).
        target_size: Side length, in pixels, of the square output.

    Returns:
        The resized frame padded to ``target_size`` x ``target_size``.
    """
    h, w = frame.shape[:2]
    scale = target_size / max(h, w)
    # Clamp to 1 px: for very elongated frames the truncated short side would
    # be 0, and cv2.resize refuses a zero-sized destination.
    nh = max(1, int(h * scale))
    nw = max(1, int(w * scale))
    frame = cv2.resize(frame, (nw, nh))

    # Split the padding as evenly as possible; any odd pixel goes to the
    # bottom/right edge.
    pad_h = target_size - nh
    pad_w = target_size - nw
    top = pad_h // 2
    left = pad_w // 2
    return cv2.copyMakeBorder(
        frame, top, pad_h - top, left, pad_w - left,
        cv2.BORDER_CONSTANT, value=0,
    )
def load_video_frames(path, max_duration=8, target_fps=5, target_size=200):
    """Decode a video into a grayscale clip tensor sampled at ``target_fps``.

    If the video is longer than ``max_duration`` seconds, a window of that
    length is picked at a random start frame and only that window is sampled.
    Within the window a frame is kept every ``original_fps / target_fps``
    source frames; each kept frame is converted to grayscale, letterboxed to a
    square with ``resize_keep_ar`` and scaled by 1/255.

    Args:
        path: Video file path handed to ``cv2.VideoCapture``.
        max_duration: Maximum clip length in seconds, or ``None`` to sample
            the whole video.
        target_fps: Desired sampling rate in frames per second.
        target_size: Side length of the square output frames.

    Returns:
        A float32 tensor of shape ``(1, T, target_size, target_size)``
        (channel, time, height, width). If no frame could be decoded, an
        all-zero tensor with ``T == 1`` is returned, so callers always receive
        a tensor.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        original_fps = cap.get(cv2.CAP_PROP_FPS)
        if not math.isfinite(original_fps) or original_fps <= 0:
            # Missing or invalid metadata: fall back to a common frame rate.
            original_fps = 30.0

        reported_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        frame_count = int(reported_count) if math.isfinite(reported_count) else 0

        # --- RANDOM SAMPLING LOGIC ---
        start_frame = 0
        # A non-positive frame count means the length is unknown (presumably
        # some recorded/streamed containers); ``None`` means "read until the
        # stream ends" instead of stopping before the first frame.
        end_frame = frame_count if frame_count > 0 else None
        if max_duration is not None:
            max_frames = int(max_duration * original_fps)
            if frame_count > max_frames:
                # Longer than allowed: pick a random window and seek straight
                # to its first frame rather than decoding everything before it.
                start_frame = random.randint(0, frame_count - max_frames)
                end_frame = start_frame + max_frames
                cap.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
            elif end_frame is None:
                # Unknown length: still honour the duration limit.
                end_frame = max_frames

        step = original_fps / target_fps
        # The read position corresponds to ``start_frame`` (0 or the seek target).
        current_frame_idx = start_frame
        # Check the window bound before reading so no frame past the window
        # is decoded.
        while end_frame is None or current_frame_idx < end_frame:
            ret, frame = cap.read()
            if not ret:
                break
            # The n-th kept frame is the first one at or after
            # start_frame + n * step.
            if current_frame_idx >= start_frame + len(frames) * step:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                gray = resize_keep_ar(gray, target_size)
                # Normalize and add the channel dimension -> (1, H, W).
                frames.append(
                    torch.tensor(gray, dtype=torch.float32).unsqueeze(0) / 255.0
                )
            current_frame_idx += 1
    finally:
        # Release the capture even if decoding or preprocessing raised.
        cap.release()

    if not frames:
        # Placeholder clip for unreadable videos; same type and layout as the
        # normal return value.
        return torch.zeros((1, 1, target_size, target_size))
    return torch.stack(frames, dim=1)  # (C, T, H, W)
# --- 2. MODEL SETUP ---
# Load the index -> class-name mapping. JSON is UTF-8, so read it with an
# explicit encoding instead of the platform default.
with open("class_mappings.json", "r", encoding="utf-8") as f:
    data = json.load(f)
idx_to_class = data["idx_to_class"]
num_classes = len(idx_to_class)

# Define the model architecture; weights come from the local checkpoint below,
# so no pretrained weights are requested.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torchvision.models.video.r2plus1d_18(weights=None)

# Replace the first stem convolution so the network accepts 1-channel
# (grayscale) input while keeping the original layer's geometry.
old_layer = model.stem[0]
new_layer = nn.Conv3d(
    in_channels=1,
    out_channels=old_layer.out_channels,
    kernel_size=old_layer.kernel_size,
    stride=old_layer.stride,
    padding=old_layer.padding,
    # ``bias`` is a boolean flag; passing the bias tensor itself would fail
    # its truth test whenever the original layer actually has a bias.
    bias=old_layer.bias is not None,
)
model.stem[0] = new_layer

# Replace the classification head to match the number of classes.
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Load weights
if os.path.exists("r2plus1d_18_best_model.pt"):
    # NOTE: torch.load unpickles the file; only load checkpoints from a
    # trusted source.
    checkpoint = torch.load("r2plus1d_18_best_model.pt", map_location=device)
    # The checkpoint is either a dict wrapping the weights under
    # "model_state" or a bare state_dict.
    if "model_state" in checkpoint:
        model.load_state_dict(checkpoint["model_state"])
    else:
        model.load_state_dict(checkpoint)
else:
    print("Warning: Model weights not found.")

model = model.to(device)
model.eval()
# --- 3. INFERENCE FUNCTION ---
def predict(video_path):
    """Classify the video at ``video_path`` with the module-level model.

    Args:
        video_path: Path of the video to classify, or ``None`` when no video
            was supplied.

    Returns:
        ``None`` if ``video_path`` is ``None``; otherwise a dict mapping each
        class name to its softmax probability (the format consumed by the
        Gradio Label output).
    """
    if video_path is None:
        return None

    # Decode the clip and move it to the model's device.
    clip = load_video_frames(video_path).to(device)

    with torch.no_grad():
        # The model takes a batch, so add a leading dimension: (1, C, T, H, W).
        logits = model(clip.unsqueeze(0))
        probabilities = torch.nn.functional.softmax(logits[0], dim=0)

    # idx_to_class is keyed by the stringified class index.
    return {
        idx_to_class[str(class_idx)]: float(prob)
        for class_idx, prob in enumerate(probabilities)
    }
# --- 4. GRADIO INTERFACE ---
# Input component: accepts an uploaded file or a webcam recording, and asks
# Gradio for mp4 so the saved file is in a format OpenCV can read.
_video_input = gr.Video(
    label="Upload or Record Video",
    sources=["upload", "webcam"],
    format="mp4",
)

# Output component: shows the three most probable classes.
_label_output = gr.Label(num_top_classes=3, label="Predictions")

interface = gr.Interface(
    fn=predict,
    inputs=_video_input,
    outputs=_label_output,
    title="Video Action Classifier",
    description="Upload a video or record yourself to classify the action.",
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()