guard2PFE committed on
Commit
1ac3d64
·
verified ·
1 Parent(s): e5e4b3d

Upload 22 files

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ lists/
2
+ output/
3
+ venv/
4
+ *.csv
5
+ *.json
6
+ __pycache__/
7
+ data/
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) [year] [fullname]
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,265 @@
1
- ---
2
- title: DeepFakeDetector Demo
3
- emoji: 🐠
4
- colorFrom: green
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 6.5.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🎯 TALL-SWIN Deepfake Detection – Video-Level Pipeline
2
+
3
+ Transformer-based deepfake detection system with custom dataset preparation, evaluation metrics, and batch inference utilities.
4
+
5
+ ---
6
+
7
+ # 📖 1. Project Overview
8
+
9
+ This project implements a **video-level deepfake detection pipeline** based on the **TALL-SWIN Vision Transformer architecture**.
10
+
11
+ It extends the original TALL/DeiT repository with:
12
+
13
+ - Custom dataset preparation scripts
14
+ - Frame cleaning and renaming utilities
15
+ - Automatic train/test split generation
16
+ - Advanced evaluation metrics (ROC, PR, confusion matrix)
17
+ - Threshold-controlled classification
18
+ - Batch inference from video lists
19
+ - Reproducible environment setup
20
+
21
+ Designed primarily for FaceForensics++ (FFPP)-style datasets.
22
+
23
+ ---
24
+
25
+ # 🏗 2. Pipeline Architecture
26
+
27
+ ```
28
+ Raw Videos
29
+
30
+ Frame Extraction
31
+
32
+ Frame Cleaning / Renaming
33
+
34
+ Train/Test Split Generation (.txt)
35
+
36
+ Training (TALL-SWIN)
37
+
38
+ Video-Level Aggregation
39
+
40
+ Evaluation (ROC / PR / Metrics)
41
+
42
+ Batch Inference
43
+ ```
44
+
45
+ ---
46
+
47
+ # 📂 3. Project Structure
48
+
49
+ ```
50
+ .
51
+ ├── main.py # Training script (base repository)
52
+ ├── engine.py # Training loop
53
+ ├── models.py # DeiT models
54
+ ├── utils.py # Distributed + checkpoint utilities
55
+ ├── video_dataset.py # Dataset loader
56
+ ├── video_dataset_aug.py # Augmentations
57
+ ├── video_dataset_config.py # Dataset config
58
+ ├── video_transforms.py # Group transforms
59
+
60
+ ├── eval_direct.py # Custom evaluation with plots
61
+ ├── test_new.py # Threshold-controlled evaluation
62
+ ├── infer_videos_txt.py # Batch video inference
63
+ ├── make_tall_txt.py # Dataset split generator
64
+ ├── make_tall_txt_count.py # Alternative split generator
65
+ ├── renumber_frames_for_tall.py # Frame renaming tool
66
+ ├── keep_only_numbered.py # Frame cleaning tool
67
+ ├── find_tall_model.py # Debug utility
68
+
69
+ ├── requirements.txt
70
+ ├── requirements-torch.txt
71
+ └── README.md
72
+ ```
73
+
74
+ ---
75
+
76
+ # ⚙ 4. Installation
77
+
78
+ ## 4.1 Clone Repository
79
+
80
+ ```bash
81
+ git clone <your_repository_url>
82
+ cd TALL4Deepfake
83
+ ```
84
+
85
+ ## 4.2 Create Virtual Environment
86
+
87
+ **Windows (PowerShell)**
88
+
89
+ ```powershell
90
+ python -m venv venv
91
+ venv\Scripts\Activate.ps1
92
+ ```
93
+
94
+ **Linux / macOS**
95
+
96
+ ```bash
97
+ python -m venv venv
98
+ source venv/bin/activate
99
+ ```
100
+
101
+ ## 4.3 Install PyTorch (CUDA Build)
102
+
103
+ PyTorch CUDA wheels are not hosted on default PyPI.
104
+
105
+ ```bash
106
+ pip install -r requirements-torch.txt
107
+ ```
108
+
109
+ ## 4.4 Install Remaining Dependencies
110
+
111
+ ```bash
112
+ pip install -r requirements.txt
113
+ ```
114
+
115
+ ## 4.5 Sanity Check
116
+
117
+ ```bash
118
+ python -c "import torch, cv2, numpy as np; print(torch.__version__); print(cv2.__version__); print(np.__version__)"
119
+ ```
120
+
121
+ ---
122
+
123
+ # 📁 5. Dataset Preparation
124
+
125
+ Expected structure:
126
+
127
+ ```
128
+ data/
129
+ ├── real/
130
+ │ ├── video_001/
131
+ │ │ ├── 0001.jpg
132
+ │ │ ├── 0002.jpg
133
+ │ │ └── ...
134
+ └── fake/
135
+ ├── video_002/
136
+ ```
137
+
138
+ ## 5.1 Renumber Frames
139
+
140
+ Ensures frame names follow `0001.jpg` format.
141
+
142
+ ```bash
143
+ python renumber_frames_for_tall.py --root data --digits 4 --copy
144
+ ```
145
+
146
+ ## 5.2 Remove Non-numbered Frames
147
+
148
+ ```bash
149
+ python keep_only_numbered.py --root data
150
+ ```
151
+
152
+ Dry run (no deletion):
153
+
154
+ ```bash
155
+ python keep_only_numbered.py --root data --dry_run
156
+ ```
157
+
158
+ ## 5.3 Generate Train/Test Split
159
+
160
+ **Full Split Generator**
161
+
162
+ ```bash
163
+ python make_tall_txt.py --root data --out lists --train_ratio 0.8
164
+ ```
165
+
166
+ Outputs:
167
+
168
+ - `lists/cdf_train_fold.txt`
169
+ - `lists/cdf_test_fold.txt`
170
+
171
+ **Alternative Split (Count-Based)**
172
+
173
+ ```bash
174
+ python make_tall_txt_count.py --root data --out lists
175
+ ```
176
+
177
+ ---
178
+
179
+ # 🚀 6. Training
180
+
181
+ ```bash
182
+ python main.py --dataset ffpp --data_dir [data_dir] --data_txt_dir [data_txt_dir] --input-size 112 --num_clips 8 --output_dir [output_dir] --opt adamw --lr 1.5e-5 --warmup-lr 1.5e-8 --min-lr 1.5e-7 --epochs 10 --sched cosine --duration 4 --batch-size 2 --thumbnail_rows 2 --disable_scaleup --cutout True --pretrained --warmup-epochs 1 --no-amp --model TALL_SWIN --hpe_to_token --num_workers 0
183
+ ```
184
+
185
+ ---
186
+
187
+ # 📊 7. Evaluation
188
+
189
+ ## 7.1 Custom Evaluation with Plots
190
+
191
+ ```bash
192
+ python test_new.py --dataset ffpp --data_dir [your_data_dir] --data_txt_dir [your_data_dir_txt] --num_clips 8 --duration 4 --thumbnail_rows 2 --batch-size 1 --num_workers 0 --initial_checkpoint [your_.pth_dir] --output_dir [your_out_dir] --save_plots
193
+ ```
194
+
195
+ ---
196
+
197
+ # 🎥 8. Batch Video Inference
198
+
199
+ Create a file `videos.txt`:
200
+
201
+ ```
202
+ C:/.../video1.mp4
203
+ C:/.../video2.mp4
204
+ ```
205
+
206
+ Run:
207
+
208
+ ```bash
209
+ python infer_videos_txt.py --video_list videos.txt --initial_checkpoint [your_.pth_dir] --dataset ffpp --duration 4 --num_clips 8
210
+ ```
211
+
212
+ Outputs:
213
+
214
+ - `results.json`
215
+ - `results.csv`
216
+
217
+ ---
218
+
219
+ # 📈 9. Implemented Metrics
220
+
221
+ - Accuracy
222
+ - Balanced Accuracy
223
+ - Precision
224
+ - Recall
225
+ - F1-score
226
+ - ROC-AUC
227
+ - PR-AUC
228
+ - Confusion Matrix
229
+ - Classification Report
230
+
231
+ ---
232
+
233
+ # 🧠 10. Video-Level Aggregation Strategy
234
+
235
+ When multiple clips are extracted per video:
236
+
237
+ ```
238
+ logits → softmax → mean over clips → threshold decision
239
+ ```
240
+
241
+ If logits shape is `[B*K, 2]`, they are reshaped into:
242
+
243
+ ```
244
+ [B, K, 2] → mean(dim=1)
245
+ ```
246
+
247
+ This ensures that classification is performed at the video level rather than the frame level.
248
+
249
+ ---
250
+
251
+ # 🖥 11. Recommended Environment
252
+
253
+ - Python 3.10
254
+ - CUDA-compatible GPU
255
+ - PyTorch 1.13.1 (cu117 build)
256
+ - NumPy 2.x (if using OpenCV ≥4.13)
257
+
258
+ ---
259
+
260
+ # 👨‍💻 Author
261
+
262
+ **Vinícius Passos Castilho Pinto**
263
+ Double-degree Engineering Student
264
+ Industrial & Automation Systems
265
+ France / Brazil
check_data.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Diagnose data directory structure and list file format
3
+ """
4
+ import os
5
+
6
def check_data_structure(data_dir=None, list_file=None):
    """Print a diagnostic report for a frame directory and a split list file.

    The report has three parts: the sub-directories found under ``data_dir``
    (with a small sample of their contents), the first lines of ``list_file``,
    and an analysis of whether the first listed path resolves under
    ``data_dir``.

    Args:
        data_dir: Root directory holding the extracted frames. When ``None``
            the original hard-coded development path is used.
        list_file: Path of the list file to inspect. When ``None`` the
            original hard-coded development path is used.

    Returns:
        None. All findings are printed to stdout.
    """
    if data_dir is None:
        data_dir = r"C:\Users\vinip\Desktop\TALL4Deepfake\data"
    if list_file is None:
        list_file = r"C:\Users\vinip\Desktop\TALL4Deepfake\lists\cdf_test_fold.txt"

    print("=" * 60)
    print("Data Structure Diagnostic")
    print("=" * 60)

    # Check data directory
    print(f"\n1. Checking data directory: {data_dir}")
    if not os.path.exists(data_dir):
        # Without the data root none of the later checks are meaningful.
        print(" ✗ Directory not found!")
        return
    print(" ✓ Directory exists")

    # List subdirectories
    subdirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    print(f"\n Subdirectories found: {subdirs}")

    # Only sample the first few entries to keep the report short.
    for subdir in subdirs[:3]:
        subdir_path = os.path.join(data_dir, subdir)
        video_folders = os.listdir(subdir_path)[:3]
        print(f"\n {subdir}/ contains: {video_folders}")

        # Check first video folder
        if video_folders:
            first_video = os.path.join(subdir_path, video_folders[0])
            if os.path.isdir(first_video):
                frames = os.listdir(first_video)[:5]
                print(f" {video_folders[0]}/ contains: {frames}")

    # Check list file
    print(f"\n2. Checking list file: {list_file}")
    if os.path.exists(list_file):
        print(" ✓ File exists")

        with open(list_file, 'r') as f:
            lines = f.readlines()

        print(f" Total lines: {len(lines)}")
        print("\n First 5 lines:")
        for i, line in enumerate(lines[:5], 1):
            print(f" {i}: {repr(line.strip())}")

        # Analyze path format
        print("\n3. Path format analysis:")
        if not lines:
            # An empty list file used to crash on lines[0]; report it instead.
            print(" ⚠ List file is empty; nothing to analyze.")
        parts = lines[0].strip().split() if lines else []
        if parts:
            video_path = parts[0]
            print(f" Video path: {repr(video_path)}")
            print(f" Uses forward slashes: {'/' in video_path}")
            print(f" Uses backslashes: {chr(92) in video_path}")

            # Try to construct full path
            full_path = os.path.join(data_dir, video_path)
            print(f"\n Constructed path: {full_path}")
            print(f" Path exists: {os.path.exists(full_path)}")

            # Try with normalization
            normalized = os.path.normpath(full_path)
            print(f"\n Normalized path: {normalized}")
            print(f" Path exists: {os.path.exists(normalized)}")

            # Check if video folder exists
            if not os.path.exists(normalized):
                # Show what the parent category directory actually contains.
                print("\n ⚠ Path doesn't exist. Looking for similar paths...")
                category = os.path.dirname(video_path).replace('/', os.sep).replace('\\', os.sep)

                category_path = os.path.join(data_dir, category)
                if os.path.exists(category_path):
                    contents = os.listdir(category_path)
                    print(f"\n Directory '{category}' contains: {contents[:10]}")
    else:
        print(" ✗ File not found!")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    check_data_structure()
engine.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the CC-by-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ """
8
+ Train and eval functions used in main.py
9
+ """
10
+ from typing import Iterable, Optional
11
+ from einops import rearrange
12
+ import torch
13
+ import numpy
14
+ from timm.data import Mixup
15
+ from timm.utils import accuracy, ModelEma
16
+ import utils
17
+ from sklearn.metrics import roc_auc_score
18
+
19
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, num_cilps: int, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    world_size: int = 1, distributed: bool = True, amp=True,
                    contrastive_nomixup=False, hard_contrastive=False,
                    finetune=False
                    ):
    """Run one training epoch and return the averaged logged statistics.

    Args:
        model: Network to train; its output is reshaped to
            ``(batch, num_cilps, -1)`` and averaged over the clip axis.
        criterion: Loss applied to the clip-averaged outputs and the targets.
        data_loader: Iterable yielding ``(samples, targets)`` batches.
        num_cilps: Number of clips per video (spelling kept for callers).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch index, used only in the log header.
        loss_scaler: Callable used for backward + step when ``amp`` is true.
        max_norm: Gradient clipping norm; ``0``/``None`` disables clipping
            on the non-AMP path.
        model_ema: Optional EMA tracker updated after every step.
        mixup_fn: Optional mixup callable applied to each batch.
        world_size, distributed, contrastive_nomixup, hard_contrastive:
            Not used in this function body; kept for caller compatibility.
        amp: Enables autocast and the ``loss_scaler`` path.
        finetune: When true the model is put in eval mode during training.

    Returns:
        Dict mapping each meter name to its global average.
    """
    # TODO fix this for finetuning
    model.train(not finetune)
    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.8f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            # batch size has to be an even number
            if targets.size(0) == 1:
                continue
            if targets.size(0) % 2 != 0:
                samples, targets = samples[:-1], targets[:-1]
            samples, targets = mixup_fn(samples, targets)

        # Read the batch size only after the optional trim above; using the
        # pre-trim value made the reshape below fail on odd-sized batches.
        batch_size = targets.size(0)

        with torch.cuda.amp.autocast(enabled=amp):
            outputs = model(samples)
            # Average the per-clip outputs into one prediction per video.
            outputs = outputs.reshape(batch_size, num_cilps, -1).mean(dim=1)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

        if amp:
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)
        else:
            loss.backward(create_graph=is_second_order)
            if max_norm is not None and max_norm != 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        # Synchronizing is only meaningful (and only legal) with CUDA present.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
88
+
89
+
90
@torch.no_grad()
def evaluate(data_loader, model, device, world_size, distributed=True, amp=False, num_crops=1, num_clips=1):
    """Evaluate ``model`` on ``data_loader`` and return averaged metrics.

    Per-sample model outputs are reshaped to
    ``(batch, num_crops * num_clips, -1)`` and averaged over the middle axis
    before accuracy, loss and ROC-AUC are computed. Column 1 of the averaged
    output is used as the score for ROC-AUC.

    Args:
        data_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate (switched to eval mode here).
        device: Device the batches are moved to.
        world_size: Not used in this function body; kept for callers.
        distributed: When true, outputs/targets are gathered from all
            processes with ``concat_all_gather``.
        amp: Enables autocast for the forward pass.
        num_crops: Number of crops per clip.
        num_clips: Number of clips per video.

    Returns:
        Dict mapping each meter name (``acc1``, ``loss``) to its global average.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    outputs = []
    targets = []
    logits = []
    binary_label = []
    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        batch_size = images.shape[0]

        # compute output
        with torch.cuda.amp.autocast(enabled=amp):
            output = model(images)
            output = output.reshape(batch_size, num_crops * num_clips, -1).mean(dim=1)

        if distributed:
            # Gather once per tensor and reuse the result for every consumer.
            all_output = concat_all_gather(output)
            all_target = concat_all_gather(target)
        else:
            all_output, all_target = output, target

        outputs.append(all_output)
        targets.append(all_target)
        logits.append(all_output[:, 1].detach().cpu().numpy())
        binary_label.append(all_target.detach().cpu().numpy())

        acc1 = accuracy(output, target, topk=(1,))[0]
        metric_logger.meters['acc1'].update(acc1.item(), images.size(0))

    # concatenate (not stack) so a smaller final batch does not break the
    # aggregation; for equal-sized batches the ordering is unchanged.
    acc_outputs = numpy.concatenate(logits, axis=0).reshape(-1, 1)
    acc_label = numpy.concatenate(binary_label, axis=0).reshape(-1, 1)

    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)

    try:
        auc_score = roc_auc_score(acc_label, acc_outputs)
    except ValueError:
        # Raised when only one class is present; AUC is undefined then, so
        # report NaN instead of aborting the whole evaluation.
        auc_score = float('nan')

    real_loss = criterion(outputs, targets)
    metric_logger.update(loss=real_loss.item())

    print('* Acc@1 {top1.global_avg:.3f} AUC {auc} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1, auc=auc_score, losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
153
+
154
+
155
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and merge the copies into one tensor.

    The per-rank copies are interleaved along the first axis (pattern
    ``(b n)``: all ranks' entry for sample 0, then sample 1, ...).
    1-D inputs use the ``'n b'`` pattern; everything else is treated as 2-D.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world = torch.distributed.get_world_size()
    # Pre-allocate one receive buffer per rank, shaped like the local tensor.
    gathered = [torch.ones_like(tensor) for _ in range(world)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), async_op=False)

    pattern = 'n b -> (b n)' if tensor.dim() == 1 else 'n b c -> (b n) c'
    return rearrange(gathered, pattern)
find_tall_model.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Find the TALL_SWIN model definition in the codebase
3
+ """
4
+ import os
5
+ import re
6
+
7
def search_for_tall_model(root_dir):
    """Search ``root_dir`` for Python files that define the TALL_SWIN model.

    Every ``.py`` file below ``root_dir`` (skipping virtualenv, VCS, cache and
    ``node_modules`` directories, and this script itself) is scanned with a
    set of case-insensitive patterns. For each matching pattern the file path
    and a few lines of context around the first matching line are printed.
    If nothing matches, files whose name contains ``swin`` or ``tall`` are
    listed instead.

    Args:
        root_dir: Directory to walk.

    Returns:
        List of matching file paths, each listed once, in discovery order.
    """
    print("=" * 60)
    print("Searching for TALL_SWIN model definition...")
    print("=" * 60)

    patterns = [
        r'class\s+TALL_SWIN',
        r'class\s+TallSwin',
        r'class\s+TALLSwin',
        r'def\s+TALL_SWIN',
        r'@register_model.*TALL',
    ]
    # Compile once instead of re-parsing each pattern for every file/line.
    compiled = [(pattern, re.compile(pattern, re.IGNORECASE)) for pattern in patterns]
    skip_dirs = {'venv', '.git', '__pycache__', 'node_modules'}
    # The pattern literals above would otherwise make this script match itself.
    this_file = os.path.abspath(__file__) if '__file__' in globals() else None

    found_files = []

    # Search in models directory and subdirectories
    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Skip venv and common ignore dirs (in place, so os.walk prunes them)
        dirnames[:] = [d for d in dirnames if d not in skip_dirs]

        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            filepath = os.path.join(dirpath, filename)
            if this_file is not None and os.path.abspath(filepath) == this_file:
                continue
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
            except OSError as e:
                # Report unreadable files instead of silently ignoring them.
                print(f"\n! Could not read {filepath}: {e}")
                continue

            lines = content.split('\n')
            recorded = False
            for pattern, regex in compiled:
                if not regex.search(content):
                    continue
                # Several patterns can hit the same file (two of them are even
                # identical under IGNORECASE); record the file only once.
                if not recorded:
                    found_files.append(filepath)
                    recorded = True
                print(f"\n✓ Found in: {filepath}")
                print(f" Pattern matched: {pattern}")

                # Show context around the first line matching this pattern
                for i, line in enumerate(lines):
                    if regex.search(line):
                        start = max(0, i - 2)
                        end = min(len(lines), i + 10)
                        # `end` is exclusive, so the last shown line is `end`.
                        print(f"\n Context (lines {start+1}-{end}):")
                        print(" " + "-" * 50)
                        for j in range(start, end):
                            marker = ">>>" if j == i else " "
                            print(f" {marker} {j+1:4d}: {lines[j]}")
                        print(" " + "-" * 50)
                        break

    if not found_files:
        print("\n✗ No TALL_SWIN model definition found!")
        print("\nSearching for any Swin-related files...")

        # Broader search: match on file names only
        for dirpath, dirnames, filenames in os.walk(root_dir):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]

            for filename in filenames:
                lowered = filename.lower()
                if 'swin' in lowered or 'tall' in lowered:
                    print(f" Found file: {os.path.join(dirpath, filename)}")

    return found_files


if __name__ == "__main__":
    import sys

    # Get the repo directory: first CLI argument, else the working directory
    if len(sys.argv) > 1:
        repo_dir = sys.argv[1]
    else:
        repo_dir = os.getcwd()

    print(f"Searching in: {repo_dir}\n")

    found = search_for_tall_model(repo_dir)

    print("\n" + "=" * 60)
    if found:
        print(f"Found {len(found)} file(s) with TALL_SWIN definition")
        print("\nNext steps:")
        print("1. Check the file(s) above for the TALL_SWIN class")
        print("2. Ensure this file is imported in models/__init__.py")
        print("3. Example fix for models/__init__.py:")
        print("\n from .tall_swin import TALL_SWIN")
        print(" # or")
        print(" from .tall_swin import *")
    else:
        print("No TALL_SWIN definition found!")
        print("\nPossible issues:")
        print("1. The model file is in a different location")
        print("2. The model has a different class name")
        print("3. The model code is missing from the repository")
    print("=" * 60)
infer_videos_txt.py ADDED
@@ -0,0 +1,433 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # infer_videos_txt.py
2
+ # -------------------------------------------
3
+ # Inference on a list of videos (one path per line in a .txt).
4
+ # For each video, frames are extracted to a temporary folder and then fed to the model
5
+ # using the same VideoDataSet pipeline as training/evaluation.
6
+ #
7
+ # PowerShell usage:
8
+ # python infer_videos_txt.py `
9
+ # --video_list "C:\path\videos.txt" `
10
+ # --initial_checkpoint "C:\path\model_best.pth" `
11
+ # --dataset ffpp `
12
+ # --threshold 0.7 `
13
+ # --num_clips 8 --duration 4 --thumbnail_rows 2 `
14
+ # --num_workers 0 `
15
+ # --output_json "C:\path\results.json" `
16
+ # --output_csv "C:\path\results.csv"
17
+ # -------------------------------------------
18
+
19
+ import os
20
+ import csv
21
+ import json
22
+ import argparse
23
+ import shutil
24
+ import tempfile
25
+ from typing import List, Dict, Any
26
+
27
+ import cv2
28
+ import numpy as np
29
+ import torch
30
+ import torch.backends.cudnn as cudnn
31
+ from timm.models import create_model
32
+
33
+ import my_models # registers TALL_SWIN
34
+ import utils
35
+
36
+ from video_dataset import VideoDataSet
37
+ from video_dataset_aug import get_augmentor, build_dataflow
38
+ from video_dataset_config import get_dataset_config
39
+
40
+
41
+ # -----------------------------
42
+ # Frame extraction (OpenCV)
43
+ # -----------------------------
44
def extract_frames_opencv(video_path: str, out_dir: str, digits: int = 5, max_frames: int = 0) -> int:
    """Extract frames from a video into ``out_dir`` as zero-padded JPEG files.

    Frames are numbered from 1 and named with ``digits`` zero-padded digits
    (e.g. ``00001.jpg`` for ``digits=5``).

    Args:
        video_path: Path of the video to decode.
        out_dir: Existing directory that receives the frame images.
        digits: Zero-padding width of the frame file names.
        max_frames: Stop after this many frames; ``0`` (or less) means all.

    Returns:
        The number of frames written.

    Raises:
        RuntimeError: If the video cannot be opened or a frame cannot be
            written to disk.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError(f"Could not open video: {video_path}")

    fmt = f"{{:0{digits}d}}.jpg"
    written = 0
    try:
        while max_frames <= 0 or written < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            out_file = os.path.join(out_dir, fmt.format(written + 1))
            # cv2.imwrite reports failure through its return value; ignoring
            # it would count frames that never reached the disk.
            if not cv2.imwrite(out_file, frame):
                raise RuntimeError(f"Could not write frame: {out_file}")
            written += 1
    finally:
        # Always release the capture handle, even when a write fails.
        cap.release()
    return written
66
+
67
+
68
def detect_digits_for_tmpl(image_tmpl: str) -> int:
    """Return the zero-padding width encoded in a template such as '{:05d}.jpg'.

    Falls back to 5 when the template holds no ``{:<width>d}`` field.
    """
    import re

    found = re.search(r"\{:\s*0?(\d+)d\}", image_tmpl)
    return int(found.group(1)) if found else 5
75
+
76
+
77
+ # -----------------------------
78
+ # IO helpers
79
+ # -----------------------------
80
def read_video_list_txt(path: str) -> List[str]:
    """Read a text file with one video path per line.

    Blank lines are skipped, surrounding whitespace and quote characters are
    stripped, and each entry is converted to an absolute path. A UTF-8 byte
    order mark at the start of the file (as written by several Windows
    editors) is ignored so it cannot corrupt the first path.

    Args:
        path: Path of the list file.

    Returns:
        Unique absolute paths in first-seen order.
    """
    # "utf-8-sig" reads plain UTF-8 unchanged and drops a leading BOM if any.
    with open(path, "r", encoding="utf-8-sig") as f:
        entries = (line.strip().strip('"').strip("'") for line in f)
        videos = [os.path.abspath(entry) for entry in entries if entry]

    # dict preserves insertion order, giving an order-keeping de-duplication.
    return list(dict.fromkeys(videos))
98
+
99
+
100
def write_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Dump per-video result rows to a CSV file with a fixed column set.

    The parent directory is created when needed. Keys missing from a row are
    written as empty cells; keys outside the column set are dropped.
    """
    columns = [
        "video",
        "pred_name",
        "pred",
        "threshold",
        "p_real",
        "p_fake",
        "n_frames",
        "n_preds",
        "status",
        "error",
    ]

    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", newline="", encoding="utf-8") as handle:
        # restval/extrasaction reproduce the "get with '' default" projection.
        writer = csv.DictWriter(handle, fieldnames=columns, restval="", extrasaction="ignore")
        writer.writeheader()
        writer.writerows(rows)
122
+
123
+
124
def write_json(path: str, payload: Dict[str, Any]) -> None:
    """Serialize ``payload`` to ``path`` as indented, non-ASCII-preserving JSON.

    The parent directory is created first when ``path`` has one.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)
