Spaces:
Runtime error
import json
import os
import urllib.request

import gradio as gr
import torch
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
    ApplyTransformToKey,
    ShortSideScale,
    UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
    CenterCropVideo,
    NormalizeVideo,
)
# Load the pre-trained SlowFast R50 action-recognition model from torch.hub
# and freeze it in inference mode (no dropout / batch-norm updates).
model = torch.hub.load('facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True)
model = model.eval()

# Preprocessing constants for the SlowFast input pipeline.
# NOTE(review): values appear to follow the standard PyTorchVideo SlowFast
# Kinetics-400 demo settings — confirm against the upstream tutorial.
side_size = 256              # short side after rescaling
crop_size = 256              # center-crop size fed to the network
mean = [0.45, 0.45, 0.45]    # per-channel normalization mean
std = [0.225, 0.225, 0.225]  # per-channel normalization std
num_frames = 32              # frames sampled for the fast pathway
slowfast_alpha = 4           # fast/slow frame-rate ratio
clip_duration = 5.0          # seconds of video decoded per prediction
# SlowFast needs two temporally-sampled views of the same clip: the "fast"
# pathway keeps every frame, the "slow" pathway keeps every alpha-th frame.
class PackPathway(torch.nn.Module):
    """Split one frame tensor into the ``[slow_pathway, fast_pathway]`` list.

    The temporal axis is assumed to be dim 1 of ``frames`` (that is the axis
    ``index_select`` subsamples below) — i.e. frames shaped (C, T, H, W).
    """

    def __init__(self, alpha=None):
        """``alpha``: fast/slow frame-rate ratio. ``None`` (the default)
        falls back to the module-level ``slowfast_alpha``, so existing
        ``PackPathway()`` call sites behave exactly as before."""
        super().__init__()
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        alpha = self.alpha if self.alpha is not None else slowfast_alpha
        fast_pathway = frames
        # T // alpha evenly spaced indices along the temporal axis (dim 1).
        slow_indices = torch.linspace(
            0, frames.shape[1] - 1, frames.shape[1] // alpha
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_indices)
        return [slow_pathway, fast_pathway]
# Full preprocessing pipeline applied to the decoded clip's "video" tensor:
# subsample to num_frames, scale to [0, 1], normalize, resize the short side,
# center-crop, then split into the two SlowFast pathways.
transform = ApplyTransformToKey(
    key="video",
    transform=Compose(
        [
            UniformTemporalSubsample(num_frames),
            Lambda(lambda x: x / 255.0),
            NormalizeVideo(mean, std),
            ShortSideScale(size=side_size),
            CenterCropVideo(crop_size),
            PackPathway(),
        ]
    ),
)
# Download the Kinetics-400 id -> class-name mapping. Cache it locally so
# the file is fetched only on the first run instead of on every restart.
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"
if not os.path.exists(json_filename):
    urllib.request.urlretrieve(json_url, json_filename)
with open(json_filename, "r") as f:
    kinetics_classnames = json.load(f)
# The JSON maps quoted class names to integer ids; invert it and strip the
# embedded quote characters so lookups by id yield clean labels.
kinetics_id_to_classname = {v: k.strip('"') for k, v in kinetics_classnames.items()}
def predict_activity(video_path):
    """Classify the dominant action in the first ``clip_duration`` seconds
    of the video at ``video_path``.

    Returns a human-readable string with the top-1 Kinetics-400 label.
    """
    video = EncodedVideo.from_path(video_path)
    video_data = video.get_clip(start_sec=0, end_sec=clip_duration)
    video_data = transform(video_data)
    # transform produced [slow_pathway, fast_pathway]; give each a batch dim.
    inputs = [pathway[None, ...] for pathway in video_data["video"]]
    with torch.no_grad():
        preds = model(inputs)
        # Functional softmax instead of building a Softmax module per call.
        probs = torch.nn.functional.softmax(preds, dim=1)
    top_class = int(probs.topk(k=1).indices[0])
    class_name = kinetics_id_to_classname[top_class]
    return f"Top predicted label: {class_name}"
# Gradio UI: upload a video, get back the predicted action label.
demo = gr.Interface(
    fn=predict_activity,
    inputs=gr.Video(label="Upload a video"),
    outputs=gr.Textbox(label="Predicted Action"),
    title="Video Activity Detection with SlowFast",
)
demo.launch()