AZIIIIIIIIZ commited on
Commit
888460f
Β·
verified Β·
1 Parent(s): f51238d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -43
app.py CHANGED
@@ -1,65 +1,76 @@
1
  import os
2
- import gradio as gr
3
  from operator import itemgetter
 
4
  from mmaction.apis import init_recognizer, inference_recognizer
5
 
6
- # --- Config & Checkpoint ---
7
- config_file = "demo/demo_configs/tsn_r50_1x1x8_video_infer.py"
8
- checkpoint_file = "checkpoints/tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb_20220818-2692d16c.pth"
9
- label_file = "tools/data/kinetics/label_map_k400.txt"
 
 
 
 
 
 
 
 
 
 
10
 
11
- # --- Initialize model ---
12
- device = "cuda" if os.path.exists("/dev/nvidia0") else "cpu"
13
- print(f"πŸš€ Initializing TSN Model on {device}...")
14
- model = init_recognizer(config_file, checkpoint_file, device=device)
 
 
 
15
 
16
- # Load labels
17
- with open(label_file) as f:
18
- labels = [x.strip() for x in f.readlines()]
19
 
20
- def analyze_video(video_path):
21
- """Run action recognition on uploaded video"""
 
 
 
 
 
 
 
 
 
 
 
22
  try:
23
  if video_path is None:
24
- return "❌ Please upload a video file."
 
 
25
 
26
- # Inference
27
- results = inference_recognizer(model, video_path)
28
 
29
- # Extract top-5 results
30
- pred_scores = results.pred_score.tolist()
31
- score_tuples = tuple(zip(range(len(pred_scores)), pred_scores))
32
- score_sorted = sorted(score_tuples, key=itemgetter(1), reverse=True)
33
  top5 = score_sorted[:5]
34
 
35
- # Format results
36
  lines = []
37
  for idx, score in top5:
38
- lines.append(f"{labels[idx]}: {score:.4f}")
39
-
40
- return "\n".join(lines)
41
-
42
  except Exception as e:
43
- return f"❌ Error: {str(e)}"
44
 
45
- # --- Gradio UI ---
46
  demo = gr.Interface(
47
  fn=analyze_video,
48
- inputs=gr.Video(label="Upload Video", height=300, type="filepath"),
49
- outputs=gr.Textbox(label="Top-5 Predictions", lines=10),
50
- title="🎬 GenVidBench - TSN Action Recognition",
51
- description="""
52
- Upload a video and run **TSN (Temporal Segment Networks, ResNet-50 backbone)**
53
- trained on **Kinetics-400**.
54
- Model: `tsn_r50_8xb32-1x1x8-100e_kinetics400-rgb`
55
- Benchmark accuracy ~80% (GenVidBench).
56
- """,
57
- examples=[["demo/demo.mp4"]] if os.path.exists("demo/demo.mp4") else None,
58
  cache_examples=False,
59
- theme=gr.themes.Soft(),
60
- allow_flagging="never"
61
  )
62
 
63
- if __name__ == "__main__":
64
- print("🌟 Starting GenVidBench TSN Demo...")
65
- demo.launch()
 
1
  import os
 
2
  from operator import itemgetter
3
+ import gradio as gr
4
  from mmaction.apis import init_recognizer, inference_recognizer
5
 
6
+ CONFIG_FILE = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
7
+ CHECKPOINT_FILE = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
8
+ LABEL_FILE = 'tools/data/kinetics/label_map_k400.txt'
9
+
10
+ def load_labels(path):
11
+ if os.path.exists(path):
12
+ with open(path, 'r') as f:
13
+ return [x.strip() for x in f if x.strip()]
14
+ return None
15
+
16
+ def build_model():
17
+ if not os.path.exists(CHECKPOINT_FILE):
18
+ raise FileNotFoundError(f'Checkpoint not found at {CHECKPOINT_FILE}')
19
+ return init_recognizer(CONFIG_FILE, CHECKPOINT_FILE, device='cpu')
20
 
21
+ print('Initializing model...')
22
+ try:
23
+ model = build_model()
24
+ print('βœ… Model loaded successfully!')
25
+ except Exception as e:
26
+ print(f'❌ Error loading model: {e}')
27
+ model = None
28
 
29
+ labels = load_labels(LABEL_FILE)
 
 
30
 
31
+ def _resolve_video_path(video_input):
32
+ """Best-effort extraction of a filesystem path from Gradio Video input."""
33
+ if isinstance(video_input, str):
34
+ return video_input
35
+ if isinstance(video_input, dict):
36
+ for key in ('name', 'path', 'video', 'file'):
37
+ val = video_input.get(key)
38
+ if isinstance(val, str):
39
+ return val
40
+ return video_input
41
+
42
+
43
+ def analyze_video(video_path: str):
44
  try:
45
  if video_path is None:
46
+ return 'Please upload a video file.'
47
+ if model is None:
48
+ return '⚠️ Model not loaded. Check logs for details.'
49
 
50
+ video_fs_path = _resolve_video_path(video_path)
51
+ result = inference_recognizer(model, video_fs_path)
52
 
53
+ pred_scores = result.pred_score.tolist()
54
+ score_sorted = sorted(zip(range(len(pred_scores)), pred_scores), key=itemgetter(1), reverse=True)
 
 
55
  top5 = score_sorted[:5]
56
 
 
57
  lines = []
58
  for idx, score in top5:
59
+ name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
60
+ lines.append(f'{name}: {score}')
61
+ return '\n'.join(lines)
 
62
  except Exception as e:
63
+ return f'❌ Error processing video: {str(e)}'
64
 
 
65
  demo = gr.Interface(
66
  fn=analyze_video,
67
+ inputs=gr.Video(label='Upload Video', height=300),
68
+ outputs=gr.Textbox(label='Analysis Results', lines=12),
69
+ title='🎬 GenVidBench - TSN (MMAction2)',
70
+ description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
 
 
 
 
 
 
71
  cache_examples=False,
72
+ flagging_mode='never'
 
73
  )
74
 
75
+ if __name__ == '__main__':
76
+ demo.launch()