130
+
131
+
132
+ # -----------------------------
133
+ # Temporary dataset builder
134
+ # -----------------------------
135
def build_tmp_dataset_from_video(video_path: str, tmp_dir: str, image_tmpl: str) -> Dict[str, Any]:
    """Extract a video's frames under ``tmp_dir`` and write a one-line list file.

    Frames go to ``<tmp_dir>/video`` using the padding width derived from
    ``image_tmpl``; the list file ``<tmp_dir>/one.txt`` describes that single
    video. Raises ``RuntimeError`` when fewer than 4 frames were extracted.

    Returns a dict with the list file name relative to ``tmp_dir``
    (``list_rel``), the frame count (``nframes``) and the ``video_id``.
    """
    video_id = "video"
    frames_dir = os.path.join(tmp_dir, video_id)
    os.makedirs(frames_dir, exist_ok=True)

    pad_width = detect_digits_for_tmpl(image_tmpl)
    frame_count = extract_frames_opencv(video_path, frames_dir, digits=pad_width, max_frames=0)
    if frame_count < 4:
        raise RuntimeError(f"Video has too few frames ({frame_count}). Need >= 4.")

    # VideoDataSet expects a list file with: <video_id> <start_idx> <num_frames> <label>
    # NOTE(review): fields are space-separated here; confirm this matches the
    # separator the dataset config hands to VideoDataSet.
    list_name = "one.txt"
    with open(os.path.join(tmp_dir, list_name), "w", encoding="utf-8") as handle:
        handle.write(f"{video_id} 1 {frame_count} 0\n")  # dummy label (0), unused during inference

    return {"list_rel": list_name, "nframes": frame_count, "video_id": video_id}
153
+
154
+
155
+ # -----------------------------
156
+ # Model + augmentor builder
157
+ # -----------------------------
158
def build_model_and_augmentor(args):
    """Build the model, load its checkpoint and create the evaluation augmentor.

    Args:
        args: Parsed CLI namespace. Reads, among others, ``device``,
            ``dataset``, ``use_lmdb``, ``model``, the model hyper-parameters
            passed to ``create_model``, ``initial_checkpoint`` and the
            augmentor settings. ``args.num_classes`` is set as a side effect.

    Returns:
        Dict with keys ``device``, ``model``, ``augmentor``, ``num_classes``,
        ``filename_seperator``, ``image_tmpl`` and ``filter_video``.
    """
    utils.init_distributed_mode(args)  # will set args.distributed, etc.

    device = torch.device(args.device)
    if device.type == "cuda" and not torch.cuda.is_available():
        # Fall back instead of failing later inside ``.to(device)``.
        print("CUDA requested but not available; falling back to CPU.")
        device = torch.device("cpu")
    cudnn.benchmark = True

    # Get dataset config for: num_classes, separator, image_tmpl, filter_video, etc.
    num_classes, _, _, _, filename_seperator, image_tmpl, filter_video, _ = get_dataset_config(
        args.dataset, args.use_lmdb
    )
    args.num_classes = num_classes

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint,
    ).to(device)
    model.eval()

    ckpt = torch.load(args.initial_checkpoint, map_location="cpu")
    if isinstance(ckpt, dict) and "model" in ckpt:
        utils.load_checkpoint(model, ckpt["model"])
    else:
        # If the checkpoint is a raw state_dict. strict=False tolerates key
        # mismatches, so surface them rather than loading a partly
        # initialised model without any notice.
        result = model.load_state_dict(ckpt, strict=False)
        if result.missing_keys:
            print(f"Warning: {len(result.missing_keys)} missing key(s) in checkpoint: {result.missing_keys}")
        if result.unexpected_keys:
            print(f"Warning: {len(result.unexpected_keys)} unexpected key(s) in checkpoint: {result.unexpected_keys}")

    # Not every model object is guaranteed to carry a default_cfg mapping.
    default_cfg = getattr(model, "default_cfg", None) or {}
    mean = default_cfg["mean"] if "mean" in default_cfg else (0.5, 0.5, 0.5)
    std = default_cfg["std"] if "std" in default_cfg else (0.5, 0.5, 0.5)

    augmentor = get_augmentor(
        False,            # is_train
        args.input_size,  # input_size
        mean,
        std,
        args.disable_scaleup,
        threed_data=args.threed_data,
        version=args.augmentor_ver,
        scale_range=args.scale_range,
        num_clips=args.num_clips,
        num_crops=args.num_crops,
        dataset=args.dataset
    )

    meta = {
        "device": device,
        "model": model,
        "augmentor": augmentor,
        "num_classes": num_classes,
        "filename_seperator": filename_seperator,
        "image_tmpl": image_tmpl,
        "filter_video": filter_video,
    }
    return meta
225
+
226
+
227
+ # -----------------------------
228
+ # Inference
229
+ # -----------------------------
230
@torch.no_grad()
def infer_one_video_from_tmp(args, meta: Dict[str, Any], tmp_root: str, list_rel_path: str, image_tmpl: str) -> Dict[str, Any]:
    """Classify the single video described by a temporary list file.

    All logits produced by the loader are concatenated, averaged, passed
    through softmax, and class 1 (FAKE) is selected when its probability
    reaches ``args.threshold``.

    Returns a dict with ``threshold``, ``p_real``, ``p_fake``, ``pred``,
    ``pred_name`` and ``n_preds`` (number of logit rows averaged).
    """
    device, model = meta["device"], meta["model"]

    video_set = VideoDataSet(
        root_path=tmp_root,
        list_file=list_rel_path,  # relative to root_path for VideoDataSet
        num_groups=args.duration,
        frames_per_group=args.frames_per_group,
        sample_offset=0,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        fixed_offset=True,
        image_tmpl=image_tmpl,  # enforce correct template (e.g., {:05d}.jpg)
        transform=meta["augmentor"],
        is_train=False,
        test_mode=False,
        seperator=meta["filename_seperator"],
        filter_video=meta["filter_video"],
        num_classes=meta["num_classes"],
        whole_video=False,
    )
    batches = build_dataflow(
        video_set, is_train=False, batch_size=1,
        workers=args.num_workers, is_distributed=False
    )

    collected = []
    for clip_batch, _unused_label in batches:
        clip_batch = clip_batch.to(device, non_blocking=True)
        # shape [1,2] typically (or [K,2] depending on pipeline)
        collected.append(model(clip_batch).detach().cpu())

    stacked = torch.cat(collected, dim=0)              # [n_preds, 2]
    averaged = stacked.mean(dim=0, keepdim=True)       # [1,2]
    probs = torch.softmax(averaged, dim=1).numpy()[0]

    # Threshold-based decision (class 1 = FAKE)
    thr = float(args.threshold)
    pred = int(probs[1] >= thr)
    label = "FAKE" if pred == 1 else "REAL"

    return {
        "threshold": thr,
        "p_real": float(probs[0]),
        "p_fake": float(probs[1]),
        "pred": pred,
        "pred_name": label,
        "n_preds": int(stacked.shape[0]),
    }
284
+
285
+
286
+ # -----------------------------
287
+ # CLI
288
+ # -----------------------------
289
def _parse_bool_flag(value):
    """Convert a CLI token such as "true" or "0" into a real bool for argparse."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command line of the list-based TALL_SWIN inference script.

    Options that carry a boolean default (``--threed_data``,
    ``--dense_sampling``, ``--use_lmdb``, ``--use_checkpoint``) are converted
    with ``_parse_bool_flag``; without a converter argparse would store the raw
    string, and ``"False"`` is truthy.

    Returns:
        argparse.Namespace built from ``sys.argv``.
    """
    ap = argparse.ArgumentParser("Infer TALL_SWIN from a list of videos (txt)")

    ap.add_argument("--video_list", required=True, help="Text file with one video path per line")
    ap.add_argument("--initial_checkpoint", required=True)
    ap.add_argument("--dataset", default="ffpp")
    ap.add_argument("--model", default="TALL_SWIN")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--num_workers", type=int, default=0)

    ap.add_argument("--duration", type=int, default=4)
    ap.add_argument("--frames_per_group", type=int, default=1)
    ap.add_argument("--num_clips", type=int, default=8)
    ap.add_argument("--num_crops", type=int, default=1)
    ap.add_argument("--thumbnail_rows", type=int, default=2)
    ap.add_argument("--input_size", type=int, default=224)

    ap.add_argument("--threshold", type=float, default=0.5,
                    help="Decision threshold for FAKE (pred=1 if p_fake >= threshold)")

    ap.add_argument("--disable_scaleup", action="store_true")
    ap.add_argument("--threed_data", default=False, type=_parse_bool_flag)
    ap.add_argument("--dense_sampling", default=True, type=_parse_bool_flag)
    ap.add_argument("--augmentor_ver", default="v1")
    ap.add_argument("--scale_range", default=[256, 320], type=int, nargs="+")
    ap.add_argument("--modality", default="rgb")
    ap.add_argument("--use_lmdb", default=False, type=_parse_bool_flag)

    ap.add_argument("--hpe_to_token", action="store_true")
    ap.add_argument("--rel_pos", action="store_true")
    ap.add_argument("--window_size", type=int, default=7)
    ap.add_argument("--no_token_mask", action="store_true")

    ap.add_argument("--drop", type=float, default=0.0)
    ap.add_argument("--drop_path", type=float, default=0.1)
    # A rate, like --drop/--drop_path: parse it as a float instead of a string.
    ap.add_argument("--drop_block", type=float, default=None)
    ap.add_argument("--use_checkpoint", default=False, type=_parse_bool_flag)

    ap.add_argument("--dist_url", default="env://")
    ap.add_argument("--world_size", default=1, type=int)
    ap.add_argument("--local_rank", default=None, type=int)

    ap.add_argument("--output_json", default="", help="Optional path to save results JSON")
    ap.add_argument("--output_csv", default="", help="Optional path to save results CSV")

    return ap.parse_args()
335
+
336
+
337
+ # -----------------------------
338
+ # Main
339
+ # -----------------------------
340
def main():
    """Score every video named in ``--video_list`` and report the outcome.

    The model and augmentor are built once. Each video is unpacked into its
    own temporary directory, scored, and the directory is removed afterwards
    whether or not scoring succeeded. A missing video is recorded as "skip",
    a failing one as "error"; neither aborts the run. Results are printed and
    optionally written to ``--output_json`` / ``--output_csv``.

    Raises:
        FileNotFoundError: If the video list or the checkpoint is not a file.
        RuntimeError: If the video list contains no entries.
    """
    args = get_args()

    for required_file in (args.video_list, args.initial_checkpoint):
        if not os.path.isfile(required_file):
            raise FileNotFoundError(required_file)

    videos = read_video_list_txt(args.video_list)
    if len(videos) == 0:
        raise RuntimeError("video_list is empty.")

    # Build model + augmentor once
    meta = build_model_and_augmentor(args)

    ok_count = 0
    results_rows: List[Dict[str, Any]] = []

    print(f"\nVideos to process: {len(videos)}")
    print(f"Checkpoint: {args.initial_checkpoint}")
    print(f"Threshold: {args.threshold}\n")

    for i, video_path in enumerate(videos, 1):
        # Every row starts with the full column set so the CSV stays rectangular.
        row: Dict[str, Any] = dict.fromkeys(
            ("error", "pred", "pred_name", "p_real", "p_fake", "n_frames", "n_preds"), ""
        )
        row["video"] = video_path
        row["status"] = "ok"
        row["threshold"] = float(args.threshold)

        if os.path.isfile(video_path):
            tmp_dir = tempfile.mkdtemp(prefix="tall_infer_")
            try:
                tmp_info = build_tmp_dataset_from_video(video_path, tmp_dir, image_tmpl=meta["image_tmpl"])
                row["n_frames"] = int(tmp_info["nframes"])

                row.update(
                    infer_one_video_from_tmp(
                        args, meta, tmp_dir, tmp_info["list_rel"], image_tmpl=meta["image_tmpl"]
                    )
                )
                ok_count += 1

                print(f"[{i}/{len(videos)}] [{row['pred_name']}] {video_path} | p_fake={float(row['p_fake']):.4f}")
            except Exception as e:
                # One bad video must not stop the batch; record and move on.
                row["status"] = "error"
                row["error"] = str(e)
                print(f"[{i}/{len(videos)}] [ERROR] {video_path}\n -> {e}")
            finally:
                shutil.rmtree(tmp_dir, ignore_errors=True)
        else:
            row["status"] = "skip"
            row["error"] = "file_not_found"
            print(f"[{i}/{len(videos)}] [SKIP] Not found: {video_path}")

        results_rows.append(row)

    summary = {
        "video_list": os.path.abspath(args.video_list),
        "checkpoint": os.path.abspath(args.initial_checkpoint),
        "dataset": args.dataset,
        "model": args.model,
        "threshold": float(args.threshold),
        "num_videos": len(videos),
        "num_ok": ok_count,
        "results": results_rows,
    }

    print("\n=== SUMMARY ===")
    print(f"ok: {ok_count}/{len(videos)}")
    if ok_count > 0:
        pf = [float(r["p_fake"]) for r in results_rows if r.get("status") == "ok"]
        print(f"avg p_fake (ok only): {sum(pf)/len(pf):.4f}")

    if args.output_json:
        write_json(args.output_json, summary)
        print(f"Saved JSON: {os.path.abspath(args.output_json)}")

    if args.output_csv:
        write_csv(args.output_csv, results_rows)
        print(f"Saved CSV: {os.path.abspath(args.output_csv)}")


if __name__ == "__main__":
    main()
keep_only_numbered.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, re
2
+ from pathlib import Path
3
+
4
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
NUM_RE = re.compile(r"^\d+$")  # stem must be only digits


def _iter_unnumbered_frames(root):
    """Yield image files in root/{real,fake}/<video>/ whose stem is not all digits."""
    for label in ("real", "fake"):
        label_dir = root / label
        if not label_dir.exists():
            continue
        for video_dir in label_dir.iterdir():
            if not video_dir.is_dir():
                continue
            for entry in video_dir.iterdir():
                is_image = entry.is_file() and entry.suffix.lower() in IMG_EXTS
                if is_image and NUM_RE.match(entry.stem) is None:
                    yield entry


