Spaces:
Sleeping
Sleeping
| """ | |
| UCF-50 Action Recognition - Gradio App | |
| Deployed on HuggingFace Spaces | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| import torchvision.models as models | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image | |
| import tempfile | |
| import os | |
class GRUModel(nn.Module):
    """GRU Model - 97.23% Accuracy

    Classifies a sequence of per-step feature vectors: the sequence is run
    through a GRU, the GRU output at the last time step goes through dropout
    and a linear layer, giving one logit per class.

    Args:
        input_dim: Size of each per-step feature vector.
        hidden_dim: GRU hidden state size.
        num_classes: Number of output logits.
        dropout: Dropout probability applied to the last GRU output before the
            classifier. When ``num_layers > 1`` it is also used as the GRU's
            inter-layer dropout.
        num_layers: Number of stacked GRU layers (default 1).
    """

    def __init__(self, input_dim=2048, hidden_dim=512, num_classes=50, dropout=0.3,
                 num_layers=1):
        super().__init__()
        self.hidden_dim = hidden_dim
        # nn.GRU applies its `dropout` only *between* stacked layers. With a
        # single layer a non-zero value has no effect and makes PyTorch emit a
        # UserWarning, so only pass it through when there is more than one layer.
        self.gru = nn.GRU(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_dim)`` (the GRU is
                built with ``batch_first=True``).
        """
        out, _ = self.gru(x)
        # Use only the output at the final time step as the sequence summary.
        last_step = out[:, -1, :]
        return self.fc(self.dropout(last_step))
# UCF-50 action labels. A name's position in this list is the class index used
# to look up the label for each of the model's top-k predictions, so the order
# must not change.
CLASS_NAMES = """
    BaseballPitch Basketball BenchPress Biking Billiards
    BreastStroke CleanAndJerk Diving Drumming Fencing
    GolfSwing HighJump HorseRace HorseRiding HulaHoop
    JavelinThrow JugglingBalls JumpRope JumpingJack Kayaking
    Lunges MilitaryParade Mixing Nunchucks PizzaTossing
    PlayingGuitar PlayingPiano PlayingTabla PlayingViolin PoleVault
    PommelHorse PullUps Punch PushUps RockClimbingIndoor
    RopeClimbing Rowing SalsaSpin SkateBoarding Skiing
    Skijet SoccerJuggling Swing TaiChi TennisSwing
    ThrowDiscus TrampolineJumping VolleyballSpiking WalkingWithDog YoYo
""".split()
print("Loading models...")

# Load feature extractor (ResNet50): the torchvision ResNet50 with its final
# child module (the classification layer) removed, so it outputs pooled
# features rather than ImageNet logits.
# torchvision >= 0.13 deprecated `pretrained=True` in favour of the `weights`
# enum; IMAGENET1K_V1 is the weight set the old flag selected. Fall back to the
# old flag on torchvision versions that predate the enum.
_resnet50_weights = getattr(models, "ResNet50_Weights", None)
if _resnet50_weights is not None:
    resnet = models.resnet50(weights=_resnet50_weights.IMAGENET1K_V1)
else:
    resnet = models.resnet50(pretrained=True)
feature_extractor = nn.Sequential(*list(resnet.children())[:-1])
feature_extractor.eval()

# Load action recognition model (GRU)
model = GRUModel(
    input_dim=2048,
    hidden_dim=512,
    num_classes=50,
    dropout=0.3
)

# Load trained weights. Accept either a training checkpoint dict carrying a
# "model_state_dict" entry or a bare state_dict (torch.save(model.state_dict())).
if os.path.exists('best_model.pth'):
    try:
        checkpoint = torch.load('best_model.pth', map_location='cpu')
        if isinstance(checkpoint, dict):
            state_dict = checkpoint.get('model_state_dict', checkpoint)
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print("✓ Trained model loaded successfully!")
    except Exception as e:
        # Best-effort: the app still starts, but the GRU keeps its random
        # initialization, so predictions will be meaningless.
        print(f" Could not load trained weights: {str(e)}")
else:
    print(" No trained model found. Using random initialization.")
