santa47 committed on
Commit
d7b868d
Β·
verified Β·
1 Parent(s): e8780a3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -103
app.py CHANGED
@@ -64,19 +64,23 @@ class ViolenceDetector3DCNN(nn.Module):
64
  # ============================================
65
  @st.cache_resource
66
  def load_model():
67
- # Download model from Hugging Face
68
- model_path = hf_hub_download(
69
- repo_id="santa47/violence-detection-3dcnn",
70
- filename="violence_detector.pth"
71
- )
72
-
73
- # Load model
74
- model = ViolenceDetector3DCNN(num_classes=2, dropout=0.5)
75
- checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
76
- model.load_state_dict(checkpoint['model_state_dict'])
77
- model.eval()
78
-
79
- return model
 
 
 
 
80
 
81
 
82
  # ============================================
@@ -84,45 +88,56 @@ def load_model():
84
  # ============================================
85
  def process_video(video_path, num_frames=16, frame_size=(112, 112)):
86
  """Extract and preprocess frames from video"""
87
- cap = cv2.VideoCapture(video_path)
88
- frames = []
89
-
90
- while True:
91
- ret, frame = cap.read()
92
- if not ret:
93
- break
94
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
95
- frame = cv2.resize(frame, frame_size)
96
- frames.append(frame)
97
-
98
- cap.release()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
 
100
- if len(frames) == 0:
 
101
  return None
102
-
103
- # Sample frames uniformly
104
- total_frames = len(frames)
105
- if total_frames >= num_frames:
106
- indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
107
- else:
108
- indices = list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames)
109
-
110
- sampled_frames = [frames[i] for i in indices]
111
-
112
- # Convert to tensor: (T, H, W, C) -> (C, T, H, W)
113
- video_tensor = np.stack(sampled_frames, axis=0)
114
- video_tensor = video_tensor.transpose(3, 0, 1, 2)
115
- video_tensor = video_tensor.astype(np.float32) / 255.0
116
-
117
- # Normalize
118
- mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1, 1)
119
- std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1, 1)
120
- video_tensor = (video_tensor - mean) / std
121
-
122
- # Add batch dimension
123
- video_tensor = torch.from_numpy(video_tensor).unsqueeze(0).float()
124
-
125
- return video_tensor
126
 
127
 
128
  # ============================================
@@ -156,67 +171,90 @@ def main():
156
  # Load model
157
  with st.spinner("Loading model..."):
158
  model = load_model()
 
 
 
 
 
159
  st.success("βœ… Model loaded!")
160
 
161
  # File uploader
162
  st.markdown("### Upload a Video")
163
  uploaded_file = st.file_uploader(
164
- "Choose a video file (AVI, MP4, MKV)",
165
- type=['avi', 'mp4', 'mkv']
166
  )
167
 
168
  if uploaded_file is not None:
169
- # Save uploaded file temporarily
170
- with tempfile.NamedTemporaryFile(delete=False, suffix='.avi') as tmp_file:
171
- tmp_file.write(uploaded_file.read())
172
- tmp_path = tmp_file.name
173
-
174
- # Display video
175
- st.video(uploaded_file)
176
-
177
- # Process and predict
178
- if st.button("πŸ” Analyze Video", type="primary"):
179
- with st.spinner("Processing video..."):
180
- # Process video
181
- video_tensor = process_video(tmp_path)
182
-
183
- if video_tensor is None:
184
- st.error("❌ Could not process video. Please try another file.")
185
- else:
186
- # Predict
187
- pred_class, confidence, probs = predict(model, video_tensor)
188
-
189
- # Display results
190
- st.markdown("---")
191
- st.markdown("### πŸ“Š Results")
192
-
193
- col1, col2 = st.columns(2)
194
-
195
- with col1:
196
- if pred_class == 1:
197
- st.error("⚠️ **VIOLENCE DETECTED**")
198
- else:
199
- st.success("βœ… **NO VIOLENCE**")
200
-
201
- with col2:
202
- st.metric("Confidence", f"{confidence * 100:.1f}%")
203
-
204
- # Probability bars
205
- st.markdown("### Probability Distribution")
206
-
207
- col1, col2 = st.columns(2)
208
- with col1:
209
- st.markdown("**Non-Violence**")
210
- st.progress(float(probs[0]))
211
- st.write(f"{probs[0] * 100:.1f}%")
212
-
213
- with col2:
214
- st.markdown("**Violence**")
215
- st.progress(float(probs[1]))
216
- st.write(f"{probs[1] * 100:.1f}%")
217
 
218
- # Cleanup
219
- os.unlink(tmp_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
 
221
  # Footer
222
  st.markdown("---")
@@ -232,4 +270,4 @@ def main():
232
 
233
 
234
  if __name__ == "__main__":
235
- main()
 
64
  # ============================================
65
  @st.cache_resource
66
  def load_model():
67
+ try:
68
+ # Download model from Hugging Face
69
+ model_path = hf_hub_download(
70
+ repo_id="santa47/violence-detection-3dcnn",
71
+ filename="violence_detector.pth"
72
+ )
73
+
74
+ # Load model
75
+ model = ViolenceDetector3DCNN(num_classes=2, dropout=0.5)
76
+ checkpoint = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)
77
+ model.load_state_dict(checkpoint['model_state_dict'])
78
+ model.eval()
79
+
80
+ return model
81
+ except Exception as e:
82
+ st.error(f"Failed to load model: {e}")
83
+ return None
84
 
