AZIIIIIIIIZ commited on
Commit
efb1c7c
·
verified ·
1 Parent(s): 9d10281

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -186
app.py CHANGED
@@ -1,186 +1,66 @@
1
- import os
2
- import gradio as gr
3
- import cv2
4
- import numpy as np
5
- from PIL import Image
6
- import torch
7
- import torchvision.transforms as transforms
8
- import torchvision.models as models
9
-
10
- # Simple video action recognition using pre-trained models
11
- class SimpleVideoAnalyzer:
12
- def __init__(self):
13
- self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
14
- print(f"Using device: {self.device}")
15
-
16
- # Load a pre-trained ResNet model for feature extraction
17
- self.model = models.resnet50(pretrained=True)
18
- self.model.eval()
19
- self.model.to(self.device)
20
-
21
- # Image preprocessing
22
- self.transform = transforms.Compose([
23
- transforms.Resize((224, 224)),
24
- transforms.ToTensor(),
25
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
26
- std=[0.229, 0.224, 0.225])
27
- ])
28
-
29
- # Load labels from Kinetics-400 if available, else fallback
30
- self.action_categories = self.load_kinetics_labels()
31
-
32
- print("✅ Simple video analyzer initialized successfully!")
33
-
34
- def extract_frames(self, video_path, num_frames=8):
35
- """Extract frames from video"""
36
- cap = cv2.VideoCapture(video_path)
37
- frames = []
38
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
39
-
40
- # Sample frames evenly
41
- frame_indices = np.linspace(0, total_frames-1, num_frames, dtype=int)
42
-
43
- for idx in frame_indices:
44
- cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
45
- ret, frame = cap.read()
46
- if ret:
47
- # Convert BGR to RGB
48
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
- frames.append(frame_rgb)
50
-
51
- cap.release()
52
- return frames
53
-
54
- def load_kinetics_labels(self):
55
- """Load Kinetics-400 class labels if available."""
56
- label_file = 'tools/data/kinetics/label_map_k400.txt'
57
- if os.path.exists(label_file):
58
- try:
59
- with open(label_file, 'r') as f:
60
- labels = [line.strip() for line in f.readlines() if line.strip()]
61
- if labels:
62
- print(f"✅ Loaded {len(labels)} Kinetics-400 labels from {label_file}")
63
- return labels
64
- except Exception:
65
- pass
66
- print("⚠️ Kinetics labels not found, using fallback categories")
67
- return [
68
- "walking", "running", "jumping", "sitting", "standing",
69
- "dancing", "cooking", "reading", "writing", "typing",
70
- "clapping", "waving", "pointing", "lifting", "throwing",
71
- "catching", "kicking", "punching", "swimming", "cycling"
72
- ]
73
-
74
- def analyze_frames(self, frames):
75
- """Analyze frames and return predictions"""
76
- features = []
77
-
78
- for frame in frames:
79
- # Convert to PIL Image
80
- pil_image = Image.fromarray(frame)
81
-
82
- # Preprocess
83
- input_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
84
-
85
- # Extract features
86
- with torch.no_grad():
87
- features.append(self.model(input_tensor).cpu().numpy())
88
-
89
- # Average features across frames (not directly used for class mapping here)
90
- _ = np.mean(features, axis=0)
91
-
92
- # Create deterministic-looking output: 1 dominant class with score 1.0 and
93
- # four tiny scores, formatted like the example
94
- num_classes = len(self.action_categories)
95
- num_return = min(5, num_classes)
96
-
97
- # Choose a dominant class index (random for demo)
98
- dominant_idx = np.random.randint(0, num_classes)
99
-
100
- # Pick four other unique indices
101
- candidate_indices = [i for i in range(num_classes) if i != dominant_idx]
102
- np.random.shuffle(candidate_indices)
103
- other_indices = candidate_indices[:max(0, num_return - 1)]
104
-
105
- results = []
106
- # Top-1 with score exactly 1.0
107
- results.append((self.action_categories[dominant_idx], "1.0"))
108
-
109
- # Four tiny scores using scientific notation similar to example
110
- for i in other_indices:
111
- tiny = 10 ** (-(14 + np.random.rand() * 3)) # ~1e-14 to 1e-17
112
- results.append((self.action_categories[i], f"{tiny:.15e}"))
113
-
114
- return results
115
-
116
- def analyze_video(self, video_path):
117
- """Main analysis function"""
118
- try:
119
- if video_path is None:
120
- return "Please upload a video file."
121
-
122
- print(f"Processing video: {video_path}")
123
-
124
- # Extract frames
125
- frames = self.extract_frames(video_path)
126
- if not frames:
127
- return "❌ Could not extract frames from video."
128
-
129
- # Analyze frames
130
- results = self.analyze_frames(frames)
131
-
132
- # Format results to match requested style: "label: score" per line
133
- result_lines = []
134
- for label, score in results:
135
- result_lines.append(f"{label}: {score}")
136
- result_text = "\n".join(result_lines)
137
-
138
- result_text += f"\n📊 Analyzed {len(frames)} frames"
139
- result_text += f"\n🔧 Using: {self.device.upper()}"
140
-
141
- return result_text
142
-
143
- except Exception as e:
144
- return f"❌ Error processing video: {str(e)}"
145
-
146
- # Initialize analyzer
147
- print("🚀 Initializing Simple Video Analyzer...")
148
- analyzer = SimpleVideoAnalyzer()
149
-
150
- # Create Gradio interface
151
- def analyze_video(video):
152
- """Gradio interface function"""
153
- return analyzer.analyze_video(video)
154
-
155
- # Create the interface
156
- demo = gr.Interface(
157
- fn=analyze_video,
158
- inputs=gr.Video(label="Upload Video", height=300),
159
- outputs=gr.Textbox(label="Analysis Results", lines=15),
160
- title="🎬 GenVidBench - Simple Video Action Recognition",
161
- description="""
162
- **Simple Video Action Recognition Demo**
163
-
164
- Upload a video to analyze its content using a simplified approach.
165
- This demo uses pre-trained ResNet features for basic action recognition.
166
-
167
- **Features:**
168
- - 🎥 Multi-frame analysis
169
- - 🧠 Pre-trained ResNet50 features
170
- - ⚡ Fast processing
171
- - 📊 Top-5 predictions
172
-
173
- **Supported formats:** MP4, AVI, MOV, etc.
174
- **Recommended:** Short videos (under 30 seconds) for best performance.
175
- """,
176
- examples=[
177
- ["demo/demo.mp4"] if os.path.exists("demo/demo.mp4") else None
178
- ],
179
- cache_examples=False,
180
- theme=gr.themes.Soft(),
181
- allow_flagging="never"
182
- )
183
-
184
- if __name__ == "__main__":
185
- print("🌟 Starting GenVidBench Simple Demo...")
186
- demo.launch()
 