def main():
    """Delete image frames whose filename stem is not purely numeric.

    Scans ``<root>/real/<video>/`` and ``<root>/fake/<video>/``, prints up to
    30 of the files selected for deletion, then removes them unless
    ``--dry_run`` is given. A file that cannot be removed is reported and the
    run continues.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--root", required=True)
    parser.add_argument("--dry_run", action="store_true", help="Only print what would be deleted")
    options = parser.parse_args()

    # Collect everything first so that nothing is deleted while still scanning.
    to_delete = list(_iter_unnumbered_frames(Path(options.root)))

    preview_limit = 30
    print(f"Found {len(to_delete)} non-numbered frames to delete.")
    for doomed in to_delete[:preview_limit]:
        print("DEL", doomed)
    if len(to_delete) > preview_limit:
        print("...")

    if options.dry_run:
        print("Dry-run only. No files deleted.")
        return

    for doomed in to_delete:
        try:
            doomed.unlink()
        except Exception as err:
            print("FAILED", doomed, err)

    print("Done.")


if __name__ == "__main__":
    main()
main.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import datetime
3
+ import numpy as np
4
+ import time
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ import json
8
+ import os
9
+ import warnings
10
+
11
+ from pathlib import Path
12
+
13
+ from timm.data import Mixup
14
+ from timm.models import create_model
15
+ from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
16
+ from timm.scheduler import create_scheduler
17
+ from timm.optim import create_optimizer
18
+ from timm.utils import NativeScaler, get_state_dict, ModelEma
19
+
20
+ #from datasets import build_dataset
21
+ from engine import train_one_epoch, evaluate
22
+ import models
23
+ import my_models
24
+ import torch.nn as nn
25
+
26
+ import utils
27
+
28
+ from video_dataset import VideoDataSet
29
+ from video_dataset_aug import get_augmentor, build_dataflow
30
+ from video_dataset_config import get_dataset_config, DATASET_CONFIG
31
+
32
+ warnings.filterwarnings("ignore", category=UserWarning)
33
+
34
def _str2bool(value):
    """Convert a CLI token such as "true" or "0" into a real bool for argparse."""
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if token in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the argument parser for the training / evaluation script.

    The parser is created with ``add_help=False`` so it can be passed as a
    parent to another ``ArgumentParser``.

    Options that carry a boolean default and take a value (``--threed_data``,
    ``--dense_sampling``, ``--use_lmdb``, ``--use_pyav``, ``--cutout``,
    ``--use_checkpoint``, ``--auto-resume``) are converted with ``_str2bool``;
    without a converter argparse would store the raw string, and ``"False"``
    is truthy.

    Returns:
        argparse.ArgumentParser with all options registered.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--epochs', default=30, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##', help='path to text of dataset')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##", help='path to dataset')
    parser.add_argument('--dataset', default='ffpp_train',
                        choices=list(DATASET_CONFIG.keys()), help='path to dataset file list')
    parser.add_argument('--duration', default=8, type=int, help='number of frames')
    parser.add_argument('--frames_per_group', default=1, type=int,
                        help='[uniform sampling] number of frames per group; '
                             '[dense sampling]: sampling frequency')
    parser.add_argument('--threed_data', default=False, type=_str2bool,
                        help='load data in the layout for 3D conv')
    parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size')
    parser.add_argument('--disable_scaleup', action='store_true',
                        help='do not scale up and then crop a small region, directly crop the input_size')
    parser.add_argument('--random_sampling', action='store_true',
                        help='perform determinstic sampling for data loader')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool,
                        help='perform dense sampling for data loader')
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'],
                        help='[v1] TSN data argmentation, [v2] resize the shorter side to `scale_range`')
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+",
                        metavar='scale_range', help='scale range for augmentor v2')
    parser.add_argument('--modality', default='rgb', type=str, help='rgb or flow')
    parser.add_argument('--use_lmdb', default=False, type=_str2bool, help='use lmdb instead of jpeg.')
    parser.add_argument('--use_pyav', default=False, type=_str2bool, help='use video directly.')

    # temporal module
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--temporal_module_name', default=None, type=str, metavar='TEM',
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'],
                        help='temporal module applied. [TAM]')
    parser.add_argument('--temporal_attention_only', action='store_true', default=False,
                        help='use attention only in temporal module]')
    parser.add_argument('--no_token_mask', action='store_true', default=False, help='do not apply token mask')
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float, help='scale of the number of spatial heads')
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float, help='scale of spatial mlp')
    parser.add_argument('--rel_pos', action='store_true', default=False,
                        help='use relative positioning in temporal module]')
    parser.add_argument('--temporal_pooling', type=str, default=None, choices=['avg', 'max', 'conv', 'depthconv'],
                        help='perform temporal pooling]')
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'],
                        help='use depth-wise bottleneck in temporal attention')

    parser.add_argument('--window_size', default=14, type=int, help='number of frames')
    parser.add_argument('--thumbnail_rows', default=4, type=int, help='number of frames per row')

    parser.add_argument('--hpe_to_token', default=False, action='store_true',
                        help='add hub position embedding to image tokens')
    # Model parameters
    parser.add_argument('--model', default='TALL_SWIN', type=str, metavar='MODEL',
                        help='Name of model to train')
    # Second spelling of --input_size above; both write to args.input_size.
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='weight decay (default: 1e-5)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-8, metavar='LR',
                        help='warmup learning rate (default: 1e-8)')
    parser.add_argument('--min-lr', type=float, default=1e-7, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (1e-7)')

    parser.add_argument('--decay-epochs', type=float, default=10, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original" '
                             '(default: rand-m9-mstd0.5)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (default: 0.0)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--cutout', default=True, type=_str2bool)
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Dataset parameters

    parser.add_argument('--output_dir', default="",
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default="", help='resume from checkpoint')
    parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler')
    parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp')
    parser.add_argument('--use_checkpoint', default=False, type=_str2bool,
                        help='use checkpoint to save memory')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # for testing and validation
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=1, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--auto-resume', default=True, type=_str2bool, help='auto resume')
    # exp
    # parser.add_argument('--simclr_w', type=float, default=0., help='weights for simclr loss')
    parser.add_argument('--contrastive_nomixup', action='store_true', help='do not involve mixup in contrastive learning')
    # NOTE(review): --finetune keeps its untyped default because the value
    # callers pass (flag vs. path) is not visible here - confirm before typing it.
    parser.add_argument('--finetune', default=False, help='finetune model')
    parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model')

    parser.add_argument('--hard_contrastive', action='store_true', help='use HEXA')
    # parser.add_argument('--selfdis_w', type=float, default=0., help='enable self distillation')

    return parser
227
+
228
+
229
def main(args):
    """Train and/or evaluate the model described by ``args``.

    Steps: initialise distributed mode, seed the RNGs, create the model (and
    optional EMA copy), optimizer, scheduler and loss, build the train and
    validation loaders, load ``--initial_checkpoint`` / ``--resume`` weights,
    then either evaluate once (``--eval``) or run the epoch loop. Each epoch
    saves ``checkpoint<epoch>.pth`` (and ``model_best.pth`` on a new best) and
    appends a JSON line to ``log.txt`` when ``--output_dir`` is set.

    Args:
        args: Namespace produced by the parser from ``get_args_parser``.
    """
    utils.init_distributed_mode(args)
    print(args)
    # Patch: tolerate namespaces created without these attributes.
    if not hasattr(args, 'hard_contrastive'):
        args.hard_contrastive = False
    if not hasattr(args, 'selfdis_w'):
        args.selfdis_w = 0.0

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by rank so workers differ)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = get_dataset_config(
        args.dataset, args.use_lmdb)

    args.num_classes = num_classes
    if args.modality == 'rgb':
        args.input_channels = 3
    elif args.modality == 'flow':
        args.input_channels = 2 * 5

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)

    print(f"Creating model: {args.model}")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Pick the loss from the same condition that enables mixup_fn: a
    # cutmix-only run also produces soft targets, which the label-smoothing
    # loss (hard integer labels) cannot consume.
    if mixup_active:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # Normalisation statistics come from the unwrapped model's config, with
    # 0.5 as the fallback for both mean and std.
    default_cfg = model_without_ddp.default_cfg
    mean = (0.5, 0.5, 0.5) if 'mean' not in default_cfg else default_cfg['mean']
    std = (0.5, 0.5, 0.5) if 'std' not in default_cfg else default_cfg['std']

    # create data loaders w/ augmentation pipeline
    video_data_cls = VideoDataSet
    train_list = os.path.join(args.data_txt_dir, train_list_name)

    train_augmentor = get_augmentor(True, args.input_size, mean, std, threed_data=False,
                                    version=args.augmentor_ver, scale_range=args.scale_range,
                                    cut_out=args.cutout, dataset=args.dataset)
    dataset_train = video_data_cls(args.data_dir, train_list, args.duration, args.frames_per_group,
                                   num_clips=args.num_clips,
                                   modality=args.modality, image_tmpl=image_tmpl,
                                   dense_sampling=args.dense_sampling,
                                   transform=train_augmentor, is_train=True, test_mode=False,
                                   seperator=filename_seperator, filter_video=filter_video)

    num_tasks = utils.get_world_size()
    data_loader_train = build_dataflow(dataset_train, is_train=True, batch_size=args.batch_size,
                                       workers=args.num_workers, is_distributed=args.distributed)

    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(False, args.input_size, mean, std, args.disable_scaleup,
                                  threed_data=args.threed_data, version=args.augmentor_ver,
                                  scale_range=args.scale_range, num_clips=args.num_clips,
                                  num_crops=args.num_crops, cut_out=False, dataset=args.dataset)
    dataset_val = video_data_cls(args.data_dir, val_list, args.duration, args.frames_per_group,
                                 num_clips=args.num_clips,
                                 modality=args.modality, image_tmpl=image_tmpl,
                                 dense_sampling=args.dense_sampling,
                                 transform=val_augmentor, is_train=False, test_mode=False,
                                 seperator=filename_seperator, filter_video=filter_video)

    data_loader_val = build_dataflow(dataset_val, is_train=False, batch_size=args.batch_size,
                                     workers=args.num_workers, is_distributed=args.distributed)

    max_accuracy = 0.0
    output_dir = Path(args.output_dir)

    if args.initial_checkpoint:
        checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])

    if args.auto_resume:
        # NOTE(review): this looks for "checkpoint.pth", but the epoch loop
        # below only writes "checkpoint<epoch>.pth" and "model_best.pth", so
        # auto-resume finds nothing unless that file is created elsewhere.
        if args.resume == '':
            args.resume = str(output_dir / "checkpoint.pth")
        if not os.path.exists(args.resume):
            args.resume = ''

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])
        # Training state is restored only when training and only when the
        # checkpoint actually carries it.
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint and args.resume_loss_scaler:
                print("Resume with previous loss scaler state")
                loss_scaler.load_state_dict(checkpoint['scaler'])
            # Checkpoints written with --no-model-ema have no 'model_ema'
            # entry; skip instead of raising KeyError.
            if args.model_ema and 'model_ema' in checkpoint:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            max_accuracy = checkpoint.get('max_accuracy', max_accuracy)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp, num_crops=args.num_crops, num_clips=args.num_clips)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training, currnet max acc is {max_accuracy:.2f}")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train, args.num_clips,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn, num_tasks, True,
            amp=args.amp,
            contrastive_nomixup=args.contrastive_nomixup,
            hard_contrastive=args.hard_contrastive,
            finetune=args.finetune
        )

        lr_scheduler.step(epoch)

        test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed, amp=args.amp, num_crops=args.num_crops, num_clips=args.num_clips)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint{}.pth'.format(epoch)]
            if test_stats["acc1"] == max_accuracy:
                checkpoint_paths.append(output_dir / 'model_best.pth')
            for checkpoint_path in checkpoint_paths:
                state_dict = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                    'scaler': loss_scaler.state_dict(),
                    'max_accuracy': max_accuracy
                }
                if args.model_ema:
                    state_dict['model_ema'] = get_state_dict(model_ema)
                utils.save_on_master(state_dict, checkpoint_path)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT training and evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
make_tall_txt.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # make_tall_txt.py
2
+ # Usage:
3
+ # python make_tall_txt.py --root frames_root --out lists --train_ratio 0.8 --seed 42
4
+ #
5
+ # Expected structure:
6
+ # root/
7
+ # real/<video_id>/*.jpg|png...
8
+ # fake/<video_id>/*.jpg|png...
9
+ #
10
+ # Output:
11
+ # lists/cdf_train_fold.txt
12
+ # lists/cdf_test_fold.txt
13
+ #
14
+ # Each line:
15
+ # relative_path start_frame end_frame label
16
+
17
+ import os
18
+ import re
19
+ import argparse
20
+ import random
21
+ from pathlib import Path
22
+
23
+ IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
24
+
25
+ # Try to parse frame index from filenames like:
26
+ # frame_000012_f000036.jpg -> uses the "f000036" part
27
+ # 000036.png -> uses the numeric stem
28
+ RE_FPART = re.compile(r"_f(\d+)", re.IGNORECASE)
29
+ RE_NUM = re.compile(r"(\d+)$")
30
+
31
def list_frames(video_dir: Path):
    """Return the image files directly inside ``video_dir``, sorted by path.

    Only regular files whose lower-cased suffix is in ``IMG_EXTS`` are kept;
    sub-directories are not searched.
    """
    return sorted(
        entry
        for entry in video_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMG_EXTS
    )
35
+
36
def parse_frame_idx(p: Path):
    """Extract the frame index encoded in a frame file name.

    The ``_f<digits>`` marker (``RE_FPART``) takes precedence; otherwise the
    digits that end the stem (``RE_NUM``) are used.

    Returns:
        The index as an ``int``, or ``None`` when neither pattern matches.
    """
    stem = p.stem
    for pattern in (RE_FPART, RE_NUM):
        match = pattern.search(stem)
        if match is not None:
            return int(match.group(1))
    return None  # no recognizable index in the name
45
+
46
def video_start_end(video_dir: Path):
    """Compute ``(start, end, frame_count)`` for one video folder.

    Returns:
        ``(None, None, 0)`` when the folder holds no frames;
        ``(min_index, max_index, frame_count)`` when at least one file name
        yields an index; ``(1, frame_count, frame_count)`` otherwise.
    """
    frames = list_frames(video_dir)
    total = len(frames)
    if total == 0:
        return None, None, 0

    parsed = [idx for idx in map(parse_frame_idx, frames) if idx is not None]
    if parsed:
        return min(parsed), max(parsed), total

    # No file name carried an index: fall back to a 1..N range.
    return 1, total, total
59
+
60
def collect_videos(root: Path, class_name: str):
    """List the non-empty video folders of one class, ordered by folder name.

    Returns:
        A list of ``(video_dir, start, end, frame_count)`` tuples; an empty
        list when ``root / class_name`` does not exist.
    """
    class_dir = root / class_name
    if not class_dir.exists():
        return []

    found = []
    for video_dir in sorted(class_dir.iterdir(), key=lambda entry: entry.name):
        if not video_dir.is_dir():
            continue
        start, end, count = video_start_end(video_dir)
        if count > 0:  # folders without frames are dropped
            found.append((video_dir, start, end, count))
    return found
73
+
74
def write_list(items, out_path: Path, root: Path, label_map):
    """Write one ``relative_path start end label`` line per item.

    Args:
        items: iterable of ``(video_dir, start, end, frame_count, label)``.
        out_path: file to (over)write, encoded as UTF-8.
        root: directory the written paths are made relative to.
        label_map: accepted for interface compatibility; not used here.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.writelines(
            f"{video_dir.relative_to(root).as_posix()} {first} {last} {label}\n"
            for video_dir, first, last, _count, label in items
        )
79
+
80
def main():
    """Build per-class train/test video lists and write them as txt files.

    Videos are collected from ``<root>/real`` and ``<root>/fake``, each class
    is shuffled and split independently with ``--train_ratio``, and the merged
    splits are written to ``<out>/cdf_train_fold.txt`` and
    ``<out>/cdf_test_fold.txt``.

    Raises:
        SystemExit: if ``--train_ratio`` is not strictly between 0 and 1, or
            if either class has no videos.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True, help="Root with real/ and fake/ video folders")
    ap.add_argument("--out", default="lists", help="Output folder for txt files")
    ap.add_argument("--train_ratio", type=float, default=0.8)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--label_real", type=int, default=0, help="Label for real")
    ap.add_argument("--label_fake", type=int, default=1, help="Label for fake")
    args = ap.parse_args()

    if not (0.0 < args.train_ratio < 1.0):
        raise SystemExit("--train_ratio must be between 0 and 1.")

    root = Path(args.root)
    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    # Collect
    real_videos = collect_videos(root, "real")
    fake_videos = collect_videos(root, "fake")

    if not real_videos:
        raise SystemExit(f"No videos found under: {root/'real'}")
    if not fake_videos:
        raise SystemExit(f"No videos found under: {root/'fake'}")

    label_map = {"real": args.label_real, "fake": args.label_fake}

    # Split by class (keeps balance more stable)
    rng = random.Random(args.seed)

    def split_class(videos, label):
        """Attach ``label`` to each video, shuffle, and cut at the train ratio."""
        vids = [(vd, s, e, n, label) for vd, s, e, n in videos]
        rng.shuffle(vids)
        k = int(round(len(vids) * args.train_ratio))
        return vids[:k], vids[k:]

    real_train, real_test = split_class(real_videos, label_map["real"])
    fake_train, fake_test = split_class(fake_videos, label_map["fake"])

    # With very few videos the rounding above can leave one side of a class
    # empty; report it instead of silently writing a one-class list.
    for split_name, part in (
        ("real/train", real_train),
        ("real/test", real_test),
        ("fake/train", fake_train),
        ("fake/test", fake_test),
    ):
        if not part:
            print(f"WARNING: split {split_name} is empty")

    train_items = real_train + fake_train
    test_items = real_test + fake_test

    rng.shuffle(train_items)
    rng.shuffle(test_items)

    # Write
    train_path = out / "cdf_train_fold.txt"
    test_path = out / "cdf_test_fold.txt"
    write_list(train_items, train_path, root, label_map)
    write_list(test_items, test_path, root, label_map)

    print("DONE")
    print(f"Train videos: {len(train_items)} (real {len(real_train)}, fake {len(fake_train)})")
    print(f"Test videos: {len(test_items)} (real {len(real_test)}, fake {len(fake_test)})")
    print("Saved:")
    print(" ", train_path.resolve())
    print(" ", test_path.resolve())
144
+
145
+ if __name__ == "__main__":
146
+ main()
make_tall_txt_count.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse, random
2
+ from pathlib import Path
3
+
4
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def count_frames(video_dir: Path) -> int:
    """Count the image files located directly inside ``video_dir``.

    A file counts when it is a regular file and its lower-cased suffix is in
    ``IMG_EXTS``; sub-directories are not searched.
    """
    total = 0
    for entry in video_dir.iterdir():
        if entry.is_file() and entry.suffix.lower() in IMG_EXTS:
            total += 1
    return total
8
+
9
def collect(root: Path, cls: str, label: int):
    """Build list entries for every usable video folder of one class.

    Returns:
        A list of ``(relative_posix_path, 1, frame_count, label)`` tuples, in
        sorted folder order, for folders holding at least 4 frames.
    """
    entries = []
    video_dirs = sorted([p for p in (root / cls).iterdir() if p.is_dir()])
    for video_dir in video_dirs:
        frame_count = count_frames(video_dir)
        if frame_count < 4:  # needs more than 3 frames
            continue
        rel = video_dir.relative_to(root).as_posix()
        entries.append((rel, 1, frame_count, label))
    return entries
18
+
19
def split(items, train_ratio, seed):
    """Shuffle ``items`` deterministically and cut them into train/test parts.

    Args:
        items: iterable of entries to split; it is not modified.
        train_ratio: fraction of entries placed in the first (train) part.
        seed: seed of the private random generator used for shuffling.

    Returns:
        ``(train, test)`` lists, where ``train`` holds the first
        ``round(len(items) * train_ratio)`` shuffled entries.
    """
    # Shuffle a copy: reordering the caller's list in place is a hidden side
    # effect. The resulting order is the same as shuffling the original.
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    k = int(round(len(shuffled) * train_ratio))
    return shuffled[:k], shuffled[k:]
24
+
25
def write_list(path: Path, items):
    """Write one ``rel start end label`` line per item to ``path`` (UTF-8)."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(
            f"{rel} {start} {end} {label}\n" for rel, start, end, label in items
        )
