|
|
import gradio as gr |
|
|
import torch |
|
|
import os |
|
|
import tempfile |
|
|
import numpy as np |
|
|
from models import Model |
|
|
from dataset import extract_features |
|
|
from eval import predict |
|
|
|
|
|
|
|
|
def load_model(checkpoint_path='checkpoint/ckp_best.pth.tar'):
    """Rebuild a ``Model`` from a saved checkpoint, ready for inference.

    The checkpoint must be a mapping with a ``'config'`` entry (keyword
    arguments forwarded to ``Model``) and a ``'state_dict'`` entry (the
    weights loaded into it). All tensors are mapped onto the CPU.

    Note: ``torch.load`` deserializes via pickle, so only point this at
    checkpoints from a trusted source.

    Returns:
        The model instance, switched to evaluation mode.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')
    net = Model(**ckpt['config'])
    net.load_state_dict(ckpt['state_dict'])
    # Inference only: put layers such as dropout/batch-norm in eval behavior.
    net.eval()
    return net
|
|
|
|
|
# Built once at import time from the default checkpoint; process_video passes
# this shared module-level instance to predict() on every request.
model = load_model()
|
|
|
|
|
def process_video(video_file):
    """Localize actions in an uploaded video and format them as text.

    Args:
        video_file: Value produced by the ``gr.Video`` input. Gradio passes a
            filesystem path (``str``) to the already-saved upload; a binary
            file-like object exposing ``.read()`` is also accepted. ``None``
            (submit pressed with nothing uploaded) yields an empty string.

    Returns:
        One ``"<label>: <start>s - <end>s"`` line per prediction (times shown
        with two decimals), joined by newlines; an empty string when there
        are no predictions.
    """
    if video_file is None:
        return ""

    # All intermediate files live in this directory, which is removed on
    # exit from the block, even if feature extraction or prediction raises.
    with tempfile.TemporaryDirectory() as work_dir:
        if isinstance(video_file, (str, os.PathLike)):
            # Gradio already wrote the upload to disk; use it in place
            # (a path string has no .read()).
            video_path = os.fspath(video_file)
        else:
            # File-like input: copy its bytes to a local file first.
            video_path = os.path.join(work_dir, "input.mp4")
            with open(video_path, "wb") as f:
                f.write(video_file.read())

        features = extract_features(video_path)

        # Name the feature file after the video's stem using splitext, so a
        # ".mp4" appearing elsewhere in the path is never rewritten.
        stem = os.path.splitext(os.path.basename(video_path))[0]
        npz_path = os.path.join(work_dir, stem + ".npz")
        np.savez(npz_path, features=features)

        predictions = predict(model, npz_path)

        # Format before the temp directory is removed, in case predict()
        # returns a lazily evaluated iterable.
        return "\n".join(
            f"{label}: {start:.2f}s - {end:.2f}s"
            for label, start, end in predictions
        )
|
|
|
|
|
# Single-page Gradio UI: one video upload goes through process_video and the
# formatted action segments are shown in a text box.
demo = gr.Interface(
    title="Temporal Action Localization",
    fn=process_video,
    inputs=gr.Video(label="Upload a video"),
    outputs=gr.Textbox(label="Detected Actions"),
)


if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()
|
|
|