AZIIIIIIIIZ committed on
Commit
f51238d
·
verified ·
1 Parent(s): 2c03908

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -80
app.py CHANGED
@@ -1,80 +1,65 @@
1
- import os
2
- from operator import itemgetter
3
- import gradio as gr
4
- from mmaction.apis import init_recognizer, inference_recognizer
5
-
6
- CONFIG_FILE = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
7
- CHECKPOINT_FILE = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
8
- LABEL_FILE = 'tools/data/kinetics/label_map_k400.txt'
9
-
10
- def load_labels(path):
11
- if os.path.exists(path):
12
- with open(path, 'r') as f:
13
- return [x.strip() for x in f if x.strip()]
14
- return None
15
-
16
- def build_model():
17
- if not os.path.exists(CONFIG_FILE):
18
- raise FileNotFoundError(f'Config not found at {CONFIG_FILE}')
19
- if not os.path.exists(CHECKPOINT_FILE):
20
- raise FileNotFoundError(f'Checkpoint not found at {CHECKPOINT_FILE}')
21
- return init_recognizer(CONFIG_FILE, CHECKPOINT_FILE, device='cpu')
22
-
23
- print('Initializing model...')
24
- try:
25
- model = build_model()
26
- print('✅ Model loaded successfully!')
27
- except Exception as e:
28
- print(f'❌ Error loading model: {e}')
29
- model = None
30
-
31
- labels = load_labels(LABEL_FILE)
32
- def _resolve_video_path(video_input):
33
- if isinstance(video_input, str):
34
- return video_input
35
- if isinstance(video_input, dict):
36
- for key in ('name', 'path', 'video', 'file'):
37
- val = video_input.get(key)
38
- if isinstance(val, str) and os.path.exists(val):
39
- return val
40
- return video_input
41
-
42
- def analyze_video(video_input):
43
- try:
44
- if video_input is None:
45
- return 'Please upload a video file.'
46
- if model is None:
47
- return '⚠️ Model not loaded. Check logs for details.'
48
- video_path = _resolve_video_path(video_input)
49
- if not isinstance(video_path, str) or not os.path.exists(video_path):
50
- return '❌ Could not resolve uploaded video path.'
51
-
52
- result = inference_recognizer(model, video_path)
53
-
54
- pred_scores = result.pred_score.tolist()
55
- score_sorted = sorted(zip(range(len(pred_scores)), pred_scores), key=itemgetter(1), reverse=True)
56
- top5 = score_sorted[:5]
57
-
58
- lines = []
59
- for idx, score in top5:
60
- name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
61
- lines.append(f'{name}: {score}')
62
- return '\n'.join(lines)
63
- except Exception as e:
64
- return f'❌ Error processing video: {str(e)}'
65
-
66
- demo = gr.Interface(
67
- fn=analyze_video,
68
- inputs=gr.Video(label='Upload Video', height=300),
69
- outputs=gr.Textbox(label='Analysis Results', lines=12),
70
- title='🎬 GenVidBench - TSN (MMAction2)',
71
- description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
72
- cache_examples=False,
73
- flagging_mode='never'
74
- )
75
-
76
- if __name__ == '__main__':
77
- demo.launch()
78
-
79
-
80
-
 
1
+ import os
2
+ import gradio as gr
3
+ from operator import itemgetter
4
+ from mmaction.apis import init_recognizer, inference_recognizer
5
+
6
+ # --- Config & Checkpoint ---
7
+ config_file = "demo/demo_configs/tsn_r50_1x1x8_video_infer.py"
8
+ checkpoint_file = "checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth"
9
+ label_file = "tools/data/kinetics/label_map_k400.txt"
10
+
11
+ # --- Initialize model ---
12
+ device = "cuda" if os.path.exists("/dev/nvidia0") else "cpu"
13
+ print(f"🚀 Initializing TSN Model on {device}...")
14
+ model = init_recognizer(config_file, checkpoint_file, device=device)
15
+
16
+ # Load labels
17
+ with open(label_file) as f:
18
+ labels = [x.strip() for x in f.readlines()]
19
+
20
+ def analyze_video(video_path):
21
+ """Run action recognition on uploaded video"""
22
+ try:
23
+ if video_path is None:
24
+ return "❌ Please upload a video file."
25
+
26
+ # Inference
27
+ results = inference_recognizer(model, video_path)
28
+
29
+ # Extract top-5 results
30
+ pred_scores = results.pred_score.tolist()
31
+ score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
32
+ score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
33
+ top5 = score_sorted[:5]
34
+
35
+ # Format results
36
+ lines = []
37
+ for idx, score in top5:
38
+ lines.append(f"{labels[idx]}: {score:.4f}")
39
+
40
+ return "\n".join(lines)
41
+
42
+ except Exception as e:
43
+ return f"❌ Error: {str(e)}"
44
+
45
+ # --- Gradio UI ---
46
+ demo = gr.Interface(
47
+ fn=analyze_video,
48
+ inputs=gr.Video(label="Upload Video", height=300, type="filepath"),
49
+ outputs=gr.Textbox(label="Top-5 Predictions", lines=10),
50
+ title="🎬 GenVidBench - TSN Action Recognition",
51
+ description="""
52
+ Upload a video and run **TSN (Temporal Segment Networks, ResNet-50 backbone)**
53
+ trained on **Kinetics-400**.
54
+ Model: `tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb`
55
+ Benchmark accuracy ~80% (GenVidBench).
56
+ """,
57
+ examples=[["demo/demo.mp4"]] if os.path.exists("demo/demo.mp4") else None,
58
+ cache_examples=False,
59
+ theme=gr.themes.Soft(),
60
+ allow_flagging="never"
61
+ )
62
+
63
+ if __name__ == "__main__":
64
+ print("🌟 Starting GenVidBench TSN Demo...")
65
+ demo.launch()