29
+
30
def main():
    """Create train/test list files from ``<root>/real`` and ``<root>/fake``.

    Each class is split independently with the same seed, the merged splits
    are shuffled, and the result is written twice under ``--out``: as
    ``train.txt``/``test.txt`` and as
    ``cdf_train_fold.txt``/``cdf_test_fold.txt``.

    Raises:
        SystemExit: if ``--train_ratio`` is outside [0, 1] or a class folder
            is missing.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True)
    ap.add_argument("--out", default="lists")
    ap.add_argument("--train_ratio", type=float, default=0.8)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--label_real", type=int, default=0)
    ap.add_argument("--label_fake", type=int, default=1)
    args = ap.parse_args()

    # A ratio outside [0, 1] would turn into a negative or oversized slice
    # index and silently produce wrong splits.
    if not (0.0 <= args.train_ratio <= 1.0):
        raise SystemExit("--train_ratio must be between 0 and 1.")

    root = Path(args.root)
    for cls in ("real", "fake"):
        if not (root / cls).is_dir():
            raise SystemExit(f"Class folder not found: {root / cls}")

    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    real = collect(root, "real", args.label_real)
    fake = collect(root, "fake", args.label_fake)

    r_tr, r_te = split(real, args.train_ratio, args.seed)
    f_tr, f_te = split(fake, args.train_ratio, args.seed)

    train = r_tr + f_tr
    test = r_te + f_te

    rng = random.Random(args.seed)
    rng.shuffle(train)
    rng.shuffle(test)

    write_list(out / "train.txt", train)
    write_list(out / "test.txt", test)

    # Same content under the file names used for the cdf fold lists.
    write_list(out / "cdf_train_fold.txt", train)
    write_list(out / "cdf_test_fold.txt", test)

    print("OK")
    print(f"train videos: {len(train)} | test videos: {len(test)}")
63
+
64
+ if __name__ == "__main__":
65
+ main()
models.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the CC-by-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ import torch
8
+ import torch.nn as nn
9
+ from functools import partial
10
+
11
+ from timm.models.vision_transformer import VisionTransformer, _cfg
12
+ from timm.models.registry import register_model
13
+
14
+
15
# DeiT checkpoints published for the patch16 / 224 models; every factory
# below initializes from one of them when ``pretrained=True``.
_DEIT_TINY_URL = "https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth"
_DEIT_SMALL_URL = "https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth"
_DEIT_BASE_URL = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"


def _load_compatible_weights(model, url):
    """Load every checkpoint tensor whose name and shape match ``model``.

    A strict ``load_state_dict`` raises as soon as the architecture differs
    from the checkpoint (other patch size, depth, input size or number of
    classes). Here mismatching tensors are skipped and reported, and the
    corresponding model parameters keep their fresh initialization.
    """
    checkpoint = torch.hub.load_state_dict_from_url(
        url=url, map_location="cpu", check_hash=True
    )
    source = checkpoint["model"]
    target = model.state_dict()
    usable = {
        name: tensor
        for name, tensor in source.items()
        if name in target and tensor.shape == target[name].shape
    }
    skipped = [name for name in source if name not in usable]
    missing = [name for name in target if name not in usable]
    if skipped or missing:
        print(
            f"[pretrained] loaded {len(usable)} tensors; "
            f"skipped from checkpoint: {skipped}; "
            f"left at initialization: {missing}"
        )
    model.load_state_dict(usable, strict=False)


def _build_deit(pretrained, url, kwargs, **arch):
    """Create a VisionTransformer with the shared DeiT settings.

    ``arch`` carries the per-variant hyper-parameters, ``kwargs`` the caller's
    extra arguments; both are forwarded to ``VisionTransformer``.
    """
    model = VisionTransformer(
        mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **arch, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        _load_compatible_weights(model, url)
    return model


@register_model
def deit_tiny_patch8_224(pretrained=False, **kwargs):
    """DeiT-tiny (dim 192, depth 12, 3 heads) with 8x8 patches."""
    return _build_deit(pretrained, _DEIT_TINY_URL, kwargs,
                       patch_size=8, embed_dim=192, depth=12, num_heads=3)


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    """DeiT-tiny (dim 192, depth 12, 3 heads) with 16x16 patches."""
    return _build_deit(pretrained, _DEIT_TINY_URL, kwargs,
                       patch_size=16, embed_dim=192, depth=12, num_heads=3)


@register_model
def deit_tiny_patch16_d_6_224(pretrained=False, **kwargs):
    """DeiT-tiny with 16x16 patches, reduced to 6 blocks."""
    return _build_deit(pretrained, _DEIT_TINY_URL, kwargs,
                       patch_size=16, embed_dim=192, depth=6, num_heads=3)


@register_model
def deit_tiny_patch32_224(pretrained=False, **kwargs):
    """DeiT-tiny (dim 192, depth 12, 3 heads) with 32x32 patches."""
    return _build_deit(pretrained, _DEIT_TINY_URL, kwargs,
                       patch_size=32, embed_dim=192, depth=12, num_heads=3)


@register_model
def deit_small_patch8_224(pretrained=False, **kwargs):
    """DeiT-small (dim 384, depth 12, 6 heads) with 8x8 patches."""
    return _build_deit(pretrained, _DEIT_SMALL_URL, kwargs,
                       patch_size=8, embed_dim=384, depth=12, num_heads=6)


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    """DeiT-small (dim 384, depth 12, 6 heads) with 16x16 patches."""
    return _build_deit(pretrained, _DEIT_SMALL_URL, kwargs,
                       patch_size=16, embed_dim=384, depth=12, num_heads=6)


@register_model
def deit_small_patch16_d_6_224(pretrained=False, **kwargs):
    """DeiT-small with 16x16 patches, reduced to 6 blocks."""
    return _build_deit(pretrained, _DEIT_SMALL_URL, kwargs,
                       patch_size=16, embed_dim=384, depth=6, num_heads=6)


@register_model
def deit_small_patch32_224(pretrained=False, **kwargs):
    """DeiT-small (dim 384, depth 12, 6 heads) with 32x32 patches."""
    return _build_deit(pretrained, _DEIT_SMALL_URL, kwargs,
                       patch_size=32, embed_dim=384, depth=12, num_heads=6)


@register_model
def deit_base_patch8_224(pretrained=False, **kwargs):
    """DeiT-base (dim 768, depth 12, 12 heads) with 8x8 patches."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                       patch_size=8, embed_dim=768, depth=12, num_heads=12)


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    """DeiT-base (dim 768, depth 12, 12 heads) with 16x16 patches."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                       patch_size=16, embed_dim=768, depth=12, num_heads=12)


@register_model
def deit_base_patch16_ft_224(pretrained=False, **kwargs):
    """DeiT-base with 16x16 patches where only the head stays trainable."""
    model = _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                        patch_size=16, embed_dim=768, depth=12, num_heads=12)

    # Freeze everything, then re-enable gradients for the classifier head.
    for param in model.parameters():
        param.requires_grad = False

    for param in model.head.parameters():
        param.requires_grad = True

    return model


@register_model
def deit_base24_patch16_224(pretrained=False, **kwargs):
    """DeiT-base with 16x16 patches, extended to 24 blocks."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                       patch_size=16, embed_dim=768, depth=24, num_heads=12)


@register_model
def deit_base16_patch16_224(pretrained=False, **kwargs):
    """DeiT-base with 16x16 patches, extended to 16 blocks."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                       patch_size=16, embed_dim=768, depth=16, num_heads=12)


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """DeiT-base with 16x16 patches built for 384x384 inputs."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs, img_size=384,
                       patch_size=16, embed_dim=768, depth=12, num_heads=12)


@register_model
def deit_base_patch32_224(pretrained=False, **kwargs):
    """DeiT-base (dim 768, depth 12, 12 heads) with 32x32 patches."""
    return _build_deit(pretrained, _DEIT_BASE_URL, kwargs,
                       patch_size=32, embed_dim=768, depth=12, num_heads=12)
renumber_frames_for_tall.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # renumber_frames_for_tall.py
2
+ # Usage:
3
+ # python renumber_frames_for_tall.py --root "C:\...\TALL4Deepfake\data" --ext .jpg --digits 4 --copy
4
+
5
+ import argparse
6
+ from pathlib import Path
7
+ import shutil
8
+
9
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def _natural_key(name):
    """Sort key that orders embedded numbers numerically (frame_2 < frame_10)."""
    import re

    # re.split with a capturing group alternates text / digits / text ...,
    # so odd positions are always digit runs and types line up across names.
    parts = re.split(r"(\d+)", name)
    numeric_aware = [int(tok) if i % 2 else tok for i, tok in enumerate(parts)]
    # The raw name breaks ties such as "01.jpg" vs "1.jpg" deterministically.
    return numeric_aware, name


def main():
    """Renumber the frames of every video folder under real/ and fake/.

    Frames are ordered by natural file-name order and renamed (or copied with
    ``--copy``) to ``1..N`` zero-padded to ``--digits`` with suffix ``--ext``.
    File contents are not converted; only the names change. With ``--copy``
    the original files stay in the folder next to the numbered copies.

    A video folder is skipped, untouched, when a numbered name would overwrite
    an existing file that is not being renumbered itself.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True, help="Root that contains real/ and fake/")
    ap.add_argument("--digits", type=int, default=4, help="Digits for numbering (4 -> 0001.jpg)")
    ap.add_argument("--ext", default=".jpg", help="Target extension for output filenames, e.g. .jpg")
    ap.add_argument("--copy", action="store_true", help="Copy instead of rename (safer)")
    args = ap.parse_args()

    root = Path(args.root)
    for cls in ["real", "fake"]:
        cls_dir = root / cls
        if not cls_dir.exists():
            print(f"[skip] {cls_dir} not found")
            continue

        for vid_dir in sorted([p for p in cls_dir.iterdir() if p.is_dir()]):
            frames = [p for p in vid_dir.iterdir() if p.is_file() and p.suffix.lower() in IMG_EXTS]
            # Natural order keeps the temporal sequence for unpadded names;
            # plain string order would place frame_10 before frame_2.
            frames.sort(key=lambda p: _natural_key(p.name))

            if not frames:
                print(f"[empty] {vid_dir}")
                continue

            targets = [f"{i:0{args.digits}d}{args.ext}" for i in range(1, len(frames) + 1)]

            # In rename mode every frame leaves the folder before the numbered
            # files come back, so only files that stay behind can collide.
            vacated = set() if args.copy else {p.name for p in frames}
            clashes = [
                name
                for src, name in zip(frames, targets)
                if name != src.name and name not in vacated and (vid_dir / name).exists()
            ]
            if clashes:
                print(f"[skip] {vid_dir}: would overwrite existing files {clashes}")
                continue

            tmp_dir = vid_dir / "_tall_tmp"
            tmp_dir.mkdir(exist_ok=True)

            for src, dst_name in zip(frames, targets):
                dst = tmp_dir / dst_name
                if args.copy:
                    shutil.copy2(src, dst)
                else:
                    shutil.move(str(src), str(dst))

            # move tmp content back
            for f in tmp_dir.iterdir():
                shutil.move(str(f), str(vid_dir / f.name))
            tmp_dir.rmdir()

            print(f"[ok] {vid_dir.name}: {len(frames)} frames -> renamed to 1..{len(frames)}")
51
+
52
+ if __name__ == "__main__":
53
+ main()
requirements-torch.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ --index-url https://download.pytorch.org/whl/cu117
2
+ torch==1.13.1+cu117
3
+ torchvision==0.14.1+cu117
4
+ opencv-python==4.13.0.92
requirements.txt ADDED
Binary file (2.37 kB). View file
 
test.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import torch.backends.cudnn as cudnn
5
+ import os
6
+ import warnings
7
+
8
+ from pathlib import Path
9
+
10
+ from timm.models import create_model
11
+ from timm.utils import ModelEma
12
+
13
+ #from datasets import build_dataset
14
+ import my_models
15
+ from engine import evaluate
16
+ #import simclr
17
+ import utils
18
+
19
+ from video_dataset import VideoDataSet
20
+ from video_dataset_aug import get_augmentor, build_dataflow
21
+ from video_dataset_config import get_dataset_config, DATASET_CONFIG
22
+
23
+ warnings.filterwarnings("ignore", category=UserWarning)
24
+ #torch.multiprocessing.set_start_method('spawn', force=True)
25
+
26
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false``/``1``/``0``.

    Without an explicit ``type``, argparse stores the raw string, so
    ``--flag False`` would yield the truthy string ``"False"``.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the argument parser shared by training and evaluation.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--epochs', default=30, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##',
                        help='path to text of dataset')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##", help='path to dataset')
    parser.add_argument('--dataset', default='ffpp',
                        choices=list(DATASET_CONFIG.keys()), help='path to dataset file list')
    parser.add_argument('--duration', default=1, type=int, help='number of frames')
    parser.add_argument('--frames_per_group', default=1, type=int,
                        help='[uniform sampling] number of frames per group; '
                             '[dense sampling]: sampling frequency')
    parser.add_argument('--threed_data', default=False, type=_str2bool,
                        help='load data in the layout for 3D conv')
    parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size')
    parser.add_argument('--disable_scaleup', action='store_true',
                        help='do not scale up and then crop a small region, directly crop the input_size')
    parser.add_argument('--random_sampling', action='store_true',
                        help='perform deterministic sampling for data loader')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool,
                        help='perform dense sampling for data loader')
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'],
                        help='[v1] TSN data augmentation, [v2] resize the shorter side to `scale_range`')
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+",
                        metavar='scale_range', help='scale range for augmentor v2')
    parser.add_argument('--modality', default='rgb', type=str, help='rgb or flow')
    parser.add_argument('--use_lmdb', default=False, type=_str2bool, help='use lmdb instead of jpeg.')
    parser.add_argument('--use_pyav', default=False, type=_str2bool, help='use video directly.')

    # temporal module
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--temporal_module_name', default=None, type=str, metavar='TEM',
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'],
                        help='temporal module applied. [TAM]')
    parser.add_argument('--temporal_attention_only', action='store_true', default=False,
                        help='use attention only in temporal module')
    parser.add_argument('--no_token_mask', action='store_true', default=False, help='do not apply token mask')
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float,
                        help='scale of the number of spatial heads')
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float, help='scale of spatial mlp')
    parser.add_argument('--rel_pos', action='store_true', default=False,
                        help='use relative positioning in temporal module')
    parser.add_argument('--temporal_pooling', type=str, default=None, choices=['avg', 'max', 'conv', 'depthconv'],
                        help='perform temporal pooling')
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'],
                        help='use depth-wise bottleneck in temporal attention')

    parser.add_argument('--window_size', default=7, type=int, help='number of frames')
    parser.add_argument('--thumbnail_rows', default=3, type=int, help='number of frames per row')

    parser.add_argument('--hpe_to_token', default=False, action='store_true',
                        help='add hub position embedding to image tokens')

    # Model parameters
    parser.add_argument('--model', default='TALL_SWIN', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw")')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='weight decay (default: 1e-5)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine")')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-7, metavar='LR',
                        help='warmup learning rate (default: 1e-7)')
    parser.add_argument('--min-lr', type=float, default=2e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 2e-6)')

    parser.add_argument('--decay-epochs', type=float, default=10, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10)')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (default: 0.0)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Run / checkpoint parameters
    parser.add_argument('--output_dir', default="./output",
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler')
    parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp')
    parser.add_argument('--use_checkpoint', default=False, type=_str2bool,
                        help='use checkpoint to save memory')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # for testing and validation
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--auto-resume', default=True, type=_str2bool, help='auto resume')

    # exp
    parser.add_argument('--contrastive_nomixup', action='store_true',
                        help='do not involve mixup in contrastive learning')
    # NOTE(review): --finetune is left untyped on purpose; the visible code
    # does not show whether callers expect a flag or a path here.
    parser.add_argument('--finetune', default=False, help='finetune model')
    parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model')

    parser.add_argument('--hard_contrastive', action='store_true', help='use HEXA')

    return parser
218
+
219
+
220
def main(args):
    """Evaluate a video model on the validation list and print its top-1 accuracy.

    Initialises (optional) distributed mode, seeds the RNGs, creates the
    model through ``create_model``, builds the validation dataset and its
    loader, optionally loads ``args.initial_checkpoint`` and finally runs
    ``evaluate``.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``).
    """
    utils.init_distributed_mode(args)
    print(args)

    # Back-fill options that the namespace may not carry.
    for option, fallback in (('hard_contrastive', False), ('selfdis_w', 0.0)):
        if not hasattr(args, option):
            setattr(args, option, fallback)

    device = torch.device(args.device)

    # Fix the seed for reproducibility; each rank gets its own offset.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    dataset_config = get_dataset_config(args.dataset, args.use_lmdb)
    (num_classes, train_list_name, val_list_name, test_list_name,
     filename_seperator, image_tmpl, filter_video, label_file) = dataset_config

    args.num_classes = num_classes
    # Only the two known modalities set a channel count; anything else
    # leaves ``args.input_channels`` untouched.
    channels_by_modality = {'rgb': 3, 'flow': 2 * 5}
    if args.modality in channels_by_modality:
        args.input_channels = channels_by_modality[args.modality]

    print(f"Creating model: {args.model}")

    model_kwargs = dict(
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint,
    )
    model = create_model(args.model, **model_kwargs)

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP
        # but before SyncBN and DDP wrapper.
        ema_device = 'cpu' if args.model_ema_force_cpu else ''
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device=ema_device,
            resume=args.resume)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Normalisation statistics come from the model's default config when it
    # provides them, otherwise fall back to 0.5 per channel.
    default_cfg = model.module.default_cfg if args.distributed else model.default_cfg
    mean = default_cfg['mean'] if 'mean' in default_cfg else (0.5, 0.5, 0.5)
    std = default_cfg['std'] if 'std' in default_cfg else (0.5, 0.5, 0.5)

    # Create the validation data loader with its augmentation pipeline.
    video_data_cls = VideoDataSet

    num_tasks = utils.get_world_size()

    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset)
    dataset_val = video_data_cls(
        args.data_dir, val_list, args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video)

    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed)

    if args.initial_checkpoint:
        checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])

    state = evaluate(
        data_loader_val, model, device, num_tasks,
        distributed=args.distributed, amp=args.amp,
        num_crops=args.num_crops, num_clips=args.num_clips)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {state['acc1']:.1f}%")
327
+
328
if __name__ == '__main__':
    # Script entry point: parse the CLI, prepare the output folder, evaluate.
    args = argparse.ArgumentParser(
        'DeiT evaluation script', parents=[get_args_parser()]
    ).parse_args()
    # An empty --output_dir is allowed; only create the folder when one is set.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
test_new.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import torch
4
+ import torch.backends.cudnn as cudnn
5
+ import os
6
+ import warnings
7
+ import json
8
+ from pathlib import Path
9
+
10
+ from timm.models import create_model
11
+
12
+ import my_models # registra TALL_SWIN
13
+ import utils
14
+
15
+ from video_dataset import VideoDataSet
16
+ from video_dataset_aug import get_augmentor, build_dataflow
17
+ from video_dataset_config import get_dataset_config, DATASET_CONFIG
18
+
19
+ from sklearn.metrics import (
20
+ accuracy_score, balanced_accuracy_score,
21
+ precision_recall_fscore_support,
22
+ confusion_matrix, classification_report,
23
+ roc_auc_score, roc_curve,
24
+ average_precision_score, precision_recall_curve
25
+ )
26
+ import matplotlib.pyplot as plt
27
+
28
+ warnings.filterwarnings("ignore", category=UserWarning)
29
+
30
+
31
def _str2bool(value):
    """Turn a command-line token into a bool.

    Common "off" spellings (``false``, ``0``, ``no``, ``off``, ...) map to
    False; every other token stays True, which is what the untyped flags
    evaluated to before (any non-empty string is truthy).
    """
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() not in ("", "0", "false", "f", "no", "n", "off")


def get_args_parser():
    """Build the argument parser for the metrics-oriented evaluation script.

    The parser is created with ``add_help=False`` so it can be passed as a
    parent to another ``ArgumentParser``.

    Returns:
        argparse.ArgumentParser: parser holding dataset, model, runtime,
        checkpoint and metrics options.
    """
    parser = argparse.ArgumentParser('DeiT evaluation script', add_help=False)

    parser.add_argument('--model', default='TALL_SWIN', type=str)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##")
    parser.add_argument('--dataset', default='ffpp', choices=list(DATASET_CONFIG.keys()))
    parser.add_argument('--duration', default=1, type=int)
    parser.add_argument('--frames_per_group', default=1, type=int)
    # Boolean-valued options take an explicit value; without a ``type`` the
    # string 'False' would be truthy, so they are parsed with _str2bool.
    parser.add_argument('--threed_data', default=False, type=_str2bool)
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--disable_scaleup', action='store_true')
    parser.add_argument('--random_sampling', action='store_true')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool)
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'])
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+")
    parser.add_argument('--modality', default='rgb', type=str)
    parser.add_argument('--use_lmdb', default=False, type=_str2bool)
    parser.add_argument('--use_pyav', default=False, type=_str2bool)

    # Temporal module / model parameters
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--temporal_module_name', default=None, type=str,
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'])
    parser.add_argument('--temporal_attention_only', action='store_true', default=False)
    parser.add_argument('--no_token_mask', action='store_true', default=False)
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float)
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float)
    parser.add_argument('--rel_pos', action='store_true', default=False)
    parser.add_argument('--temporal_pooling', type=str, default=None,
                        choices=['avg', 'max', 'conv', 'depthconv'])
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'])

    parser.add_argument('--window_size', default=7, type=int)
    parser.add_argument('--thumbnail_rows', default=3, type=int)
    parser.add_argument('--hpe_to_token', default=False, action='store_true')

    parser.add_argument('--drop', type=float, default=0.0)
    parser.add_argument('--drop-path', type=float, default=0.1)
    parser.add_argument('--drop-block', type=float, default=None)

    # Runtime
    parser.add_argument('--output_dir', default="./output")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=8, type=int)

    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)

    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://')

    # Checkpoint
    parser.add_argument('--initial_checkpoint', type=str, default='',
                        help='path do .pth/.pth.tar com checkpoint (espera key "model")')

    # Metrics / reporting
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold para decidir classe 1 (fake) a partir de prob[:,1]')
    parser.add_argument('--metrics_out', default='', type=str,
                        help='pasta pra salvar metrics.json e plots (default: output_dir)')
    parser.add_argument('--save_plots', action='store_true',
                        help='salvar cm.png / roc.png / pr.png')

    return parser
101
+
102
+
103
@torch.no_grad()
def eval_with_outputs(data_loader, model, device, threshold: float = 0.5):
    """Run the model over a loader and collect labels, scores and decisions.

    Column 1 of the softmax output is used as the score of class 1 (fake);
    a sample is predicted as class 1 when that score is >= ``threshold``.
    When the model returns more rows than there are targets, the rows are
    assumed to be per-clip logits and are averaged per target.

    Args:
        data_loader: iterable yielding ``(samples, targets)`` tensor pairs.
        model: module called as ``model(samples)``; put into eval mode here.
        device: device the samples and targets are moved to.
        threshold: decision threshold applied to the class-1 probability.

    Returns:
        tuple of three 1-D numpy arrays ``(y_true, y_score, y_pred)``.

    Raises:
        RuntimeError: if the number of logit rows is not a multiple of the
            number of targets in a batch.
    """
    model.eval()
    thr = float(threshold)
    true_chunks, score_chunks, pred_chunks = [], [], []

    for samples, targets in data_loader:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(samples)  # [B,2] or [B*K,2]

        # If the logits came back per clip, aggregate them per video.
        B = targets.shape[0]
        if logits.shape[0] != B:
            if logits.shape[0] % B != 0:
                raise RuntimeError(
                    f"logits batch ({logits.shape[0]}) não é múltiplo do target batch ({B})."
                )
            K = logits.shape[0] // B
            logits = logits.view(B, K, -1).mean(dim=1)  # [B,2]

        # Score of class 1 (fake), then the thresholded decision.
        fake_prob = torch.softmax(logits, dim=1)[:, 1]
        decision = (fake_prob >= thr).long()

        true_chunks.append(targets.detach().cpu().numpy())
        score_chunks.append(fake_prob.detach().cpu().numpy())
        pred_chunks.append(decision.detach().cpu().numpy())

    y_true = np.concatenate(true_chunks).astype(int)
    y_score = np.concatenate(score_chunks).astype(float)
    y_pred = np.concatenate(pred_chunks).astype(int)
    return y_true, y_score, y_pred
140
+
141
+
142
def plot_confusion(cm, out_path):
    """Render a confusion matrix as an annotated image and save it.

    Args:
        cm: 2-D array of counts; each cell value is written on the plot.
        out_path: destination file for the figure.
    """
    plt.figure(figsize=(6, 5))
    plt.imshow(cm)
    plt.title("Confusion Matrix")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    # imshow puts columns on the x axis, so the text goes at (col, row).
    for (row, col), count in np.ndenumerate(cm):
        plt.text(col, row, str(count), ha="center", va="center")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
153
+
154
+
155
def plot_roc(y, scores, out_path):
    """Plot the ROC curve with its AUC in the legend and save it.

    Args:
        y: ground-truth labels passed to ``roc_curve`` / ``roc_auc_score``.
        scores: predicted scores for the positive class.
        out_path: destination file for the figure.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y, scores)
    auc = roc_auc_score(y, scores)

    plt.figure(figsize=(7, 6))
    plt.plot(false_pos_rate, true_pos_rate, label=f"AUC={auc:.4f}")
    # Diagonal reference line, labelled "Chance".
    plt.plot([0, 1], [0, 1], "--", label="Chance")
    plt.xlabel("FPR")
    plt.ylabel("TPR")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
