usman-khn committed on
Commit
4ffed49
·
verified ·
1 Parent(s): e01b129

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +4 -4
  2. app.py +172 -0
  3. best_model.pth +3 -0
  4. gitattributes +35 -0
  5. requirements.txt +9 -0
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
- title: Crowd Behavior Detection
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: yellow
6
  sdk: gradio
7
  sdk_version: 6.0.0
8
  app_file: app.py
 
1
  ---
2
+ title: App
3
+ emoji: 🐨
4
+ colorFrom: blue
5
+ colorTo: red
6
  sdk: gradio
7
  sdk_version: 6.0.0
8
  app_file: app.py
app.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from PIL import Image
6
+ import numpy as np
7
+ import gradio as gr
8
+
9
+ # --- 1. Model Configuration ---
10
+ SEQUENCE_LENGTH = 16
11
+ NUM_CLASSES = 4
12
+ MODEL_PATH = "best_model.pth" # Ensure this file is in the same directory
13
+
14
+ # Device setup for loading the model
15
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+
17
+ # Class names mapping (must match the order used in your training code)
18
+ # Based on your classification report: 0=aggressive, 1=idle, 2=panic, 3=normal
19
+ CLASS_NAMES = ["aggressive", "idle", "panic", "normal"]
20
+
21
+ # --- 2. Model Definition (Copied from your notebook) ---
22
+ class CNNLSTM(nn.Module):
23
+ def __init__(self, num_classes):
24
+ super(CNNLSTM, self).__init__()
25
+ self.cnn = nn.Sequential(
26
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
27
+ nn.ReLU(),
28
+ nn.MaxPool2d(2),
29
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
30
+ nn.ReLU(),
31
+ nn.MaxPool2d(2)
32
+ )
33
+ # Input size calculation: 64 channels * (64/2/2) * (64/2/2) = 64 * 16 * 16
34
+ self.lstm = nn.LSTM(input_size=64*16*16, hidden_size=128, batch_first=True)
35
+ self.fc = nn.Linear(128, num_classes)
36
+
37
+ def forward(self, x):
38
+ B, T, C, H, W = x.size() # Batch, Time (Sequence Length), Channel, Height, Width
39
+
40
+ # Apply CNN to each frame
41
+ x = x.view(B * T, C, H, W)
42
+ x = self.cnn(x)
43
+
44
+ # Flatten and reshape for LSTM
45
+ x = x.view(B, T, -1)
46
+
47
+ # Pass through LSTM
48
+ # We only need the output of the last time step
49
+ x, _ = self.lstm(x)
50
+ x = x[:, -1, :]
51
+
52
+ return self.fc(x)
53
+
54
+ # --- 3. Model Loading and Prediction Function ---
55
+ def load_model():
56
+ """Loads the trained model weights."""
57
+ model = CNNLSTM(num_classes=NUM_CLASSES).to(device)
58
+
59
+ if not os.path.exists(MODEL_PATH):
60
+ raise FileNotFoundError(f"Model file not found at {MODEL_PATH}. Please ensure your 'best_model.pth' is uploaded.")
61
+
62
+ # Load state_dict and map to CPU if necessary
63
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
64
+ model.eval()
65
+ return model
66
+
67
+ # Global model instance
68
+ try:
69
+ model = load_model()
70
+ except FileNotFoundError as e:
71
+ print(e)
72
+ # This allows the app to start even if the model file is missing initially,
73
+ # but the prediction function will fail until it's fixed.
74
+ model = None
75
+
76
+ # Transformation pipeline for a single frame
77
+ transform = transforms.Compose([
78
+ transforms.Resize((64, 64)),
79
+ transforms.ToTensor(),
80
+ ])
81
+
82
+ def predict_crowd_behavior(input_images):
83
+ """
84
+ Predicts the crowd behavior from a list of frames.
85
+
86
+ Args:
87
+ input_images (list[PIL.Image]): A list of images (frames).
88
+
89
+ Returns:
90
+ str: The predicted class name.
91
+ """
92
+ if model is None:
93
+ return "ERROR: Model could not be loaded. Check logs for missing file."
94
+
95
+ if not input_images or len(input_images) != SEQUENCE_LENGTH:
96
+ return f"ERROR: Please upload exactly {SEQUENCE_LENGTH} frames."
97
+
98
+ try:
99
+ frames_tensor = []
100
+ for img in input_images:
101
+ # Ensure image is RGB (convert from any format PIL loads)
102
+ if img.mode != 'RGB':
103
+ img = img.convert('RGB')
104
+ frames_tensor.append(transform(img))
105
+
106
+ # Stack the frames and add batch dimension (1, T, C, H, W)
107
+ video_tensor = torch.stack(frames_tensor).unsqueeze(0).to(device)
108
+
109
+ with torch.no_grad():
110
+ output = model(video_tensor)
111
+
112
+ # Get the predicted class index
113
+ predicted_class_idx = torch.argmax(output, dim=1).item()
114
+
115
+ # Map index to class name
116
+ predicted_class_name = CLASS_NAMES[predicted_class_idx]
117
+
118
+ # Return the prediction and all class probabilities
119
+ probabilities = torch.softmax(output, dim=1)[0].cpu().numpy()
120
+
121
+ # Format the output as a dictionary for Gradio to display nicely
122
+ output_data = {
123
+ CLASS_NAMES[i]: probabilities[i] for i in range(len(CLASS_NAMES))
124
+ }
125
+
126
+ return output_data
127
+
128
+ except Exception as e:
129
+ return f"Prediction failed: {e}"
130
+
131
+ # --- 4. Gradio Interface ---
132
+
133
+ # Create an Image component for each frame in the sequence
134
+ image_components = [
135
+ gr.Image(
136
+ label=f"Frame {i+1}",
137
+ type="pil",
138
+ width=100,
139
+ height=100
140
+ )
141
+ for i in range(SEQUENCE_LENGTH)
142
+ ]
143
+
144
+ description = f"""
145
+ # 🧠 CNN-LSTM Crowd Behavior Analysis from Aerial Video
146
+ This model analyzes a sequence of **{SEQUENCE_LENGTH} consecutive frames** extracted from an aerial video (e.g., drone footage) to classify the crowd's behavior.
147
+
148
+ ## 🛠 Instructions
149
+ 1. **Extract Frames:** Use the custom script you have (`extract_frames` from your notebook) or another tool to get **16 consecutive frames** from your video segment.
150
+ 2. **Upload:** Upload each of the 16 frames to the image slots below.
151
+ 3. **Predict:** Click the 'Predict Behavior' button to see the results.
152
+
153
+ The model classifies into one of these behaviors: **aggressive, idle, panic, or normal**.
154
+ """
155
+
156
+ # ... (lines 148-154)
157
+
158
+ iface = gr.Interface(
159
+ fn=predict_crowd_behavior,
160
+ inputs=image_components,
161
+ outputs=gr.Label(num_top_classes=NUM_CLASSES),
162
+ title="Crowd Behavior Classifier (CNN-LSTM Hybrid)",
163
+ description=description,
164
+ live=False,
165
+ # FIX IS HERE: Change 'allow_flagging' to 'flagging_enabled'
166
+ #flagging_enabled=False,
167
+ )
168
+
169
+ if __name__ == "__main__":
170
+ # Gradio will run on localhost when run locally.
171
+ # Hugging Face Spaces will automatically use `iface.launch()` when deploying.
172
+ iface.launch()
best_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2ed13c9cc7796a3d55dd2843c1e06a2664773d1ea0a95db22f698e4b840e9ef8
3
+ size 33904098
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ gradio>=4.0.0 # Force installation of Gradio 4.0 or newer
5
+ tqdm
6
+ torchinfo
7
+ scikit-learn
8
+ numpy
9
+ matplotlib