LEGIONM36 committed on
Commit
469c325
·
verified ·
1 Parent(s): 3a3de60

Upload 4 files

Browse files
Files changed (4) hide show
  1. Video Swin Transformer Prototype1.pth +3 -0
  2. model.py +15 -0
  3. readme.md +19 -0
  4. train.py +245 -0
Video Swin Transformer Prototype1.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1af320e12466d5e9f66e4b8c80bfbe5541e111d0777753b623bbc43098da87eb
3
+ size 126258949
model.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torchvision.models.video as models
3
+
4
def build_swin_model():
    """Construct a Swin3D-T video classifier with a binary (2-class) head."""
    print("Initializing Video Swin Transformer...")
    # Tiny variant from torchvision, trained from scratch (no pretrained weights).
    net = models.swin3d_t(weights=None)
    # Replace the stock classification head with a 2-way linear layer,
    # keeping the backbone's feature width.
    net.head = nn.Linear(net.head.in_features, 2)
    return net
readme.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Swin Transformer for Video
2
+
3
+ ## Model Architecture
4
+ - **Type**: Video Swin Transformer (Tiny - Swin3D-T)
5
+ - **Source**: Torchvision `models.swin3d_t`.
6
+ - **Modifications**: Classification head (Linear) modified to output 2 classes (Binary).
7
+ - **Features**: Hierarchical transformer with shifted windows, adapted for 3D Video processing.
8
+
9
+ ## Dataset Structure
10
+ Expects `Dataset` folder in parent directory.
11
+ ```
12
+ Dataset/
13
+ ├── violence/
14
+ └── no-violence/
15
+ ```
16
+
17
+ ## How to Run
18
+ 1. Install dependencies: `torch`, `opencv-python`, `scikit-learn`, `numpy`, `torchvision`.
19
+ 2. Run `python train.py`.
train.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.optim as optim
7
+ from torch.utils.data import Dataset, DataLoader
8
+ from sklearn.model_selection import train_test_split
9
+ from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
10
+ import time
11
+ from model import build_swin_model
12
+
13
# --- Configuration ---
# BASE_DIR resolves to the *parent* of the directory containing this script,
# so the Dataset/ folder is expected to sit beside the project directory
# (matches the layout described in readme.md).
BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
DATASET_DIR = os.path.join(BASE_DIR, "Dataset")
MODEL_SAVE_PATH = "best_model_swin.pth"  # best-validation checkpoint (full pickled model)

# Hyperparameters
IMG_SIZE = 224  # Swin usually expects 224
SEQ_LEN = 32  # frames sampled per clip
BATCH_SIZE = 8  # Heavy model
EPOCHS = 80  # upper bound; early stopping usually ends training sooner
LEARNING_RATE = 1e-4
PATIENCE = 5  # epochs without val-loss improvement before stopping
25
+
26
+ # --- Dataset Class ---
27
class VideoDataset(Dataset):
    """Dataset of video files decoded into fixed-length clip tensors.

    Each item is a ``(clip, label)`` pair where ``clip`` is a float32 tensor
    of shape (3, SEQ_LEN, IMG_SIZE, IMG_SIZE) — RGB, values scaled to [0, 1].
    Unreadable or corrupt videos degrade to an all-zero clip instead of
    crashing the DataLoader worker.
    """

    def __init__(self, video_paths, labels, transform=None):
        """
        Args:
            video_paths: list of paths to .avi/.mp4 files.
            labels: integer class labels aligned with ``video_paths``.
            transform: optional callable applied to the clip tensor.
        """
        self.video_paths = video_paths
        self.labels = labels
        # Fix: `transform` used to be accepted but silently discarded.
        self.transform = transform

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        path = self.video_paths[idx]
        label = self.labels[idx]

        try:
            frames = self._load_video(path)
        except Exception as e:
            # Keep training alive on bad files; surface the problem in logs.
            print(f"Error loading {path}: {e}")
            frames = np.zeros((3, SEQ_LEN, IMG_SIZE, IMG_SIZE), dtype=np.float32)

        # Swin3D expects (C, T, H, W) normalized to [0, 1]; handled in _load_video.
        clip = torch.tensor(frames, dtype=torch.float32)
        if self.transform is not None:
            clip = self.transform(clip)
        return clip, label

    def _load_video(self, path):
        """Decode `path` into a (3, SEQ_LEN, IMG_SIZE, IMG_SIZE) float array."""
        cap = cv2.VideoCapture(path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                frame = cv2.resize(frame, (IMG_SIZE, IMG_SIZE))
                # OpenCV decodes BGR; the model expects RGB.
                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(frame)
        finally:
            cap.release()

        if len(frames) == 0:
            return np.zeros((3, SEQ_LEN, IMG_SIZE, IMG_SIZE), dtype=np.float32)

        # Temporal sampling: pad short clips by repeating the last frame,
        # uniformly subsample long clips down to SEQ_LEN frames.
        if len(frames) < SEQ_LEN:
            while len(frames) < SEQ_LEN:
                frames.append(frames[-1])
        elif len(frames) > SEQ_LEN:
            indices = np.linspace(0, len(frames) - 1, SEQ_LEN, dtype=int)
            frames = [frames[i] for i in indices]

        frames = np.array(frames)              # (T, H, W, C) uint8
        frames = frames / 255.0                # -> float in [0, 1]
        frames = frames.transpose(3, 0, 1, 2)  # (C, T, H, W)
        return frames
77
+
78
+ # --- Data Preparation ---
79
def prepare_data():
    """Gather video paths/labels and split into train/val/test (70/15/15).

    Returns:
        ((X_train, y_train), (X_val, y_val), (X_test, y_test)) where X_* are
        lists of file paths and y_* are labels (1 = violence, 0 = no-violence).

    Raises:
        FileNotFoundError: if either class directory is missing.
    """
    violence_dir = os.path.join(DATASET_DIR, 'violence')
    no_violence_dir = os.path.join(DATASET_DIR, 'no-violence')

    if not os.path.exists(violence_dir) or not os.path.exists(no_violence_dir):
        raise FileNotFoundError("Dataset directories not found.")

    # Fix: tuple-based, case-insensitive extension check — previously
    # '.AVI'/'.MP4' files were silently skipped.
    video_exts = ('.avi', '.mp4')
    violence_files = [os.path.join(violence_dir, f)
                      for f in os.listdir(violence_dir)
                      if f.lower().endswith(video_exts)]
    no_violence_files = [os.path.join(no_violence_dir, f)
                         for f in os.listdir(no_violence_dir)
                         if f.lower().endswith(video_exts)]

    X = violence_files + no_violence_files
    y = [1] * len(violence_files) + [0] * len(no_violence_files)

    # 70% train, then the remaining 30% split evenly into val/test;
    # stratified so the class balance is preserved in every split.
    X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.30, random_state=42, stratify=y)
    X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, random_state=42, stratify=y_temp)

    return (X_train, y_train), (X_val, y_val), (X_test, y_test)
