AZIIIIIIIIZ commited on
Commit
77b2228
Β·
verified Β·
1 Parent(s): efb1c7c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -66
app.py CHANGED
@@ -1,66 +1,66 @@
1
- import os
2
- from operator import itemgetter
3
- import gradio as gr
4
- from mmaction.apis import init_recognizer, inference_recognizer
5
-
6
- CONFIG_FILE = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
7
- CHECKPOINT_FILE = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
8
- LABEL_FILE = 'tools/data/kinetics/label_map_k400.txt'
9
-
10
- def load_labels(path):
11
- if os.path.exists(path):
12
- with open(path, 'r') as f:
13
- return [x.strip() for x in f if x.strip()]
14
- return None
15
-
16
- def build_model():
17
- if not os.path.exists(CHECKPOINT_FILE):
18
- raise FileNotFoundError(f'Checkpoint not found at {CHECKPOINT_FILE}')
19
- return init_recognizer(CONFIG_FILE, CHECKPOINT_FILE, device='cpu')
20
-
21
- print('Initializing model...')
22
- try:
23
- model = build_model()
24
- print('βœ… Model loaded successfully!')
25
- except Exception as e:
26
- print(f'❌ Error loading model: {e}')
27
- model = None
28
-
29
- labels = load_labels(LABEL_FILE)
30
-
31
- def analyze_video(video_path: str):
32
- try:
33
- if video_path is None:
34
- return 'Please upload a video file.'
35
- if model is None:
36
- return '⚠️ Model not loaded. Check logs for details.'
37
-
38
- result = inference_recognizer(model, video_path)
39
-
40
- pred_scores = result.pred_score.tolist()
41
- score_sorted = sorted(zip(range(len(pred_scores)), pred_scores), key=itemgetter(1), reverse=True)
42
- top5 = score_sorted[:5]
43
-
44
- lines = []
45
- for idx, score in top5:
46
- name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
47
- lines.append(f'{name}: {score}')
48
- return '\n'.join(lines)
49
- except Exception as e:
50
- return f'❌ Error processing video: {str(e)}'
51
-
52
- demo = gr.Interface(
53
- fn=analyze_video,
54
- inputs=gr.Video(label='Upload Video', height=300),
55
- outputs=gr.Textbox(label='Analysis Results', lines=12),
56
- title='🎬 GenVidBench - TSN (MMAction2)',
57
- description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
58
- cache_examples=False,
59
- allow_flagging='never'
60
- )
61
-
62
- if __name__ == '__main__':
63
- demo.launch()
64
-
65
-
66
-
 
1
+ import os
2
+ from operator import itemgetter
3
+ import gradio as gr
4
+ from mmaction.apis import init_recognizer, inference_recognizer
5
+
6
+ CONFIG_FILE = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
7
+ CHECKPOINT_FILE = 'checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth'
8
+ LABEL_FILE = 'tools/data/kinetics/label_map_k400.txt'
9
+
10
+ def load_labels(path):
11
+ if os.path.exists(path):
12
+ with open(path, 'r') as f:
13
+ return [x.strip() for x in f if x.strip()]
14
+ return None
15
+
16
+ def build_model():
17
+ if not os.path.exists(CHECKPOINT_FILE):
18
+ raise FileNotFoundError(f'Checkpoint not found at {CHECKPOINT_FILE}')
19
+ return init_recognizer(CONFIG_FILE, CHECKPOINT_FILE, device='cpu')
20
+
21
+ print('Initializing model...')
22
+ try:
23
+ model = build_model()
24
+ print('βœ… Model loaded successfully!')
25
+ except Exception as e:
26
+ print(f'❌ Error loading model: {e}')
27
+ model = None
28
+
29
+ labels = load_labels(LABEL_FILE)
30
+
31
+ def analyze_video(video_path: str):
32
+ try:
33
+ if video_path is None:
34
+ return 'Please upload a video file.'
35
+ if model is None:
36
+ return '⚠️ Model not loaded. Check logs for details.'
37
+
38
+ result = inference_recognizer(model, video_path)
39
+
40
+ pred_scores = result.pred_score.tolist()
41
+ score_sorted = sorted(zip(range(len(pred_scores)), pred_scores), key=itemgetter(1), reverse=True)
42
+ top5 = score_sorted[:5]
43
+
44
+ lines = []
45
+ for idx, score in top5:
46
+ name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
47
+ lines.append(f'{name}: {score}')
48
+ return '\n'.join(lines)
49
+ except Exception as e:
50
+ return f'❌ Error processing video: {str(e)}'
51
+
52
+ demo = gr.Interface(
53
+ fn=analyze_video,
54
+ inputs=gr.Video(label='Upload Video', height=300),
55
+ outputs=gr.Textbox(label='Analysis Results', lines=12),
56
+ title='🎬 GenVidBench - TSN (MMAction2)',
57
+ description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
58
+ cache_examples=False,
59
+ allow_flagging='never'
60
+ )
61
+
62
+ if __name__ == '__main__':
63
+ demo.launch()
64
+
65
+
66
+