model.eval()
print("Models loaded!")
def extract_frames(video_path, num_frames=32):
    """Extract uniformly sampled frames from video

    Picks ``num_frames`` frame indices spread evenly over the video (when the
    video has fewer frames, every frame is used and the last index is
    repeated), seeks to and decodes each one, and converts it from OpenCV's
    BGR channel order to an RGB PIL image.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        num_frames: Number of frames to return.

    Returns:
        A list of ``num_frames`` PIL RGB images, or ``None`` when the video
        cannot be opened, reports no frames, or none of the sampled frames can
        be decoded. Sampled frames that fail to decode are skipped and the list
        is padded at the end with copies of the last decoded frame.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return None
        # Some containers report 0 or a negative value when the count is unknown.
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            return None

        if total_frames < num_frames:
            frame_indices = list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames)
        else:
            frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)

        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if ret:
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(Image.fromarray(frame_rgb))
    finally:
        # Release the capture even if decoding raises.
        cap.release()

    if not frames:
        # Nothing decodable: report failure instead of feeding the model blank frames.
        return None
    frames.extend([frames[-1]] * (num_frames - len(frames)))
    return frames[:num_frames]
def preprocess_frames(frames):
    """Turn a sequence of frames into one normalized image batch.

    Every frame is resized to 224x224, converted to a float tensor, and
    normalized with the ImageNet per-channel mean/std; the results are stacked
    along a new leading dimension.

    Args:
        frames: Iterable of PIL images.

    Returns:
        A tensor with one entry per frame along dim 0 (for RGB input each
        entry is 3x224x224).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pipeline = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    per_frame = [pipeline(image) for image in frames]
    return torch.stack(per_frame)
def _open_web_writer(path, fps, frame_size):
    """Return an opened cv2.VideoWriter for ``path``, or None if no codec works."""
    # Browsers generally cannot play MPEG-4 Part 2 ('mp4v'), so try H.264
    # ('avc1') first; many pip OpenCV builds lack an H.264 encoder, hence the fallback.
    for codec in ('avc1', 'mp4v'):
        writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*codec), fps, frame_size)
        if writer.isOpened():
            return writer
        writer.release()
    return None


def convert_video_for_web(video_path):
    """Convert video to web-compatible format

    Re-encodes every frame of ``video_path`` into a new temporary ``.mp4``
    file intended for in-browser playback.

    Args:
        video_path: Path to the source video, or ``None``.

    Returns:
        ``None`` if ``video_path`` is ``None``; otherwise the path of the
        re-encoded temporary file. Falls back to returning ``video_path``
        unchanged when the input cannot be opened, no encoder is available,
        no frames are read, or any other error occurs (the temporary file is
        removed in that case). On success the caller owns the temporary file;
        this function does not delete it.
    """
    if video_path is None:
        return None

    capture = None
    writer = None
    temp_output = None
    try:
        # Create temp file for converted video; close our handle so the
        # VideoWriter can open the path itself.
        fd, temp_output = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)

        capture = cv2.VideoCapture(video_path)
        if not capture.isOpened():
            raise RuntimeError(f"cannot open {video_path}")

        # Keep fractional rates such as 29.97 (int() would truncate and make
        # the output drift); fall back when the container reports no rate.
        fps = capture.get(cv2.CAP_PROP_FPS)
        if not fps or fps <= 0:
            fps = 25.0
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))

        writer = _open_web_writer(temp_output, fps, (width, height))
        if writer is None:
            raise RuntimeError("no usable video encoder")

        # Read and write all frames
        written = 0
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            writer.write(frame)
            written += 1
        if written == 0:
            raise RuntimeError("no frames could be read")

        # Releasing the writer finalizes the container before we hand it out.
        writer.release()
        writer = None
        return temp_output
    except Exception as e:
        print(f"Video conversion failed: {e}")
        if writer is not None:
            writer.release()
            writer = None
        if temp_output is not None:
            try:
                os.remove(temp_output)
            except OSError:
                pass
        return video_path  # Return original if conversion fails
    finally:
        if capture is not None:
            capture.release()