96
+
97
+ # --- Early Stopping ---
98
class EarlyStopping:
    """Stop training when validation loss stops improving.

    Tracks the best (lowest) validation loss seen so far. Every call that
    fails to improve on it increments a counter; once the counter reaches
    `patience`, the `early_stop` flag is raised. Each improvement also
    checkpoints the model to `path`.
    """

    def __init__(self, patience=5, verbose=False, path='checkpoint.pth'):
        self.patience = patience
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        # Higher score == lower loss, so improvements are score increases.
        # Equal-to-best counts as an improvement (counter resets, model saved).
        score = -val_loss
        improved = self.best_score is None or score >= self.best_score
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
            return
        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Serialize the model and record the new validation-loss floor."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model, self.path)  # Saving full model
        self.val_loss_min = val_loss
129
+
130
if __name__ == "__main__":
    # End-to-end training pipeline: split the dataset, train with early
    # stopping on validation loss, then evaluate the best checkpoint on the
    # held-out test set.
    start_time = time.time()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Prepare Data
    try:
        (X_train, y_train), (X_val, y_val), (X_test, y_test) = prepare_data()
        print(f"Dataset Split Stats:")
        print(f"Train: {len(X_train)} samples")
        print(f"Val: {len(X_val)} samples")
        print(f"Test: {len(X_test)} samples")
    except Exception as e:
        print(f"Data preparation failed: {e}")
        exit(1)

    train_dataset = VideoDataset(X_train, y_train)
    val_dataset = VideoDataset(X_val, y_val)
    test_dataset = VideoDataset(X_test, y_test)

    # num_workers=0: decode videos in the main process (no worker subprocesses).
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = build_swin_model().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Checkpoints the best model to MODEL_SAVE_PATH as validation loss improves.
    early_stopping = EarlyStopping(patience=PATIENCE, verbose=True, path=MODEL_SAVE_PATH)

    print("\nStarting Swin Transformer Training...")

    for epoch in range(EPOCHS):
        # --- Training pass ---
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Progress line is overwritten in place via the carriage return.
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1} Batch {batch_idx}/{len(train_loader)} Loss: {loss.item():.4f}", end='\r')

        train_acc = 100 * correct / total
        avg_train_loss = train_loss / len(train_loader)

        # Validation
        model.eval()
        val_loss = 0.0
        correct_val = 0
        total_val = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).sum().item()

        val_acc = 100 * correct_val / total_val
        avg_val_loss = val_loss / len(val_loader)

        print(f'\nEpoch [{epoch+1}/{EPOCHS}] '
              f'Train Loss: {avg_train_loss:.4f} Acc: {train_acc:.2f}% '
              f'Val Loss: {avg_val_loss:.4f} Acc: {val_acc:.2f}%')

        # Saves on improvement; stops after PATIENCE stagnant epochs.
        early_stopping(avg_val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Evaluation
    print("\nLoading best Swin model for evaluation...")
    if os.path.exists(MODEL_SAVE_PATH):
        # NOTE(review): loads a full pickled model (saved via torch.save(model)).
        # torch >= 2.6 defaults weights_only=True and would raise here —
        # confirm the torch version, or pass weights_only=False explicitly.
        model = torch.load(MODEL_SAVE_PATH)
    else:
        print("Warning: Model file not found, using last epoch model.")

    model.eval()
    all_preds = []
    all_labels = []

    print("Evaluating on Test set...")
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Label 0 = 'No Violence', 1 = 'Violence' (matches prepare_data's encoding).
    print("\n=== Swin Transformer Evaluation Report ===")
    print(classification_report(all_labels, all_preds, target_names=['No Violence', 'Violence']))
    print("Confusion Matrix:")
    print(confusion_matrix(all_labels, all_preds))
    acc = accuracy_score(all_labels, all_preds)
    print(f"\nFinal Test Accuracy: {acc*100:.2f}%")

    elapsed = time.time() - start_time
    print(f"\nTotal execution time: {elapsed/60:.2f} minutes")