JonSnow1512 committed on
Commit
c2d642e
·
verified ·
1 Parent(s): 22d1533

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import cv2
4
+ import numpy as np
5
+ import gradio as gr
6
+ from torchvision import transforms
7
+ from transformers import VideoMAEForVideoClassification
8
+
9
+ # Class mapping
10
# Label name -> class id. A label's position in the tuple below is its id,
# so the order must not be changed.
class_mapping = {
    label: index
    for index, label in enumerate(
        (
            "Abuse",
            "Arrest",
            "Arson",
            "Assault",
            "Burglary",
            "Explosion",
            "Fighting",
            "Normal Videos",
            "Road Accidents",
            "Robbery",
            "Shooting",
            "Shoplifting",
            "Stealing",
            "Vandalism",
        )
    )
}
# Class id -> label name (inverse of class_mapping).
reverse_mapping = dict(zip(class_mapping.values(), class_mapping))
16
+
17
+ # Load model
18
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the fine-tuned checkpoint, attaching the label tables defined above.
# NOTE(review): ignore_mismatched_sizes=True lets loading continue when a
# checkpoint weight has a different shape than the model expects -- confirm
# the classification head is really restored from the checkpoint.
model = VideoMAEForVideoClassification.from_pretrained(
    "OPear/videomae-large-finetuned-UCF-Crime",
    ignore_mismatched_sizes=True,
    id2label=reverse_mapping,
    label2id=class_mapping,
)
model = model.to(device)
# Inference only: switch the module to evaluation mode.
model.eval()
26
+
27
+ # Preprocessing function
28
def load_video_frames(video_path, num_frames=16, size=(224, 224)):
    """Sample evenly spaced frames from a video and return them normalized.

    Frame positions are chosen with ``np.linspace`` over the frame count the
    container reports. When the clip has fewer frames than ``num_frames`` the
    same position is selected several times; each such frame is repeated in
    place so the temporal spacing stays uniform. If decoding stops before all
    requested frames were collected, the last decoded frame is used as
    padding.

    Args:
        video_path: Path of the video file to decode.
        num_frames: Number of frames to return; must be at least 1.
        size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        A float32 tensor of shape ``[num_frames, 3, H, W]`` in RGB order,
        scaled to ``[0, 1]`` and then normalized with the standard ImageNet
        mean/std.

    Raises:
        ValueError: If ``num_frames`` is below 1, the video cannot be opened,
            or no frame could be read.
    """
    if num_frames < 1:
        raise ValueError("num_frames must be a positive integer.")

    frames = []
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            frame_indices = np.linspace(0, total_frames - 1, num_frames, dtype=int)
            # repeats[i] = how many times frame i was selected. This gives an
            # O(1) lookup per frame and keeps duplicates for short clips.
            repeats = np.bincount(frame_indices, minlength=total_frames)

            for i in range(total_frames):
                # grab() advances the stream without decoding the image, so
                # frames that were not selected stay cheap.
                if not cap.grab():
                    break
                if repeats[i] == 0:
                    continue
                ret, frame = cap.retrieve()
                if not ret:
                    break
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(frame, size)
                frames.extend([frame] * int(repeats[i]))
    finally:
        # Always release the capture, even if decoding or resizing raised.
        cap.release()

    if not frames:
        raise ValueError("No frames read from video.")
    if len(frames) < num_frames:
        # The reported frame count can exceed what is actually decodable;
        # pad with the last good frame.
        frames.extend([frames[-1]] * (num_frames - len(frames)))

    clip = np.stack(frames, axis=0)  # [T, H, W, 3], uint8
    clip = torch.tensor(clip, dtype=torch.float32).permute(0, 3, 1, 2) / 255.0

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    # Normalize broadcasts over leading dimensions, so the whole clip is
    # handled in one call instead of frame by frame.
    return normalize(clip).contiguous()  # [T, 3, H, W]
56
+
57
+ # Prediction function
58
def predict_crime(video_file):
    """Classify an uploaded video and return a Markdown summary.

    Args:
        video_file: Path of the uploaded video, or a falsy value when nothing
            was uploaded.

    Returns:
        A Markdown string with the predicted class and its softmax
        confidence, or an ``"Error: ..."`` string if anything goes wrong.
        Errors are reported as text because this function is the UI callback.
    """
    if not video_file:
        return "Error: No video provided."

    try:
        frames = load_video_frames(video_file)  # [T, 3, H, W]
        # VideoMAE expects pixel_values as [batch, num_frames, channels, H, W].
        # Only a batch dimension is added; swapping the frame and channel axes
        # makes the model see T channels and reject the input.
        input_tensor = frames.unsqueeze(0).to(device)  # [1, T, 3, H, W]

        with torch.no_grad():
            outputs = model(pixel_values=input_tensor)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
            pred_id = torch.argmax(probs, dim=-1).item()
            pred_class = reverse_mapping[pred_id]
            confidence = probs[0][pred_id].item()

        return f"**Predicted Class:** {pred_class}\n**Confidence:** {confidence:.4f}"

    except Exception as e:
        return f"Error: {str(e)}"
74
+
75
+ # Gradio interface
76
interface = gr.Interface(
    fn=predict_crime,
    # gr.Video has no `type` parameter (that belongs to Image/Audio/File) and
    # already hands the callback a file path. Passing type="filepath" raises
    # TypeError on Gradio releases that reject unknown keyword arguments.
    inputs=gr.Video(label="Upload a Crime-related Video"),
    outputs="markdown",
    title="🎥 Crime Type Classifier",
    description="Upload a video (preferably 5–10s, .mp4 format). The model predicts the crime type using a fine-tuned VideoMAE.",
)

# Start the web UI only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()