def _error_outputs(message):
    """Build the 8 Gradio output values used when no prediction can be shown.

    Order matches ``predict_action``: label dict, result text, four sample
    frames, frames-section visibility, video-preview update.
    """
    return (
        None,
        message,
        None, None, None, None,
        gr.update(visible=False),
        gr.update(value=None, visible=False),
    )


def _format_result_text(top_indices, top_probs):
    """Render the top-k predictions as the Markdown shown in the results panel."""
    parts = [
        f"**Predicted Action:** {CLASS_NAMES[top_indices[0]]}\n\n",
        f"**Confidence:** {top_probs[0] * 100:.2f}%\n\n",
        "**Top 5 Predictions:**\n\n",
    ]
    parts.extend(
        f"{rank}. {CLASS_NAMES[idx]}: {prob * 100:.2f}%\n"
        for rank, (idx, prob) in enumerate(zip(top_indices, top_probs), 1)
    )
    return "".join(parts)


def predict_action(video_path):
    """Main prediction function

    Runs the whole pipeline for one uploaded video: samples 32 frames,
    embeds them with ``feature_extractor``, classifies the frame sequence with
    ``model``, and formats the top-5 predictions for the UI.

    Args:
        video_path: Path to the uploaded video, or ``None`` if nothing was uploaded.

    Returns:
        An 8-tuple of Gradio outputs, in order: ``{class_name: probability}``
        for the top 5 classes, Markdown result text, four sample frames, a
        visibility update for the frames section, and an update for the video
        preview (made visible with the re-encoded video on success). Failures
        are reported through the result text instead of being raised.
    """
    if video_path is None:
        return _error_outputs("Please upload a video first.")

    try:
        frames = extract_frames(video_path, num_frames=32)
        if frames is None or len(frames) == 0:
            return _error_outputs(
                "Error: Could not extract frames from video. Please try another video."
            )

        frames_tensor = preprocess_frames(frames)

        with torch.no_grad():
            features = feature_extractor(frames_tensor)
            # Flatten each frame's features to a vector, then add a batch
            # dimension: (num_frames, D) -> (1, num_frames, D) for the GRU.
            features = features.view(features.size(0), -1).unsqueeze(0)
            outputs = model(features)
            probs = F.softmax(outputs, dim=1)
            top5_probs, top5_indices = torch.topk(probs, 5)

        top5_probs = top5_probs[0].numpy()
        top5_indices = top5_indices[0].numpy()

        # Create prediction dictionary for Gradio
        predictions = {
            CLASS_NAMES[idx]: float(prob)
            for idx, prob in zip(top5_indices, top5_probs)
        }
        result_text = _format_result_text(top5_indices, top5_probs)

        # Four evenly spaced frames for display (indices 0, 10, 20, 31 for 32 frames).
        sample_positions = np.linspace(0, len(frames) - 1, 4).astype(int)
        sample_frames = [frames[i] for i in sample_positions]

        # Re-encode only after a successful prediction so failed requests
        # don't spend time on (or leave behind) a converted copy.
        web_video = convert_video_for_web(video_path)

        return (
            predictions,
            result_text,
            *sample_frames,
            gr.update(visible=True),
            # The preview component is created hidden; returning only a value
            # would never reveal it, so toggle visibility along with the value.
            gr.update(value=web_video, visible=True),
        )
    except Exception as e:
        return _error_outputs(f"Error processing video: {str(e)}")