167
+
168
+
169
def plot_pr(y, scores, out_path):
    """Plot the precision-recall curve with its average precision and save it.

    Args:
        y: ground-truth labels passed to ``precision_recall_curve``.
        scores: predicted scores for the positive class.
        out_path: destination file for the figure.
    """
    precision, recall, _ = precision_recall_curve(y, scores)
    ap = average_precision_score(y, scores)

    plt.figure(figsize=(7, 6))
    plt.plot(recall, precision, label=f"AP={ap:.4f}")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend(loc="best")
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()
180
+
181
+
182
def main(args):
    """Evaluate a checkpoint on the validation list and export metrics.

    Builds the model and validation loader, loads ``args.initial_checkpoint``,
    runs ``eval_with_outputs`` and then prints, and writes to disk, accuracy,
    balanced accuracy, precision/recall/F1, ROC-AUC, PR-AUC and the confusion
    matrix. Raw outputs go to ``eval_outputs.npz``; plots are written only
    when ``args.save_plots`` is set.

    Args:
        args: parsed command-line namespace (see ``get_args_parser``).

    Raises:
        RuntimeError: if ``args.initial_checkpoint`` is empty.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Offset the seed by the rank so every process gets its own stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, image_tmpl, filter_video, label_file = \
        get_dataset_config(args.dataset, args.use_lmdb)

    args.num_classes = num_classes
    args.input_channels = 3 if args.modality == 'rgb' else 2 * 5

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=False
    )
    model.to(device)

    # mean/std: this script never wraps the model in DistributedDataParallel,
    # so ``model.module`` is only dereferenced when it actually exists;
    # otherwise the config is read from the model itself.
    cfg_owner = model.module if args.distributed and hasattr(model, 'module') else model
    default_cfg = cfg_owner.default_cfg
    mean = (0.5, 0.5, 0.5) if 'mean' not in default_cfg else default_cfg['mean']
    std = (0.5, 0.5, 0.5) if 'std' not in default_cfg else default_cfg['std']

    # dataset (val list)
    video_data_cls = VideoDataSet
    val_list = os.path.join(args.data_txt_dir, val_list_name)

    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset
    )

    dataset_val = video_data_cls(
        args.data_dir, val_list,
        args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video
    )

    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed
    )

    if not args.initial_checkpoint:
        raise RuntimeError("Passe --initial_checkpoint apontando pro checkpoint do modelo.")

    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    # Many checkpoints are stored as {"model": state_dict, ...}.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        utils.load_checkpoint(model, checkpoint["model"])
    else:
        # Otherwise treat the object as a bare state_dict.
        model.load_state_dict(checkpoint, strict=False)

    # eval
    y_true, y_score, y_pred = eval_with_outputs(
        data_loader_val, model, device, threshold=args.threshold
    )

    acc = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred)

    roc_auc = roc_auc_score(y_true, y_score)
    pr_auc = average_precision_score(y_true, y_score)

    print(f"\nN={len(y_true)} | thr={args.threshold:.3f}")
    print(f"acc={acc:.4f} | bacc={bacc:.4f} | prec={prec:.4f} | rec={rec:.4f} | f1={f1:.4f} | roc_auc={roc_auc:.4f} | pr_auc={pr_auc:.4f}")
    print(classification_report(y_true, y_pred, digits=4, zero_division=0))

    # A blank or whitespace-only --metrics_out falls back to --output_dir
    # instead of producing an empty path for makedirs.
    outdir = (args.metrics_out or "").strip() or args.output_dir
    os.makedirs(outdir, exist_ok=True)

    out_json = {
        "threshold": float(args.threshold),
        "acc": float(acc),
        "balanced_acc": float(bacc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "roc_auc": float(roc_auc),
        "pr_auc": float(pr_auc),
        "confusion_matrix": cm.tolist(),
        "n": int(len(y_true)),
    }
    with open(os.path.join(outdir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)

    np.savez(os.path.join(outdir, "eval_outputs.npz"),
             y_true=y_true, y_score=y_score, y_pred=y_pred)

    if args.save_plots:
        plot_confusion(cm, os.path.join(outdir, "cm.png"))
        plot_roc(y_true, y_score, os.path.join(outdir, "roc.png"))
        plot_pr(y_true, y_score, os.path.join(outdir, "pr.png"))
        print(f"\n✔ Plots + metrics saved in: {os.path.abspath(outdir)}")
    else:
        print(f"\n✔ Metrics saved in: {os.path.abspath(os.path.join(outdir, 'metrics.json'))}")
312
+
313
+
314
if __name__ == '__main__':
    # Script entry point: parse the CLI, prepare the output folder, evaluate.
    args = argparse.ArgumentParser(
        'DeiT evaluation script', parents=[get_args_parser()]
    ).parse_args()
    # Create the output directory up front when one is configured.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
utils.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2015-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the CC-by-NC license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+ """
8
+ Misc functions, including distributed helpers.
9
+
10
+ Mostly copy-paste from torchvision references.
11
+ """
12
+ import io
13
+ import os
14
+ import time
15
+ from collections import defaultdict, deque
16
+ import datetime
17
+ import tempfile
18
+
19
+ import torch
20
+ import torch.distributed as dist
21
+ from fvcore.common.checkpoint import Checkpointer
22
+
23
+
24
class SmoothedValue(object):
    """Keep a bounded window of recent values plus running global totals.

    The window feeds ``median``, ``avg``, ``max`` and ``value``; the
    running ``total``/``count`` feed ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        """Create an empty tracker.

        Args:
            window_size: maximum number of recent values kept in the window.
            fmt: format string used by ``__str__``; it may reference
                ``median``, ``avg``, ``global_avg``, ``max`` and ``value``.
        """
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and ``n`` times in the totals."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        summed_count, summed_total = packed.tolist()
        self.count = int(summed_count)
        self.total = summed_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Average over everything ever recorded (``total / count``)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
84
+
85
+
86
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress."""

    def __init__(self, delimiter="\t"):
        """Create a logger whose meters are created on first use.

        Args:
            delimiter: string placed between the fields of a log line.
        """
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name.

        Tensors are converted with ``.item()``; values must then be
        ``float`` or ``int``.
        """
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails: expose meters as
        # attributes, e.g. ``logger.loss``.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            loss_str.append(
                "{}: {}".format(name, str(meter))
            )
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register (or replace) the meter stored under ``name``."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress lines.

        A line is printed every ``print_freq`` iterations and on the last
        one, followed by a total-time summary once the iterable is
        exhausted.

        Args:
            iterable: sized iterable (``len()`` is required).
            print_freq: number of iterations between log lines.
            header: optional prefix for every log line.

        Yields:
            The items of ``iterable``, unchanged.
        """
        i = 0
        if not header:
            header = ''
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        num_iters = len(iterable)
        space_fmt = ':' + str(len(str(num_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if torch.cuda.is_available():
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            # Time spent waiting for the item vs. the whole iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_iters - 1:
                eta_seconds = iter_time.global_avg * (num_iters - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                if torch.cuda.is_available():
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time),
                        memory=torch.cuda.max_memory_allocated() / MB))
                else:
                    print(log_msg.format(
                        i, num_iters, eta=eta_string,
                        meters=str(self),
                        time=str(iter_time), data=str(data_time)))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # Guard the per-iteration average so an empty iterable does not
        # raise ZeroDivisionError in the summary line.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_iters, 1)))
166
+
167
+
168
def _load_checkpoint_for_ema(model_ema, checkpoint):
    """Feed an already-loaded checkpoint object to ``ModelEma._load_checkpoint``.

    ``_load_checkpoint`` is handed a file-like object here, so the
    checkpoint is serialised into an in-memory buffer first.

    Args:
        model_ema: object exposing ``_load_checkpoint(file_like)``.
        checkpoint: checkpoint object that ``torch.save`` can serialise.
    """
    buffer = io.BytesIO()
    torch.save(checkpoint, buffer)
    # Rewind so the reader starts at the beginning of the payload.
    buffer.seek(0)
    model_ema._load_checkpoint(buffer)
176
+
177
+ """
178
+ def load_checkpoint(model, state_dict, mode=None):
179
+
180
+ # reuse Checkpointer in fvcore to support flexible loading
181
+ ckpt = Checkpointer(model, save_to_disk=False)
182
+ # since Checkpointer requires the weight to be put under `model` field, we need to save it to disk
183
+ tmp_path = tempfile.NamedTemporaryFile('w+b')
184
+ torch.save({'model': state_dict}, tmp_path.name)
185
+ ckpt.load(tmp_path.name)
186
+ """
187
def load_checkpoint(model, state_dict):
    """Load weights into ``model`` non-strictly and report key mismatches.

    The state dict is loaded directly (avoids writing temp files on
    Windows). A wrapper dict with a ``'state_dict'`` entry is unwrapped.

    Args:
        model: module exposing ``load_state_dict``.
        state_dict: the weights, or a dict holding them under ``'state_dict'``.
    """
    wrapped = isinstance(state_dict, dict) and 'state_dict' in state_dict
    weights = state_dict['state_dict'] if wrapped else state_dict

    # strict=False: mismatching keys are counted and printed, not fatal.
    missing, unexpected = model.load_state_dict(weights, strict=False)

    if missing:
        print(f"[load_checkpoint] Missing keys: {len(missing)}")
    if unexpected:
        print(f"[load_checkpoint] Unexpected keys: {len(unexpected)}")
198
+
199
def setup_for_distributed(is_master):
    """Replace the built-in ``print`` so that only the master process prints.

    The replacement accepts an extra ``force=True`` keyword that prints
    regardless of ``is_master``.

    Args:
        is_master: whether the current process is allowed to print.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        # Always strip ``force`` so it never reaches the real print().
        force = kwargs.pop('force', False)
        if not (is_master or force):
            return
        builtin_print(*args, **kwargs)

    __builtin__.print = print
