# HAT / app.py
# Darknsu's picture
# Update app.py
# adaa321 verified
import gradio as gr
import torch
import os
import tempfile
import numpy as np
from models import Model # Modify based on your actual model class
from dataset import extract_features # Or however you handle input
from eval import predict # Assume this runs inference and returns timestamps
# Load model
def load_model(checkpoint_path='checkpoint/ckp_best.pth.tar'):
    """Rebuild the model from a saved checkpoint and switch it to eval mode.

    Args:
        checkpoint_path: File read with ``torch.load`` (tensors mapped to
            the CPU). The loaded object is indexed with ``'config'`` —
            expanded as keyword arguments for ``Model`` — and
            ``'state_dict'``, which is passed to ``load_state_dict``.

    Returns:
        The ``Model`` instance after its weights were loaded and ``eval()``
        was called on it.
    """
    ckpt = torch.load(checkpoint_path, map_location='cpu')

    # Construct the network first, then restore its weights.
    model_kwargs = ckpt['config']
    net = Model(**model_kwargs)
    net.load_state_dict(ckpt['state_dict'])

    net.eval()
    return net
# Load the model once at import time with the default checkpoint path;
# process_video reads this module-level instance on every call.
model = load_model()
def process_video(video_file):
    """Run action localization on an uploaded video and format the result.

    Args:
        video_file: Value received from the ``gr.Video`` input. Gradio hands
            the upload over as a filesystem path (``str``); a path-like
            object or an open binary file object (anything exposing
            ``read()``) is accepted too. ``None`` means nothing was uploaded.

    Returns:
        A string with one line per prediction, formatted as
        ``"<label>: <start>s - <end>s"`` (two decimals), or a short notice
        when no video was supplied.
    """
    if video_file is None:
        return "No video provided."

    # Keep every intermediate file in one scratch directory so cleanup
    # happens even when feature extraction or inference raises.
    with tempfile.TemporaryDirectory() as workdir:
        if isinstance(video_file, (str, os.PathLike)):
            # The upload already lives on disk and is owned by Gradio:
            # use it in place and never delete it here.
            video_path = os.fspath(video_file)
        else:
            # File-like input: persist it so it can be addressed by path.
            video_path = os.path.join(workdir, "input.mp4")
            with open(video_path, "wb") as f:
                f.write(video_file.read())

        # Optional: convert to features using your function
        features = extract_features(video_path)  # Modify if needed

        # Save to a temp .npz file because predict() is given a path.
        npz_path = os.path.join(workdir, "features.npz")
        np.savez(npz_path, features=features)

        predictions = predict(model, npz_path)

        # Format while the scratch files still exist, in case the
        # predictions are produced lazily from the .npz file.
        results = "\n".join(
            f"{label}: {start:.2f}s - {end:.2f}s"
            for label, start, end in predictions
        )

    return results
def _build_interface():
    """Assemble the Gradio interface: one video input, one text output."""
    video_input = gr.Video(label="Upload a video")
    actions_output = gr.Textbox(label="Detected Actions")
    return gr.Interface(
        fn=process_video,
        inputs=video_input,
        outputs=actions_output,
        title="Temporal Action Localization",
    )


# Module-level handle to the interface.
demo = _build_interface()

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()