AZIIIIIIIIZ's picture
upload model name in app
02022d1 verified
import os
import traceback
from operator import itemgetter

import gradio as gr
from mmaction.apis import init_recognizer, inference_recognizer
# Model/config/label locations. The defaults are relative paths, so they are
# resolved against the working directory the app is launched from. Each one
# can be overridden through an environment variable, which lets a different
# config, checkpoint or label map be used without editing this file.
CONFIG_FILE = os.environ.get(
    'TSN_CONFIG_FILE',
    'demo/demo_configs/tsn_r50_1x1x8_video_infer.py',
)
CHECKPOINT_FILE = os.environ.get(
    'TSN_CHECKPOINT_FILE',
    'checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth',
)
LABEL_FILE = os.environ.get(
    'TSN_LABEL_FILE',
    'tools/data/kinetics/label_map_k400.txt',
)
def load_labels(path):
    """Load a label map with one class name per line.

    Surrounding whitespace is stripped from each line and blank lines are
    skipped. Callers index the returned list by class index, so a blank line
    in the middle of the file would shift every later label by one.

    Args:
        path: Path to a UTF-8 text file.

    Returns:
        A list of class names in file order, or ``None`` if ``path`` does
        not exist.
    """
    if not os.path.exists(path):
        return None
    # Explicit encoding: the default is locale-dependent and can fail on
    # non-ASCII label names (e.g. in a C/POSIX locale container).
    with open(path, 'r', encoding='utf-8') as f:
        return [name for name in (line.strip() for line in f) if name]
def build_model(config_file=None, checkpoint_file=None, device='cpu'):
    """Construct the action recognizer from a config and a checkpoint.

    Args:
        config_file: Config passed to ``init_recognizer``. Defaults to the
            module-level ``CONFIG_FILE`` when ``None``.
        checkpoint_file: Weights file. Defaults to the module-level
            ``CHECKPOINT_FILE`` when ``None``.
        device: Device string forwarded to ``init_recognizer``
            (``'cpu'`` by default, as before).

    Returns:
        Whatever ``init_recognizer`` returns for these arguments.

    Raises:
        FileNotFoundError: If the checkpoint file, or a path-like config,
            does not exist.
    """
    # Resolve defaults at call time so the module constants are read lazily.
    if config_file is None:
        config_file = CONFIG_FILE
    if checkpoint_file is None:
        checkpoint_file = CHECKPOINT_FILE
    if not os.path.exists(checkpoint_file):
        raise FileNotFoundError(f'Checkpoint not found at {checkpoint_file}')
    # Fail fast with a clear message instead of an error from deep inside the
    # config loader. Only path-like configs are checked, so non-path config
    # objects are passed through untouched.
    if isinstance(config_file, (str, os.PathLike)) and not os.path.exists(config_file):
        raise FileNotFoundError(f'Config not found at {config_file}')
    return init_recognizer(config_file, checkpoint_file, device=device)
# Build the model once at import time. A load failure must not stop the app
# from starting: ``model`` is left as None and analyze_video reports that to
# the user instead.
print('Initializing model...')
try:
    model = build_model()
    print('✅ Model loaded successfully!')
except Exception as e:
    print(f'❌ Error loading model: {e}')
    # The UI tells users to "check logs", so put the full traceback there
    # (stderr) rather than only the one-line message.
    traceback.print_exc()
    model = None

labels = load_labels(LABEL_FILE)
if not labels:
    # Predictions still work without a label map, but are shown as class_<idx>.
    print(f'⚠️ Label map missing or empty at {LABEL_FILE}; predictions will show class indices.')
def _resolve_video_path(video_input):
    """Best-effort extraction of a filesystem path from Gradio Video input.

    Handles, in order:
      * ``str`` -- returned unchanged;
      * ``os.PathLike`` (e.g. ``pathlib.Path``) -- converted with ``os.fspath``;
      * ``dict`` -- the first of the keys ``name``, ``path``, ``video``,
        ``file`` whose value resolves to a string. Values may themselves be
        path-like or nested dicts, e.g. ``{'video': {'path': ...}}``.

    Anything else, or a dict with no resolvable entry, is returned unchanged,
    so the caller's error handling reports it.
    """
    if isinstance(video_input, str):
        return video_input
    if isinstance(video_input, os.PathLike):
        return os.fspath(video_input)
    if isinstance(video_input, dict):
        for key in ('name', 'path', 'video', 'file'):
            candidate = video_input.get(key)
            if isinstance(candidate, (str, os.PathLike, dict)):
                resolved = _resolve_video_path(candidate)
                # Only accept a real string path; otherwise try the next key.
                if isinstance(resolved, str):
                    return resolved
    return video_input
def analyze_video(video_path: str):
    """Classify an uploaded video and return its five highest-scoring classes.

    Args:
        video_path: The value Gradio passes from the ``Video`` input. Normally
            a filesystem path string; other payload shapes are unwrapped by
            ``_resolve_video_path``. ``None`` or a blank string means nothing
            was uploaded.

    Returns:
        One ``"<label>: <score>"`` line per prediction, highest score first.
        Otherwise a user-facing message when there is no input, no loaded
        model, or inference fails. Inference errors are returned as text, and
        their traceback is printed to stderr, rather than raised.
    """
    if video_path is None or (isinstance(video_path, str) and not video_path.strip()):
        return 'Please upload a video file.'
    if model is None:
        return '⚠️ Model not loaded. Check logs for details.'
    try:
        video_fs_path = _resolve_video_path(video_path)
        result = inference_recognizer(model, video_fs_path)
        pred_scores = result.pred_score.tolist()
        # enumerate pairs each score with its class index. sorted() is stable
        # even with reverse=True, so tied scores keep ascending index order.
        ranked = sorted(enumerate(pred_scores), key=itemgetter(1), reverse=True)
        lines = []
        for idx, score in ranked[:5]:
            # Fall back to a synthetic name when the label map is missing or
            # shorter than the score vector.
            name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
            lines.append(f'{name}: {score}')
        return '\n'.join(lines)
    except Exception as e:
        # UI boundary: log the full traceback for debugging, but hand the
        # user a readable message instead of a failed request.
        traceback.print_exc()
        return f'❌ Error processing video: {e}'
# Gradio UI: one video upload in, a text summary of the top predictions out.
# The title previously contained mojibake (UTF-8 emoji decoded with the wrong
# codec) and now uses the intended clapper-board emoji.
demo = gr.Interface(
    fn=analyze_video,
    inputs=gr.Video(label='Upload Video', height=300),
    outputs=gr.Textbox(label='Analysis Results', lines=12),
    title='🎬 GenVidBench - TSN (MMAction2)',
    description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
    cache_examples=False,
    flagging_mode='never',
)
def _launch():
    """Start the Gradio server for the module-level ``demo`` interface."""
    demo.launch()


# Serve only when executed as a script; importing this module has no side
# effect beyond building the model and the interface.
if __name__ == '__main__':
    _launch()