# app.py — UCF-50 action recognition demo (HuggingFace Spaces)
"""
UCF-50 Action Recognition - Gradio App
Deployed on HuggingFace Spaces
"""
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import numpy as np
from PIL import Image
import tempfile
import os
class GRUModel(nn.Module):
    """Single-layer GRU classifier over per-frame feature vectors.

    Consumes a sequence of frame embeddings (e.g. ResNet50 pool5 features,
    shape (batch, seq_len, input_dim)) and classifies the whole clip from
    the last time step's hidden state.

    Args:
        input_dim: Size of each per-frame feature vector.
        hidden_dim: GRU hidden-state size.
        num_classes: Number of action categories.
        dropout: Dropout probability applied to the last GRU output
            before the final classification layer.
    """

    def __init__(self, input_dim=2048, hidden_dim=512, num_classes=50, dropout=0.3):
        super().__init__()
        self.hidden_dim = hidden_dim
        # NOTE: nn.GRU's own `dropout` only acts *between* stacked layers,
        # so with num_layers=1 it has no effect and PyTorch emits a
        # UserWarning. Pass 0 here; regularization happens in self.dropout.
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            dropout=0,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Classify a batch of sequences.

        Args:
            x: Tensor of shape (batch, seq_len, input_dim).

        Returns:
            Logits of shape (batch, num_classes).
        """
        out, _ = self.gru(x)   # out: (batch, seq_len, hidden_dim)
        out = out[:, -1, :]    # last time step summarizes the clip
        out = self.dropout(out)
        return self.fc(out)
# The 50 UCF-50 action categories. Model output index i maps to
# CLASS_NAMES[i]; the order appears alphabetical and is assumed to match
# the label order used at training time — verify against the training script.
CLASS_NAMES = [
    'BaseballPitch', 'Basketball', 'BenchPress', 'Biking', 'Billiards',
    'BreastStroke', 'CleanAndJerk', 'Diving', 'Drumming', 'Fencing',
    'GolfSwing', 'HighJump', 'HorseRace', 'HorseRiding', 'HulaHoop',
    'JavelinThrow', 'JugglingBalls', 'JumpRope', 'JumpingJack', 'Kayaking',
    'Lunges', 'MilitaryParade', 'Mixing', 'Nunchucks', 'PizzaTossing',
    'PlayingGuitar', 'PlayingPiano', 'PlayingTabla', 'PlayingViolin', 'PoleVault',
    'PommelHorse', 'PullUps', 'Punch', 'PushUps', 'RockClimbingIndoor',
    'RopeClimbing', 'Rowing', 'SalsaSpin', 'SkateBoarding', 'Skiing',
    'Skijet', 'SoccerJuggling', 'Swing', 'TaiChi', 'TennisSwing',
    'ThrowDiscus', 'TrampolineJumping', 'VolleyballSpiking', 'WalkingWithDog', 'YoYo'
]
print("Loading models...")

# Load feature extractor (ResNet50)
# ImageNet-pretrained backbone with the final FC layer stripped; the
# remaining global-average-pool output is a 2048-d feature per frame.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision in
# favor of `weights=ResNet50_Weights.IMAGENET1K_V1` — confirm the installed
# torchvision version before switching.
resnet = models.resnet50(pretrained=True)
feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
feature_extractor.eval()  # inference only: fixes batch-norm running stats

# Load action recognition model (GRU)
model = GRUModel(
    input_dim=2048,  # must match the ResNet50 feature size above
    hidden_dim=512,
    num_classes=50,
    dropout=0.3
)

# Load trained weights
# The checkpoint is expected to be a dict containing 'model_state_dict'.
# On any failure the app keeps running with randomly initialized weights
# (predictions will be meaningless but the UI still works).
if os.path.exists('best_model.pth'):
    try:
        checkpoint = torch.load('best_model.pth', map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        print("✓ Trained model loaded successfully!")
    except Exception as e:
        print(f" Could not load trained weights: {str(e)}")
else:
    print(" No trained model found. Using random initialization.")
model.eval()  # disable dropout for inference
print("Models loaded!")
def extract_frames(video_path, num_frames=32):
    """Sample `num_frames` frames spread evenly across a video.

    Returns a list of RGB PIL images, or None when the video reports zero
    frames. Short videos (and unreadable frames) are padded by repeating
    the last available frame, so the result always has `num_frames` entries.
    """
    capture = cv2.VideoCapture(video_path)
    frame_total = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
    if frame_total == 0:
        capture.release()
        return None

    # Pick which frame positions to grab: a uniform spread when there are
    # enough frames, otherwise every frame plus repeats of the final index.
    if frame_total >= num_frames:
        positions = np.linspace(0, frame_total - 1, num_frames, dtype=int)
    else:
        padding = [frame_total - 1] * (num_frames - frame_total)
        positions = list(range(frame_total)) + padding

    sampled = []
    for pos in positions:
        capture.set(cv2.CAP_PROP_POS_FRAMES, pos)
        ok, bgr = capture.read()
        if not ok:
            continue  # unreadable frame; compensated by padding below
        sampled.append(Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)))
    capture.release()

    # Guarantee exactly num_frames entries even if some reads failed.
    filler = sampled[-1] if sampled else Image.new('RGB', (224, 224))
    while len(sampled) < num_frames:
        sampled.append(filler)
    return sampled[:num_frames]
# Built once at import time: the standard ImageNet eval transform the
# ResNet50 feature extractor expects (224x224, ImageNet mean/std).
_FRAME_TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def preprocess_frames(frames):
    """Convert a list of PIL RGB frames into a model-ready tensor.

    Args:
        frames: Iterable of PIL images (any size; resized here).

    Returns:
        Float tensor of shape (len(frames), 3, 224, 224), normalized with
        ImageNet statistics.
    """
    # The Compose pipeline is hoisted to module level so repeated
    # predictions don't rebuild it on every call.
    return torch.stack([_FRAME_TRANSFORM(frame) for frame in frames])
def convert_video_for_web(video_path):
    """Re-encode a video as mp4 so browsers can play it in the preview.

    Args:
        video_path: Path to the uploaded video, or None.

    Returns:
        Path to a re-encoded temporary .mp4 on success, the original path
        when conversion fails (best effort), or None when no video given.
    """
    if video_path is None:
        return None
    temp_output = None
    cap = None
    out = None
    try:
        # Create temp file for converted video (caller owns cleanup).
        temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
        cap = cv2.VideoCapture(video_path)
        # Guard against broken/missing metadata: fps of 0 would produce an
        # unplayable output file.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Define codec and create VideoWriter
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height))
        # Copy every frame into the new container.
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            out.write(frame)
        return temp_output
    except Exception as e:
        print(f"Video conversion failed: {e}")
        # Best effort: remove the half-written temp file before falling
        # back to the original path.
        if temp_output and os.path.exists(temp_output):
            try:
                os.remove(temp_output)
            except OSError:
                pass
        return video_path  # Return original if conversion fails
    finally:
        # Always release OpenCV handles, even when an exception fired;
        # releasing the writer also finalizes the output file.
        if cap is not None:
            cap.release()
        if out is not None:
            out.release()
def _prediction_failure(message):
    """Build the 8-tuple of Gradio outputs for an error/empty state.

    Keeps the frames row hidden and clears the chart, frame images and
    video preview, showing only `message` in the results panel.
    """
    return (
        None,                       # prediction_chart
        message,                    # result_text
        None, None, None, None,     # frame1..frame4
        gr.update(visible=False),   # frames_container stays hidden
        None                        # video_preview
    )


def predict_action(video_path):
    """Run the full recognition pipeline on an uploaded video.

    Args:
        video_path: Filesystem path from the gr.File input, or None.

    Returns:
        8-tuple matching the Gradio outputs list: (label dict, markdown
        result text, four sample frames, frames-row visibility update,
        web-playable video path).
    """
    if video_path is None:
        return _prediction_failure("Please upload a video first.")
    try:
        # Re-encode for browser playback; analysis uses the original file.
        web_video = convert_video_for_web(video_path)
        frames = extract_frames(video_path, num_frames=32)
        if not frames:
            return _prediction_failure(
                "Error: Could not extract frames from video. Please try another video."
            )
        frames_tensor = preprocess_frames(frames)
        # Per-frame ResNet50 features -> (1, 32, 2048) sequence for the GRU.
        with torch.no_grad():
            features = feature_extractor(frames_tensor)
            features = features.view(features.size(0), -1)
            features = features.unsqueeze(0)
            outputs = model(features)
            probs = F.softmax(outputs, dim=1)
            top5_probs, top5_indices = torch.topk(probs, 5)
        top5_probs = top5_probs[0].numpy()
        top5_indices = top5_indices[0].numpy()
        # gr.Label expects {class_name: probability}.
        predictions = {
            CLASS_NAMES[idx]: float(prob)
            for idx, prob in zip(top5_indices, top5_probs)
        }
        # Markdown summary for the results panel.
        result_text = f"**Predicted Action:** {CLASS_NAMES[top5_indices[0]]}\n\n"
        result_text += f"**Confidence:** {top5_probs[0] * 100:.2f}%\n\n"
        result_text += "**Top 5 Predictions:**\n\n"
        for i, (idx, prob) in enumerate(zip(top5_indices, top5_probs), 1):
            result_text += f"{i}. {CLASS_NAMES[idx]}: {prob * 100:.2f}%\n"
        # Show a spread of the 32 analyzed frames (first, two middle, last);
        # extract_frames guarantees exactly 32 entries.
        sample_frames = [frames[i] for i in [0, 10, 20, 31]]
        return (
            predictions,
            result_text,
            sample_frames[0],
            sample_frames[1],
            sample_frames[2],
            sample_frames[3],
            gr.update(visible=True),
            web_video
        )
    except Exception as e:
        # Surface unexpected failures in the UI instead of crashing Gradio.
        return _prediction_failure(f"Error processing video: {str(e)}")
# Custom CSS
css = """
.gradio-container {
max-width: 1400px !important;
margin: auto;
}
#upload-zone {
border: 2px dashed #d1d5db;
border-radius: 12px;
padding: 2rem;
background: #f9fafb;
transition: all 0.3s ease;
}
#upload-zone:hover {
border-color: #2563eb;
background: #eff6ff;
}
.primary-button {
background: #2563eb !important;
border: none !important;
font-weight: 600 !important;
font-size: 1.1em !important;
padding: 0.8rem 2rem !important;
}
#title-text {
font-size: 2.5em;
font-weight: 700;
color: #111827;
margin-bottom: 0.3rem;
}
#subtitle-text {
color: #6b7280;
font-size: 1.1em;
margin-bottom: 2rem;
}
.results-container {
background: #f9fafb;
border-radius: 12px;
padding: 1.5rem;
border: 1px solid #e5e7eb;
}
.frame-container img {
border-radius: 8px;
border: 1px solid #e5e7eb;
}
"""
# Create Gradio interface
# Layout: header, collapsible model details, a two-column upload/results
# row, an initially-hidden extracted-frames strip, and a footer.
# NOTE(review): original indentation was lost in transit; component nesting
# below (e.g. video_preview inside the left column) reconstructs the most
# plausible structure — confirm against the deployed layout.
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("<div id='title-text'>Video Action Recognition</div>")
    gr.Markdown("<div id='subtitle-text'>GRU-based sequence model · 97.23% accuracy on UCF-50</div>")

    # Model details (collapsed)
    with gr.Accordion("Model Details", open=False):
        gr.Markdown("""
