# ayeshaishaq004 - Create app.py (commit 690f9df, verified)
import torch
import json
import urllib.request
import gradio as gr
from pytorchvideo.data.encoded_video import EncodedVideo
from pytorchvideo.transforms import (
ApplyTransformToKey,
ShortSideScale,
UniformTemporalSubsample,
)
from torchvision.transforms import Compose, Lambda
from torchvision.transforms._transforms_video import (
CenterCropVideo,
NormalizeVideo,
)
# Load model
# Fetch the pretrained 'slowfast_r50' entry point from the PyTorchVideo hub
# repository and put it into evaluation mode in the same expression.
model = torch.hub.load(
    'facebookresearch/pytorchvideo', 'slowfast_r50', pretrained=True
).eval()
# Constants
# Sizes handed to ShortSideScale and CenterCropVideo respectively.
side_size = crop_size = 256
# Normalisation statistics handed to NormalizeVideo (three entries each).
mean = [0.45] * 3
std = [0.225] * 3
# Frames kept by UniformTemporalSubsample, and the divisor PackPathway uses
# to thin out the slow pathway.
num_frames, slowfast_alpha = 32, 4
# End time (seconds) of the clip read from the start of each video.
clip_duration = 5.0
# Prepare SlowFast transform
class PackPathway(torch.nn.Module):
    """Turn one frame tensor into the ``[slow, fast]`` pair fed to SlowFast.

    The fast pathway is the input tensor itself. The slow pathway keeps
    ``T // alpha`` evenly spaced indices along dimension 1, where
    ``T = frames.shape[1]``.

    Args:
        alpha: Divisor applied to the size of dimension 1 to get the number
            of slow-pathway frames. When left as ``None`` the module-level
            ``slowfast_alpha`` is read at call time, which matches the
            behaviour of the argument-less constructor.

    Raises:
        ValueError: If ``alpha`` is given and is smaller than 1.
    """

    def __init__(self, alpha=None):
        super().__init__()
        if alpha is not None and alpha < 1:
            raise ValueError(f"alpha must be >= 1, got {alpha}")
        self.alpha = alpha

    def forward(self, frames: torch.Tensor):
        """Return ``[slow_pathway, fast_pathway]`` for ``frames``."""
        # Resolve lazily so the global is only needed when no explicit
        # ratio was supplied.
        alpha = slowfast_alpha if self.alpha is None else self.alpha
        total = frames.shape[1]
        # Build the index tensor on the same device as the frames;
        # index_select requires both tensors to live on one device.
        slow_indices = torch.linspace(
            0, total - 1, total // alpha, device=frames.device
        ).long()
        slow_pathway = torch.index_select(frames, 1, slow_indices)
        return [slow_pathway, frames]
# Preprocessing steps for the "video" entry of a clip dict, listed in the
# order they are composed; PackPathway comes last.
_video_steps = [
    UniformTemporalSubsample(num_frames),
    Lambda(lambda x: x / 255.0),
    NormalizeVideo(mean, std),
    ShortSideScale(size=side_size),
    CenterCropVideo(crop_size),
    PackPathway(),
]
transform = ApplyTransformToKey(key="video", transform=Compose(_video_steps))
# Load Kinetics-400 class names
json_url = "https://dl.fbaipublicfiles.com/pyslowfast/dataset/class_names/kinetics_classnames.json"
json_filename = "kinetics_classnames.json"

# Download the label file to the working directory, then parse it.
urllib.request.urlretrieve(json_url, json_filename)
with open(json_filename, "r") as f:
    kinetics_classnames = json.loads(f.read())

# Invert the parsed mapping so its values become the keys, and remove any
# double quotes surrounding the original keys.
kinetics_id_to_classname = {
    class_id: quoted_name.strip('"')
    for quoted_name, class_id in kinetics_classnames.items()
}
def predict_activity(video_path):
    """Classify the action shown at the start of a video.

    The clip from second 0 to ``clip_duration`` is decoded, passed through
    ``transform`` and scored by ``model``; the index of the highest
    probability is looked up in ``kinetics_id_to_classname``.

    Args:
        video_path: Path of the video file to classify. ``None`` or an empty
            string is treated as "no video supplied".

    Returns:
        A display string: ``"Top predicted label: <name>"`` on success, or a
        short explanatory message when there is nothing to analyse.
    """
    # Guard against a missing path (e.g. the form submitted with no upload),
    # which would otherwise fail inside the decoder with an opaque error.
    if not video_path:
        return "No video provided. Please upload a video."

    video = EncodedVideo.from_path(video_path)
    video_data = video.get_clip(start_sec=0, end_sec=clip_duration)
    # NOTE(review): get_clip is documented to return None for "video" when
    # no frames fall inside the requested range - confirm for the installed
    # pytorchvideo version.
    if video_data is None or video_data.get("video") is None:
        return "Could not read any frames from the uploaded video."

    video_data = transform(video_data)
    # Add a leading batch dimension to every tensor produced by the transform.
    inputs = [pathway[None, ...] for pathway in video_data["video"]]

    with torch.no_grad():
        preds = model(inputs)

    # Functional softmax avoids building a throwaway nn.Softmax module.
    probs = torch.softmax(preds, dim=1)
    top_class = probs.topk(k=1).indices[0]
    class_name = kinetics_id_to_classname[int(top_class)]
    return f"Top predicted label: {class_name}"
# Gradio UI
# Build the interface first, then start serving it.
demo = gr.Interface(
    fn=predict_activity,
    inputs=gr.Video(label="Upload a video"),
    outputs=gr.Textbox(label="Predicted Action"),
    title="Video Activity Detection with SlowFast",
)
demo.launch()