# AZIIIIIIIIZ's picture
# Update app.py
# 77b2228 verified
# raw
# history blame
# 2.08 kB
import os
import traceback
from operator import itemgetter

import gradio as gr
from mmaction.apis import init_recognizer, inference_recognizer
# Locations of the model config, trained weights and class-name list.
# All three are relative paths, so they resolve against the current working
# directory at the moment they are opened/checked.
# NOTE(review): presumably the app is started from an MMAction2 checkout that
# contains demo/, checkpoints/ and tools/ — confirm against the deployment.
CONFIG_FILE: str = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
CHECKPOINT_FILE: str = (
    'checkpoints/'
    'tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth'
)
LABEL_FILE: str = 'tools/data/kinetics/label_map_k400.txt'
def load_labels(path):
    """Read a class-label file with one label per line.

    Surrounding whitespace is stripped from every line and blank lines are
    skipped, so a label's list index is its position among the non-empty
    lines of the file.

    Args:
        path: Path to a UTF-8 encoded text file.

    Returns:
        The list of label strings, or ``None`` if ``path`` does not exist.

    Raises:
        OSError: For I/O failures other than a missing file (e.g. the path is
            a directory or is not readable).
    """
    try:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # could mis-decode non-ASCII label text.
        with open(path, 'r', encoding='utf-8') as f:
            stripped = (line.strip() for line in f)
            return [label for label in stripped if label]
    except FileNotFoundError:
        return None
def build_model(config_file=None, checkpoint_file=None, device='cpu'):
    """Build the action recognizer from an MMAction2 config and checkpoint.

    Calling it with no arguments behaves as before: it uses ``CONFIG_FILE``
    and ``CHECKPOINT_FILE`` on the CPU.

    Args:
        config_file: Path to the model config; ``None`` means ``CONFIG_FILE``.
        checkpoint_file: Path to the weights; ``None`` means
            ``CHECKPOINT_FILE``.
        device: Device string forwarded to ``init_recognizer``.

    Returns:
        The model returned by ``init_recognizer``.

    Raises:
        FileNotFoundError: If the checkpoint or config path does not exist.
            Both are checked before ``init_recognizer`` is called, so the
            message names the missing path.
    """
    # Resolve defaults at call time (not as default-argument values) so the
    # module constants are read when the model is built, as in the original.
    if config_file is None:
        config_file = CONFIG_FILE
    if checkpoint_file is None:
        checkpoint_file = CHECKPOINT_FILE
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'Checkpoint not found at {checkpoint_file}')
    if not os.path.exists(config_file):
        raise FileNotFoundError(f'Config not found at {config_file}')
    return init_recognizer(config_file, checkpoint_file, device=device)
# Build the model once at import time; analyze_video reads this module-level
# ``model``. A failure is reported but not fatal: ``model`` is set to None and
# analyze_video answers "Model not loaded" instead of the app failing to start.
print('Initializing model...')
try:
    model = build_model()
    print('✅ Model loaded successfully!')
except Exception as e:
    print(f'❌ Error loading model: {e}')
    # The UI tells users to "Check logs for details", so log the full
    # traceback rather than only the exception message.
    traceback.print_exc()
    model = None
# Optional human-readable class names; None when the label file is missing.
labels = load_labels(LABEL_FILE)
def analyze_video(video_path: str):
    """Classify an uploaded video and return its top-5 predictions as text.

    Args:
        video_path: File path supplied by the Gradio ``Video`` input, or
            ``None`` when nothing was uploaded.

    Returns:
        One ``"<label>: <score>"`` line per prediction, highest score first.
        If there is no input, the model failed to load, or inference raised,
        a human-readable message is returned instead.
    """
    # This is the UI boundary: every failure becomes a message string because
    # the output text box is the only channel back to the user.
    try:
        if video_path is None:
            return 'Please upload a video file.'
        if model is None:
            return '⚠️ Model not loaded. Check logs for details.'
        result = inference_recognizer(model, video_path)
        pred_scores = result.pred_score.tolist()
        # sorted() is stable, so equal scores keep ascending class-index order.
        ranked = sorted(enumerate(pred_scores), key=itemgetter(1), reverse=True)
        lines = []
        for idx, score in ranked[:5]:
            # Fall back to a generic name when labels are missing or too short.
            name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
            lines.append(f'{name}: {score}')
        return '\n'.join(lines)
    except Exception as e:
        return f'❌ Error processing video: {e}'
# Single-input / single-output web UI: an uploaded video goes to
# analyze_video and its returned text is shown in the results box.
demo = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label='Upload Video', height=300),
    outputs=gr.Textbox(label='Analysis Results', lines=12),
    title='🎬 GenVidBench - TSN (MMAction2)',
    description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
    # No examples are configured, so there is nothing to cache.
    cache_examples=False,
    # NOTE(review): newer Gradio releases rename this to ``flagging_mode``;
    # keep it while the installed Gradio version still accepts it.
    allow_flagging='never',
)

if __name__ == '__main__':
    demo.launch()