212
+
213
+
214
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    return dist.is_available() and dist.is_initialized()
220
+
221
+
222
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
226
+
227
+
228
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
232
+
233
+
234
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
236
+
237
+
238
def save_on_master(*args, **kwargs):
    """Forward to ``torch.save`` on the main process; do nothing elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
241
+
242
+
243
def init_distributed_mode(args):
    """Configure torch.distributed from environment variables.

    Reads ``RANK``/``WORLD_SIZE``/``LOCAL_RANK`` when the first two are
    set, otherwise ``SLURM_PROCID``. With neither present it sets
    ``args.distributed = False`` and returns. In the distributed case it
    fills ``args.rank``, ``args.gpu``, ``args.dist_backend`` (and
    ``args.world_size`` on the ``RANK`` path), initialises the NCCL
    process group and silences ``print`` on non-master ranks.

    Args:
        args: namespace providing ``dist_url`` and ``world_size``; mutated
            in place.
    """
    env = os.environ
    if 'RANK' in env and 'WORLD_SIZE' in env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    elif 'SLURM_PROCID' in env:
        args.rank = int(env['SLURM_PROCID'])
        # Map the global rank onto the GPUs visible to this process.
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
video_dataset.py ADDED
@@ -0,0 +1,1228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import six
3
+ from typing import Union
4
+ import random
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+ import torch.utils.data as data
9
+
10
+ try:
11
+ import lmdb
12
+ import pyarrow as pa
13
+ _HAS_LMDB = True
14
+ except ImportError as e:
15
+ _HAS_LMDB = False
16
+ _LMDB_ERROR_MSG = e
17
+
18
+ try:
19
+ import av
20
+ _HAS_PYAV = True
21
+ except ImportError as e:
22
+ _HAS_PYAV = False
23
+ _PYAV_ERROR_MSG = e
24
+
25
+
26
def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False, start_frame_idx=0, end_frame_idx=None):
    """Pick the frame indices of one clip from a video.

    Args:
        video_frames (int): total frame number of a video
        sampling_rate (int): sampling rate for clip, pick one every k frames
        frames_per_clip (int): number of frames of a clip
        fixed_offset (bool): if True, use the deterministic centred offset
            ``(video_frames - sampling_rate * frames_per_clip) // 2``
            instead of a random one.
        start_frame_idx (int): inclusive lower bound of the random offset.
        end_frame_idx (int | None): exclusive upper bound of the random
            offset; defaults to ``video_frames - sampling_rate * frames_per_clip``.

    Returns:
        list[int]: frame indices (started from zero); indices running past
        the end of the video wrap around modulo ``video_frames``.
    """
    clip_span = sampling_rate * frames_per_clip
    highest_idx = video_frames - clip_span if end_frame_idx is None else end_frame_idx
    if highest_idx <= 0:
        # Video too short for a full clip: start at frame 0 and wrap.
        random_offset = 0
    elif fixed_offset:
        random_offset = (video_frames - clip_span) // 2
    else:
        # Index the size-1 draw explicitly: int() on a 1-D array is
        # deprecated in recent NumPy releases.
        random_offset = int(np.random.randint(start_frame_idx, highest_idx, 1)[0])
    frame_idx = [int(random_offset + i * sampling_rate) % video_frames for i in range(frames_per_clip)]
    return frame_idx
50
+
51
+
52
def compute_img_diff(image_1, image_2, bound=255.0):
    """Return ``image_1 - image_2`` rescaled into an 8-bit PIL image.

    The signed difference is shifted by ``bound`` and scaled by
    ``255 / (2 * bound)``, so a zero difference lands mid-range.

    Args:
        image_1: array-like image (minuend).
        image_2: array-like image (subtrahend), same shape as ``image_1``.
        bound: offset and scale of the difference (default 255.0).

    Returns:
        PIL.Image.Image: the rescaled difference image.
    """
    # np.float64 replaces the removed ``np.float`` alias (same dtype).
    image_diff = np.asarray(image_1, dtype=np.float64) - np.asarray(image_2, dtype=np.float64)
    image_diff += bound
    image_diff *= (255.0 / float(2 * bound))
    image_diff = image_diff.astype(np.uint8)
    image_diff = Image.fromarray(image_diff)
    return image_diff
59
+
60
+
61
def load_image(root_path, directory, image_tmpl, idx, modality):
    """Load the frame image(s) of a video for the given modality.

    :param root_path: dataset root folder.
    :param directory: video folder, relative to ``root_path``.
    :param image_tmpl: file-name template, formatted with a frame index.
    :param idx: frame index; if it is a list, load a batch of images.
    :param modality: ``'rgb'`` loads one image per index, ``'rgbdiff'``
        returns the difference between frame ``i + 1`` and frame ``i``,
        ``'flow'`` loads the ``x_`` and ``y_`` prefixed images per index.
        Any other value yields an empty list.
    :return: list of PIL images (two per index for ``'flow'``).
    :raises ValueError: when an image still cannot be loaded after 10 tries.
    """

    def _safe_load_image(img_path):
        # Retry a few times before giving up on the image.
        img = None
        for _ in range(10):
            try:
                # The context manager closes the file even when copy() fails.
                with Image.open(img_path) as img_tmp:
                    img = img_tmp.copy()
                break
            except Exception as e:
                print('[Will try load again] error loading image: {}, '
                      'error: {}'.format(img_path, str(e)))
        if img is None:
            raise ValueError('[Fail 10 times] error loading image: {}'.format(img_path))
        return img

    if not isinstance(idx, list):
        idx = [idx]
    out = []
    if modality == 'rgb':
        for i in idx:
            image_path_file = os.path.join(root_path, directory, image_tmpl.format(i))
            out.append(_safe_load_image(image_path_file))
    elif modality == 'rgbdiff':
        # Load every needed frame once (each index and its successor), then
        # build the per-index differences from that cache.
        tmp = {}
        new_idx = np.unique(np.concatenate((np.asarray(idx), np.asarray(idx) + 1)))
        for i in new_idx:
            image_path_file = os.path.join(root_path, directory, image_tmpl.format(i))
            tmp[i] = _safe_load_image(image_path_file)
        for k in idx:
            img_ = compute_img_diff(tmp[k + 1], tmp[k])
            out.append(img_)
        del tmp
    elif modality == 'flow':
        for i in idx:
            flow_x_name = os.path.join(root_path, directory, "x_" + image_tmpl.format(i))
            flow_y_name = os.path.join(root_path, directory, "y_" + image_tmpl.format(i))
            out.extend([_safe_load_image(flow_x_name), _safe_load_image(flow_y_name)])

    return out
113
+
114
+
115
def load_sound(data_dir, record, idx, fps, audio_length, resampling_rate,
               window_size=10, step_size=5, eps=1e-6):
    """Build a log-power spectrogram image for the audio around a frame.

    idx must be the center frame of a clip.

    Args:
        data_dir (str): folder that contains the audio file.
        record: object exposing ``start_frame`` and ``path`` (path of the audio
            file relative to ``data_dir``).
        idx (int): centre frame of the clip, relative to ``record.start_frame``.
        fps (float): frame rate used to convert frame indices to seconds.
        audio_length (float): length in seconds of the audio window.
        resampling_rate (int): sample rate used for every seconds-to-samples
            conversion below.
        window_size (int): STFT window length in milliseconds.
        step_size (int): STFT hop length in milliseconds.
        eps (float): constant added before the logarithm.

    Returns:
        list: a single-element list holding the spectrogram as a PIL image; an
        all-zero image when the audio file does not exist.
    """
    import librosa

    centre_sec = (record.start_frame + idx) / fps
    left_sec = centre_sec - (audio_length / 2.0)
    right_sec = centre_sec + (audio_length / 2.0)
    audio_fname = os.path.join(data_dir, record.path)
    # TODO: generate 0s if the audio file does not exist.
    if not os.path.exists(audio_fname):
        return [Image.fromarray(np.zeros((256, 256 * int(audio_length / 1.28))))]
    # NOTE(review): the file is loaded at its native rate (sr=None) while all
    # arithmetic below uses resampling_rate -- presumably the audio was already
    # resampled on disk; confirm against the extraction pipeline.
    samples, _ = librosa.core.load(audio_fname, sr=None, mono=True)
    duration = samples.shape[0] / float(resampling_rate)

    left_sample = int(round(left_sec * resampling_rate))
    right_sample = int(round(right_sec * resampling_rate))

    required_samples = int(round(resampling_rate * audio_length))

    # Window sticking out of the recording: fall back to its first/last samples.
    if left_sec < 0:
        samples = samples[:required_samples]
    elif right_sec > duration:
        samples = samples[-required_samples:]
    else:
        samples = samples[left_sample:right_sample]

    # TODO: is the size of spec is fixed if number of samples are different?
    # if the samples is not long enough, repeat the waveform
    if len(samples) < required_samples:
        multiplies = required_samples / len(samples)
        samples = np.tile(samples, int(multiplies + 0.5) + 1)
        samples = samples[:required_samples]

    # log spectrogram
    nperseg = int(round(window_size * resampling_rate / 1e3))
    noverlap = int(round(step_size * resampling_rate / 1e3))
    spec = librosa.stft(samples, n_fft=511, window='hann', hop_length=noverlap,
                        win_length=nperseg, pad_mode='constant')
    spec = np.log(np.real(spec * np.conj(spec)) + eps)
    return [Image.fromarray(spec)]
156
+
157
+
158
def load_data_lmdb(videos, idx, modality):
    """Decode the frames selected by ``idx`` from a sequence of encoded buffers.

    Args:
        videos: indexable container of encoded image buffers.
        idx: iterable of frame indices.
        modality (str): 'rgb', 'flow' or 'rgbdiff'; any other value yields an
            empty list.

    Returns:
        list: PIL images. 'rgb' gives one image per index, 'flow' two grayscale
        images per index (entries ``2 * i - 1`` and ``2 * i``), and 'rgbdiff'
        one difference image (entry ``k + 1`` minus entry ``k``) per index.
    """
    def _convert_buffer_to_PIL(tmp_buf, is_flow=False):
        # Wrap the raw bytes in a file-like object PIL can read from.
        stream = six.BytesIO()
        stream.write(tmp_buf)
        stream.seek(0)
        mode = 'L' if is_flow else 'RGB'
        converted = Image.open(stream).convert(mode)
        result = converted.copy()
        converted.close()
        return result

    decoded = []
    if modality == 'rgb':
        raw_buffers = [videos[i] for i in idx]
        decoded.extend(_convert_buffer_to_PIL(raw) for raw in raw_buffers)
    elif modality == 'flow':
        flow_positions = np.asarray(idx) * 2 - 1
        raw_pairs = [[videos[i], videos[i + 1]] for i in flow_positions]
        for raw_x, raw_y in raw_pairs:
            flow_x = _convert_buffer_to_PIL(raw_x, True)
            flow_y = _convert_buffer_to_PIL(raw_y, True)
            decoded.extend([flow_x, flow_y])
    elif modality == 'rgbdiff':
        # Every requested frame k also needs frame k + 1; decode each once.
        needed = np.unique(np.concatenate((np.asarray(idx), np.asarray(idx) + 1)))
        frames = {}
        for i in needed:
            frames[i] = _convert_buffer_to_PIL(videos[i])
        for k in idx:
            decoded.append(compute_img_diff(frames[k + 1], frames[k]))
        del frames
    return decoded
191
+
192
+
193
def sample_train_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling, num_clips=1):
    """Randomly sample the 1-based frame indices used for a training sample.

    Args:
        video_length (int): number of frames in the video.
        num_consecutive_frames (int): frames loaded per sampled index; it
            shortens the range of valid start indices.
        num_frames (int): frames per clip (dense) or number of groups (uniform).
        sample_freq (int): temporal stride (dense) or frames drawn per group
            (uniform).
        dense_sampling (bool): True for strided clips built by ``random_clip``,
            False for uniform sampling over the whole video.
        num_clips (int): number of clips; only used with dense sampling.

    Returns:
        numpy.ndarray: flat array of 1-based frame indices.
    """
    max_frame_idx = max(1, video_length - num_consecutive_frames + 1)
    if dense_sampling:
        frame_idx = np.zeros((num_clips, num_frames), dtype=int)
        if num_clips == 1:  # backward compatibility
            frame_idx[0] = np.asarray(random_clip(max_frame_idx, sample_freq, num_frames, False))
        else:
            # Split the valid start range into num_clips segments, one clip each.
            max_start_frame_idx = max_frame_idx - sample_freq * num_frames
            frames_per_segment = max_start_frame_idx // num_clips
            for i in range(num_clips):
                if frames_per_segment <= 0:
                    # Video too short for disjoint segments: unconstrained clip.
                    clip = random_clip(max_frame_idx, sample_freq, num_frames, False)
                else:
                    clip = random_clip(max_frame_idx, sample_freq, num_frames, False,
                                       i * frames_per_segment, (i + 1) * frames_per_segment)
                frame_idx[i] = np.asarray(clip)
        frame_idx = frame_idx.flatten()
    else:  # uniform sampling
        total_frames = num_frames * sample_freq
        ave_frames_per_group = max_frame_idx // num_frames
        if ave_frames_per_group >= sample_freq:
            # randomly sample sample_freq images per segment, with the same
            # offsets reused in every segment
            frame_idx = np.arange(0, num_frames) * ave_frames_per_group
            frame_idx = np.repeat(frame_idx, repeats=sample_freq)
            offsets = np.random.choice(ave_frames_per_group, sample_freq, replace=False)
            offsets = np.tile(offsets, num_frames)
            frame_idx = frame_idx + offsets
        elif max_frame_idx < total_frames:
            # need to sample the same images
            frame_idx = np.random.choice(max_frame_idx, total_frames)
        else:
            # sample cross all images
            frame_idx = np.random.choice(max_frame_idx, total_frames, replace=False)
        frame_idx = np.sort(frame_idx)
    # Shift from 0-based to 1-based indices.
    frame_idx = frame_idx + 1
    return frame_idx
268
+
269
+
270
def sample_val_test_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling,
                         fixed_offset, num_clips, whole_video):
    """Sample the 1-based frame indices used for validation or testing.

    Args:
        video_length (int): number of frames in the video.
        num_consecutive_frames (int): frames loaded per sampled index; it
            shortens the range of valid start indices.
        num_frames (int): frames per clip (dense) or number of groups (uniform).
        sample_freq (int): temporal stride (dense) or frames drawn per group
            (uniform).
        dense_sampling (bool): True for strided clips, False for uniform sampling.
        fixed_offset (bool): True to place clips deterministically.
        num_clips (int): number of clips to sample.
        whole_video (bool): True to return every ``sample_freq``-th frame and
            ignore all other sampling options.

    Returns:
        numpy.ndarray: flat array of 1-based frame indices.
    """
    max_frame_idx = max(1, video_length - num_consecutive_frames + 1)
    if whole_video:
        return np.arange(1, max_frame_idx, step=sample_freq, dtype=int)

    frame_indices = []
    if dense_sampling:
        if fixed_offset:
            # Spread the clip start positions evenly over the valid range.
            sample_pos = max(1, 1 + max_frame_idx - sample_freq * num_frames)
            start_list = np.linspace(0, sample_pos - 1, num=num_clips, dtype=int)
            for start_idx in start_list.tolist():
                frame_indices += [(idx * sample_freq + start_idx) % max_frame_idx
                                  for idx in range(num_frames)]
        else:
            for _ in range(num_clips):
                frame_indices.extend(random_clip(max_frame_idx, sample_freq, num_frames))
    elif fixed_offset:  # uniform sampling, deterministic offsets
        first_offset = -num_clips // 2 + 1
        for sample_offset in range(first_offset, num_clips // 2 + 1):
            if max_frame_idx > num_frames:
                tick = max_frame_idx / float(num_frames)
                # Keep the offset inside one tick so indices stay in their group.
                curr_sample_offset = sample_offset
                if curr_sample_offset >= tick / 2.0:
                    curr_sample_offset = tick / 2.0 - 1e-4
                elif curr_sample_offset < -tick / 2.0:
                    curr_sample_offset = -tick / 2.0
                frame_idx = np.array([int(tick / 2.0 + curr_sample_offset + tick * x)
                                      for x in range(num_frames)])
            else:
                # A local RandomState yields the same draws as seeding the
                # global generator, without clobbering the global RNG state.
                rng = np.random.RandomState(sample_offset - first_offset)
                frame_idx = rng.choice(max_frame_idx, num_frames)
            frame_idx = np.sort(frame_idx)
            frame_indices.extend(frame_idx.tolist())
    else:  # uniform sampling, one sample per clip
        total_frames = num_frames * sample_freq
        ave_frames_per_group = max_frame_idx // num_frames
        for i in range(num_clips):
            if ave_frames_per_group >= sample_freq:
                # randomly sample sample_freq images per segment (unseeded)
                frame_idx = np.arange(0, num_frames) * ave_frames_per_group
                frame_idx = np.repeat(frame_idx, repeats=sample_freq)
                offsets = np.random.choice(ave_frames_per_group, sample_freq,
                                           replace=False)
                offsets = np.tile(offsets, num_frames)
                frame_idx = frame_idx + offsets
            elif max_frame_idx < total_frames:
                # need to sample the same images; seeded per clip index
                frame_idx = np.random.RandomState(i).choice(max_frame_idx, total_frames)
            else:
                # sample cross all images; seeded per clip index
                frame_idx = np.random.RandomState(i).choice(max_frame_idx, total_frames,
                                                            replace=False)
            frame_idx = np.sort(frame_idx)
            frame_indices.extend(frame_idx.tolist())
    # Shift from 0-based to 1-based indices.
    return np.asarray(frame_indices) + 1
334
+
335
+
336
class VideoRecord(object):
    """Description of one video: its path, frame range and label."""

    def __init__(self, path, start_frame, end_frame, label, reverse=False):
        # The last path component doubles as the video identifier.
        self.video_id = os.path.basename(path)
        self.path = path
        self.label = label
        self.reverse = reverse
        self.start_frame = start_frame
        self.end_frame = end_frame

    @property
    def num_frames(self):
        """Number of frames in the inclusive range [start_frame, end_frame]."""
        first, last = self.start_frame, self.end_frame
        return last - first + 1

    def __str__(self):
        return self.path
351
+
352
+
353
class VideoDataSet(data.Dataset):
    """Dataset of videos stored as per-frame image files (or audio files)."""

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=True, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list (relative to root_path), each line with
                folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            num_clips (int): number of clips sampled per video
            modality (str): 'rgb', 'flow', 'rgbdiff' or 'sound'
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
            seperator (str): field separator of the list file
            filter_video (int): outside test mode, drop videos with fewer frames than this
            num_classes (int): number of classes, needed for multi-label targets
            whole_video (bool): take whole video
            fps (float): frame rate per second, used to localize sound when frame idx is selected.
            audio_length (float): the time window to extract audio feature.
            resampling_rate (int): used to resampling audio extracted from wav

        Raises:
            ValueError: if ``modality`` is not one of the supported values.
        """
        if modality not in ['flow', 'rgb', 'rgbdiff', 'sound']:
            raise ValueError("modality should be 'flow' or 'rgb' or 'rgbdiff' or 'sound'.")

        self.root_path = root_path
        self.list_file = os.path.join(root_path, list_file)
        self.num_groups = num_groups
        self.num_frames = num_groups
        self.frames_per_group = frames_per_group
        self.sample_freq = frames_per_group
        self.num_clips = num_clips
        self.sample_offset = sample_offset
        self.fixed_offset = fixed_offset
        self.dense_sampling = dense_sampling
        self.modality = modality.lower()
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.is_train = is_train
        self.test_mode = test_mode
        self.separator = seperator
        self.filter_video = filter_video
        self.whole_video = whole_video
        self.fps = fps
        self.audio_length = audio_length
        self.resampling_rate = resampling_rate
        self.video_length = (self.num_frames * self.sample_freq) / self.fps

        # Each sampled index expands to this many consecutive frames.
        if self.modality in ['flow', 'rgbdiff']:
            self.num_consecutive_frames = 5
        else:
            self.num_consecutive_frames = 1

        self.video_list, self.multi_label = self._parse_list()
        self.num_classes = num_classes

    def _parse_list(self):
        """Read the list file and build the video records.

        Returns:
            tuple: (list of VideoRecord, bool telling whether the list looks
            multi-label, i.e. lines have more than four fields on average).
        """
        # usually it is [video_id, num_frames, class_idx]
        # or [video_id, start_frame, end_frame, list of class_idx]
        tmp = []
        original_video_numbers = 0
        # Context manager so the list file is always closed.
        with open(self.list_file) as list_fh:
            for line in list_fh:
                line = line.strip()
                if not line:
                    # Blank lines carry no record; skip instead of crashing.
                    continue
                elements = line.split(self.separator)
                total_frame = int(elements[2]) - int(elements[1]) + 1
                original_video_numbers += 1
                # Short videos are only filtered out when labels are in use.
                if self.test_mode or total_frame >= self.filter_video:
                    tmp.append(elements)

        num = len(tmp)
        print("The number of videos is {} (with more than {} frames) "
              "(original: {})".format(num, self.filter_video, original_video_numbers), flush=True)
        assert (num > 0)
        # TODO: a better way to check if multi-label or not
        multi_label = np.mean(np.asarray([len(x) for x in tmp])) > 4.0
        file_list = []
        for item in tmp:
            if self.test_mode:
                file_list.append([item[0], int(item[1]), int(item[2]), -1])
            else:
                labels = [float(value) for value in item[3:]]
                if not multi_label:
                    labels = labels[0] if len(labels) == 1 else labels
                file_list.append([item[0], int(item[1]), int(item[2]), labels])

        video_list = [VideoRecord(item[0], item[1], item[2], item[3]) for item in file_list]
        # rgbdiff needs frame k + 1 for every frame k, so it has one frame less
        if self.modality in ['rgbdiff']:
            for record in video_list:
                record.end_frame -= 1

        return video_list, multi_label

    def remove_data(self, idx):
        """Drop the videos whose positions in ``video_list`` are listed in ``idx``."""
        original_video_num = len(self.video_list)
        # Set membership keeps this linear in the number of videos.
        drop = {int(i) for i in idx}
        self.video_list = [v for i, v in enumerate(self.video_list) if i not in drop]
        removed = original_video_num - len(self.video_list)
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, removed, len(self.video_list)))

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def __getitem__(self, index):
        """
        Returns:
            the output of ``self.transform`` applied to the loaded images
                (documented upstream as a (3xgxf)xHxW torch.FloatTensor, g is
                number of groups and f is the frames per group).
            the label, see ``get_label``
        """
        record = self.video_list[index]
        indices = self._sample_indices(record) if self.is_train else self._get_val_indices(record)
        images = self.get_data(record, indices)
        images = self.transform(images)
        label = self.get_label(record)

        return images, label

    def _frame_window(self, record, seg_ind):
        """Absolute frame numbers loaded for the 1-based sampled index ``seg_ind``."""
        # Clamp to the last frame of the record. Clamping to num_frames (as
        # before) is only equivalent when start_frame == 1.
        first = seg_ind + record.start_frame - 1
        return [min(first + i, record.end_frame) for i in range(self.num_consecutive_frames)]

    def get_data(self, record, indices):
        """Load the images (or spectrograms) of ``record`` for the given indices.

        Args:
            record: VideoRecord describing the video.
            indices: 1-based frame indices relative to the start of the record.

        Returns:
            list: PIL images in sampling order.
        """
        images = []
        if self.whole_video:
            # Drop trailing indices that do not fill a complete clip.
            tmp = len(indices) % self.num_frames
            if tmp != 0:
                indices = indices[:-tmp]
            num_clips = len(indices) // self.num_frames
        else:
            num_clips = self.num_clips
        if self.modality == 'sound':
            half = self.num_frames // 2
            for clip_no in range(num_clips):
                clip_indices = indices[clip_no * self.num_frames: (clip_no + 1) * self.num_frames]
                # One spectrogram per clip, centred on the clip's middle frame.
                if self.num_frames % 2 == 0:
                    center_idx = (clip_indices[half - 1] + clip_indices[half]) // 2
                else:
                    center_idx = clip_indices[half]
                center_idx = min(record.num_frames, center_idx)
                seg_imgs = load_sound(self.root_path, record, center_idx,
                                      self.fps, self.audio_length, self.resampling_rate)
                images.extend(seg_imgs)
        else:
            for seg_ind in indices:
                seg_imgs = load_image(self.root_path, record.path, self.image_tmpl,
                                      self._frame_window(record, seg_ind), self.modality)
                images.extend(seg_imgs)
        return images

    def get_label(self, record):
        """Return the training target of ``record``.

        In test mode this is the video id; otherwise an int class index, or a
        multi-hot float tensor of size ``num_classes`` for multi-label lists.
        """
        if self.test_mode:
            # in test mode, return the video id as label
            return record.video_id
        if not self.multi_label:
            return int(record.label)
        # create a binary vector.
        label = torch.zeros(self.num_classes, dtype=torch.float)
        for x in record.label:
            label[int(x)] = 1.0
        return label

    def __len__(self):
        return len(self.video_list)