**Architecture:** ResNet50 feature extractor + GRU sequence model
**Performance:** 97.23% Top-1 accuracy · 99.85% Top-5 accuracy
**Dataset:** UCF-50 (50 human action categories)
**Parameters:** 3.96M trainable parameters
**Supported Actions:** BaseballPitch, Basketball, BenchPress, Biking, Billiards, BreastStroke, CleanAndJerk, Diving, Drumming, Fencing, GolfSwing, HighJump, HorseRace, HorseRiding, HulaHoop, JavelinThrow, JugglingBalls, JumpRope, JumpingJack, Kayaking, Lunges, MilitaryParade, Mixing, Nunchucks, PizzaTossing, PlayingGuitar, PlayingPiano, PlayingTabla, PlayingViolin, PoleVault, PommelHorse, PullUps, Punch, PushUps, RockClimbingIndoor, RopeClimbing, Rowing, SalsaSpin, SkateBoarding, Skiing, Skijet, SoccerJuggling, Swing, TaiChi, TennisSwing, ThrowDiscus, TrampolineJumping, VolleyballSpiking, WalkingWithDog, YoYo
""")
    gr.Markdown("---")

    # Main interface
    with gr.Row():
        # Left column - Upload
        with gr.Column(scale=1):
            gr.Markdown("### Upload Video")
            with gr.Group(elem_id="upload-zone"):
                video_input = gr.File(
                    label="Drop video file here or click to upload",
                    file_types=["video"],
                    type="filepath"
                )
            # Add a second video component for playback
            # (populated by predict_action with the re-encoded mp4)
            video_preview = gr.Video(
                label="Video Preview",
                visible=False,
                interactive=False,
                show_label=False
            )
            predict_button = gr.Button(
                "Analyze Video",
                variant="primary",
                size="lg",
                elem_classes="primary-button"
            )
            gr.Markdown("""