1
+ import os
2
+ from operator import itemgetter
3
+ import gradio as gr
4
+ from mmaction.apis import init_recognizer, inference_recognizer
5
+
6
+ CONFIG_FILE = 'demo/demo_configs/tsn_r50_1x1x8_video_infer.py'
7
+ CHECKPOINT_FILE = 'checkpoints/tsn_r50_1x1x3_100e_kinetics400_rgb_20200614-e508be42.pth'
8
+ LABEL_FILE = 'tools/data/kinetics/label_map_k400.txt'
9
+
10
+ def load_labels(path):
11
+ if os.path.exists(path):
12
+ with open(path, 'r') as f:
13
+ return [x.strip() for x in f if x.strip()]
14
+ return None
15
+
16
+ def build_model():
17
+ if not os.path.exists(CHECKPOINT_FILE):
18
+ raise FileNotFoundError(f'Checkpoint not found at {CHECKPOINT_FILE}')
19
+ return init_recognizer(CONFIG_FILE, CHECKPOINT_FILE, device='cpu')
20
+
21
+ print('Initializing model...')
22
+ try:
23
+ model = build_model()
24
+ print('✅ Model loaded successfully!')
25
+ except Exception as e:
26
+ print(f'❌ Error loading model: {e}')
27
+ model = None
28
+
29
+ labels = load_labels(LABEL_FILE)
30
+
31
+ def analyze_video(video_path: str):
32
+ try:
33
+ if video_path is None:
34
+ return 'Please upload a video file.'
35
+ if model is None:
36
+ return '⚠️ Model not loaded. Check logs for details.'
37
+
38
+ result = inference_recognizer(model, video_path)
39
+
40
+ pred_scores = result.pred_score.tolist()
41
+ score_sorted = sorted(zip(range(len(pred_scores)), pred_scores), key=itemgetter(1), reverse=True)
42
+ top5 = score_sorted[:5]
43
+
44
+ lines = []
45
+ for idx, score in top5:
46
+ name = labels[idx] if labels and idx < len(labels) else f'class_{idx}'
47
+ lines.append(f'{name}: {score}')
48
+ return '\n'.join(lines)
49
+ except Exception as e:
50
+ return f'❌ Error processing video: {str(e)}'
51
+
52
+ demo = gr.Interface(
53
+ fn=analyze_video,
54
+ inputs=gr.Video(label='Upload Video', height=300),
55
+ outputs=gr.Textbox(label='Analysis Results', lines=12),
56
+ title='🎬 GenVidBench - TSN (MMAction2)',
57
+ description='Upload a video. Inference uses TSN R50 on Kinetics-400.',
58
+ cache_examples=False,
59
+ allow_flagging='never'
60
+ )
61
+
62
+ if __name__ == '__main__':
63
+ demo.launch()
64
+
65
+
66
+