541
+
542
+
543
class VideoDataSetLMDB(data.Dataset):
    """Dataset of videos stored as serialized frame buffers in an LMDB file.

    Does not support the 'sound' modality.
    """

    def __init__(self, datadir, db_name, num_groups=16, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False,
                 seperator=' ', filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            datadir (str): folder that contains the database
            db_name (str): name of the LMDB database inside ``datadir``
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            num_clips (int): number of clips sampled per video
            modality (str): 'rgb', 'flow' or 'rgbdiff'
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
            seperator (str): field separator of the companion ``.txt`` list file
            filter_video (int): when > 0 and not in test mode, drop videos with
                fewer frames than this (read from the companion list file)
            num_classes (int): number of classes, needed for multi-label targets
            whole_video (bool): take whole video

        Raises:
            ValueError: if lmdb could not be imported, the modality is not
                supported, or filtering leaves keys and length inconsistent.
        """
        # TODO: handle multi-label?
        # TODO: flow data?

        if not _HAS_LMDB:
            raise ValueError(_LMDB_ERROR_MSG)

        if modality not in ['flow', 'rgb', 'rgbdiff']:
            raise ValueError("modality should be 'flow' or 'rgb' or 'rgbdiff'.")

        self.db_path = os.path.join(datadir, db_name)

        self.num_groups = num_groups
        self.num_frames = num_groups
        self.frames_per_group = frames_per_group
        self.sample_freq = frames_per_group
        self.num_clips = num_clips
        self.sample_offset = sample_offset
        self.fixed_offset = fixed_offset
        self.dense_sampling = dense_sampling
        self.modality = modality.lower()
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.is_train = is_train
        self.test_mode = test_mode
        self.seperator = seperator
        self.filter_video = filter_video
        self.whole_video = whole_video
        self.fps = fps
        self.audio_length = audio_length
        self.resampling_rate = resampling_rate
        self.video_length = (self.num_frames * self.sample_freq) / self.fps

        # Each sampled index expands to this many consecutive frames.
        if self.modality in ['flow', 'rgbdiff']:
            self.num_consecutive_frames = 5
        else:
            self.num_consecutive_frames = 1

        self.multi_label = None

        # Only the metadata is read here; the environment used for samples is
        # opened lazily in maybe_open_and_get_buffer.
        self.db = None
        db = lmdb.open(self.db_path, max_readers=1, subdir=os.path.isdir(self.db_path),
                       readonly=True, lock=False, readahead=False, meminit=False)
        with db.begin(write=False) as txn:
            self.length = pa.deserialize(txn.get(b'__len__'))
            self.keys = pa.deserialize(txn.get(b'__keys__'))
        db.close()

        # TODO: a hack way to filter video
        self.list_file = self.db_path.replace(".lmdb", ".txt")

        valid_video_numbers = self.length
        invalid_video_ids = set()
        if self.filter_video > 0:
            valid_video_numbers = 0
            # Context manager so the list file is always closed.
            with open(self.list_file) as list_fh:
                for line in list_fh:
                    elements = line.strip().split(self.seperator)
                    total_frame = int(elements[2]) - int(elements[1]) + 1
                    if self.test_mode or total_frame >= self.filter_video:
                        valid_video_numbers += 1
                    else:
                        # Keys are matched on the last path component, as bytes.
                        name = u'{}'.format(elements[0].split("/")[-1]).encode('ascii')
                        invalid_video_ids.add(name)

        print("The number of videos is {} (with more than {} frames) "
              "(original: {})".format(valid_video_numbers, self.filter_video, self.length),
              flush=True)

        # remove keys and update length
        self.length = valid_video_numbers
        self.keys = [k for k in self.keys if k not in invalid_video_ids]

        if self.length != len(self.keys):
            raise ValueError("Do not filter video correctly.")

        self.num_classes = num_classes
        self.unpacked_video = None

    def remove_data(self, idx):
        """Drop the videos whose positions in ``keys`` are listed in ``idx``."""
        original_video_num = self.length
        drop = {int(i) for i in idx}
        self.keys = [v for i, v in enumerate(self.keys) if i not in drop]
        # Derive the length from the keys so duplicated or out-of-range
        # positions in idx cannot desynchronise the two.
        self.length = len(self.keys)
        removed = original_video_num - self.length
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, removed, self.length))

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def __getitem__(self, index):
        """Return ``(transformed images, label)`` for the video at ``index``."""
        unpacked_video = self.maybe_open_and_get_buffer(index)
        # First entry is read as the frame count and the last as the label.
        num_frames = unpacked_video[0] - 1 if self.modality == 'rgbdiff' else unpacked_video[0]
        record = VideoRecord(self.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
        indices = self._sample_indices(record) if self.is_train else self._get_val_indices(record)
        images = self.get_data(record, indices, unpacked_video)
        images = self.transform(images)
        label = self.get_label(record)
        self.unpacked_video = None
        return images, label

    def maybe_open_and_get_buffer(self, index):
        """Open the database on first use and return the deserialized video.

        When the entry at ``index`` cannot be deserialized, the error is
        printed and the first video of the database is returned instead.
        """
        if self.db is None:
            self.db = lmdb.open(self.db_path, max_readers=1, subdir=os.path.isdir(self.db_path),
                                readonly=True, lock=False, readahead=False, meminit=False)

        with self.db.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        try:
            unpacked_video = pa.deserialize(byteflow)
        except Exception as e:
            # Best-effort fallback kept from the original: substitute video 0.
            with self.db.begin(write=False) as txn:
                byteflow = txn.get(self.keys[0])
            unpacked_video = pa.deserialize(byteflow)
            print(self.keys[index], e, flush=True)

        self.unpacked_video = unpacked_video
        return unpacked_video

    def get_data(self, record, indices, unpacked_video):
        """Decode the frames of ``unpacked_video`` selected by ``indices``."""
        images = []
        for seg_ind in indices:
            new_seg_ind = [min(seg_ind + record.start_frame - 1 + i, record.num_frames)
                           for i in range(self.num_consecutive_frames)]
            images.extend(load_data_lmdb(unpacked_video, new_seg_ind, self.modality))
        return images

    def get_label(self, record):
        """Return the training target of ``record``.

        In test mode this is the video id; otherwise an int class index, or a
        multi-hot float tensor of size ``num_classes`` when ``multi_label`` is set.
        """
        if self.test_mode:
            # in test mode, return the video id as label
            return record.video_id
        if not self.multi_label:
            return int(record.label)
        # create a binary vector.
        label = torch.zeros(self.num_classes, dtype=torch.float)
        for x in record.label:
            label[int(x)] = 1.0
        return label

    def __len__(self):
        return self.length
723
+
724
+
725
class MultiVideoDataSet(data.Dataset):
    """Several ``VideoDataSet`` instances (one per modality) sampled in lock-step."""

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """
        # root_path, modality and transform become list, each for one modality
        # (a single non-list value is treated as a one-element list)

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (list[str]): the file path to the root of video folder, one per modality
            list_file (str): the file list (relative to each root), each line with
                folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality (list[str]): one modality name per dataset
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform (list): the transformer for preprocessing, one per modality
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
        """
        # The scalar defaults would otherwise be indexed character by character.
        root_path = self._as_list(root_path)
        modality = self._as_list(modality)
        transform = self._as_list(transform)

        video_datasets = []
        for i in range(len(modality)):
            # VideoDataSet joins list_file with its root itself; joining here
            # as well duplicated the prefix for relative roots.
            video_datasets.append(VideoDataSet(
                root_path[i], list_file,
                num_groups=num_groups, frames_per_group=frames_per_group,
                sample_offset=sample_offset, num_clips=num_clips, modality=modality[i],
                dense_sampling=dense_sampling, fixed_offset=fixed_offset,
                image_tmpl=image_tmpl, transform=transform[i], is_train=is_train,
                test_mode=test_mode, seperator=seperator, filter_video=filter_video,
                num_classes=num_classes, whole_video=whole_video, fps=fps,
                audio_length=audio_length, resampling_rate=resampling_rate))

        self.video_datasets = video_datasets
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_frames = num_groups
        self.sample_freq = frames_per_group
        self.dense_sampling = dense_sampling
        self.num_clips = num_clips
        self.fixed_offset = fixed_offset
        self.modality = modality
        self.num_classes = num_classes
        self.whole_video = whole_video

        # Records of the first modality drive the sampling; the largest frame
        # window is used so the indices are valid for every modality.
        self.video_list = video_datasets[0].video_list
        self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets])

    @staticmethod
    def _as_list(value):
        """Return ``value`` as a list, wrapping anything that is not a list/tuple."""
        return list(value) if isinstance(value, (list, tuple)) else [value]

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def remove_data(self, idx):
        """Drop the videos at positions ``idx`` from every per-modality dataset."""
        for video_dataset in self.video_datasets:
            video_dataset.remove_data(idx)
        self.video_list = self.video_datasets[0].video_list

    def __getitem__(self, index):
        """
        Returns:
            list: one transformed sample per modality, all built from the same
                sampled frame indices.
            the label reported by the first modality's dataset
        """
        record = self.video_list[index]
        if self.is_train:
            indices = self._sample_indices(record)
        else:
            indices = self._get_val_indices(record)

        multi_modalities = []
        for video_dataset in self.video_datasets:
            record = video_dataset.video_list[index]
            images = video_dataset.get_data(record, indices)
            images = video_dataset.transform(images)
            label = video_dataset.get_label(record)
            multi_modalities.append((images, label))

        return [x for x, y in multi_modalities], multi_modalities[0][1]

    def __len__(self):
        return len(self.video_list)
818
+
819
+
820
+ class MultiVideoDataSetLMDB(data.Dataset):
821
+
822
def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
             modality='rgb', dense_sampling=False, fixed_offset=True,
             image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
             filter_video=0, num_classes=None, whole_video=False,
             fps=29.97, audio_length=1.28, resampling_rate=24000):
    """Create one child dataset per modality and remember the shared sampling options.

    ``root_path``, ``modality`` and ``transform`` are indexed in parallel, one
    entry per modality. A ``'sound'`` modality is served by ``VideoDataSet``
    (using the ``.txt`` counterpart of ``list_file`` inside its root folder);
    every other modality is served by ``VideoDataSetLMDB``.

    Arguments have a different meaning when dense_sampling is True:
        - num_groups ==> number of frames
        - frames_per_group ==> sample every K frame
        - sample_offset ==> number of clips used in validation or test mode

    Args:
        root_path: per-modality root folders
        list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
        num_groups (int): number of frames per data sample
        frames_per_group (int): number of frames within one group
        sample_offset (int): used in validation/test, the offset when sampling frames from a group
        modality: per-modality names (e.g. rgb, rgbdiff, flow, sound)
        dense_sampling (bool): dense sampling in I3D
        fixed_offset (bool): used for generating the same videos used in TSM
        image_tmpl (str): template of image ids
        transform: per-modality preprocessing transforms
        is_train (bool): shuffle the video but keep the causality
        test_mode (bool): testing mode, no label
    """
    # Positional arguments that are identical for every child dataset.
    sampling_args = (num_groups, frames_per_group, sample_offset, num_clips)
    trailing_args = (is_train, test_mode, seperator, filter_video, num_classes,
                     whole_video, fps, audio_length, resampling_rate)

    children = []
    for pos, name in enumerate(modality):
        if name == 'sound':
            # Sound is read from the plain-text list stored next to the data.
            dataset_cls = VideoDataSet
            listing = os.path.join(root_path[pos], list_file.replace(".lmdb", ".txt"))
        else:
            dataset_cls = VideoDataSetLMDB
            listing = list_file
        children.append(dataset_cls(root_path[pos], listing, *sampling_args,
                                    name, dense_sampling, fixed_offset,
                                    image_tmpl, transform[pos], *trailing_args))

    self.video_datasets = children
    self.is_train = is_train
    self.test_mode = test_mode
    self.num_frames = num_groups
    self.sample_freq = frames_per_group
    self.dense_sampling = dense_sampling
    self.num_clips = num_clips
    self.fixed_offset = fixed_offset
    self.modality = modality
    self.num_classes = num_classes
    self.whole_video = whole_video

    # Use the widest per-sample frame window required by any modality.
    self.num_consecutive_frames = max([child.num_consecutive_frames for child in self.video_datasets])
880
+
881
def _sample_indices(self, record):
    """Return training-time frame indices for ``record`` via ``sample_train_clip``."""
    clip_args = (record.num_frames, self.num_consecutive_frames, self.num_frames,
                 self.sample_freq, self.dense_sampling, self.num_clips)
    return sample_train_clip(*clip_args)
884
+
885
def _get_val_indices(self, record):
    """Return validation/test frame indices for ``record`` via ``sample_val_test_clip``."""
    clip_args = (record.num_frames, self.num_consecutive_frames, self.num_frames,
                 self.sample_freq, self.dense_sampling, self.fixed_offset,
                 self.num_clips, self.whole_video)
    return sample_val_test_clip(*clip_args)
889
+
890
def remove_data(self, idx):
    """Forward the removal of the samples listed in ``idx`` to every child dataset."""
    for child in self.video_datasets:
        child.remove_data(idx)
893
+
894
def __getitem__(self, index):
    """Load sample ``index`` from every modality using one shared set of frame indices.

    Returns:
        list: the transformed data of each modality, in ``self.modality`` order.
        The label reported by the first modality's dataset.
    """
    multi_modalities = []
    indices = None
    for modality, video_dataset in zip(self.modality, self.video_datasets):
        # Frame indices are sampled once, from the first modality, and then
        # reused so that all modalities look at the same positions.
        if indices is None:
            if modality == 'sound':
                record = video_dataset.video_list[index]
            else:
                unpacked_video = video_dataset.maybe_open_and_get_buffer(index)
                # rgbdiff has one frame fewer than the stored frame count.
                num_frames = unpacked_video[0] - 1 if modality == 'rgbdiff' else unpacked_video[0]
                record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
            indices = video_dataset._sample_indices(record) if video_dataset.is_train else video_dataset._get_val_indices(record)

        if modality == 'sound':
            record = video_dataset.video_list[index]
            images = video_dataset.get_data(record, indices)
        else:
            # Only fetch the buffer if the index-sampling step above has not
            # already left it on the dataset.
            if video_dataset.unpacked_video is None:
                video_dataset.maybe_open_and_get_buffer(index)
            unpacked_video = video_dataset.unpacked_video
            num_frames = unpacked_video[0] - 1 if modality == 'rgbdiff' else unpacked_video[0]
            record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])

            images = video_dataset.get_data(record, indices, video_dataset.unpacked_video)
            # Reset the cached buffer so the next sample starts clean.
            video_dataset.unpacked_video = None

        images = video_dataset.transform(images)
        label = video_dataset.get_label(record)
        multi_modalities.append((images, label))

    return [x for x, y in multi_modalities], multi_modalities[0][1]
930
+
931
def __len__(self):
    """Return the number of samples, taken from the first modality's dataset."""
    primary = self.video_datasets[0]
    return len(primary)
933
+
934
+
935
class VideoDataSetOnline(VideoDataSet):
    """Video dataset that decodes frames from video files on the fly with PyAV."""

    def __init__(self, root_path, list_file, num_groups=8, frames_per_group=1, sample_offset=0,
                 num_clips=1, modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality (str): rgb or rgbdiff
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
            fps (float): frame rate per second, used to localize sound when frame idx is selected.
            audio_length (float): the time window to extract audio feature.
            resampling_rate (int): used to resampling audio extracted from wav

        Raises:
            ValueError: if PyAV is not available or ``modality`` is not 'rgb'/'rgbdiff'.
        """

        if not _HAS_PYAV:
            raise ValueError(_PYAV_ERROR_MSG)
        if modality not in ['rgb', 'rgbdiff']:
            raise ValueError("modality should be 'rgb' or 'rgbdiff'.")

        super().__init__(root_path, list_file, num_groups, frames_per_group, sample_offset,
                         num_clips, modality, dense_sampling, fixed_offset,
                         image_tmpl, transform, is_train, test_mode, seperator,
                         filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)

    def remove_data(self, idx):
        """Drop the videos whose positions are listed in ``idx`` and print a summary."""
        original_video_num = len(self.video_list)
        # A set makes the membership test O(1) per video; int() also normalises
        # numpy/torch scalar indices so they compare equal to enumerate()'s ints.
        dropped = {int(i) for i in idx}
        self.video_list = [v for i, v in enumerate(self.video_list) if i not in dropped]
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, len(idx), len(self.video_list)))

    @staticmethod
    def _remap_indices(indices, actual_frames, expected_frames):
        """Rescale indices from the listed frame count to the real one, capped at the last frame."""
        remapped = np.around(indices * (actual_frames / expected_frames)).astype(int)
        # Rounding can land one past the end when the real video is much shorter.
        return np.minimum(remapped, actual_frames - 1)

    def _selective_decoding(self, container, index, timebase):
        """Decode only the pts range spanned by ``index``.

        Returns:
            (frames, success): ``self.num_frames`` frames evenly subsampled from
            the decoded range, and whether decoding produced any frame at all.
        """
        # Seeking in the stream is imprecise, so seek to an earlier PTS by a margin.
        margin = 1024
        start_idx, end_idx = min(index), max(index)
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)
        seek_offset = max(video_start_pts - margin, 0)
        container.seek(seek_offset, any_frame=False, backward=True,
                       stream=container.streams.video[0])
        try:
            frames = {}
            for frame in container.decode({'video': 0}):
                if frame.pts < video_start_pts:
                    continue
                if frame.pts > video_end_pts:
                    break
                frames[frame.pts] = frame
            # the decoded frames is a whole region but we might subsample it
            video_frames = np.asarray([frames[pts].to_rgb().to_ndarray() for pts in sorted(frames)])
            if len(video_frames) == 0:  # somehow decoding is wrong
                return video_frames, False
            picks = np.linspace(0, len(video_frames) - 1, num=self.num_frames, dtype=int)
            return video_frames[picks, ...], True
        except Exception:
            # Deliberate best effort: any decoding failure makes the caller
            # fall back to decoding the entire video.
            return None, False

    def get_data(self, record, indices):
        """Decode the frames selected by ``indices`` and return them as PIL images.

        Args:
            record: video record providing ``path`` and ``num_frames``.
            indices: numpy array of frame indices (shifted down by one before use).

        Returns:
            list: one ``PIL.Image`` per selected frame.
        """
        indices = indices - 1
        container = av.open(os.path.join(self.root_path, record.path))
        try:
            stream = container.streams.video[0]
            stream.thread_type = "AUTO"
            frames_length = stream.frames
            duration = stream.duration
            # If the decoding information is missing, decode the entire video.
            decode_all = duration is None or frames_length == 0

            video_frames = None
            if not decode_all:
                # Perform selective decoding. The remapped copy is kept separate
                # so the decode-all fallback below still starts from the
                # un-remapped indices instead of scaling them a second time.
                clip_indices = indices
                if frames_length != record.num_frames:
                    clip_indices = self._remap_indices(indices, frames_length, record.num_frames)
                timebase = duration / frames_length
                for i in range(self.num_clips):
                    curr_index = clip_indices[i * self.num_frames: (i + 1) * self.num_frames]
                    curr_video_frames, success = self._selective_decoding(container, curr_index, timebase)
                    if not success:
                        decode_all = True
                        break
                    if video_frames is not None:
                        video_frames = np.concatenate((video_frames, curr_video_frames), axis=0)
                    else:
                        video_frames = curr_video_frames

            if decode_all:
                container.seek(0, any_frame=False, backward=True, stream=stream)
                frames = {frame.pts: frame for frame in container.decode({'video': 0})}
                video_frames = np.asarray([frames[pts].to_rgb().to_ndarray() for pts in sorted(frames)])
                total_frames = len(video_frames)
                frame_indices = indices
                if total_frames != record.num_frames:
                    frame_indices = self._remap_indices(indices, total_frames, record.num_frames)
                video_frames = video_frames[frame_indices, ...]

            # TODO: support rgb diff, calculate end_pts differently.
            return [Image.fromarray(frame) for frame in video_frames]
        finally:
            # Release the file handle even when decoding raises.
            container.close()
1083
+
1084
+
1085
class MultiVideoDataSetOnline(data.Dataset):
    """Multi-modality dataset mixing PyAV (rgb/rgbdiff), file-list (sound) and LMDB (flow) sources."""

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """
        # root_path, modality and transform become list, each for one modality

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path: per-modality root folders
            list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality: per-modality names; each one of rgb, rgbdiff, sound or flow
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: per-modality preprocessing transforms
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label

        Raises:
            ValueError: if a modality is not one of rgb, rgbdiff, sound or flow.
        """

        # TODO: support mixed LMDB, pyAV, etc.

        video_datasets = []
        for i, name in enumerate(modality):
            if name == 'rgb' or name == 'rgbdiff':
                video_dataset_cls = VideoDataSetOnline
                list_file_ = list_file
            elif name == 'sound':
                video_dataset_cls = VideoDataSet
                list_file_ = list_file
            elif name == 'flow':
                video_dataset_cls = VideoDataSetLMDB
                list_file_ = list_file.replace(".txt", ".lmdb")
            else:
                # Without this branch the class chosen for the previous
                # modality would be silently reused (or the name be unbound).
                raise ValueError(f'Unsupported modality: {name}')

            tmp = video_dataset_cls(root_path[i], list_file_,
                                    num_groups, frames_per_group, sample_offset,
                                    num_clips, name, dense_sampling, fixed_offset,
                                    image_tmpl, transform[i], is_train, test_mode, seperator,
                                    filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)
            video_datasets.append(tmp)

        self.video_datasets = video_datasets
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_frames = num_groups
        self.sample_freq = frames_per_group
        self.dense_sampling = dense_sampling
        self.num_clips = num_clips
        self.fixed_offset = fixed_offset
        self.modality = modality
        self.num_classes = num_classes
        self.whole_video = whole_video

        self.video_list = video_datasets[0].video_list
        self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets])

    def _sample_indices(self, record):
        """Return training-time frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Return validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def remove_data(self, idx):
        """Remove the samples listed in ``idx`` from every child and refresh ``video_list``."""
        for video_dataset in self.video_datasets:
            video_dataset.remove_data(idx)
        self.video_list = self.video_datasets[0].video_list

    def __getitem__(self, index):
        """Load sample ``index`` from every modality using one shared set of frame indices.

        Returns:
            list: the transformed data of each modality, in ``self.modality`` order.
            The label reported by the first modality's dataset.
        """
        multi_modalities = []
        indices = None
        for modality, video_dataset in zip(self.modality, self.video_datasets):
            # Indices are sampled once, from the first modality, then shared.
            if indices is None:
                if modality != 'flow':
                    record = video_dataset.video_list[index]
                else:
                    unpacked_video = video_dataset.maybe_open_and_get_buffer(index)
                    num_frames = unpacked_video[0]
                    record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
                indices = video_dataset._sample_indices(record) if video_dataset.is_train else video_dataset._get_val_indices(record)

            if modality != 'flow':
                record = video_dataset.video_list[index]
                images = video_dataset.get_data(record, indices)
            else:
                # Only fetch the buffer if the sampling step above has not
                # already left it on the dataset.
                if video_dataset.unpacked_video is None:
                    video_dataset.maybe_open_and_get_buffer(index)
                unpacked_video = video_dataset.unpacked_video
                num_frames = unpacked_video[0]
                record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])

                images = video_dataset.get_data(record, indices, video_dataset.unpacked_video)
                video_dataset.unpacked_video = None

            images = video_dataset.transform(images)
            label = video_dataset.get_label(record)
            multi_modalities.append((images, label))

        return [x for x, y in multi_modalities], multi_modalities[0][1]

    def __len__(self):
        """Return the number of samples, taken from the first modality's dataset."""
        return len(self.video_datasets[0])