**Requirements:**
- Clear view of human performing action
- 3-10 seconds recommended
- Formats: MP4, AVI, MOV
""")

        # Right column - Results
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            with gr.Group(elem_classes="results-container"):
                result_text = gr.Markdown("*Upload a video and click 'Analyze Video' to see predictions*")
                prediction_chart = gr.Label(
                    label="Confidence Distribution",
                    num_top_classes=5,
                    show_label=True
                )

    # Frames section (hidden initially; made visible by predict_action)
    with gr.Column(visible=False) as frames_container:
        gr.Markdown("### Extracted Frames")
        gr.Markdown("*Sample frames used for analysis*")
        with gr.Row():
            frame1 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame2 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame3 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame4 = gr.Image(label="", show_label=False, elem_classes="frame-container")

    # Connect prediction function
    # The outputs list order must match predict_action's 8-tuple exactly.
    predict_button.click(
        fn=predict_action,
        inputs=video_input,
        outputs=[
            prediction_chart,
            result_text,
            frame1,
            frame2,
            frame3,
            frame4,
            frames_container,
            video_preview
        ]
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown("""
<div style='text-align: center; color: #6b7280; font-size: 0.9em; padding: 1rem 0'>
<a href='https://github.com/NoobML/ucf50-action-recognition'
style='color: #2563eb; text-decoration: none; font-weight: 500'>
View Source Code
</a>
<span style='margin: 0 1em; color: #d1d5db'>·</span>
<span>PyTorch · ResNet50 · GRU</span>
</div>
""")

# Launch
if __name__ == "__main__":
    demo.launch()