85
 
86
  # ============================================
 
88
  # ============================================
89
  def process_video(video_path, num_frames=16, frame_size=(112, 112)):
90
  """Extract and preprocess frames from video"""
91
+ try:
92
+ cap = cv2.VideoCapture(video_path)
93
+
94
+ if not cap.isOpened():
95
+ st.error(f"❌ Cannot open video file: {video_path}")
96
+ return None
97
+
98
+ frames = []
99
+
100
+ while True:
101
+ ret, frame = cap.read()
102
+ if not ret:
103
+ break
104
+ frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
105
+ frame = cv2.resize(frame, frame_size)
106
+ frames.append(frame)
107
+
108
+ cap.release()
109
+
110
+ if len(frames) == 0:
111
+ st.error("❌ No frames extracted from video")
112
+ return None
113
+
114
+ # Sample frames uniformly
115
+ total_frames = len(frames)
116
+ if total_frames >= num_frames:
117
+ indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
118
+ else:
119
+ indices = list(range(total_frames)) + [total_frames - 1] * (num_frames - total_frames)
120
+
121
+ sampled_frames = [frames[i] for i in indices]
122
+
123
+ # Convert to tensor: (T, H, W, C) -> (C, T, H, W)
124
+ video_tensor = np.stack(sampled_frames, axis=0)
125
+ video_tensor = video_tensor.transpose(3, 0, 1, 2)
126
+ video_tensor = video_tensor.astype(np.float32) / 255.0
127
+
128
+ # Normalize
129
+ mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1, 1)
130
+ std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1, 1)
131
+ video_tensor = (video_tensor - mean) / std
132
+
133
+ # Add batch dimension
134
+ video_tensor = torch.from_numpy(video_tensor).unsqueeze(0).float()
135
+
136
+ return video_tensor
137
 
138
+ except Exception as e:
139
+ st.error(f"❌ Error processing video: {e}")
140
  return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
 
143
  # ============================================
 
171
  # Load model
172
  with st.spinner("Loading model..."):
173
  model = load_model()
174
+
175
+ if model is None:
176
+ st.error("❌ Failed to load model. Please refresh the page.")
177
+ return
178
+
179
  st.success("βœ… Model loaded!")
180
 
181
  # File uploader
182
  st.markdown("### Upload a Video")
183
  uploaded_file = st.file_uploader(
184
+ "Choose a video file (AVI, MP4, MKV, MOV, WEBM)",
185
+ type=['avi', 'mp4', 'mkv', 'mov', 'webm']
186
  )
187
 
188
  if uploaded_file is not None:
189
+ # FIX 1: Get proper file extension from uploaded file
190
+ file_extension = os.path.splitext(uploaded_file.name)[1].lower()
191
+ if not file_extension:
192
+ file_extension = '.mp4'
193
+
194
+ # FIX 2: Read file bytes ONCE and store
195
+ video_bytes = uploaded_file.read()
196
+
197
+ # FIX 3: Save with correct extension
198
+ tmp_path = None
199
+ try:
200
+ with tempfile.NamedTemporaryFile(delete=False, suffix=file_extension) as tmp_file:
201
+ tmp_file.write(video_bytes)
202
+ tmp_path = tmp_file.name
203
+
204
+ # FIX 4: Display video using bytes (not the file object after read)
205
+ st.video(video_bytes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
+ # Process and predict
208
+ if st.button("πŸ” Analyze Video", type="primary"):
209
+ with st.spinner("Processing video..."):
210
+ # Process video
211
+ video_tensor = process_video(tmp_path)
212
+
213
+ if video_tensor is None:
214
+ st.error("❌ Could not process video. Please try another file.")
215
+ else:
216
+ # Predict
217
+ pred_class, confidence, probs = predict(model, video_tensor)
218
+
219
+ # Display results
220
+ st.markdown("---")
221
+ st.markdown("### πŸ“Š Results")
222
+
223
+ col1, col2 = st.columns(2)
224
+
225
+ with col1:
226
+ if pred_class == 1:
227
+ st.error("⚠️ **VIOLENCE DETECTED**")
228
+ else:
229
+ st.success("βœ… **NO VIOLENCE**")
230
+
231
+ with col2:
232
+ st.metric("Confidence", f"{confidence * 100:.1f}%")
233
+
234
+ # Probability bars
235
+ st.markdown("### Probability Distribution")
236
+
237
+ col1, col2 = st.columns(2)
238
+ with col1:
239
+ st.markdown("**Non-Violence**")
240
+ st.progress(float(probs[0]))
241
+ st.write(f"{probs[0] * 100:.1f}%")
242
+
243
+ with col2:
244
+ st.markdown("**Violence**")
245
+ st.progress(float(probs[1]))
246
+ st.write(f"{probs[1] * 100:.1f}%")
247
+
248
+ except Exception as e:
249
+ st.error(f"❌ Error: {e}")
250
+
251
+ finally:
252
+ # FIX 5: Cleanup in finally block (always runs)
253
+ if tmp_path and os.path.exists(tmp_path):
254
+ try:
255
+ os.unlink(tmp_path)
256
+ except:
257
+ pass
258
 
259
  # Footer
260
  st.markdown("---")
 
270
 
271
 
272
  if __name__ == "__main__":
273
+ main()