1204
+
1205
+
1206
def get_dataloader(loader_type, *args, **kwargs) -> \
        Union[VideoDataSetLMDB, VideoDataSetOnline, VideoDataSet]:
    """Instantiate the single-modality dataset backing ``loader_type``.

    ``loader_type`` is one of 'lmdb', 'pyav' or 'jpeg'; remaining arguments are
    passed straight to the dataset constructor. Raises ValueError otherwise.
    """
    backends = (
        ('lmdb', VideoDataSetLMDB),
        ('pyav', VideoDataSetOnline),
        ('jpeg', VideoDataSet),
    )
    for key, dataset_cls in backends:
        if loader_type == key:
            return dataset_cls(*args, **kwargs)
    raise ValueError(f'Unknown dataloader type: {loader_type}')
1216
+
1217
+
1218
def get_multimodality_dataloader(loader_type, *args, **kwargs) -> \
        Union[MultiVideoDataSetLMDB, MultiVideoDataSetOnline, MultiVideoDataSet]:
    """Instantiate the multi-modality dataset backing ``loader_type``.

    ``loader_type`` is one of 'lmdb', 'pyav' or 'jpeg'; remaining arguments are
    passed straight to the dataset constructor. Raises ValueError otherwise.
    """
    backends = (
        ('lmdb', MultiVideoDataSetLMDB),
        ('pyav', MultiVideoDataSetOnline),
        ('jpeg', MultiVideoDataSet),
    )
    for key, dataset_cls in backends:
        if loader_type == key:
            return dataset_cls(*args, **kwargs)
    raise ValueError(f'Unknown dataloader type: {loader_type}')
1228
+
video_dataset_aug.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import multiprocessing
2
+ from typing import List
3
+
4
+ import torch
5
+ import torch.nn.parallel
6
+ import torch.optim
7
+ import torch.utils.data
8
+ import torch.utils.data.distributed
9
+ import torchvision.transforms as transforms
10
+ from video_transforms import (GroupRandomHorizontalFlip, GroupOverSample,
11
+ GroupMultiScaleCrop, GroupScale, GroupCenterCrop, GroupRandomCrop,
12
+ GroupNormalize, Stack, ToTorchFormatTensor, GroupRandomScale,GroupCutout)
13
+
14
def get_augmentor(is_train: bool, image_size: int, mean: List[float] = None,
                  std: List[float] = None, disable_scaleup: bool = False,
                  threed_data: bool = False, version: str = 'v1', scale_range: List[int] = None,
                  modality: str = 'rgb', num_clips: int = 1, num_crops: int = 1, cut_out=True, dataset: str = ''):
    """Compose the frame-group preprocessing pipeline.

    Args:
        is_train: build the random training pipeline instead of the evaluation one.
        image_size: side length of the final crop.
        mean: per-channel mean; defaults to [0.485, 0.456, 0.406].
        std: per-channel std; defaults to [0.229, 0.224, 0.225].
        disable_scaleup: at evaluation, scale straight to ``image_size``
            instead of ``image_size / 0.875``.
        threed_data: forwarded to ``Stack`` and ``GroupNormalize``.
        version: training crop strategy, 'v1' (multi-scale crop) or 'v2'
            (random scale + random crop); any other value adds no crop.
        scale_range: [low, high] range for 'v2' scaling; defaults to [256, 320].
        modality: 'flow' makes the horizontal flip invert flow values.
        num_clips: number of clips per sample.
        num_crops: number of spatial crops at evaluation.
        cut_out: append ``GroupCutout(n_holes=1, length=16)``.
        dataset: dataset name; flipping is skipped for ststv/jester/mini_ststv names.

    Returns:
        A ``transforms.Compose`` over the selected steps.
    """
    mean = [0.485, 0.456, 0.406] if mean is None else mean
    std = [0.229, 0.224, 0.225] if std is None else std
    scale_range = [256, 320] if scale_range is None else scale_range

    augments = []
    if is_train:
        if version == 'v1':
            augments.append(GroupMultiScaleCrop(image_size, [1, .875, .75, .66]))
        elif version == 'v2':
            augments.extend([
                GroupRandomScale(scale_range),
                GroupRandomCrop(image_size),
            ])
        if not (dataset.startswith('ststv') or 'jester' in dataset or 'mini_ststv' in dataset):
            augments.append(GroupRandomHorizontalFlip(is_flow=(modality == 'flow')))
    else:
        scaled_size = image_size if disable_scaleup else int(image_size / 0.875 + 0.5)
        if num_crops == 1:
            augments.extend([
                GroupScale(scaled_size),
                GroupCenterCrop(image_size),
            ])
        else:
            # Only the 10-crop protocol adds mirrored copies.
            augments.append(
                GroupOverSample(image_size, scaled_size, num_crops=num_crops, flip=(num_crops == 10))
            )
    augments.extend([
        Stack(threed_data=threed_data),
        ToTorchFormatTensor(num_clips_crops=num_clips * num_crops),
        GroupNormalize(mean=mean, std=std, threed_data=threed_data),
    ])
    # NOTE(review): cutout is applied whenever cut_out is True, including when
    # is_train is False - confirm that evaluation callers pass cut_out=False.
    if cut_out:
        augments.append(GroupCutout(n_holes=1, length=16))

    return transforms.Compose(augments)
59
+
60
+
61
def build_dataflow(dataset, is_train, batch_size, workers=36, is_distributed=False, drop_last=True):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Args:
        dataset: the dataset to iterate over.
        is_train: shuffle the data (only when no distributed sampler is used).
        batch_size: number of samples per batch.
        workers: requested worker processes, capped at the CPU count.
        is_distributed: use a ``DistributedSampler`` instead of plain iteration.
        drop_last: drop the final incomplete batch. Defaults to True (the
            previous hard-coded behaviour); pass False for evaluation so that
            no sample is skipped.

    Returns:
        The configured ``DataLoader`` (``pin_memory`` is always enabled).
    """
    workers = min(workers, multiprocessing.cpu_count())

    sampler = torch.utils.data.distributed.DistributedSampler(dataset) if is_distributed else None
    # DataLoader rejects shuffle=True together with an explicit sampler.
    shuffle = is_train and sampler is None

    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                       num_workers=workers, drop_last=drop_last,
                                       pin_memory=True, sampler=sampler)
73
+
video_dataset_config.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
# Per-dataset settings, keyed by dataset name.
DATASET_CONFIG = {
    'ffpp': {
        'num_classes': 2,
        # NOTE(review): train, val and test all point at the same list file -
        # confirm this is intended and not a leftover from an evaluation run.
        'train_list_name': 'cdf_test_fold.txt',
        'val_list_name': 'cdf_test_fold.txt',
        'test_list_name': 'cdf_test_fold.txt',
        'filename_seperator': " ",
        'image_tmpl': '{:04d}.jpg',
        'filter_video': 3,
    }
}
14
+
15
+
16
def get_dataset_config(dataset, use_lmdb=False):
    """Look up the settings of ``dataset`` in ``DATASET_CONFIG``.

    Args:
        dataset: key into ``DATASET_CONFIG``.
        use_lmdb: when True, list names ending in ``.txt`` are returned with a
            ``.lmdb`` extension instead.

    Returns:
        tuple: (num_classes, train_list_name, val_list_name, test_list_name,
        filename_seperator, image_tmpl, filter_video, label_file).
        ``filter_video`` defaults to 0 and ``label_file`` to None when absent.

    Raises:
        KeyError: if ``dataset`` or one of its required keys is missing.
    """
    def list_name(key):
        name = ret[key]
        # Swap only the extension; a blanket replace("txt", "lmdb") would also
        # rewrite "txt" occurring anywhere else in the file name.
        if use_lmdb and name.endswith(".txt"):
            return name[:-len(".txt")] + ".lmdb"
        return name

    ret = DATASET_CONFIG[dataset]
    num_classes = ret['num_classes']
    train_list_name = list_name('train_list_name')
    val_list_name = list_name('val_list_name')
    test_list_name = list_name('test_list_name')
    filename_seperator = ret['filename_seperator']
    image_tmpl = ret['image_tmpl']
    filter_video = ret.get('filter_video', 0)
    label_file = ret.get('label_file', None)

    return num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, \
        image_tmpl, filter_video, label_file
video_transforms.py ADDED
@@ -0,0 +1,420 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchvision
2
+ import random
3
+ from PIL import Image, ImageOps
4
+ import numbers
5
+ import torch
6
+ import numpy as np
7
+ import math
8
+
9
class GroupRandomCrop(object):
    """Cut the same randomly placed window out of every image of a group."""

    def __init__(self, size):
        # A bare number requests a square window.
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    def __call__(self, img_group):
        """Return the cropped images; all inputs must share the first image's size."""
        width, height = img_group[0].size
        target_h, target_w = self.size

        # One offset is drawn per call so that all frames stay aligned.
        left = random.randint(0, width - target_w)
        top = random.randint(0, height - target_h)
        already_sized = width == target_w and height == target_h
        box = (left, top, left + target_w, top + target_h)

        cropped = []
        for frame in img_group:
            assert(frame.size[0] == width and frame.size[1] == height)
            cropped.append(frame if already_sized else frame.crop(box))
        return cropped
34
+
35
+
36
class GroupCenterCrop(object):
    """Apply torchvision's ``CenterCrop(size)`` to every image of a group."""

    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
42
+
43
+
44
class GroupRandomHorizontalFlip(object):
    """Mirror the whole group left-to-right with a probability of 0.5.

    With ``is_flow`` set, every second image (starting at the first) is also
    inverted after mirroring.
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # The ``is_flow`` argument is ignored; the constructor flag decides.
        if random.random() >= 0.5:
            return img_group

        mirrored = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
        if self.is_flow:
            for pos in range(0, len(mirrored), 2):
                # invert flow pixel values when flipping
                mirrored[pos] = ImageOps.invert(mirrored[pos])
        return mirrored
60
+
61
+
62
+
63
class GroupNormalize(object):
    """Subtract ``mean`` and divide by ``std`` in place, channel by channel."""

    def __init__(self, mean, std, threed_data=False):
        self.threed_data = threed_data
        if not self.threed_data:
            self.mean = mean
            self.std = std
        else:
            # convert to the proper format: broadcastable (C, 1, 1, 1) tensors
            self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1)
            self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1)

    def __call__(self, tensor):
        """Normalize ``tensor`` in place and return it."""
        if self.threed_data:
            tensor.sub_(self.mean).div_(self.std)
            return tensor

        # Repeat the statistics so they line up with the stacked frames.
        channels = tensor.size()[0]
        tiled_mean = self.mean * (channels // len(self.mean))
        tiled_std = self.std * (channels // len(self.std))

        # TODO: make efficient
        for plane, offset, scale in zip(tensor, tiled_mean, tiled_std):
            plane.sub_(offset).div_(scale)
        return tensor
87
+
88
class GroupCutout(object):
    """Randomly mask out one or more patches, identically for every frame.

    Args:
        n_holes (int): Number of patches to cut out.
        length (int): The length (in pixels) of each square patch.
    """
    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, imgs):
        """
        Args:
            imgs (Tensor): stacked frames whose last two dimensions are (H, W).
        Returns:
            Tensor: same shape as ``imgs`` with ``n_holes`` patches of up to
            ``length`` x ``length`` pixels set to zero in every channel.
        """
        height, width = imgs.shape[-2], imgs.shape[-1]
        half = self.length // 2

        # The mask is built once and shared by all frames so that the holes
        # stay at the same place over time.
        mask = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # Each hole gets its own centre; rows are drawn from the height
            # and columns from the width so non-square inputs work too.
            y = int(np.random.randint(height))
            x = int(np.random.randint(width))
            y1, y2 = max(y - half, 0), min(y + half, height)
            x1, x2 = max(x - half, 0), min(x + half, width)
            mask[y1: y2, x1: x2] = 0.

        # Broadcasting over the leading dimensions keeps the input shape.
        return imgs * torch.from_numpy(mask)
136
+
137
class GroupScale(object):
    """Resize every image of a group with torchvision's ``Resize``.

    size: target size handed to ``Resize`` (an int is the size of the smaller edge)
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        return list(map(self.worker, img_group))
151
+
152
class GroupRandomScale(object):
    """Resize the whole group to a randomly chosen smaller-edge size.

    size: ``[low, high]``; the smaller edge is drawn uniformly from this
        inclusive integer range on every call
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        low, high = self.size[0], self.size[1]
        # randint's upper bound is exclusive, hence the +1.
        edge = np.random.randint(low=low, high=high + 1, dtype=int)
        return GroupScale(edge, interpolation=self.interpolation)(img_group)
170
+
171
class GroupOverSample(object):
    """Produce several fixed crops (optionally mirrored) of every image in a group."""

    def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False):
        self.crop_size = (crop_size, crop_size) if isinstance(crop_size, int) else crop_size
        self.scale_worker = GroupScale(scale_size) if scale_size is not None else None

        if num_crops not in [1, 3, 5, 10]:
            raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops))
        self.num_crops = num_crops

        self.flip = flip

    @staticmethod
    def _three_crop_offsets(image_w, image_h, crop_w, crop_h):
        """Return the three (x, y) offsets used when ``num_crops`` is 3."""
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4
        center = (2 * w_step, 2 * h_step)
        if image_w != crop_w and image_h != crop_h:
            # Both edges have slack: walk the diagonal.
            return [(0 * w_step, 0 * h_step), (4 * w_step, 4 * h_step), center]
        if image_w < image_h:
            # top, bottom, center
            return [(2 * w_step, 0 * h_step), (2 * w_step, 4 * h_step), center]
        # left, right, center
        return [(0 * w_step, 2 * h_step), (4 * w_step, 2 * h_step), center]

    def __call__(self, img_group):
        """Return, per offset, the crops of all images followed by their mirrored copies (if ``flip``)."""
        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        if self.num_crops == 3:
            offsets = self._three_crop_offsets(image_w, image_h, crop_w, crop_h)
        else:
            offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)

        oversample_group = []
        for o_w, o_h in offsets:
            box = (o_w, o_h, o_w + crop_w, o_h + crop_h)
            crops = [img.crop(box) for img in img_group]
            oversample_group.extend(crops)
            if not self.flip:
                continue
            for pos, (img, crop) in enumerate(zip(img_group, crops)):
                mirrored = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
                # Every second single-channel image is inverted after mirroring.
                if img.mode == 'L' and pos % 2 == 0:
                    mirrored = ImageOps.invert(mirrored)
                oversample_group.append(mirrored)
        return oversample_group
234
+
235
+
236
class GroupMultiScaleCrop(object):
    """Crop a randomly scaled window from every image and resize it to ``input_size``.

    Args:
        input_size: output size; an int means a square ``[size, size]``.
        scales: crop sizes as fractions of the shorter image edge; defaults
            to ``DEFAULT_SCALES``.
        max_distort: how far apart (in positions of ``scales``) the width and
            height scales of a crop may be.
        fix_crop: pick the crop position from a fixed grid instead of uniformly.
        more_fix_crop: use the extended 13-position grid.
    """

    # Fractions of the shorter edge. The previous inline default contained
    # 875 (missing decimal point), i.e. a crop 875x larger than the image.
    DEFAULT_SCALES = (1, .875, .75, .66)

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else list(self.DEFAULT_SCALES)
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Apply one sampled crop to all images, then resize them to ``input_size``."""
        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        target = (self.input_size[0], self.input_size[1])
        return [img.crop(box).resize(target, self.interpolation) for img in img_group]

    def _sample_crop_size(self, im_size):
        """Return a random ``(crop_w, crop_h, w_offset, h_offset)`` for an image of ``im_size``."""
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap sizes within 3 px of the target so no needless resize happens.
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = [(w, h)
                 for i, h in enumerate(crop_h)
                 for j, w in enumerate(crop_w)
                 if abs(i - j) <= self.max_distort]

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset at random from the fixed grid."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """List the fixed crop offsets: 5 positions, or 13 with ``more_fix_crop``."""
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = list()
        ret.append((0, 0))  # upper left
        ret.append((4 * w_step, 0))  # upper right
        ret.append((0, 4 * h_step))  # lower left
        ret.append((4 * w_step, 4 * h_step))  # lower right
        ret.append((2 * w_step, 2 * h_step))  # center

        if more_fix_crop:
            ret.append((0, 2 * h_step))  # center left
            ret.append((4 * w_step, 2 * h_step))  # center right
            ret.append((2 * w_step, 4 * h_step))  # lower center
            ret.append((2 * w_step, 0 * h_step))  # upper center

            ret.append((1 * w_step, 1 * h_step))  # upper left quarter
            ret.append((3 * w_step, 1 * h_step))  # upper right quarter
            ret.append((1 * w_step, 3 * h_step))  # lower left quarter
            ret.append((3 * w_step, 3 * h_step))  # lower right quarter

        return ret
308
+
309
+
310
class GroupRandomSizedCrop(object):
    """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
    and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio.
    This is popularly used to train the Inception networks
    size: edge length of the square output
    interpolation: Default: PIL.Image.BILINEAR
    """
    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        full_w, full_h = img_group[0].size[0], img_group[0].size[1]

        # Try up to ten times to draw a window that fits inside the image.
        window = None
        for _ in range(10):
            area = full_w * full_h
            target_area = random.uniform(0.08, 1.0) * area
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= full_w and h <= full_h:
                x1 = random.randint(0, full_w - w)
                y1 = random.randint(0, full_h - h)
                window = (x1, y1, w, h)
                break

        if window is None:
            # Fallback
            scale = GroupScale(self.size, interpolation=self.interpolation)
            crop = GroupRandomCrop(self.size)
            return crop(scale(img_group))

        x1, y1, w, h = window
        out_group = []
        for img in img_group:
            piece = img.crop((x1, y1, x1 + w, y1 + h))
            assert(piece.size == (w, h))
            out_group.append(piece.resize((self.size, self.size), self.interpolation))
        return out_group
355
+
356
+
357
class Stack(object):
    """Merge a group of same-mode images into a single numpy array.

    The mode of the first image decides the layout:

    * ``'L'``: each image gets a trailing channel axis and the images are
      concatenated along it.
    * ``'RGB'`` with ``threed_data``: images are stacked on a new leading axis
      (``roll`` is not consulted on this path).
    * ``'RGB'`` otherwise: images are concatenated along the channel axis;
      with ``roll`` the channel order of every image is reversed first.

    Args:
        roll: reverse the channel order of each RGB image before concatenating.
        threed_data: stack RGB images on a new first axis instead of
            concatenating channels.

    Raises:
        ValueError: if the first image's mode is neither ``'L'`` nor ``'RGB'``.
    """

    def __init__(self, roll=False, threed_data=False):
        self.roll = roll
        self.threed_data = threed_data

    def __call__(self, img_group):
        mode = img_group[0].mode

        if mode == 'L':
            return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)

        if mode == 'RGB':
            if self.threed_data:
                return np.stack(img_group, axis=0)
            if self.roll:
                return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            return np.concatenate(img_group, axis=2)

        # Previously this fell through and returned None, which only failed
        # later in the pipeline with an unrelated error. Fail here instead.
        raise ValueError(
            "Stack supports only 'L' and 'RGB' images, got mode {!r}".format(mode)
        )
374
+
375
+
376
class ToTorchFormatTensor(object):
    """Convert a PIL.Image or numpy.ndarray with channels last into a
    channels-first torch.FloatTensor.

    A 3-D array or a PIL image (H x W x C) becomes (C x H x W); a 4-D array
    has its last axis moved to the front. With ``div`` true the values are
    divided by 255, mapping the [0, 255] range to [0.0, 1.0].

    ``num_clips_crops`` is stored on the instance but not used by this class.
    """

    def __init__(self, div=True, num_clips_crops=1):
        self.div = div
        self.num_clips_crops = num_clips_crops

    def __call__(self, pic):
        if not isinstance(pic, np.ndarray):
            # PIL Image: wrap the raw bytes, then view them as H x W x C
            # (PIL's size is (width, height), hence the swapped indices).
            flat = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            hwc = flat.view(pic.size[1], pic.size[0], len(pic.mode))
            # HWC -> CHW; the original author noted that this reordering
            # dominates the loading time.
            tensor = hwc.permute(2, 0, 1).contiguous()
        elif pic.ndim == 4:
            # Last axis moves to the front; presumably the input is
            # (frames, H, W, C) as built by Stack(threed_data=True) - verify against caller.
            tensor = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
        else:
            # data is H x W x (frames * C)
            tensor = torch.from_numpy(pic).permute(2, 0, 1).contiguous()

        tensor = tensor.float()
        if self.div:
            return tensor.div(255)
        return tensor
399
+
400
+
401
class IdentityTransform(object):
    """No-op transform: calling it returns the argument unchanged."""

    def __call__(self, data):
        # The same object is handed back; nothing is copied or modified.
        return data
405
+
406
+
407
if __name__ == "__main__":
    # When run as a script, build a sample preprocessing pipeline:
    # scale -> random crop -> over-sample -> stack -> tensor -> normalize.
    # The pipeline is only constructed here; it is not applied to any data.
    trans = torchvision.transforms.Compose([
        GroupScale(256),
        GroupRandomCrop(224),
        GroupOverSample(224, 224, num_crops=3, flip=False),
        Stack(),
        ToTorchFormatTensor(num_clips_crops=9),
        GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ])
419
+
420
+