""" UCF-50 Action Recognition - Gradio App Deployed on HuggingFace Spaces """ import gradio as gr import torch import torch.nn as nn import torch.nn.functional as F import torchvision.transforms as transforms import torchvision.models as models import cv2 import numpy as np from PIL import Image import tempfile import os class GRUModel(nn.Module): """GRU Model - 97.23% Accuracy""" def __init__(self, input_dim=2048, hidden_dim=512, num_classes=50, dropout=0.3): super(GRUModel, self).__init__() self.hidden_dim = hidden_dim self.gru = nn.GRU( input_size=input_dim, hidden_size=hidden_dim, num_layers=1, batch_first=True, dropout=0 if dropout == 0 else dropout ) self.dropout = nn.Dropout(dropout) self.fc = nn.Linear(hidden_dim, num_classes) def forward(self, x): out, hidden = self.gru(x) out = out[:, -1, :] out = self.dropout(out) out = self.fc(out) return out CLASS_NAMES = [ 'BaseballPitch', 'Basketball', 'BenchPress', 'Biking', 'Billiards', 'BreastStroke', 'CleanAndJerk', 'Diving', 'Drumming', 'Fencing', 'GolfSwing', 'HighJump', 'HorseRace', 'HorseRiding', 'HulaHoop', 'JavelinThrow', 'JugglingBalls', 'JumpRope', 'JumpingJack', 'Kayaking', 'Lunges', 'MilitaryParade', 'Mixing', 'Nunchucks', 'PizzaTossing', 'PlayingGuitar', 'PlayingPiano', 'PlayingTabla', 'PlayingViolin', 'PoleVault', 'PommelHorse', 'PullUps', 'Punch', 'PushUps', 'RockClimbingIndoor', 'RopeClimbing', 'Rowing', 'SalsaSpin', 'SkateBoarding', 'Skiing', 'Skijet', 'SoccerJuggling', 'Swing', 'TaiChi', 'TennisSwing', 'ThrowDiscus', 'TrampolineJumping', 'VolleyballSpiking', 'WalkingWithDog', 'YoYo' ] print("Loading models...") # Load feature extractor (ResNet50) resnet = models.resnet50(pretrained=True) feature_extractor = nn.Sequential(*list(resnet.children())[:-1]) feature_extractor.eval() # Load action recognition model (GRU) model = GRUModel( input_dim=2048, hidden_dim=512, num_classes=50, dropout=0.3 ) # Load trained weights if os.path.exists('best_model.pth'): try: checkpoint = torch.load('best_model.pth', map_location='cpu') model.load_state_dict(checkpoint['model_state_dict']) print("✓ Trained model loaded successfully!") except Exception as e: print(f" Could not load trained weights: {str(e)}") else: print(" No trained model found. Using random initialization.") model.eval() print("Models loaded!") def extract_frames(video_path, num_frames=32): """Extract uniformly sampled frames from video""" cap = cv2.VideoCapture(video_path) total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) if total_frames == 0: cap.release() return None if total_frames < num_frames: frame_indices = list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames) else: frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int) frames = [] for idx in frame_indices: cap.set(cv2.CAP_PROP_POS_FRAMES, idx) ret, frame = cap.read() if ret: frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) frames.append(Image.fromarray(frame_rgb)) cap.release() while len(frames) < num_frames: frames.append(frames[-1] if frames else Image.new('RGB', (224, 224))) return frames[:num_frames] def preprocess_frames(frames): """Preprocess frames for model input""" transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) return torch.stack([transform(frame) for frame in frames]) def convert_video_for_web(video_path): """Convert video to web-compatible format""" if video_path is None: return None try: # Create temp file for converted video temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name # Open original video cap = cv2.VideoCapture(video_path) # Get video properties fps = int(cap.get(cv2.CAP_PROP_FPS)) width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # Define codec and create VideoWriter fourcc = cv2.VideoWriter_fourcc(*'mp4v') out = cv2.VideoWriter(temp_output, fourcc, fps, (width, height)) # Read and write all frames while True: ret, frame = cap.read() if not ret: break out.write(frame) cap.release() out.release() return temp_output except Exception as e: print(f"Video conversion failed: {e}") return video_path # Return original if conversion fails def predict_action(video_path): """Main prediction function""" if video_path is None: return ( None, "Please upload a video first.", None, None, None, None, gr.update(visible=False), None # Add this for converted video ) try: # Convert video for web playback web_video = convert_video_for_web(video_path) # Extract frames (still use original path for analysis) frames = extract_frames(video_path, num_frames=32) if frames is None or len(frames) == 0: return ( None, "Error: Could not extract frames from video. Please try another video.", None, None, None, None, gr.update(visible=False), None ) # Preprocess frames_tensor = preprocess_frames(frames) # Extract features with torch.no_grad(): features = feature_extractor(frames_tensor) features = features.view(features.size(0), -1) features = features.unsqueeze(0) # Predict outputs = model(features) probs = F.softmax(outputs, dim=1) top5_probs, top5_indices = torch.topk(probs, 5) # Format results top5_probs = top5_probs[0].numpy() top5_indices = top5_indices[0].numpy() # Create prediction dictionary for Gradio predictions = { CLASS_NAMES[idx]: float(prob) for idx, prob in zip(top5_indices, top5_probs) } # Create result text result_text = f"**Predicted Action:** {CLASS_NAMES[top5_indices[0]]}\n\n" result_text += f"**Confidence:** {top5_probs[0] * 100:.2f}%\n\n" result_text += "**Top 5 Predictions:**\n\n" for i, (idx, prob) in enumerate(zip(top5_indices, top5_probs), 1): result_text += f"{i}. {CLASS_NAMES[idx]}: {prob * 100:.2f}%\n" # Get sample frames for display sample_frames = [frames[i] for i in [0, 10, 20, 31]] return ( predictions, result_text, sample_frames[0], sample_frames[1], sample_frames[2], sample_frames[3], gr.update(visible=True), web_video # Return converted video ) except Exception as e: return ( None, f"Error processing video: {str(e)}", None, None, None, None, gr.update(visible=False), None ) # Custom CSS css = """ .gradio-container { max-width: 1400px !important; margin: auto; } #upload-zone { border: 2px dashed #d1d5db; border-radius: 12px; padding: 2rem; background: #f9fafb; transition: all 0.3s ease; } #upload-zone:hover { border-color: #2563eb; background: #eff6ff; } .primary-button { background: #2563eb !important; border: none !important; font-weight: 600 !important; font-size: 1.1em !important; padding: 0.8rem 2rem !important; } #title-text { font-size: 2.5em; font-weight: 700; color: #111827; margin-bottom: 0.3rem; } #subtitle-text { color: #6b7280; font-size: 1.1em; margin-bottom: 2rem; } .results-container { background: #f9fafb; border-radius: 12px; padding: 1.5rem; border: 1px solid #e5e7eb; } .frame-container img { border-radius: 8px; border: 1px solid #e5e7eb; } """ # Create Gradio interface with gr.Blocks(css=css, theme=gr.themes.Soft()) as demo: # Header gr.Markdown("