# Custom CSS
# Passed to gr.Blocks(css=...). Widens and centers the app container and styles
# the elements tagged in the layout: elem_id "upload-zone", elem_classes
# "primary-button", "results-container" and "frame-container", plus the
# header's inline HTML <div id='title-text'> / <div id='subtitle-text'>.
css = """
.gradio-container {
max-width: 1400px !important;
margin: auto;
}
#upload-zone {
border: 2px dashed #d1d5db;
border-radius: 12px;
padding: 2rem;
background: #f9fafb;
transition: all 0.3s ease;
}
#upload-zone:hover {
border-color: #2563eb;
background: #eff6ff;
}
.primary-button {
background: #2563eb !important;
border: none !important;
font-weight: 600 !important;
font-size: 1.1em !important;
padding: 0.8rem 2rem !important;
}
#title-text {
font-size: 2.5em;
font-weight: 700;
color: #111827;
margin-bottom: 0.3rem;
}
#subtitle-text {
color: #6b7280;
font-size: 1.1em;
margin-bottom: 2rem;
}
.results-container {
background: #f9fafb;
border-radius: 12px;
padding: 1.5rem;
border: 1px solid #e5e7eb;
}
.frame-container img {
border-radius: 8px;
border: 1px solid #e5e7eb;
}
"""
# Create Gradio interface
with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo:
    # Header
    gr.Markdown("<div id='title-text'>Video Action Recognition</div>")
    gr.Markdown("<div id='subtitle-text'>GRU-based sequence model · 97.23% accuracy on UCF-50</div>")

    # Model details (collapsed). Items are separated by blank lines: a single
    # newline is only a soft break in Markdown and would run every item
    # together into one paragraph. The action list is generated from
    # CLASS_NAMES so it cannot drift from the labels the model actually uses.
    with gr.Accordion("Model Details", open=False):
        gr.Markdown(f"""
**Architecture:** ResNet50 feature extractor + GRU sequence model

**Performance:** 97.23% Top-1 accuracy · 99.85% Top-5 accuracy

**Dataset:** UCF-50 (50 human action categories)

**Parameters:** 3.96M trainable parameters

**Supported Actions:** {', '.join(CLASS_NAMES)}
""")

    gr.Markdown("---")

    # Main interface
    with gr.Row():
        # Left column - Upload
        with gr.Column(scale=1):
            gr.Markdown("### Upload Video")
            with gr.Group(elem_id="upload-zone"):
                video_input = gr.File(
                    label="Drop video file here or click to upload",
                    file_types=["video"],
                    type="filepath",
                )

            # Playback of the re-encoded upload. Starts hidden; the prediction
            # callback fills it in and is responsible for making it visible.
            video_preview = gr.Video(
                label="Video Preview",
                visible=False,
                interactive=False,
                show_label=False,
            )

            predict_button = gr.Button(
                "Analyze Video",
                variant="primary",
                size="lg",
                elem_classes="primary-button",
            )

            gr.Markdown("""
**Requirements:**
- Clear view of human performing action
- 3-10 seconds recommended
- Formats: MP4, AVI, MOV
""")

        # Right column - Results
        with gr.Column(scale=1):
            gr.Markdown("### Results")
            with gr.Group(elem_classes="results-container"):
                result_text = gr.Markdown("*Upload a video and click 'Analyze Video' to see predictions*")
                prediction_chart = gr.Label(
                    label="Confidence Distribution",
                    num_top_classes=5,
                    show_label=True,
                )

    # Frames section (hidden until a prediction succeeds)
    with gr.Column(visible=False) as frames_container:
        gr.Markdown("### Extracted Frames")
        gr.Markdown("*Sample frames used for analysis*")
        with gr.Row():
            frame1 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame2 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame3 = gr.Image(label="", show_label=False, elem_classes="frame-container")
            frame4 = gr.Image(label="", show_label=False, elem_classes="frame-container")

    # Connect prediction function; output order must match predict_action's tuple.
    predict_button.click(
        fn=predict_action,
        inputs=video_input,
        outputs=[
            prediction_chart,
            result_text,
            frame1,
            frame2,
            frame3,
            frame4,
            frames_container,
            video_preview,
        ],
    )

    # Footer
    gr.Markdown("---")
    gr.Markdown("""
<div style='text-align: center; color: #6b7280; font-size: 0.9em; padding: 1rem 0'>
<a href='https://github.com/NoobML/ucf50-action-recognition'
style='color: #2563eb; text-decoration: none; font-weight: 500'>
View Source Code
</a>
<span style='margin: 0 1em; color: #d1d5db'>·</span>
<span>PyTorch · ResNet50 · GRU</span>
</div>
""")

# Launch
if __name__ == "__main__":
    demo.launch()