Spaces:
Running
Running
Upload 22 files
Browse files- .gitignore +7 -0
- LICENSE +21 -0
- README.md +265 -12
- check_data.py +89 -0
- engine.py +171 -0
- find_tall_model.py +102 -0
- infer_videos_txt.py +433 -0
- keep_only_numbered.py +50 -0
- main.py +462 -0
- make_tall_txt.py +146 -0
- make_tall_txt_count.py +65 -0
- models.py +245 -0
- renumber_frames_for_tall.py +53 -0
- requirements-torch.txt +4 -0
- requirements.txt +0 -0
- test.py +333 -0
- test_new.py +319 -0
- utils.py +265 -0
- video_dataset.py +1228 -0
- video_dataset_aug.py +73 -0
- video_dataset_config.py +31 -0
- video_transforms.py +420 -0
.gitignore
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
lists/
|
| 2 |
+
output/
|
| 3 |
+
venv/
|
| 4 |
+
*.csv
|
| 5 |
+
*.json
|
| 6 |
+
__pycache__/
|
| 7 |
+
data/
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) [year] [fullname]
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,265 @@
|
|
| 1 |
-
--
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🎯 TALL-SWIN Deepfake Detection – Video-Level Pipeline
|
| 2 |
+
|
| 3 |
+
Transformer-based deepfake detection system with custom dataset preparation, evaluation metrics, and batch inference utilities.
|
| 4 |
+
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
# 📖 1. Project Overview
|
| 8 |
+
|
| 9 |
+
This project implements a **video-level deepfake detection pipeline** based on the **TALL-SWIN Vision Transformer architecture**.
|
| 10 |
+
|
| 11 |
+
It extends the original TALL/DeiT repository with:
|
| 12 |
+
|
| 13 |
+
- Custom dataset preparation scripts
|
| 14 |
+
- Frame cleaning and renaming utilities
|
| 15 |
+
- Automatic train/test split generation
|
| 16 |
+
- Advanced evaluation metrics (ROC, PR, confusion matrix)
|
| 17 |
+
- Threshold-controlled classification
|
| 18 |
+
- Batch inference from video lists
|
| 19 |
+
- Reproducible environment setup
|
| 20 |
+
|
| 21 |
+
Designed primarily for FaceForensics++ (FFPP)-style datasets.
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
# 🏗 2. Pipeline Architecture
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
Raw Videos
|
| 29 |
+
↓
|
| 30 |
+
Frame Extraction
|
| 31 |
+
↓
|
| 32 |
+
Frame Cleaning / Renaming
|
| 33 |
+
↓
|
| 34 |
+
Train/Test Split Generation (.txt)
|
| 35 |
+
↓
|
| 36 |
+
Training (TALL-SWIN)
|
| 37 |
+
↓
|
| 38 |
+
Video-Level Aggregation
|
| 39 |
+
↓
|
| 40 |
+
Evaluation (ROC / PR / Metrics)
|
| 41 |
+
↓
|
| 42 |
+
Batch Inference
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
---
|
| 46 |
+
|
| 47 |
+
# 📂 3. Project Structure
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
.
|
| 51 |
+
├── main.py # Training script (base repository)
|
| 52 |
+
├── engine.py # Training loop
|
| 53 |
+
├── models.py # DeiT models
|
| 54 |
+
├── utils.py # Distributed + checkpoint utilities
|
| 55 |
+
├── video_dataset.py # Dataset loader
|
| 56 |
+
├── video_dataset_aug.py # Augmentations
|
| 57 |
+
├── video_dataset_config.py # Dataset config
|
| 58 |
+
├── video_transforms.py # Group transforms
|
| 59 |
+
│
|
| 60 |
+
├── eval_direct.py # Custom evaluation with plots
|
| 61 |
+
├── test_new.py # Threshold-controlled evaluation
|
| 62 |
+
├── infer_videos_txt.py # Batch video inference
|
| 63 |
+
├── make_tall_txt.py # Dataset split generator
|
| 64 |
+
├── make_tall_txt_count.py # Alternative split generator
|
| 65 |
+
├── renumber_frames_for_tall.py # Frame renaming tool
|
| 66 |
+
├── keep_only_numbered.py # Frame cleaning tool
|
| 67 |
+
├── find_tall_model.py # Debug utility
|
| 68 |
+
│
|
| 69 |
+
├── requirements.txt
|
| 70 |
+
├── requirements-torch.txt
|
| 71 |
+
└── README.md
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
---
|
| 75 |
+
|
| 76 |
+
# ⚙ 4. Installation
|
| 77 |
+
|
| 78 |
+
## 4.1 Clone Repository
|
| 79 |
+
|
| 80 |
+
```bash
|
| 81 |
+
git clone <your_repository_url>
|
| 82 |
+
cd TALL4Deepfake
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## 4.2 Create Virtual Environment
|
| 86 |
+
|
| 87 |
+
**Windows (PowerShell)**
|
| 88 |
+
|
| 89 |
+
```powershell
|
| 90 |
+
python -m venv venv
|
| 91 |
+
venv\Scripts\Activate.ps1
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
**Linux / macOS**
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
python -m venv venv
|
| 98 |
+
source venv/bin/activate
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
## 4.3 Install PyTorch (CUDA Build)
|
| 102 |
+
|
| 103 |
+
PyTorch CUDA wheels are not hosted on default PyPI.
|
| 104 |
+
|
| 105 |
+
```bash
|
| 106 |
+
pip install -r requirements-torch.txt
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
## 4.4 Install Remaining Dependencies
|
| 110 |
+
|
| 111 |
+
```bash
|
| 112 |
+
pip install -r requirements.txt
|
| 113 |
+
```
|
| 114 |
+
|
| 115 |
+
## 4.5 Sanity Check
|
| 116 |
+
|
| 117 |
+
```bash
|
| 118 |
+
python -c "import torch, cv2, numpy as np; print(torch.__version__); print(cv2.__version__); print(np.__version__)"
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
# 📁 5. Dataset Preparation
|
| 124 |
+
|
| 125 |
+
Expected structure:
|
| 126 |
+
|
| 127 |
+
```
|
| 128 |
+
data/
|
| 129 |
+
├── real/
|
| 130 |
+
│ ├── video_001/
|
| 131 |
+
│ │ ├── 0001.jpg
|
| 132 |
+
│ │ ├── 0002.jpg
|
| 133 |
+
│ │ └── ...
|
| 134 |
+
└── fake/
|
| 135 |
+
├── video_002/
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
## 5.1 Renumber Frames
|
| 139 |
+
|
| 140 |
+
Ensures frame names follow `0001.jpg` format.
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
python renumber_frames_for_tall.py --root data --digits 4 --copy
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
## 5.2 Remove Non-numbered Frames
|
| 147 |
+
|
| 148 |
+
```bash
|
| 149 |
+
python keep_only_numbered.py --root data
|
| 150 |
+
```
|
| 151 |
+
|
| 152 |
+
Dry run (no deletion):
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
python keep_only_numbered.py --root data --dry_run
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
## 5.3 Generate Train/Test Split
|
| 159 |
+
|
| 160 |
+
**Full Split Generator**
|
| 161 |
+
|
| 162 |
+
```bash
|
| 163 |
+
python make_tall_txt.py --root data --out lists --train_ratio 0.8
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
Outputs:
|
| 167 |
+
|
| 168 |
+
- `lists/cdf_train_fold.txt`
|
| 169 |
+
- `lists/cdf_test_fold.txt`
|
| 170 |
+
|
| 171 |
+
**Alternative Split (Count-Based)**
|
| 172 |
+
|
| 173 |
+
```bash
|
| 174 |
+
python make_tall_txt_count.py --root data --out lists
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
---
|
| 178 |
+
|
| 179 |
+
# 🚀 6. Training
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
python main.py --dataset ffpp --data_dir [data_dir] --data_txt_dir [data_txt_dir] --input-size 112 --num_clips 8 --output_dir [output_dir] --opt adamw --lr 1.5e-5 --warmup-lr 1.5e-8 --min-lr 1.5e-7 --epochs 10 --sched cosine --duration 4 --batch-size 2 --thumbnail_rows 2 --disable_scaleup --cutout True --pretrained --warmup-epochs 1 --no-amp --model TALL_SWIN --hpe_to_token --num_workers 0
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
---
|
| 186 |
+
|
| 187 |
+
# 📊 7. Evaluation
|
| 188 |
+
|
| 189 |
+
## 7.1 Custom Evaluation with Plots
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
python test_new.py --dataset ffpp --data_dir [your_data_dir] --data_txt_dir [your_data_dir_txt] --num_clips 8 --duration 4 --thumbnail_rows 2 --batch-size 1 --num_workers 0 --initial_checkpoint [your_.pth_dir] --output_dir [your_out_dir] --save_plots
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
---
|
| 196 |
+
|
| 197 |
+
# 🎥 8. Batch Video Inference
|
| 198 |
+
|
| 199 |
+
Create a file `videos.txt`:
|
| 200 |
+
|
| 201 |
+
```
|
| 202 |
+
C:/.../video1.mp4
|
| 203 |
+
C:/.../video2.mp4
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
Run:
|
| 207 |
+
|
| 208 |
+
```bash
|
| 209 |
+
python infer_videos_txt.py --video_list videos.txt --initial_checkpoint [your_.pth_dir] --dataset ffpp --duration 4 --num_clips 8
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
Outputs:
|
| 213 |
+
|
| 214 |
+
- `results.json`
|
| 215 |
+
- `results.csv`
|
| 216 |
+
|
| 217 |
+
---
|
| 218 |
+
|
| 219 |
+
# 📈 9. Implemented Metrics
|
| 220 |
+
|
| 221 |
+
- Accuracy
|
| 222 |
+
- Balanced Accuracy
|
| 223 |
+
- Precision
|
| 224 |
+
- Recall
|
| 225 |
+
- F1-score
|
| 226 |
+
- ROC-AUC
|
| 227 |
+
- PR-AUC
|
| 228 |
+
- Confusion Matrix
|
| 229 |
+
- Classification Report
|
| 230 |
+
|
| 231 |
+
---
|
| 232 |
+
|
| 233 |
+
# 🧠 10. Video-Level Aggregation Strategy
|
| 234 |
+
|
| 235 |
+
When multiple clips are extracted per video:
|
| 236 |
+
|
| 237 |
+
```
|
| 238 |
+
logits → softmax → mean over clips → threshold decision
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
If logits shape is `[B*K, 2]`, they are reshaped into:
|
| 242 |
+
|
| 243 |
+
```
|
| 244 |
+
[B, K, 2] → mean(dim=1)
|
| 245 |
+
```
|
| 246 |
+
|
| 247 |
+
Ensures video-level classification rather than frame-level.
|
| 248 |
+
|
| 249 |
+
---
|
| 250 |
+
|
| 251 |
+
# 🖥 11. Recommended Environment
|
| 252 |
+
|
| 253 |
+
- Python 3.10
|
| 254 |
+
- CUDA-compatible GPU
|
| 255 |
+
- PyTorch 1.13.1 (cu117 build)
|
| 256 |
+
- NumPy 2.x (if using OpenCV ≥4.13)
|
| 257 |
+
|
| 258 |
+
---
|
| 259 |
+
|
| 260 |
+
# 👨‍💻 Author
|
| 261 |
+
|
| 262 |
+
**Vinícius Passos Castilho Pinto**
|
| 263 |
+
Double-degree Engineering Student
|
| 264 |
+
Industrial & Automation Systems
|
| 265 |
+
France / Brazil
|
check_data.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Diagnose data directory structure and list file format
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
def check_data_structure(data_dir=None, list_file=None):
    """Print a diagnostic report about the frame directory and a split list file.

    The report has three parts: the sub-directories found under ``data_dir``
    (with a sample of their contents), the first lines of ``list_file``, and
    an analysis of whether the first path in the list resolves under
    ``data_dir``.

    Args:
        data_dir: Root directory holding the frame folders. When ``None`` the
            original hard-coded location is used.
        list_file: Path of the split ``.txt`` file to inspect. When ``None``
            the original hard-coded location is used.

    Returns:
        None. All findings are printed to stdout. The function returns early
        (before inspecting the list file) when ``data_dir`` is not a directory.
    """
    if data_dir is None:
        data_dir = r"C:\Users\vinip\Desktop\TALL4Deepfake\data"
    if list_file is None:
        list_file = r"C:\Users\vinip\Desktop\TALL4Deepfake\lists\cdf_test_fold.txt"

    print("=" * 60)
    print("Data Structure Diagnostic")
    print("=" * 60)

    # --- 1. Data directory -------------------------------------------------
    print(f"\n1. Checking data directory: {data_dir}")
    # isdir (not exists): a plain file here would make os.listdir raise.
    if not os.path.isdir(data_dir):
        print(" ✗ Directory not found!")
        return
    print(" ✓ Directory exists")

    subdirs = [d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))]
    print(f"\n Subdirectories found: {subdirs}")

    # Only sample the first few entries at each level to keep output short.
    for subdir in subdirs[:3]:
        subdir_path = os.path.join(data_dir, subdir)
        video_folders = os.listdir(subdir_path)[:3]
        print(f"\n {subdir}/ contains: {video_folders}")

        if video_folders:
            first_video = os.path.join(subdir_path, video_folders[0])
            if os.path.isdir(first_video):
                frames = os.listdir(first_video)[:5]
                print(f" {video_folders[0]}/ contains: {frames}")

    # --- 2. List file ------------------------------------------------------
    print(f"\n2. Checking list file: {list_file}")
    if os.path.exists(list_file):
        print(" ✓ File exists")

        with open(list_file, 'r') as f:
            lines = f.readlines()

        print(f" Total lines: {len(lines)}")
        print("\n First 5 lines:")
        for i, line in enumerate(lines[:5], 1):
            print(f" {i}: {repr(line.strip())}")

        # --- 3. Path format analysis (first entry only) --------------------
        print("\n3. Path format analysis:")
        # An empty list file has no first line; treat it as "nothing to analyse".
        parts = lines[0].strip().split() if lines else []
        if not lines:
            print(" ⚠ List file is empty; nothing to analyse.")
        if parts:
            video_path = parts[0]
            print(f" Video path: {repr(video_path)}")
            print(f" Uses forward slashes: {'/' in video_path}")
            print(f" Uses backslashes: {chr(92) in video_path}")

            full_path = os.path.join(data_dir, video_path)
            print(f"\n Constructed path: {full_path}")
            print(f" Path exists: {os.path.exists(full_path)}")

            normalized = os.path.normpath(os.path.join(data_dir, video_path))
            print(f"\n Normalized path: {normalized}")
            print(f" Path exists: {os.path.exists(normalized)}")

            if not os.path.exists(normalized):
                # Show what actually lives next to the expected folder.
                print("\n ⚠ Path doesn't exist. Looking for similar paths...")
                category = os.path.dirname(video_path).replace('/', os.sep).replace('\\', os.sep)

                category_path = os.path.join(data_dir, category)
                if os.path.exists(category_path):
                    contents = os.listdir(category_path)
                    print(f"\n Directory '{category}' contains: {contents[:10]}")
    else:
        print(" ✗ File not found!")

    print("\n" + "=" * 60)


if __name__ == "__main__":
    check_data_structure()
|
engine.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the CC-by-NC license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
"""
|
| 8 |
+
Train and eval functions used in main.py
|
| 9 |
+
"""
|
| 10 |
+
from typing import Iterable, Optional
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import torch
|
| 13 |
+
import numpy
|
| 14 |
+
from timm.data import Mixup
|
| 15 |
+
from timm.utils import accuracy, ModelEma
|
| 16 |
+
import utils
|
| 17 |
+
from sklearn.metrics import roc_auc_score
|
| 18 |
+
|
| 19 |
+
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
                    data_loader: Iterable, num_cilps:int, optimizer: torch.optim.Optimizer,
                    device: torch.device, epoch: int, loss_scaler, max_norm: float = 0,
                    model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
                    world_size: int = 1, distributed: bool = True, amp=True,
                    contrastive_nomixup=False, hard_contrastive=False,
                    finetune=False
                    ):
    """Run one training epoch and return the averaged logged metrics.

    For every batch the model output is reshaped to
    ``(batch_size, num_cilps, -1)`` and averaged over the clip axis before the
    loss is computed, so the criterion sees one prediction per sample.

    Args:
        model: Network being trained.
        criterion: Loss called as ``criterion(outputs, targets)``.
        data_loader: Iterable yielding ``(samples, targets)`` batches.
        num_cilps: Number of clips per sample used for the reshape above
            (parameter name kept as-is for caller compatibility).
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        epoch: Epoch index, used only in the log header.
        loss_scaler: Callable used when ``amp`` is true; it receives the loss,
            optimizer, clipping value and parameters and is expected to do
            backward + step.
        max_norm: Gradient-clipping norm; ``0``/``None`` disables clipping on
            the non-AMP path.
        model_ema: Optional EMA wrapper updated after each step.
        mixup_fn: Optional mixup callable applied to ``(samples, targets)``.
        world_size, distributed, contrastive_nomixup, hard_contrastive:
            Accepted for interface compatibility; not used in this function.
        amp: Enables ``torch.cuda.amp.autocast`` and the ``loss_scaler`` path.
        finetune: When true the model is put in eval mode (see TODO below).

    Returns:
        Dict mapping each meter name to its ``global_avg``.
    """
    # TODO fix this for finetuning
    # finetune=True -> model.train(False); otherwise regular training mode.
    model.train(not finetune)

    metric_logger = utils.MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.8f}'))
    header = 'Epoch: [{}]'.format(epoch)
    print_freq = 50

    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        if mixup_fn is not None:
            # mixup needs an even batch: skip singletons, drop the odd sample.
            if targets.size(0) == 1:
                continue
            if targets.size(0) % 2 != 0:
                samples, targets = samples[:-1], targets[:-1]
            samples, targets = mixup_fn(samples, targets)

        # Read the batch size only AFTER the optional trimming above; using the
        # pre-trim value made the reshape below disagree with the model output.
        batch_size = targets.size(0)

        with torch.cuda.amp.autocast(enabled=amp):
            outputs = model(samples)
            # Average the per-clip predictions of each sample.
            outputs = outputs.reshape(batch_size, num_cilps, -1).mean(dim=1)
            loss = criterion(outputs, targets)

        loss_value = loss.item()

        optimizer.zero_grad()

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order

        if amp:
            loss_scaler(loss, optimizer, clip_grad=max_norm,
                        parameters=model.parameters(), create_graph=is_second_order)
        else:
            loss.backward(create_graph=is_second_order)
            if max_norm is not None and max_norm != 0.0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()

        # synchronize() raises on hosts without CUDA, so only call it when usable.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        if model_ema is not None:
            model_ema.update(model)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@torch.no_grad()
def evaluate(data_loader, model, device, world_size, distributed=True, amp=False, num_crops=1, num_clips=1):
    """Evaluate ``model`` on ``data_loader`` and report accuracy, AUC and loss.

    Per batch, the model output is reshaped to
    ``(batch_size, num_crops * num_clips, -1)`` and averaged over the middle
    axis. Column 1 of the averaged output is used as the score for
    ``roc_auc_score``. The AUC is printed but not part of the returned dict.

    Args:
        data_loader: Iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; switched to eval mode here.
        device: Device the batch tensors are moved to.
        world_size: Accepted for interface compatibility; not used here.
        distributed: When true, outputs and targets are all-gathered across
            processes (via ``concat_all_gather``) before being accumulated.
        amp: Enables ``torch.cuda.amp.autocast`` for the forward pass.
        num_crops: Number of crops per sample.
        num_clips: Number of clips per sample.

    Returns:
        Dict mapping each meter name to its ``global_avg``.
    """
    criterion = torch.nn.CrossEntropyLoss()
    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Test:'

    # switch to evaluation mode
    model.eval()

    outputs = []
    targets = []
    logits = []
    binary_label = []
    for images, target in metric_logger.log_every(data_loader, 10, header):
        images = images.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)
        batch_size = images.shape[0]

        with torch.cuda.amp.autocast(enabled=amp):
            output = model(images)
            # One prediction per sample: average over all crops/clips.
            output = output.reshape(batch_size, num_crops * num_clips, -1).mean(dim=1)

        # Gather once per tensor (the collectives are not free) and reuse.
        if distributed:
            all_output = concat_all_gather(output)
            all_target = concat_all_gather(target)
        else:
            all_output, all_target = output, target

        outputs.append(all_output)
        targets.append(all_target)
        logits.append(all_output[:, 1].data.cpu().numpy())
        binary_label.append(all_target.detach().cpu().numpy())

        # Running accuracy is computed on this process's local batch.
        acc1 = accuracy(output, target, topk=(1,))[0]
        metric_logger.meters['acc1'].update(acc1.item(), images.size(0))

    # concatenate (not stack): the last batch may be smaller than the others,
    # and stack requires every per-batch array to have the same shape.
    acc_outputs = numpy.concatenate(logits, axis=0).reshape(-1, 1)
    acc_label = numpy.concatenate(binary_label, axis=0).reshape(-1, 1)

    outputs = torch.cat(outputs, dim=0)
    targets = torch.cat(targets, dim=0)

    auc_score = roc_auc_score(acc_label, acc_outputs)

    real_loss = criterion(outputs, targets)
    metric_logger.update(loss=real_loss.item())

    print('* Acc@1 {top1.global_avg:.3f} AUC {auc} loss {losses.global_avg:.3f}'
          .format(top1=metric_logger.acc1,auc=auc_score,losses=metric_logger.loss))
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@torch.no_grad()
def concat_all_gather(tensor):
    """
    Gather ``tensor`` from every process and merge the copies into one tensor.

    The gathered copies are interleaved along the first axis (pattern
    ``(b n)``): entry ``b`` of every rank is placed next to each other, rather
    than concatenating rank after rank.
    *** Warning ***: torch.distributed.all_gather has no gradient.
    """
    world = torch.distributed.get_world_size()
    # Pre-allocated receive buffers, one per rank.
    gathered = [torch.ones_like(tensor) for _ in range(world)]
    torch.distributed.all_gather(gathered, tensor.contiguous(), async_op=False)

    # 1-D inputs have no trailing feature axis to carry through.
    pattern = 'n b -> (b n)' if tensor.dim() == 1 else 'n b c -> (b n) c'
    return rearrange(gathered, pattern)
|
find_tall_model.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Find the TALL_SWIN model definition in the codebase
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
def search_for_tall_model(root_dir):
    """Search ``root_dir`` recursively for a TALL_SWIN model definition.

    Every ``.py`` file (outside the skipped directories) is scanned with a set
    of case-insensitive patterns. For each pattern that matches, the file and
    a few lines of context around the first matching line are printed. If
    nothing matches, file names containing ``swin`` or ``tall`` are listed as
    a fallback hint.

    Args:
        root_dir: Directory to walk.

    Returns:
        List of file paths containing at least one match. Each path appears
        once, in discovery order, even when several patterns match it.
    """
    print("=" * 60)
    print("Searching for TALL_SWIN model definition...")
    print("=" * 60)

    # Compiled once; reused for every file and every line.
    patterns = [re.compile(p, re.IGNORECASE) for p in (
        r'class\s+TALL_SWIN',
        r'class\s+TallSwin',
        r'class\s+TALLSwin',
        r'def\s+TALL_SWIN',
        r'@register_model.*TALL',
    )]
    skip_dirs = {'venv', '.git', '__pycache__', 'node_modules'}

    found_files = []

    for dirpath, dirnames, filenames in os.walk(root_dir):
        # Prune in place so os.walk does not descend into ignored directories.
        dirnames[:] = [d for d in dirnames if d not in skip_dirs]

        for filename in filenames:
            if not filename.endswith('.py'):
                continue
            filepath = os.path.join(dirpath, filename)
            try:
                with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()
            except OSError as e:
                # Unreadable file: report it and keep scanning the rest.
                print(f"\n⚠ Could not read {filepath}: {e}")
                continue

            lines = content.split('\n')
            for pattern in patterns:
                if not pattern.search(content):
                    continue
                # Several patterns can hit the same file; record it only once.
                if filepath not in found_files:
                    found_files.append(filepath)
                print(f"\n✓ Found in: {filepath}")
                print(f" Pattern matched: {pattern.pattern}")

                # Show context around the first matching line only.
                for i, line in enumerate(lines):
                    if pattern.search(line):
                        start = max(0, i-2)
                        end = min(len(lines), i+10)
                        # `end` is exclusive, so the last 1-based line shown is `end`.
                        print(f"\n Context (lines {start+1}-{end}):")
                        print(" " + "-"*50)
                        for j in range(start, end):
                            marker = ">>>" if j == i else " "
                            print(f" {marker} {j+1:4d}: {lines[j]}")
                        print(" " + "-"*50)
                        break

    if not found_files:
        print("\n✗ No TALL_SWIN model definition found!")
        print("\nSearching for any Swin-related files...")

        # Broader search: match on file names only.
        for dirpath, dirnames, filenames in os.walk(root_dir):
            dirnames[:] = [d for d in dirnames if d not in skip_dirs]

            for filename in filenames:
                lowered = filename.lower()
                if 'swin' in lowered or 'tall' in lowered:
                    print(f" Found file: {os.path.join(dirpath, filename)}")

    return found_files
|
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    import sys

    # Repository root: first command-line argument when given, else the cwd.
    repo_dir = sys.argv[1] if len(sys.argv) > 1 else os.getcwd()

    print(f"Searching in: {repo_dir}\n")

    found = search_for_tall_model(repo_dir)

    # Closing summary: either next steps or likely causes of the miss.
    if found:
        summary = (
            f"Found {len(found)} file(s) with TALL_SWIN definition",
            "\nNext steps:",
            "1. Check the file(s) above for the TALL_SWIN class",
            "2. Ensure this file is imported in models/__init__.py",
            "3. Example fix for models/__init__.py:",
            "\n from .tall_swin import TALL_SWIN",
            " # or",
            " from .tall_swin import *",
        )
    else:
        summary = (
            "No TALL_SWIN definition found!",
            "\nPossible issues:",
            "1. The model file is in a different location",
            "2. The model has a different class name",
            "3. The model code is missing from the repository",
        )

    print("\n" + "=" * 60)
    for message in summary:
        print(message)
    print("=" * 60)
|
infer_videos_txt.py
ADDED
|
@@ -0,0 +1,433 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# infer_videos_txt.py
|
| 2 |
+
# -------------------------------------------
|
| 3 |
+
# Inference on a list of videos (one path per line in a .txt).
|
| 4 |
+
# For each video, frames are extracted to a temporary folder and then fed to the model
|
| 5 |
+
# using the same VideoDataSet pipeline as training/evaluation.
|
| 6 |
+
#
|
| 7 |
+
# PowerShell usage:
|
| 8 |
+
# python infer_videos_txt.py `
|
| 9 |
+
# --video_list "C:\path\videos.txt" `
|
| 10 |
+
# --initial_checkpoint "C:\path\model_best.pth" `
|
| 11 |
+
# --dataset ffpp `
|
| 12 |
+
# --threshold 0.7 `
|
| 13 |
+
# --num_clips 8 --duration 4 --thumbnail_rows 2 `
|
| 14 |
+
# --num_workers 0 `
|
| 15 |
+
# --output_json "C:\path\results.json" `
|
| 16 |
+
# --output_csv "C:\path\results.csv"
|
| 17 |
+
# -------------------------------------------
|
| 18 |
+
|
| 19 |
+
import os
|
| 20 |
+
import csv
|
| 21 |
+
import json
|
| 22 |
+
import argparse
|
| 23 |
+
import shutil
|
| 24 |
+
import tempfile
|
| 25 |
+
from typing import List, Dict, Any
|
| 26 |
+
|
| 27 |
+
import cv2
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
import torch.backends.cudnn as cudnn
|
| 31 |
+
from timm.models import create_model
|
| 32 |
+
|
| 33 |
+
import my_models # registers TALL_SWIN
|
| 34 |
+
import utils
|
| 35 |
+
|
| 36 |
+
from video_dataset import VideoDataSet
|
| 37 |
+
from video_dataset_aug import get_augmentor, build_dataflow
|
| 38 |
+
from video_dataset_config import get_dataset_config
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# -----------------------------
|
| 42 |
+
# Frame extraction (OpenCV)
|
| 43 |
+
# -----------------------------
|
| 44 |
+
def extract_frames_opencv(video_path: str, out_dir: str, digits: int = 5, max_frames: int = 0) -> int:
    """Decode a video with OpenCV and dump its frames as zero-padded JPEG files.

    Frames are written into ``out_dir`` as ``00001.jpg``, ``00002.jpg``, ...
    (1-based numbering, ``digits`` characters wide).

    Args:
        video_path: Path of the video file to decode.
        out_dir: Directory that receives the JPEG files (must already exist).
        digits: Zero-padding width of the frame number in each file name.
        max_frames: Stop after this many frames; ``0`` or less means no limit.

    Returns:
        The number of frames written.

    Raises:
        RuntimeError: If the video cannot be opened or a frame cannot be written.
    """
    cap = cv2.VideoCapture(video_path)
    fmt = f"{{:0{digits}d}}.jpg"
    written = 0
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Could not open video: {video_path}")

        while max_frames <= 0 or written < max_frames:
            ok, frame = cap.read()
            if not ok:
                break
            out_file = os.path.join(out_dir, fmt.format(written + 1))
            # imwrite reports failure through its return value; ignoring it
            # would count a frame that never reached the disk.
            if not cv2.imwrite(out_file, frame):
                raise RuntimeError(f"Could not write frame: {out_file}")
            written += 1
    finally:
        # Always free the decoder handle, including on the error paths above.
        cap.release()

    return written
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def detect_digits_for_tmpl(image_tmpl: str) -> int:
    """Return the numeric field width of a frame-name template.

    ``'{:05d}.jpg'`` gives ``5``. When the template holds no ``{:<width>d}``
    field, the width falls back to ``5``.
    """
    import re

    found = re.search(r"\{:\s*0?(\d+)d\}", image_tmpl)
    return int(found.group(1)) if found else 5
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# -----------------------------
|
| 78 |
+
# IO helpers
|
| 79 |
+
# -----------------------------
|
| 80 |
+
def read_video_list_txt(path: str) -> List[str]:
    """Read a text file holding one video path per line.

    Blank lines are skipped, surrounding whitespace and quote characters are
    removed, and every path is made absolute.

    Args:
        path: Location of the list file.

    Returns:
        Absolute paths, de-duplicated, in order of first appearance.
    """
    videos: List[str] = []
    # "utf-8-sig" drops a leading byte-order mark when present (plain UTF-8 is
    # still read normally); otherwise the BOM would stick to the first path.
    with open(path, "r", encoding="utf-8-sig") as f:
        for line in f:
            p = line.strip().strip('"').strip("'")
            if p:
                videos.append(os.path.abspath(p))

    # dict keys keep insertion order, so this de-duplicates without reordering.
    return list(dict.fromkeys(videos))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def write_csv(path: str, rows: List[Dict[str, Any]]) -> None:
    """Dump the per-video result rows into a CSV file.

    The parent directory is created when the path names one. Only the fixed
    column set below is written; keys missing from a row become empty cells
    and any other keys are ignored.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    columns = [
        "video",
        "pred_name",
        "pred",
        "threshold",
        "p_real",
        "p_fake",
        "n_frames",
        "n_preds",
        "status",
        "error",
    ]

    with open(path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        writer.writerows({col: record.get(col, "") for col in columns} for record in rows)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def write_json(path: str, payload: Dict[str, Any]) -> None:
    """Serialize the results summary to a JSON file.

    The parent directory is created when the path names one. Output is UTF-8,
    indented by two spaces, with non-ASCII characters written unescaped.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, ensure_ascii=False)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# -----------------------------
|
| 133 |
+
# Temporary dataset builder
|
| 134 |
+
# -----------------------------
|
| 135 |
+
def build_tmp_dataset_from_video(video_path: str, tmp_dir: str, image_tmpl: str) -> Dict[str, Any]:
    """Lay out a one-video dataset inside ``tmp_dir``.

    All frames of ``video_path`` are extracted into ``<tmp_dir>/video`` and a
    single-line list file ``<tmp_dir>/one.txt`` describing them is written.

    Returns:
        A dict with ``list_rel`` (list file name relative to ``tmp_dir``),
        ``nframes`` (number of extracted frames) and ``video_id``.

    Raises:
        RuntimeError: If fewer than 4 frames could be extracted.
    """
    video_id = "video"
    list_name = "one.txt"

    frames_dir = os.path.join(tmp_dir, video_id)
    os.makedirs(frames_dir, exist_ok=True)

    # Frame files are numbered with the padding width the template asks for.
    frame_count = extract_frames_opencv(
        video_path,
        frames_dir,
        digits=detect_digits_for_tmpl(image_tmpl),
        max_frames=0,
    )
    if frame_count < 4:
        raise RuntimeError(f"Video has too few frames ({frame_count}). Need >= 4.")

    # List-file line layout: <video_id> <start_idx> <num_frames> <label>.
    # The label is a dummy 0; this script never reads the targets back.
    with open(os.path.join(tmp_dir, list_name), "w", encoding="utf-8") as handle:
        handle.write(f"{video_id} 1 {frame_count} 0\n")

    return {"list_rel": list_name, "nframes": frame_count, "video_id": video_id}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# -----------------------------
|
| 156 |
+
# Model + augmentor builder
|
| 157 |
+
# -----------------------------
|
| 158 |
+
def build_model_and_augmentor(args):
    """Create the model, load its checkpoint and build the evaluation augmentor.

    Args:
        args: Parsed command-line namespace. ``args.num_classes`` is set here
            from the dataset configuration.

    Returns:
        A dict with the keys ``device``, ``model``, ``augmentor``,
        ``num_classes``, ``filename_seperator``, ``image_tmpl`` and
        ``filter_video``.
    """
    # NOTE(review): presumably fills in the distributed-related fields on
    # args; confirm against utils.init_distributed_mode.
    utils.init_distributed_mode(args)

    device = torch.device(args.device)
    cudnn.benchmark = True

    # Only some of the eight returned config values are needed for inference.
    num_classes, _, _, _, filename_seperator, image_tmpl, filter_video, _ = get_dataset_config(
        args.dataset, args.use_lmdb
    )
    args.num_classes = num_classes

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=False,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint,
    ).to(device)
    model.eval()

    ckpt = torch.load(args.initial_checkpoint, map_location="cpu")
    if isinstance(ckpt, dict) and "model" in ckpt:
        utils.load_checkpoint(model, ckpt["model"])
    else:
        # The checkpoint is treated as a raw state_dict. strict=False tolerates
        # key mismatches, so report them: a silently half-loaded model would
        # produce meaningless predictions.
        load_result = model.load_state_dict(ckpt, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print(
                "[WARN] Checkpoint does not fully match the model: "
                f"{len(load_result.missing_keys)} missing key(s), "
                f"{len(load_result.unexpected_keys)} unexpected key(s)."
            )

    # Fall back to 0.5 normalization when the model config does not name its own.
    mean = (0.5, 0.5, 0.5) if "mean" not in model.default_cfg else model.default_cfg["mean"]
    std = (0.5, 0.5, 0.5) if "std" not in model.default_cfg else model.default_cfg["std"]

    augmentor = get_augmentor(
        False,  # is_train
        args.input_size,  # input_size
        mean,
        std,
        args.disable_scaleup,
        threed_data=args.threed_data,
        version=args.augmentor_ver,
        scale_range=args.scale_range,
        num_clips=args.num_clips,
        num_crops=args.num_crops,
        dataset=args.dataset,
    )

    meta = {
        "device": device,
        "model": model,
        "augmentor": augmentor,
        "num_classes": num_classes,
        "filename_seperator": filename_seperator,
        "image_tmpl": image_tmpl,
        "filter_video": filter_video,
    }
    return meta
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
# -----------------------------
|
| 228 |
+
# Inference
|
| 229 |
+
# -----------------------------
|
| 230 |
+
@torch.no_grad()
def infer_one_video_from_tmp(args, meta: Dict[str, Any], tmp_root: str, list_rel_path: str, image_tmpl: str) -> Dict[str, Any]:
    """Run the model on a temporary one-video dataset and return its verdict.

    Args:
        args: Parsed command-line namespace (sampling options, threshold, workers).
        meta: Dict holding ``device``, ``model``, ``augmentor``,
            ``filename_seperator``, ``filter_video`` and ``num_classes``.
        tmp_root: Root folder of the temporary dataset.
        list_rel_path: List file path handed to ``VideoDataSet`` as ``list_file``.
        image_tmpl: Frame file-name template, e.g. ``{:05d}.jpg``.

    Returns:
        A dict with ``threshold``, ``p_real``, ``p_fake``, ``pred`` (1 = FAKE),
        ``pred_name`` and ``n_preds`` (number of logit rows averaged).

    Raises:
        RuntimeError: If the data loader yields no samples.
    """
    device = meta["device"]
    model = meta["model"]
    augmentor = meta["augmentor"]

    dataset = VideoDataSet(
        root_path=tmp_root,
        list_file=list_rel_path,
        num_groups=args.duration,
        frames_per_group=args.frames_per_group,
        sample_offset=0,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        fixed_offset=True,
        image_tmpl=image_tmpl,  # must match the names of the extracted frames
        transform=augmentor,
        is_train=False,
        test_mode=False,
        seperator=meta["filename_seperator"],
        filter_video=meta["filter_video"],
        num_classes=meta["num_classes"],
        whole_video=False,
    )

    loader = build_dataflow(
        dataset, is_train=False, batch_size=1,
        workers=args.num_workers, is_distributed=False
    )

    batch_logits = []
    for samples, _targets in loader:
        samples = samples.to(device, non_blocking=True)
        logits = model(samples)
        batch_logits.append(logits.detach().cpu())

    # torch.cat on an empty list fails with an opaque message; fail clearly instead.
    if not batch_logits:
        raise RuntimeError("No samples were produced for this video; cannot run inference.")

    # Average the logits over every prediction row, then turn them into probabilities.
    logits_all = torch.cat(batch_logits, dim=0)
    logits_mean = logits_all.mean(dim=0, keepdim=True)
    probs = torch.softmax(logits_mean, dim=1).numpy()[0]

    # Threshold-based decision (class 1 = FAKE)
    thr = float(args.threshold)
    pred = int(probs[1] >= thr)

    return {
        "threshold": thr,
        "p_real": float(probs[0]),
        "p_fake": float(probs[1]),
        "pred": pred,
        "pred_name": "FAKE" if pred == 1 else "REAL",
        "n_preds": int(logits_all.shape[0]),
    }
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
# -----------------------------
|
| 287 |
+
# CLI
|
| 288 |
+
# -----------------------------
|
| 289 |
+
def _str2bool(value):
    """Parse a command-line boolean such as true/false, yes/no, 1/0, on/off."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args():
    """Parse the command-line options of the video-list inference script.

    Returns:
        The populated ``argparse.Namespace``.
    """
    ap = argparse.ArgumentParser("Infer TALL_SWIN from a list of videos (txt)")

    ap.add_argument("--video_list", required=True, help="Text file with one video path per line")
    ap.add_argument("--initial_checkpoint", required=True)
    ap.add_argument("--dataset", default="ffpp")
    ap.add_argument("--model", default="TALL_SWIN")
    ap.add_argument("--device", default="cuda")
    ap.add_argument("--num_workers", type=int, default=0)

    ap.add_argument("--duration", type=int, default=4)
    ap.add_argument("--frames_per_group", type=int, default=1)
    ap.add_argument("--num_clips", type=int, default=8)
    ap.add_argument("--num_crops", type=int, default=1)
    ap.add_argument("--thumbnail_rows", type=int, default=2)
    ap.add_argument("--input_size", type=int, default=224)

    ap.add_argument("--threshold", type=float, default=0.5,
                    help="Decision threshold for FAKE (pred=1 if p_fake >= threshold)")

    ap.add_argument("--disable_scaleup", action="store_true")
    # Boolean options go through _str2bool: without a type, "--dense_sampling
    # False" would be stored as the non-empty (truthy) string "False".
    ap.add_argument("--threed_data", type=_str2bool, default=False)
    ap.add_argument("--dense_sampling", type=_str2bool, default=True)
    ap.add_argument("--augmentor_ver", default="v1")
    ap.add_argument("--scale_range", default=[256, 320], type=int, nargs="+")
    ap.add_argument("--modality", default="rgb")
    ap.add_argument("--use_lmdb", type=_str2bool, default=False)

    ap.add_argument("--hpe_to_token", action="store_true")
    ap.add_argument("--rel_pos", action="store_true")
    ap.add_argument("--window_size", type=int, default=7)
    ap.add_argument("--no_token_mask", action="store_true")

    ap.add_argument("--drop", type=float, default=0.0)
    ap.add_argument("--drop_path", type=float, default=0.1)
    # A rate, so parse it as a number instead of leaving it a string.
    ap.add_argument("--drop_block", type=float, default=None)
    ap.add_argument("--use_checkpoint", type=_str2bool, default=False)

    ap.add_argument("--dist_url", default="env://")
    ap.add_argument("--world_size", default=1, type=int)
    ap.add_argument("--local_rank", default=None, type=int)

    ap.add_argument("--output_json", default="", help="Optional path to save results JSON")
    ap.add_argument("--output_csv", default="", help="Optional path to save results CSV")

    return ap.parse_args()
|
| 335 |
+
|
| 336 |
+
|
| 337 |
+
# -----------------------------
|
| 338 |
+
# Main
|
| 339 |
+
# -----------------------------
|
| 340 |
+
def main():
    """Classify every video named in the list file and report the results.

    The model is built once; each video is then extracted into its own
    temporary folder, scored, and the folder removed. A missing file is
    recorded as ``skip`` and a failing video as ``error`` without stopping the
    run. A summary is printed and optionally saved as JSON and/or CSV.
    """
    args = get_args()

    for required in (args.video_list, args.initial_checkpoint):
        if not os.path.isfile(required):
            raise FileNotFoundError(required)

    videos = read_video_list_txt(args.video_list)
    if len(videos) == 0:
        raise RuntimeError("video_list is empty.")
    total = len(videos)

    # Build model + augmentor once
    meta = build_model_and_augmentor(args)
    image_tmpl = meta["image_tmpl"]
    threshold = float(args.threshold)

    results_rows: List[Dict[str, Any]] = []
    ok_count = 0

    print(f"\nVideos to process: {total}")
    print(f"Checkpoint: {args.initial_checkpoint}")
    print(f"Threshold: {args.threshold}\n")

    for i, video_path in enumerate(videos, 1):
        tag = f"[{i}/{total}]"
        # Key order is kept as is because it is the order written to the JSON output.
        row: Dict[str, Any] = {
            "video": video_path,
            "status": "ok",
            "error": "",
            "pred": "",
            "pred_name": "",
            "threshold": threshold,
            "p_real": "",
            "p_fake": "",
            "n_frames": "",
            "n_preds": "",
        }

        if not os.path.isfile(video_path):
            row.update(status="skip", error="file_not_found")
            results_rows.append(row)
            print(f"{tag} [SKIP] Not found: {video_path}")
            continue

        tmp_dir = tempfile.mkdtemp(prefix="tall_infer_")
        try:
            tmp_info = build_tmp_dataset_from_video(video_path, tmp_dir, image_tmpl=image_tmpl)
            row["n_frames"] = int(tmp_info["nframes"])

            row.update(
                infer_one_video_from_tmp(args, meta, tmp_dir, tmp_info["list_rel"], image_tmpl=image_tmpl)
            )
            ok_count += 1

            print(f"{tag} [{row['pred_name']}] {video_path} | p_fake={float(row['p_fake']):.4f}")
        except Exception as e:
            # One bad video must not abort the batch; record the failure and go on.
            row["status"] = "error"
            row["error"] = str(e)
            print(f"{tag} [ERROR] {video_path}\n -> {e}")
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        results_rows.append(row)

    summary = {
        "video_list": os.path.abspath(args.video_list),
        "checkpoint": os.path.abspath(args.initial_checkpoint),
        "dataset": args.dataset,
        "model": args.model,
        "threshold": threshold,
        "num_videos": total,
        "num_ok": ok_count,
        "results": results_rows,
    }

    print("\n=== SUMMARY ===")
    print(f"ok: {ok_count}/{total}")
    if ok_count > 0:
        pf = [float(r["p_fake"]) for r in results_rows if r.get("status") == "ok"]
        print(f"avg p_fake (ok only): {sum(pf)/len(pf):.4f}")

    if args.output_json:
        write_json(args.output_json, summary)
        print(f"Saved JSON: {os.path.abspath(args.output_json)}")

    if args.output_csv:
        write_csv(args.output_csv, results_rows)
        print(f"Saved CSV: {os.path.abspath(args.output_csv)}")


if __name__ == "__main__":
    main()
|
keep_only_numbered.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, re
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
NUM_RE = re.compile(r"^\d+$")  # stem must be only digits


def main():
    """Delete image files whose name is not purely numeric.

    Scans ``<root>/real/<video>/`` and ``<root>/fake/<video>/`` and collects
    every image (by extension, case-insensitive) whose stem is not made of
    digits only. The first 30 candidates are printed; with ``--dry_run``
    nothing is deleted.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True)
    ap.add_argument("--dry_run", action="store_true", help="Only print what would be deleted")
    args = ap.parse_args()

    root = Path(args.root)
    to_delete = []

    for cls in ("real", "fake"):
        cls_dir = root / cls
        if not cls_dir.exists():
            continue
        for vid_dir in cls_dir.iterdir():
            if not vid_dir.is_dir():
                continue
            to_delete.extend(
                entry
                for entry in vid_dir.iterdir()
                if entry.is_file()
                and entry.suffix.lower() in IMG_EXTS
                and not NUM_RE.match(entry.stem)
            )

    print(f"Found {len(to_delete)} non-numbered frames to delete.")
    # Show only a preview so a large clean-up does not flood the console.
    for entry in to_delete[:30]:
        print("DEL", entry)
    if len(to_delete) > 30:
        print("...")

    if args.dry_run:
        print("Dry-run only. No files deleted.")
        return

    for entry in to_delete:
        try:
            entry.unlink()
        except Exception as e:
            # Best effort: report the file that could not be removed and continue.
            print("FAILED", entry, e)

    print("Done.")


if __name__ == "__main__":
    main()
|
main.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import datetime
|
| 3 |
+
import numpy as np
|
| 4 |
+
import time
|
| 5 |
+
import torch
|
| 6 |
+
import torch.backends.cudnn as cudnn
|
| 7 |
+
import json
|
| 8 |
+
import os
|
| 9 |
+
import warnings
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
|
| 13 |
+
from timm.data import Mixup
|
| 14 |
+
from timm.models import create_model
|
| 15 |
+
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
|
| 16 |
+
from timm.scheduler import create_scheduler
|
| 17 |
+
from timm.optim import create_optimizer
|
| 18 |
+
from timm.utils import NativeScaler, get_state_dict, ModelEma
|
| 19 |
+
|
| 20 |
+
#from datasets import build_dataset
|
| 21 |
+
from engine import train_one_epoch, evaluate
|
| 22 |
+
import models
|
| 23 |
+
import my_models
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
|
| 26 |
+
import utils
|
| 27 |
+
|
| 28 |
+
from video_dataset import VideoDataSet
|
| 29 |
+
from video_dataset_aug import get_augmentor, build_dataflow
|
| 30 |
+
from video_dataset_config import get_dataset_config, DATASET_CONFIG
|
| 31 |
+
|
| 32 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 33 |
+
|
| 34 |
+
def _str2bool(value):
    """Parse a command-line boolean such as true/false, yes/no, 1/0, on/off."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the argument parser for training and evaluation.

    The parser is created with ``add_help=False``. Boolean value options are
    parsed with ``_str2bool``; without a type, a value such as
    ``--dense_sampling False`` would be stored as the truthy string "False".

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--epochs', default=30, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##', help='path to text of dataset')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##", help='path to dataset')
    parser.add_argument('--dataset', default='ffpp_train',
                        choices=list(DATASET_CONFIG.keys()), help='path to dataset file list')
    parser.add_argument('--duration', default=8, type=int, help='number of frames')
    parser.add_argument('--frames_per_group', default=1, type=int,
                        help='[uniform sampling] number of frames per group; '
                             '[dense sampling]: sampling frequency')
    parser.add_argument('--threed_data', default=False, type=_str2bool,
                        help='load data in the layout for 3D conv')
    parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size')
    parser.add_argument('--disable_scaleup', action='store_true',
                        help='do not scale up and then crop a small region, directly crop the input_size')
    parser.add_argument('--random_sampling', action='store_true',
                        help='perform determinstic sampling for data loader')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool,
                        help='perform dense sampling for data loader')
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'],
                        help='[v1] TSN data argmentation, [v2] resize the shorter side to `scale_range`')
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+",
                        metavar='scale_range', help='scale range for augmentor v2')
    parser.add_argument('--modality', default='rgb', type=str, help='rgb or flow')
    parser.add_argument('--use_lmdb', default=False, type=_str2bool, help='use lmdb instead of jpeg.')
    parser.add_argument('--use_pyav', default=False, type=_str2bool, help='use video directly.')

    # temporal module
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--temporal_module_name', default=None, type=str, metavar='TEM',
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'],
                        help='temporal module applied. [TAM]')
    parser.add_argument('--temporal_attention_only', action='store_true', default=False,
                        help='use attention only in temporal module]')
    parser.add_argument('--no_token_mask', action='store_true', default=False, help='do not apply token mask')
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float, help='scale of the number of spatial heads')
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float, help='scale of spatial mlp')
    parser.add_argument('--rel_pos', action='store_true', default=False,
                        help='use relative positioning in temporal module]')
    parser.add_argument('--temporal_pooling', type=str, default=None, choices=['avg', 'max', 'conv', 'depthconv'],
                        help='perform temporal pooling]')
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'],
                        help='use depth-wise bottleneck in temporal attention')

    parser.add_argument('--window_size', default=14, type=int, help='number of frames')
    parser.add_argument('--thumbnail_rows', default=4, type=int, help='number of frames per row')

    parser.add_argument('--hpe_to_token', default=False, action='store_true',
                        help='add hub position embedding to image tokens')
    # Model parameters
    parser.add_argument('--model', default='TALL_SWIN', type=str, metavar='MODEL',
                        help='Name of model to train')
    # Second spelling of --input_size; both options write to args.input_size.
    parser.add_argument('--input-size', default=224, type=int, help='images input size')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='weight decay (default: 1e-5)')
    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-8, metavar='LR',
                        help='warmup learning rate (default: 1e-8)')
    parser.add_argument('--min-lr', type=float, default=1e-7, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 1e-7)')

    parser.add_argument('--decay-epochs', type=float, default=10, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The help text used to embed a literal '" + \' fragment from a broken
    # string concatenation, and the call ended in a stray tuple comma.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". (default: rand-m9-mstd0.5)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (default: 0.0)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--cutout', default=True, type=_str2bool)
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Dataset parameters

    parser.add_argument('--output_dir', default="",
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default="", help='resume from checkpoint')
    parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler')
    parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp')
    parser.add_argument('--use_checkpoint', default=False, type=_str2bool,
                        help='use checkpoint to save memory')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # for testing and validation
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=1, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--auto-resume', default=True, type=_str2bool, help='auto resume')
    # exp
    # parser.add_argument('--simclr_w', type=float, default=0., help='weights for simclr loss')
    parser.add_argument('--contrastive_nomixup', action='store_true', help='do not involve mixup in contrastive learning')
    # NOTE(review): --finetune is left untyped because its use is not visible
    # here (it may carry a path rather than a flag); confirm before typing it.
    parser.add_argument('--finetune', default=False, help='finetune model')
    parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model')

    parser.add_argument('--hard_contrastive', action='store_true', help='use HEXA')
    # parser.add_argument('--selfdis_w', type=float, default=0., help='enable self distillation')

    return parser
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def _find_latest_checkpoint(output_dir):
    """Return the newest checkpoint path inside ``output_dir`` or ``''``.

    ``checkpoint.pth`` is preferred when it exists. Otherwise the
    ``checkpoint<epoch>.pth`` file with the highest epoch number is used,
    because that is the naming scheme the training loop in ``main`` writes.
    """
    legacy = output_dir / "checkpoint.pth"
    if legacy.exists():
        return str(legacy)

    best_epoch, best_path = -1, ''
    for candidate in output_dir.glob('checkpoint*.pth'):
        suffix = candidate.stem[len('checkpoint'):]
        if suffix.isdigit() and int(suffix) > best_epoch:
            best_epoch, best_path = int(suffix), str(candidate)
    return best_path


def main(args):
    """Train or evaluate the video model described by ``args``.

    Builds the model, optimizer, LR scheduler, loss and the train/val data
    loaders, optionally restores state from ``args.initial_checkpoint`` /
    ``args.resume``, and then either evaluates once (``args.eval``) or runs
    the epoch loop, saving a checkpoint and appending a JSON line to
    ``log.txt`` after every epoch when ``args.output_dir`` is set.

    ``args`` is mutated: ``num_classes``, ``input_channels`` (for the
    ``rgb``/``flow`` modalities), ``resume`` and ``start_epoch`` may be
    overwritten.
    """
    utils.init_distributed_mode(args)
    print(args)
    # Patch: default the options that the given namespace may not define.
    if not hasattr(args, 'hard_contrastive'):
        args.hard_contrastive = False
    if not hasattr(args, 'selfdis_w'):
        args.selfdis_w = 0.0

    device = torch.device(args.device)

    # Fix the seed for reproducibility; adding the rank gives each process
    # its own seed.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)

    cudnn.benchmark = True

    (num_classes, train_list_name, val_list_name, test_list_name, filename_seperator,
     image_tmpl, filter_video, label_file) = get_dataset_config(args.dataset, args.use_lmdb)

    args.num_classes = num_classes
    if args.modality == 'rgb':
        args.input_channels = 3
    elif args.modality == 'flow':
        args.input_channels = 2 * 5

    mixup_fn = None
    mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
            prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
            label_smoothing=args.smoothing, num_classes=args.num_classes)

    print(f"Creating model: {args.model}")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device='cpu' if args.model_ema_force_cpu else '',
            resume=args.resume)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    optimizer = create_optimizer(args, model)
    loss_scaler = NativeScaler()
    lr_scheduler, _ = create_scheduler(args, optimizer)

    # Mixup produces soft targets whenever it is active (mixup OR cutmix), so
    # the soft-target loss must follow ``mixup_active`` and not ``args.mixup``
    # alone; smoothing is then handled by the mixup label transform.
    if mixup_active:
        criterion = SoftTargetCrossEntropy()
    elif args.smoothing:
        criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # ``model_without_ddp`` is the bare module in both the distributed and
    # the single-process case, so one lookup covers both.
    default_cfg = model_without_ddp.default_cfg
    mean = default_cfg['mean'] if 'mean' in default_cfg else (0.5, 0.5, 0.5)
    std = default_cfg['std'] if 'std' in default_cfg else (0.5, 0.5, 0.5)

    # Create data loaders with the augmentation pipeline.
    video_data_cls = VideoDataSet
    train_list = os.path.join(args.data_txt_dir, train_list_name)

    train_augmentor = get_augmentor(True, args.input_size, mean, std, threed_data=False,
                                    version=args.augmentor_ver, scale_range=args.scale_range,
                                    cut_out=args.cutout, dataset=args.dataset)
    dataset_train = video_data_cls(args.data_dir, train_list, args.duration, args.frames_per_group,
                                   num_clips=args.num_clips,
                                   modality=args.modality, image_tmpl=image_tmpl,
                                   dense_sampling=args.dense_sampling,
                                   transform=train_augmentor, is_train=True, test_mode=False,
                                   seperator=filename_seperator, filter_video=filter_video)

    num_tasks = utils.get_world_size()
    data_loader_train = build_dataflow(dataset_train, is_train=True, batch_size=args.batch_size,
                                       workers=args.num_workers, is_distributed=args.distributed)

    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(False, args.input_size, mean, std, args.disable_scaleup,
                                  threed_data=args.threed_data, version=args.augmentor_ver,
                                  scale_range=args.scale_range, num_clips=args.num_clips,
                                  num_crops=args.num_crops, cut_out=False, dataset=args.dataset)
    dataset_val = video_data_cls(args.data_dir, val_list, args.duration, args.frames_per_group,
                                 num_clips=args.num_clips,
                                 modality=args.modality, image_tmpl=image_tmpl,
                                 dense_sampling=args.dense_sampling,
                                 transform=val_augmentor, is_train=False, test_mode=False,
                                 seperator=filename_seperator, filter_video=filter_video)

    data_loader_val = build_dataflow(dataset_val, is_train=False, batch_size=args.batch_size,
                                     workers=args.num_workers, is_distributed=args.distributed)

    max_accuracy = 0.0
    output_dir = Path(args.output_dir)

    if args.initial_checkpoint:
        checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])

    if args.auto_resume and args.resume == '':
        # Checkpoints are written as ``checkpoint<epoch>.pth`` below, so look
        # for the newest of those instead of only ``checkpoint.pth``.
        args.resume = _find_latest_checkpoint(output_dir)

    if args.resume:
        if args.resume.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(args.resume, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])
        if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            args.start_epoch = checkpoint['epoch'] + 1
            if 'scaler' in checkpoint and args.resume_loss_scaler:
                print("Resume with previous loss scaler state")
                loss_scaler.load_state_dict(checkpoint['scaler'])
            # Checkpoints saved without EMA (or by other tools) do not carry
            # these entries; skip them instead of raising KeyError.
            if args.model_ema and 'model_ema' in checkpoint:
                utils._load_checkpoint_for_ema(model_ema, checkpoint['model_ema'])
            max_accuracy = checkpoint.get('max_accuracy', max_accuracy)

    if args.eval:
        test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed,
                              amp=args.amp, num_crops=args.num_crops, num_clips=args.num_clips)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
        return

    print(f"Start training, currnet max acc is {max_accuracy:.2f}")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):

        if args.distributed:
            data_loader_train.sampler.set_epoch(epoch)

        train_stats = train_one_epoch(
            model, criterion, data_loader_train, args.num_clips,
            optimizer, device, epoch, loss_scaler,
            args.clip_grad, model_ema, mixup_fn, num_tasks, True,
            amp=args.amp,
            contrastive_nomixup=args.contrastive_nomixup,
            hard_contrastive=args.hard_contrastive,
            finetune=args.finetune
        )

        lr_scheduler.step(epoch)

        test_stats = evaluate(data_loader_val, model, device, num_tasks, distributed=args.distributed,
                              amp=args.amp, num_crops=args.num_crops, num_clips=args.num_clips)
        print(f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")

        max_accuracy = max(max_accuracy, test_stats["acc1"])
        print(f'Max accuracy: {max_accuracy:.2f}%')

        if args.output_dir:
            checkpoint_paths = [output_dir / 'checkpoint{}.pth'.format(epoch)]
            # The current epoch matched (or set) the best accuracy so far.
            if test_stats["acc1"] == max_accuracy:
                checkpoint_paths.append(output_dir / 'model_best.pth')
            for checkpoint_path in checkpoint_paths:
                state_dict = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lr_scheduler': lr_scheduler.state_dict(),
                    'epoch': epoch,
                    'args': args,
                    'scaler': loss_scaler.state_dict(),
                    'max_accuracy': max_accuracy
                }
                if args.model_ema:
                    state_dict['model_ema'] = get_state_dict(model_ema)
                utils.save_on_master(state_dict, checkpoint_path)

        log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                     **{f'test_{k}': v for k, v in test_stats.items()},
                     'epoch': epoch,
                     'n_parameters': n_parameters}

        if args.output_dir and utils.is_main_process():
            with (output_dir / "log.txt").open("a") as f:
                f.write(json.dumps(log_stats) + "\n")

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
if __name__ == '__main__':
    # The CLI parser inherits every option from the shared parser built by
    # get_args_parser().
    parser = argparse.ArgumentParser(
        'DeiT training and evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()
    # Create the output directory (when one is given) before main() runs.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|
make_tall_txt.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# make_tall_txt.py
|
| 2 |
+
# Usage:
|
| 3 |
+
# python make_tall_txt.py --root frames_root --out lists --train_ratio 0.8 --seed 42
|
| 4 |
+
#
|
| 5 |
+
# Expected structure:
|
| 6 |
+
# root/
|
| 7 |
+
# real/<video_id>/*.jpg|png...
|
| 8 |
+
# fake/<video_id>/*.jpg|png...
|
| 9 |
+
#
|
| 10 |
+
# Output:
|
| 11 |
+
# lists/cdf_train_fold.txt
|
| 12 |
+
# lists/cdf_test_fold.txt
|
| 13 |
+
#
|
| 14 |
+
# Each line:
|
| 15 |
+
# relative_path start_frame end_frame label
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
import re
|
| 19 |
+
import argparse
|
| 20 |
+
import random
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
|
| 23 |
+
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
|
| 24 |
+
|
| 25 |
+
# Try to parse frame index from filenames like:
|
| 26 |
+
# frame_000012_f000036.jpg -> uses the "f000036" part
|
| 27 |
+
# 000036.png -> uses the numeric stem
|
| 28 |
+
RE_FPART = re.compile(r"_f(\d+)", re.IGNORECASE)
|
| 29 |
+
RE_NUM = re.compile(r"(\d+)$")
|
| 30 |
+
|
| 31 |
+
def list_frames(video_dir: Path):
    """Return the image files directly inside ``video_dir``, sorted by path.

    Only regular files whose lower-cased suffix is in ``IMG_EXTS`` are kept;
    sub-directories are not descended into.
    """
    return sorted(
        entry
        for entry in video_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMG_EXTS
    )
|
| 35 |
+
|
| 36 |
+
def parse_frame_idx(p: Path):
    """Extract the frame index encoded in the file name of ``p``.

    The ``_f<digits>`` marker (``RE_FPART``) is tried first, then a trailing
    run of digits in the stem (``RE_NUM``). Returns the index as ``int``, or
    ``None`` when neither pattern matches.
    """
    stem = p.stem
    for pattern in (RE_FPART, RE_NUM):
        found = pattern.search(stem)
        if found is not None:
            return int(found.group(1))
    return None  # unknown
|
| 45 |
+
|
| 46 |
+
def video_start_end(video_dir: Path):
    """Return ``(start, end, frame_count)`` for one video folder.

    ``start``/``end`` are the smallest and largest frame indices parsed from
    the file names. When no name yields an index the range falls back to
    ``1..N``; an empty folder gives ``(None, None, 0)``.
    """
    frames = list_frames(video_dir)
    total = len(frames)
    if total == 0:
        return None, None, 0

    parsed = [idx for idx in map(parse_frame_idx, frames) if idx is not None]
    if parsed:
        return min(parsed), max(parsed), total

    # Fallback: no index could be parsed, so number the frames 1..N.
    return 1, total, total
|
| 59 |
+
|
| 60 |
+
def collect_videos(root: Path, class_name: str):
    """List the non-empty video folders under ``root/class_name``.

    Returns ``(video_dir, start, end, frame_count)`` tuples sorted by folder
    name, or an empty list when the class directory does not exist.
    """
    class_dir = root / class_name
    if not class_dir.exists():
        return []

    found = []
    for entry in class_dir.iterdir():
        if not entry.is_dir():
            continue
        first, last, count = video_start_end(entry)
        # Folders without any image frame are dropped.
        if count > 0:
            found.append((entry, first, last, count))
    return sorted(found, key=lambda record: record[0].name)
|
| 73 |
+
|
| 74 |
+
def write_list(items, out_path: Path, root: Path, label_map):
    """Write one ``relative_path start end label`` line per item.

    Each item is ``(video_dir, start, end, frame_count, label)``; the path is
    written relative to ``root`` with forward slashes. ``label_map`` is
    accepted but not used by this function.
    """
    with open(out_path, "w", encoding="utf-8") as handle:
        handle.writelines(
            f"{video_dir.relative_to(root).as_posix()} {first} {last} {label}\n"
            for video_dir, first, last, _count, label in items
        )
|
| 79 |
+
|
| 80 |
+
def main():
    """Build the train/test list files for a ``real``/``fake`` frame tree.

    Videos are collected per class, shuffled and split per class with
    ``--train_ratio`` (so both lists keep a similar class balance), then each
    list is shuffled and written to ``cdf_train_fold.txt`` and
    ``cdf_test_fold.txt`` inside ``--out``.

    Raises:
        SystemExit: if ``--train_ratio`` is not strictly between 0 and 1, or
            if either class folder contains no videos.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True, help="Root with real/ and fake/ video folders")
    ap.add_argument("--out", default="lists", help="Output folder for txt files")
    ap.add_argument("--train_ratio", type=float, default=0.8)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--label_real", type=int, default=0, help="Label for real")
    ap.add_argument("--label_fake", type=int, default=1, help="Label for fake")
    args = ap.parse_args()

    if not (0.0 < args.train_ratio < 1.0):
        raise SystemExit("--train_ratio must be between 0 and 1.")

    root = Path(args.root)
    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    # Collect
    real_videos = collect_videos(root, "real")
    fake_videos = collect_videos(root, "fake")

    if not real_videos:
        raise SystemExit(f"No videos found under: {root/'real'}")
    if not fake_videos:
        raise SystemExit(f"No videos found under: {root/'fake'}")

    label_map = {"real": args.label_real, "fake": args.label_fake}

    # Split by class (keeps balance more stable). A single seeded generator
    # is shared by every shuffle below, so the call order determines the
    # output for a given seed.
    rng = random.Random(args.seed)

    def split_class(videos, label):
        vids = [(vd, s, e, n, label) for vd, s, e, n in videos]
        rng.shuffle(vids)
        k = int(round(len(vids) * args.train_ratio))
        return vids[:k], vids[k:]

    real_train, real_test = split_class(real_videos, label_map["real"])
    fake_train, fake_test = split_class(fake_videos, label_map["fake"])

    train_items = real_train + fake_train
    test_items = real_test + fake_test

    rng.shuffle(train_items)
    rng.shuffle(test_items)

    # Write
    train_path = out / "cdf_train_fold.txt"
    test_path = out / "cdf_test_fold.txt"
    write_list(train_items, train_path, root, label_map)
    write_list(test_items, test_path, root, label_map)

    print("DONE")
    print(f"Train videos: {len(train_items)} (real {len(real_train)}, fake {len(fake_train)})")
    print(f"Test videos: {len(test_items)} (real {len(real_test)}, fake {len(fake_test)})")
    print("Saved:")
    print(" ", train_path.resolve())
    print(" ", test_path.resolve())


if __name__ == "__main__":
    main()
|
make_tall_txt_count.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse, random
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}
|
| 5 |
+
|
| 6 |
+
def count_frames(video_dir: Path) -> int:
    """Count the image files (by ``IMG_EXTS`` suffix) directly in ``video_dir``."""
    images = [
        entry
        for entry in video_dir.iterdir()
        if entry.is_file() and entry.suffix.lower() in IMG_EXTS
    ]
    return len(images)
|
| 8 |
+
|
| 9 |
+
def collect(root: Path, cls: str, label: int):
    """Build ``(relative_path, 1, frame_count, label)`` rows for one class.

    Video folders under ``root/cls`` are visited in sorted order; folders
    with fewer than four image frames are left out.
    """
    class_dir = root / cls
    video_dirs = sorted(entry for entry in class_dir.iterdir() if entry.is_dir())

    rows = []
    for video_dir in video_dirs:
        frame_count = count_frames(video_dir)
        if frame_count < 4:  # needs more than 3 frames
            continue
        rows.append((video_dir.relative_to(root).as_posix(), 1, frame_count, label))
    return rows
|
| 18 |
+
|
| 19 |
+
def split(items, train_ratio, seed):
    """Shuffle ``items`` deterministically and split them into two lists.

    Args:
        items: sequence of entries to split; it is left unmodified.
        train_ratio: fraction of the entries that goes into the first list
            (the cut index is ``round(len(items) * train_ratio)``).
        seed: seed of the private random generator used for the shuffle.

    Returns:
        ``(train, test)`` lists that together contain every entry once.
    """
    # Shuffle a copy so the caller's list is not reordered as a side effect;
    # the permutation is the same one an in-place shuffle would produce.
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)
    k = int(round(len(shuffled) * train_ratio))
    return shuffled[:k], shuffled[k:]
|
| 24 |
+
|
| 25 |
+
def write_list(path: Path, items):
    """Write each ``(rel, start, end, label)`` item as one space-separated line."""
    with open(path, "w", encoding="utf-8") as handle:
        handle.writelines(
            f"{rel} {start} {end} {label}\n" for rel, start, end, label in items
        )
|
| 29 |
+
|
| 30 |
+
def main():
    """Create the train/test list files from per-folder frame counts.

    Each class (``real``, ``fake``) is collected and split separately with
    the same seed; the merged train and test lists are then shuffled and
    written twice, as ``train.txt``/``test.txt`` and as
    ``cdf_train_fold.txt``/``cdf_test_fold.txt``, inside ``--out``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True)
    ap.add_argument("--out", default="lists")
    ap.add_argument("--train_ratio", type=float, default=0.8)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--label_real", type=int, default=0)
    ap.add_argument("--label_fake", type=int, default=1)
    args = ap.parse_args()

    root = Path(args.root)
    out = Path(args.out)
    out.mkdir(parents=True, exist_ok=True)

    real = collect(root, "real", args.label_real)
    fake = collect(root, "fake", args.label_fake)

    r_tr, r_te = split(real, args.train_ratio, args.seed)
    f_tr, f_te = split(fake, args.train_ratio, args.seed)

    train = r_tr + f_tr
    test = r_te + f_te

    # One generator shuffles both merged lists, train first.
    rng = random.Random(args.seed)
    rng.shuffle(train)
    rng.shuffle(test)

    # The same two lists are written under both file-name conventions.
    outputs = (
        ("train.txt", train),
        ("test.txt", test),
        ("cdf_train_fold.txt", train),
        ("cdf_test_fold.txt", test),
    )
    for file_name, rows in outputs:
        write_list(out / file_name, rows)

    print("OK")
    print(f"train videos: {len(train)} | test videos: {len(test)}")


if __name__ == "__main__":
    main()
|
models.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the CC-by-NC license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
from timm.models.vision_transformer import VisionTransformer, _cfg
|
| 12 |
+
from timm.models.registry import register_model
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register_model
|
| 16 |
+
def deit_tiny_patch8_224(pretrained=False, **kwargs):
|
| 17 |
+
model = VisionTransformer(
|
| 18 |
+
patch_size=8, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
| 19 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 20 |
+
model.default_cfg = _cfg()
|
| 21 |
+
if pretrained:
|
| 22 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 23 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
|
| 24 |
+
map_location="cpu", check_hash=True
|
| 25 |
+
)
|
| 26 |
+
model.load_state_dict(checkpoint["model"])
|
| 27 |
+
return model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@register_model
|
| 31 |
+
def deit_tiny_patch16_224(pretrained=False, **kwargs):
|
| 32 |
+
model = VisionTransformer(
|
| 33 |
+
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
| 34 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 35 |
+
model.default_cfg = _cfg()
|
| 36 |
+
if pretrained:
|
| 37 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 38 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
|
| 39 |
+
map_location="cpu", check_hash=True
|
| 40 |
+
)
|
| 41 |
+
model.load_state_dict(checkpoint["model"])
|
| 42 |
+
return model
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
@register_model
|
| 46 |
+
def deit_tiny_patch16_d_6_224(pretrained=False, **kwargs):
|
| 47 |
+
model = VisionTransformer(
|
| 48 |
+
patch_size=16, embed_dim=192, depth=6, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
| 49 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 50 |
+
model.default_cfg = _cfg()
|
| 51 |
+
if pretrained:
|
| 52 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 53 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
|
| 54 |
+
map_location="cpu", check_hash=True
|
| 55 |
+
)
|
| 56 |
+
model.load_state_dict(checkpoint["model"])
|
| 57 |
+
return model
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@register_model
|
| 61 |
+
def deit_tiny_patch32_224(pretrained=False, **kwargs):
|
| 62 |
+
model = VisionTransformer(
|
| 63 |
+
patch_size=32, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
|
| 64 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 65 |
+
model.default_cfg = _cfg()
|
| 66 |
+
if pretrained:
|
| 67 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 68 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
|
| 69 |
+
map_location="cpu", check_hash=True
|
| 70 |
+
)
|
| 71 |
+
model.load_state_dict(checkpoint["model"])
|
| 72 |
+
return model
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
@register_model
|
| 76 |
+
def deit_small_patch8_224(pretrained=False, **kwargs):
|
| 77 |
+
model = VisionTransformer(
|
| 78 |
+
patch_size=8, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
| 79 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 80 |
+
model.default_cfg = _cfg()
|
| 81 |
+
if pretrained:
|
| 82 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 83 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
|
| 84 |
+
map_location="cpu", check_hash=True
|
| 85 |
+
)
|
| 86 |
+
model.load_state_dict(checkpoint["model"])
|
| 87 |
+
return model
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@register_model
|
| 91 |
+
def deit_small_patch16_224(pretrained=False, **kwargs):
|
| 92 |
+
model = VisionTransformer(
|
| 93 |
+
patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
| 94 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 95 |
+
model.default_cfg = _cfg()
|
| 96 |
+
if pretrained:
|
| 97 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 98 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
|
| 99 |
+
map_location="cpu", check_hash=True
|
| 100 |
+
)
|
| 101 |
+
model.load_state_dict(checkpoint["model"])
|
| 102 |
+
return model
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@register_model
|
| 106 |
+
def deit_small_patch16_d_6_224(pretrained=False, **kwargs):
|
| 107 |
+
model = VisionTransformer(
|
| 108 |
+
patch_size=16, embed_dim=384, depth=6, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
| 109 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 110 |
+
model.default_cfg = _cfg()
|
| 111 |
+
if pretrained:
|
| 112 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 113 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
|
| 114 |
+
map_location="cpu", check_hash=True
|
| 115 |
+
)
|
| 116 |
+
model.load_state_dict(checkpoint["model"])
|
| 117 |
+
return model
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@register_model
|
| 121 |
+
def deit_small_patch32_224(pretrained=False, **kwargs):
|
| 122 |
+
model = VisionTransformer(
|
| 123 |
+
patch_size=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
|
| 124 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 125 |
+
model.default_cfg = _cfg()
|
| 126 |
+
if pretrained:
|
| 127 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 128 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
|
| 129 |
+
map_location="cpu", check_hash=True
|
| 130 |
+
)
|
| 131 |
+
model.load_state_dict(checkpoint["model"])
|
| 132 |
+
return model
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@register_model
|
| 136 |
+
def deit_base_patch8_224(pretrained=False, **kwargs):
|
| 137 |
+
model = VisionTransformer(
|
| 138 |
+
patch_size=8, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 139 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 140 |
+
model.default_cfg = _cfg()
|
| 141 |
+
if pretrained:
|
| 142 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 143 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 144 |
+
map_location="cpu", check_hash=True
|
| 145 |
+
)
|
| 146 |
+
model.load_state_dict(checkpoint["model"])
|
| 147 |
+
return model
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
@register_model
|
| 151 |
+
def deit_base_patch16_224(pretrained=False, **kwargs):
|
| 152 |
+
model = VisionTransformer(
|
| 153 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 154 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 155 |
+
model.default_cfg = _cfg()
|
| 156 |
+
if pretrained:
|
| 157 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 158 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 159 |
+
map_location="cpu", check_hash=True
|
| 160 |
+
)
|
| 161 |
+
model.load_state_dict(checkpoint["model"])
|
| 162 |
+
return model
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
@register_model
|
| 166 |
+
def deit_base_patch16_ft_224(pretrained=False, **kwargs):
|
| 167 |
+
model = VisionTransformer(
|
| 168 |
+
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 169 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 170 |
+
model.default_cfg = _cfg()
|
| 171 |
+
if pretrained:
|
| 172 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 173 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 174 |
+
map_location="cpu", check_hash=True
|
| 175 |
+
)
|
| 176 |
+
model.load_state_dict(checkpoint["model"])
|
| 177 |
+
|
| 178 |
+
for m in model.parameters():
|
| 179 |
+
m.requires_grad = False
|
| 180 |
+
|
| 181 |
+
for m in model.head.parameters():
|
| 182 |
+
m.requires_grad = True
|
| 183 |
+
|
| 184 |
+
return model
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
@register_model
|
| 189 |
+
def deit_base24_patch16_224(pretrained=False, **kwargs):
|
| 190 |
+
model = VisionTransformer(
|
| 191 |
+
patch_size=16, embed_dim=768, depth=24, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 192 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 193 |
+
model.default_cfg = _cfg()
|
| 194 |
+
if pretrained:
|
| 195 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 196 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 197 |
+
map_location="cpu", check_hash=True
|
| 198 |
+
)
|
| 199 |
+
model.load_state_dict(checkpoint["model"])
|
| 200 |
+
return model
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@register_model
|
| 204 |
+
def deit_base16_patch16_224(pretrained=False, **kwargs):
|
| 205 |
+
model = VisionTransformer(
|
| 206 |
+
patch_size=16, embed_dim=768, depth=16, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
| 207 |
+
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
| 208 |
+
model.default_cfg = _cfg()
|
| 209 |
+
if pretrained:
|
| 210 |
+
checkpoint = torch.hub.load_state_dict_from_url(
|
| 211 |
+
url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
|
| 212 |
+
map_location="cpu", check_hash=True
|
| 213 |
+
)
|
| 214 |
+
model.load_state_dict(checkpoint["model"])
|
| 215 |
+
return model
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    """Build a 12-block DeiT-Base style ViT for 384x384 inputs (patch 16).

    Args:
        pretrained: when true, download the DeiT checkpoint below and load it
            strictly into the model.
        **kwargs: forwarded unchanged to ``VisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    # NOTE(review): this URL names the 224-pixel checkpoint; its position
    # embedding presumably does not match a 384-pixel model - confirm before
    # relying on pretrained=True.
    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
    state = torch.hub.load_state_dict_from_url(
        url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
@register_model
def deit_base_patch32_224(pretrained=False, **kwargs):
    """Build a 12-block DeiT-Base style ViT with 32x32 patches.

    Args:
        pretrained: when true, download the DeiT checkpoint below and load it
            strictly into the model.
        **kwargs: forwarded unchanged to ``VisionTransformer``.

    Returns:
        The constructed model with ``default_cfg`` attached.
    """
    layer_norm = partial(nn.LayerNorm, eps=1e-6)
    model = VisionTransformer(
        patch_size=32, embed_dim=768, depth=12, num_heads=12,
        mlp_ratio=4, qkv_bias=True, norm_layer=layer_norm, **kwargs)
    model.default_cfg = _cfg()

    if not pretrained:
        return model

    # NOTE(review): this URL names the patch-16 checkpoint; its patch and
    # position embeddings presumably do not match a patch-32 model - confirm
    # before relying on pretrained=True.
    weights_url = "https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth"
    state = torch.hub.load_state_dict_from_url(
        url=weights_url, map_location="cpu", check_hash=True)
    model.load_state_dict(state["model"])
    return model
|
renumber_frames_for_tall.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# renumber_frames_for_tall.py
|
| 2 |
+
# Usage:
|
| 3 |
+
# python renumber_frames_for_tall.py --root "C:\...\TALL4Deepfake\data" --ext .jpg --digits 4 --copy
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import shutil
|
| 8 |
+
|
| 9 |
+
IMG_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}


def _natural_key(name):
    """Return a sort key that orders digit runs numerically (frame2 < frame10)."""
    import re  # local import so this block does not depend on the file header

    tokens = re.split(r"(\d+)", name)
    # With a capturing group, re.split alternates text / digit-run tokens, so
    # odd positions are always the digit runs.
    parts = [int(tok) if idx % 2 else tok for idx, tok in enumerate(tokens)]
    # The raw name breaks ties such as "001.jpg" vs "1.jpg" deterministically.
    return parts, name


def main():
    """Renumber the frames of every video folder under <root>/real and <root>/fake.

    Frames are ordered by natural filename order and renamed (or copied, with
    ``--copy``) to ``0001<ext>``, ``0002<ext>``, ... inside the same folder.
    Only the file name changes; image data is never re-encoded.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", required=True, help="Root that contains real/ and fake/")
    ap.add_argument("--digits", type=int, default=4, help="Digits for numbering (4 -> 0001.jpg)")
    ap.add_argument("--ext", default=".jpg", help="Target extension for output filenames, e.g. .jpg")
    ap.add_argument("--copy", action="store_true", help="Copy instead of rename (safer)")
    args = ap.parse_args()

    # Accept "--ext jpg" as well as "--ext .jpg".
    ext = args.ext if args.ext.startswith(".") else f".{args.ext}"
    transfer = shutil.copy2 if args.copy else shutil.move

    root = Path(args.root)
    for cls in ("real", "fake"):
        cls_dir = root / cls
        if not cls_dir.exists():
            print(f"[skip] {cls_dir} not found")
            continue

        for vid_dir in sorted(p for p in cls_dir.iterdir() if p.is_dir()):
            frames = sorted(
                (p for p in vid_dir.iterdir()
                 if p.is_file() and p.suffix.lower() in IMG_EXTS),
                key=lambda p: _natural_key(p.name),
            )
            if not frames:
                print(f"[empty] {vid_dir}")
                continue

            targets = [f"{i:0{args.digits}d}{ext}" for i in range(1, len(frames) + 1)]

            if args.copy:
                # In copy mode the originals stay in place, so a numbered copy
                # must never land on a *different* existing frame.
                clashes = [
                    name for src, name in zip(frames, targets)
                    if (vid_dir / name).exists() and not (vid_dir / name).samefile(src)
                ]
                if clashes:
                    print(f"[conflict] {vid_dir}: {len(clashes)} target name(s) "
                          f"would overwrite other frames; skipped")
                    continue

            # Stage in a scratch folder so renames cannot collide with frames
            # that have not been processed yet.
            tmp_dir = vid_dir / "_tall_tmp"
            tmp_dir.mkdir(exist_ok=True)

            for src, dst_name in zip(frames, targets):
                transfer(str(src), str(tmp_dir / dst_name))

            # move tmp content back
            for f in tmp_dir.iterdir():
                shutil.move(str(f), str(vid_dir / f.name))
            tmp_dir.rmdir()

            print(f"[ok] {vid_dir.name}: {len(frames)} frames -> renamed to 1..{len(frames)}")


if __name__ == "__main__":
    main()
|
requirements-torch.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--index-url https://download.pytorch.org/whl/cu117
|
| 2 |
+
torch==1.13.1+cu117
|
| 3 |
+
torchvision==0.14.1+cu117
|
| 4 |
+
opencv-python==4.13.0.92
|
requirements.txt
ADDED
|
Binary file (2.37 kB). View file
|
|
|
test.py
ADDED
|
@@ -0,0 +1,333 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.backends.cudnn as cudnn
|
| 5 |
+
import os
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from timm.models import create_model
|
| 11 |
+
from timm.utils import ModelEma
|
| 12 |
+
|
| 13 |
+
#from datasets import build_dataset
|
| 14 |
+
import my_models
|
| 15 |
+
from engine import evaluate
|
| 16 |
+
#import simclr
|
| 17 |
+
import utils
|
| 18 |
+
|
| 19 |
+
from video_dataset import VideoDataSet
|
| 20 |
+
from video_dataset_aug import get_augmentor, build_dataflow
|
| 21 |
+
from video_dataset_config import get_dataset_config, DATASET_CONFIG
|
| 22 |
+
|
| 23 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 24 |
+
#torch.multiprocessing.set_start_method('spawn', force=True)
|
| 25 |
+
|
| 26 |
+
def _str2bool(value):
    """argparse ``type=`` converter turning common true/false spellings into a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser shared by training and evaluation.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``ArgumentParser``.

    Value-taking boolean options (``--threed_data``, ``--dense_sampling``,
    ``--use_lmdb``, ``--use_pyav``, ``--use_checkpoint``, ``--auto-resume``)
    are parsed with ``_str2bool`` so that e.g. ``--use_lmdb False`` yields
    ``False`` rather than the truthy string ``"False"``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('DeiT training and evaluation script', add_help=False)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)
    parser.add_argument('--epochs', default=30, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##', help='path to text of dataset')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##", help='path to dataset')
    parser.add_argument('--dataset', default='ffpp',
                        choices=list(DATASET_CONFIG.keys()), help='path to dataset file list')
    parser.add_argument('--duration', default=1, type=int, help='number of frames')
    parser.add_argument('--frames_per_group', default=1, type=int,
                        help='[uniform sampling] number of frames per group; '
                             '[dense sampling]: sampling frequency')
    parser.add_argument('--threed_data', default=False, type=_str2bool,
                        help='load data in the layout for 3D conv')
    parser.add_argument('--input_size', default=224, type=int, metavar='N', help='input image size')
    parser.add_argument('--disable_scaleup', action='store_true',
                        help='do not scale up and then crop a small region, directly crop the input_size')
    parser.add_argument('--random_sampling', action='store_true',
                        help='perform determinstic sampling for data loader')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool,
                        help='perform dense sampling for data loader')
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'],
                        help='[v1] TSN data argmentation, [v2] resize the shorter side to `scale_range`')
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+",
                        metavar='scale_range', help='scale range for augmentor v2')
    parser.add_argument('--modality', default='rgb', type=str, help='rgb or flow')
    parser.add_argument('--use_lmdb', default=False, type=_str2bool, help='use lmdb instead of jpeg.')
    parser.add_argument('--use_pyav', default=False, type=_str2bool, help='use video directly.')

    # temporal module
    parser.add_argument('--pretrained', action='store_true', default=False,
                        help='Start with pretrained version of specified network (if avail)')
    parser.add_argument('--temporal_module_name', default=None, type=str, metavar='TEM',
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'],
                        help='temporal module applied. [TAM]')
    parser.add_argument('--temporal_attention_only', action='store_true', default=False,
                        help='use attention only in temporal module]')
    parser.add_argument('--no_token_mask', action='store_true', default=False, help='do not apply token mask')
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float, help='scale of the number of spatial heads')
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float, help='scale of spatial mlp')
    parser.add_argument('--rel_pos', action='store_true', default=False,
                        help='use relative positioning in temporal module]')
    parser.add_argument('--temporal_pooling', type=str, default=None, choices=['avg', 'max', 'conv', 'depthconv'],
                        help='perform temporal pooling]')
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'],
                        help='use depth-wise bottleneck in temporal attention')

    parser.add_argument('--window_size', default=7, type=int, help='number of frames')
    parser.add_argument('--thumbnail_rows', default=3, type=int, help='number of frames per row')

    parser.add_argument('--hpe_to_token', default=False, action='store_true',
                        help='add hub position embedding to image tokens')

    # Model parameters
    parser.add_argument('--model', default='TALL_SWIN', type=str, metavar='MODEL',
                        help='Name of model to train')

    parser.add_argument('--drop', type=float, default=0.0, metavar='PCT',
                        help='Dropout rate (default: 0.)')
    parser.add_argument('--drop-path', type=float, default=0.1, metavar='PCT',
                        help='Drop path rate (default: 0.1)')
    parser.add_argument('--drop-block', type=float, default=None, metavar='PCT',
                        help='Drop block rate (default: None)')

    parser.add_argument('--model-ema', action='store_true')
    parser.add_argument('--no-model-ema', action='store_false', dest='model_ema')
    parser.set_defaults(model_ema=True)
    parser.add_argument('--model-ema-decay', type=float, default=0.99996, help='')
    parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='')

    # Optimizer parameters
    parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
                        help='Optimizer (default: "adamw"')
    parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
                        help='Optimizer Epsilon (default: 1e-8)')
    parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
                        help='Optimizer Betas (default: None, use opt default)')
    parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
                        help='Clip gradient norm (default: None, no clipping)')
    parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
                        help='SGD momentum (default: 0.9)')
    parser.add_argument('--weight-decay', type=float, default=1e-5,
                        help='weight decay (default: 1e-5)')

    # Learning rate schedule parameters
    parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
                        help='LR scheduler (default: "cosine"')
    parser.add_argument('--lr', type=float, default=5e-5, metavar='LR',
                        help='learning rate (default: 5e-5)')
    parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                        help='learning rate noise on/off epoch percentages')
    parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
                        help='learning rate noise limit percent (default: 0.67)')
    parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                        help='learning rate noise std-dev (default: 1.0)')
    parser.add_argument('--warmup-lr', type=float, default=1e-7, metavar='LR',
                        help='warmup learning rate (default: 1e-7)')
    parser.add_argument('--min-lr', type=float, default=2e-6, metavar='LR',
                        help='lower lr bound for cyclic schedulers that hit 0 (default: 2e-6)')

    parser.add_argument('--decay-epochs', type=float, default=10, metavar='N',
                        help='epoch interval to decay LR')
    parser.add_argument('--warmup-epochs', type=int, default=10, metavar='N',
                        help='epochs to warmup LR, if scheduler supports')
    parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                        help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
    parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
                        help='patience epochs for Plateau LR scheduler (default: 10')
    parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                        help='LR decay rate (default: 0.1)')

    # Augmentation parameters
    parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
                        help='Color jitter factor (default: 0.4)')
    # The original help literal swallowed a `" + \` line continuation and the
    # call ended in a stray comma; both are cleaned up here.
    parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
                        help='Use AutoAugment policy. "v0" or "original". '
                             '(default: rand-m9-mstd0.5-inc1)')
    parser.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)')
    parser.add_argument('--train-interpolation', type=str, default='bicubic',
                        help='Training interpolation (random, bilinear, bicubic default: "bicubic")')

    parser.add_argument('--repeated-aug', action='store_true')
    parser.add_argument('--no-repeated-aug', action='store_false', dest='repeated_aug')
    parser.set_defaults(repeated_aug=False)

    # * Random Erase params
    parser.add_argument('--reprob', type=float, default=0.0, metavar='PCT',
                        help='Random erase prob (default: 0.0)')
    parser.add_argument('--remode', type=str, default='pixel',
                        help='Random erase mode (default: "pixel")')
    parser.add_argument('--recount', type=int, default=1,
                        help='Random erase count (default: 1)')
    parser.add_argument('--resplit', action='store_true', default=False,
                        help='Do not random erase first (clean) augmentation split')

    # * Mixup params
    parser.add_argument('--mixup', type=float, default=0,
                        help='mixup alpha, mixup enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix', type=float, default=0,
                        help='cutmix alpha, cutmix enabled if > 0. (default: 0)')
    parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
                        help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
    parser.add_argument('--mixup-prob', type=float, default=1.0,
                        help='Probability of performing mixup or cutmix when either/both is enabled')
    parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
                        help='Probability of switching to cutmix when both mixup and cutmix enabled')
    parser.add_argument('--mixup-mode', type=str, default='batch',
                        help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')

    # Runtime / checkpoint parameters
    parser.add_argument('--output_dir', default="./output",
                        help='path where to save, empty for no saving')
    parser.add_argument('--device', default='cuda',
                        help='device to use for training / testing')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--resume', default='', help='resume from checkpoint')
    parser.add_argument('--no-resume-loss-scaler', action='store_false', dest='resume_loss_scaler')
    parser.add_argument('--no-amp', action='store_false', dest='amp', help='disable amp')
    parser.add_argument('--use_checkpoint', default=False, type=_str2bool,
                        help='use checkpoint to save memory')
    parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
                        help='start epoch')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--pin-mem', action='store_true',
                        help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
    parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
                        help='')
    parser.set_defaults(pin_mem=True)

    # for testing and validation
    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)

    # distributed training parameters
    parser.add_argument('--world_size', default=1, type=int,
                        help='number of distributed processes')
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')

    parser.add_argument('--auto-resume', default=True, type=_str2bool, help='auto resume')

    # exp
    parser.add_argument('--contrastive_nomixup', action='store_true',
                        help='do not involve mixup in contrastive learning')
    # NOTE(review): --finetune is left untyped; it is unclear from this file
    # whether callers pass a flag-like value or a checkpoint path.
    parser.add_argument('--finetune', default=False, help='finetune model')
    parser.add_argument('--initial_checkpoint', type=str, default='', help='path to the pretrained model')

    parser.add_argument('--hard_contrastive', action='store_true', help='use HEXA')

    return parser
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def main(args):
    """Evaluate a model on the validation list of the selected dataset.

    Sets up (optional) distributed mode and seeding, builds the model named by
    ``args.model``, builds the validation data loader, optionally loads
    ``args.initial_checkpoint``, runs ``evaluate`` and prints the top-1
    accuracy.

    Args:
        args: parsed command-line namespace from ``get_args_parser``.
    """
    utils.init_distributed_mode(args)
    print(args)

    # Fill in attributes that older argument sets may lack.
    for attr_name, fallback in (('hard_contrastive', False), ('selfdis_w', 0.0)):
        if not hasattr(args, attr_name):
            setattr(args, attr_name, fallback)

    device = torch.device(args.device)

    # Seed per rank so that each process is reproducible.
    rank_seed = args.seed + utils.get_rank()
    torch.manual_seed(rank_seed)
    np.random.seed(rank_seed)

    cudnn.benchmark = True

    (num_classes, _train_list_name, val_list_name, _test_list_name,
     filename_seperator, image_tmpl, filter_video, _label_file) = get_dataset_config(
        args.dataset, args.use_lmdb)

    args.num_classes = num_classes
    channels_by_modality = {'rgb': 3, 'flow': 2 * 5}
    if args.modality in channels_by_modality:
        args.input_channels = channels_by_modality[args.modality]

    print(f"Creating model: {args.model}")

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=args.use_checkpoint,
    )

    # TODO: finetuning

    model.to(device)

    model_ema = None
    if args.model_ema:
        # Important to create EMA model after cuda(), DP wrapper, and AMP but before SyncBN and DDP wrapper
        ema_device = 'cpu' if args.model_ema_force_cpu else ''
        model_ema = ModelEma(
            model,
            decay=args.model_ema_decay,
            device=ema_device,
            resume=args.resume)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)

    # Normalisation statistics come from the underlying (unwrapped) model.
    cfg = model_without_ddp.default_cfg
    mean = cfg['mean'] if 'mean' in cfg else (0.5, 0.5, 0.5)
    std = cfg['std'] if 'std' in cfg else (0.5, 0.5, 0.5)

    # create data loaders w/ augmentation pipeline
    video_data_cls = VideoDataSet

    num_tasks = utils.get_world_size()

    val_list = os.path.join(args.data_txt_dir, val_list_name)
    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset)
    dataset_val = video_data_cls(
        args.data_dir, val_list, args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video)

    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed)

    if args.initial_checkpoint:
        # NOTE(review): the weights are loaded into `model`, which is the DDP
        # wrapper when distributed - confirm utils.load_checkpoint handles that.
        checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
        utils.load_checkpoint(model, checkpoint['model'])

    state = evaluate(
        data_loader_val, model, device, num_tasks,
        distributed=args.distributed, amp=args.amp,
        num_crops=args.num_crops, num_clips=args.num_clips)
    print(f"Accuracy of the network on the {len(dataset_val)} test images: {state['acc1']:.1f}%")


if __name__ == '__main__':
    parser = argparse.ArgumentParser('DeiT evaluation script', parents=[get_args_parser()])
    args = parser.parse_args()
    # An empty --output_dir means "do not save", so only create it when set.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    main(args)
|
test_new.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.backends.cudnn as cudnn
|
| 5 |
+
import os
|
| 6 |
+
import warnings
|
| 7 |
+
import json
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
|
| 10 |
+
from timm.models import create_model
|
| 11 |
+
|
| 12 |
+
import my_models # registra TALL_SWIN
|
| 13 |
+
import utils
|
| 14 |
+
|
| 15 |
+
from video_dataset import VideoDataSet
|
| 16 |
+
from video_dataset_aug import get_augmentor, build_dataflow
|
| 17 |
+
from video_dataset_config import get_dataset_config, DATASET_CONFIG
|
| 18 |
+
|
| 19 |
+
from sklearn.metrics import (
|
| 20 |
+
accuracy_score, balanced_accuracy_score,
|
| 21 |
+
precision_recall_fscore_support,
|
| 22 |
+
confusion_matrix, classification_report,
|
| 23 |
+
roc_auc_score, roc_curve,
|
| 24 |
+
average_precision_score, precision_recall_curve
|
| 25 |
+
)
|
| 26 |
+
import matplotlib.pyplot as plt
|
| 27 |
+
|
| 28 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _str2bool(value):
    """argparse ``type=`` converter turning common true/false spellings into a bool."""
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if text in {"0", "false", "f", "no", "n", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def get_args_parser():
    """Build the command-line parser for the metrics-producing evaluation script.

    The parser is created with ``add_help=False`` so that it can be passed as
    a ``parents=[...]`` entry to another ``ArgumentParser``.

    Value-taking boolean options (``--threed_data``, ``--dense_sampling``,
    ``--use_lmdb``, ``--use_pyav``) are parsed with ``_str2bool`` so that e.g.
    ``--dense_sampling False`` yields ``False`` rather than the truthy string
    ``"False"``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser('DeiT evaluation script', add_help=False)

    parser.add_argument('--model', default='TALL_SWIN', type=str)
    parser.add_argument('--model_name', default="TALL_SWIN")
    parser.add_argument('--batch-size', default=2, type=int)

    # Dataset parameters
    parser.add_argument('--data_txt_dir', type=str, default='##path_for_dataset_txt##')
    parser.add_argument('--data_dir', type=str, default="##path_for_dataset##")
    parser.add_argument('--dataset', default='ffpp', choices=list(DATASET_CONFIG.keys()))
    parser.add_argument('--duration', default=1, type=int)
    parser.add_argument('--frames_per_group', default=1, type=int)
    parser.add_argument('--threed_data', default=False, type=_str2bool)
    parser.add_argument('--input_size', default=224, type=int)
    parser.add_argument('--disable_scaleup', action='store_true')
    parser.add_argument('--random_sampling', action='store_true')
    parser.add_argument('--dense_sampling', default=True, type=_str2bool)
    parser.add_argument('--augmentor_ver', default='v1', type=str, choices=['v1', 'v2'])
    parser.add_argument('--scale_range', default=[256, 320], type=int, nargs="+")
    parser.add_argument('--modality', default='rgb', type=str)
    parser.add_argument('--use_lmdb', default=False, type=_str2bool)
    parser.add_argument('--use_pyav', default=False, type=_str2bool)

    # temporal module / model params
    parser.add_argument('--pretrained', action='store_true', default=False)
    parser.add_argument('--temporal_module_name', default=None, type=str,
                        choices=['ResNet3d', 'TAM', 'TTAM', 'TSM', 'TTSM', 'MSA'])
    parser.add_argument('--temporal_attention_only', action='store_true', default=False)
    parser.add_argument('--no_token_mask', action='store_true', default=False)
    parser.add_argument('--temporal_heads_scale', default=1.0, type=float)
    parser.add_argument('--temporal_mlp_scale', default=1.0, type=float)
    parser.add_argument('--rel_pos', action='store_true', default=False)
    parser.add_argument('--temporal_pooling', type=str, default=None,
                        choices=['avg', 'max', 'conv', 'depthconv'])
    parser.add_argument('--bottleneck', default=None, choices=['regular', 'dw'])

    parser.add_argument('--window_size', default=7, type=int)
    parser.add_argument('--thumbnail_rows', default=3, type=int)
    parser.add_argument('--hpe_to_token', default=False, action='store_true')

    parser.add_argument('--drop', type=float, default=0.0)
    parser.add_argument('--drop-path', type=float, default=0.1)
    parser.add_argument('--drop-block', type=float, default=None)

    # runtime
    parser.add_argument('--output_dir', default="./output")
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--num_workers', default=8, type=int)

    parser.add_argument('--num_crops', default=1, type=int, choices=[1, 3, 5, 10])
    parser.add_argument('--num_clips', default=3, type=int)

    parser.add_argument('--world_size', default=1, type=int)
    parser.add_argument("--local_rank", type=int)
    parser.add_argument('--dist_url', default='env://')

    # checkpoint
    parser.add_argument('--initial_checkpoint', type=str, default='',
                        help='path do .pth/.pth.tar com checkpoint (espera key "model")')

    # decision threshold and metric outputs
    parser.add_argument('--threshold', type=float, default=0.5,
                        help='threshold para decidir classe 1 (fake) a partir de prob[:,1]')
    parser.add_argument('--metrics_out', default='', type=str,
                        help='pasta pra salvar metrics.json e plots (default: output_dir)')
    parser.add_argument('--save_plots', action='store_true',
                        help='salvar cm.png / roc.png / pr.png')

    return parser
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@torch.no_grad()
def eval_with_outputs(data_loader, model, device, threshold: float = 0.5):
    """Run ``model`` over ``data_loader`` and collect binary predictions.

    Args:
        data_loader: iterable yielding ``(samples, targets)`` tensor pairs.
        model: module called as ``model(samples)``; it is switched to eval mode.
        device: device the samples and targets are moved to.
        threshold: a sample is predicted as class 1 when its softmax
            probability for class 1 is ``>= threshold``.

    Returns:
        Tuple ``(y_true, y_score, y_pred)`` of 1-D numpy arrays: the targets
        (int), the class-1 probabilities (float) and the thresholded
        predictions (int). All three are empty when the loader yields nothing.

    Raises:
        RuntimeError: if the number of logit rows is not a multiple of the
            number of targets in a batch.
    """
    model.eval()
    y_true, y_score, y_pred = [], [], []

    thr = float(threshold)

    for samples, targets in data_loader:
        samples = samples.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(samples)  # [B, C] or [B*K, C]

        # If the model returned one row per clip, average the K consecutive
        # rows that belong to each target so we end up with one row per video.
        B = targets.shape[0]
        if logits.shape[0] != B:
            if logits.shape[0] % B != 0:
                raise RuntimeError(
                    f"logits batch ({logits.shape[0]}) não é múltiplo do target batch ({B})."
                )
            K = logits.shape[0] // B
            # reshape (not view) so non-contiguous model outputs also work
            logits = logits.reshape(B, K, -1).mean(dim=1)  # [B, C]

        probs = torch.softmax(logits, dim=1)  # [B, C]
        p1 = probs[:, 1]  # score of class 1

        # The decision threshold is applied here.
        hat = (p1 >= thr).long()

        y_true.append(targets.detach().cpu().numpy())
        y_score.append(p1.detach().cpu().numpy())
        y_pred.append(hat.detach().cpu().numpy())

    if not y_true:
        # np.concatenate([]) raises an opaque ValueError; return empty arrays
        # with the documented dtypes instead.
        return np.empty(0, dtype=int), np.empty(0, dtype=float), np.empty(0, dtype=int)

    y_true = np.concatenate(y_true).astype(int)
    y_score = np.concatenate(y_score).astype(float)
    y_pred = np.concatenate(y_pred).astype(int)
    return y_true, y_score, y_pred
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def plot_confusion(cm, out_path):
    """Render the confusion matrix ``cm`` as an annotated image at ``out_path``.

    Each cell is drawn with ``imshow`` and labelled with its value; rows are
    labelled "True" and columns "Predicted".
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    ax.imshow(cm)
    ax.set_title("Confusion Matrix")
    ax.set_xlabel("Predicted")
    ax.set_ylabel("True")

    # ndenumerate yields (row, col) indices; text() takes (x, y) = (col, row).
    for (row, col), count in np.ndenumerate(cm):
        ax.text(col, row, str(count), ha="center", va="center")

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def plot_roc(y, scores, out_path):
    """Plot the ROC curve of ``scores`` against labels ``y`` and save it.

    The legend shows the ROC AUC, and a dashed diagonal marks chance level.
    """
    false_pos_rate, true_pos_rate, _ = roc_curve(y, scores)
    area = roc_auc_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(false_pos_rate, true_pos_rate, label=f"AUC={area:.4f}")
    ax.plot([0, 1], [0, 1], "--", label="Chance")
    ax.set_xlabel("FPR")
    ax.set_ylabel("TPR")
    ax.legend(loc="best")

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def plot_pr(y, scores, out_path):
    """Plot the precision-recall curve of ``scores`` against ``y`` and save it.

    The legend shows the average precision.
    """
    precision, recall, _ = precision_recall_curve(y, scores)
    avg_precision = average_precision_score(y, scores)

    fig, ax = plt.subplots(figsize=(7, 6))
    ax.plot(recall, precision, label=f"AP={avg_precision:.4f}")
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.legend(loc="best")

    fig.tight_layout()
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def main(args):
    """Evaluate a checkpoint on the validation list and write metrics to disk.

    Steps: initialise (optional) distributed mode, seed the RNGs, build the
    model and the validation data loader, load ``args.initial_checkpoint``,
    run ``eval_with_outputs`` and then save ``metrics.json`` and
    ``eval_outputs.npz`` (plus ``cm.png`` / ``roc.png`` / ``pr.png`` when
    ``args.save_plots`` is set) into ``args.metrics_out`` or, if that is
    empty, ``args.output_dir``.

    Raises:
        RuntimeError: if ``args.initial_checkpoint`` is empty.
    """
    utils.init_distributed_mode(args)
    print(args)

    device = torch.device(args.device)

    # Offset the seed by the rank so each process gets its own RNG stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # Only the validation-related entries of the dataset config are used here.
    num_classes, _train_list_name, val_list_name, _test_list_name, filename_seperator, image_tmpl, filter_video, _label_file = \
        get_dataset_config(args.dataset, args.use_lmdb)

    args.num_classes = num_classes
    args.input_channels = 3 if args.modality == 'rgb' else 2 * 5

    print(f"Creating model: {args.model}")
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        duration=args.duration,
        hpe_to_token=args.hpe_to_token,
        rel_pos=args.rel_pos,
        window_size=args.window_size,
        thumbnail_rows=args.thumbnail_rows,
        token_mask=not args.no_token_mask,
        online_learning=False,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        use_checkpoint=False
    )
    model.to(device)

    # mean/std: taken from the model's default_cfg, falling back to 0.5.
    # The model is not wrapped in a distributed container in this script, so
    # only go through `.module` when such a wrapper is actually present;
    # unconditionally reading `model.module` in distributed mode raised
    # AttributeError.
    cfg_owner = model.module if (args.distributed and hasattr(model, 'module')) else model
    default_cfg = cfg_owner.default_cfg
    mean = (0.5, 0.5, 0.5) if 'mean' not in default_cfg else default_cfg['mean']
    std = (0.5, 0.5, 0.5) if 'std' not in default_cfg else default_cfg['std']

    # dataset (val list)
    video_data_cls = VideoDataSet
    val_list = os.path.join(args.data_txt_dir, val_list_name)

    val_augmentor = get_augmentor(
        False, args.input_size, mean, std, args.disable_scaleup,
        threed_data=args.threed_data, version=args.augmentor_ver,
        scale_range=args.scale_range, num_clips=args.num_clips,
        num_crops=args.num_crops, dataset=args.dataset
    )

    dataset_val = video_data_cls(
        args.data_dir, val_list,
        args.duration, args.frames_per_group,
        num_clips=args.num_clips,
        modality=args.modality,
        dense_sampling=args.dense_sampling,
        image_tmpl=image_tmpl,
        transform=val_augmentor,
        is_train=False, test_mode=False,
        seperator=filename_seperator, filter_video=filter_video
    )

    data_loader_val = build_dataflow(
        dataset_val, is_train=False, batch_size=args.batch_size,
        workers=args.num_workers, is_distributed=args.distributed
    )

    if not args.initial_checkpoint:
        raise RuntimeError("Passe --initial_checkpoint apontando pro checkpoint do modelo.")

    checkpoint = torch.load(args.initial_checkpoint, map_location='cpu')
    # Many checkpoints are stored as {"model": state_dict, ...}.
    if isinstance(checkpoint, dict) and "model" in checkpoint:
        utils.load_checkpoint(model, checkpoint["model"])
    else:
        # Otherwise treat the loaded object as a bare state_dict.
        model.load_state_dict(checkpoint, strict=False)

    # eval
    y_true, y_score, y_pred = eval_with_outputs(
        data_loader_val, model, device, threshold=args.threshold
    )

    acc = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    cm = confusion_matrix(y_true, y_pred)

    roc_auc = roc_auc_score(y_true, y_score)
    pr_auc = average_precision_score(y_true, y_score)

    print(f"\nN={len(y_true)} | thr={args.threshold:.3f}")
    print(f"acc={acc:.4f} | bacc={bacc:.4f} | prec={prec:.4f} | rec={rec:.4f} | f1={f1:.4f} | roc_auc={roc_auc:.4f} | pr_auc={pr_auc:.4f}")
    print(classification_report(y_true, y_pred, digits=4, zero_division=0))

    # A missing or whitespace-only --metrics_out falls back to output_dir
    # (previously "   " became '' and os.makedirs('') failed).
    outdir = (args.metrics_out or '').strip() or args.output_dir
    os.makedirs(outdir, exist_ok=True)

    out_json = {
        "threshold": float(args.threshold),
        "acc": float(acc),
        "balanced_acc": float(bacc),
        "precision": float(prec),
        "recall": float(rec),
        "f1": float(f1),
        "roc_auc": float(roc_auc),
        "pr_auc": float(pr_auc),
        "confusion_matrix": cm.tolist(),
        "n": int(len(y_true)),
    }
    with open(os.path.join(outdir, "metrics.json"), "w", encoding="utf-8") as f:
        json.dump(out_json, f, indent=2)

    # Raw per-sample outputs, so other thresholds can be tried offline.
    np.savez(os.path.join(outdir, "eval_outputs.npz"),
             y_true=y_true, y_score=y_score, y_pred=y_pred)

    if args.save_plots:
        plot_confusion(cm, os.path.join(outdir, "cm.png"))
        plot_roc(y_true, y_score, os.path.join(outdir, "roc.png"))
        plot_pr(y_true, y_score, os.path.join(outdir, "pr.png"))
        print(f"\n✔ Plots + metrics saved in: {os.path.abspath(outdir)}")
    else:
        print(f"\n✔ Metrics saved in: {os.path.abspath(os.path.join(outdir, 'metrics.json'))}")
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
if __name__ == '__main__':
    # Build the CLI on top of the shared argument definitions.
    parser = argparse.ArgumentParser(
        'DeiT evaluation script',
        parents=[get_args_parser()],
    )
    args = parser.parse_args()

    # Create the output folder up front when one is configured.
    if args.output_dir:
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    main(args)
|
utils.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2015-present, Facebook, Inc.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the CC-by-NC license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
#
|
| 7 |
+
"""
|
| 8 |
+
Misc functions, including distributed helpers.
|
| 9 |
+
|
| 10 |
+
Mostly copy-paste from torchvision references.
|
| 11 |
+
"""
|
| 12 |
+
import io
|
| 13 |
+
import os
|
| 14 |
+
import time
|
| 15 |
+
from collections import defaultdict, deque
|
| 16 |
+
import datetime
|
| 17 |
+
import tempfile
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.distributed as dist
|
| 21 |
+
from fvcore.common.checkpoint import Checkpointer
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class SmoothedValue(object):
    """Keep a bounded window of recent values plus running global totals.

    The window (``self.deque``) backs ``median``, ``avg``, ``max`` and
    ``value``; ``total`` / ``count`` back ``global_avg``.
    """

    def __init__(self, window_size=20, fmt=None):
        # `fmt` may reference median, avg, global_avg, max and value.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` once in the window and with weight ``n`` globally."""
        self.deque.append(value)
        self.total += value * n
        self.count += n

    def synchronize_between_processes(self):
        """
        Sum ``count`` and ``total`` across processes.

        Warning: does not synchronize the deque!
        """
        if is_dist_avail_and_initialized():
            stats = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
            dist.barrier()
            dist.all_reduce(stats)
            reduced_count, reduced_total = stats.tolist()
            self.count = int(reduced_count)
            self.total = reduced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        return torch.tensor(list(self.deque)).median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window (float32)."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """Weighted mean over every update since construction."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = dict(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value,
        )
        return self.fmt.format(**stats)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class MetricLogger(object):
    """Collect named ``SmoothedValue`` meters and print periodic progress.

    Meters are created on first use and are also reachable as attributes
    (``logger.loss`` returns ``logger.meters['loss']``).
    """

    def __init__(self, delimiter="\t"):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Record one value per keyword; tensors are converted via ``.item()``."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # __getattr__ only runs when normal lookup fails. Read `meters`
        # through __dict__ so that an instance whose __init__ has not run
        # yet (copy / unpickle / __new__) raises AttributeError instead of
        # recursing forever on `self.meters`.
        meters = self.__dict__.get('meters', {})
        if attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        return self.delimiter.join(
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        )

    def synchronize_between_processes(self):
        """Synchronize every meter's global count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable`` while printing progress lines.

        A line is printed every ``print_freq`` items and on the last item;
        a total-time summary is printed once the iterable is exhausted.
        ``iterable`` must support ``len()``.
        """
        if not header:
            header = ''
        num_items = len(iterable)
        use_cuda = torch.cuda.is_available()

        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the running index to the width of the total count.
        space_fmt = ':' + str(len(str(num_items))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0

        for i, obj in enumerate(iterable):
            # Time spent waiting for the item vs. the whole iteration.
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == num_items - 1:
                eta_seconds = iter_time.global_avg * (num_items - i)
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, num_items, **fields))
            end = time.time()

        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(..., 1): an empty iterable must not crash the summary line.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(num_items, 1)))
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _load_checkpoint_for_ema(model_ema, checkpoint):
|
| 169 |
+
"""
|
| 170 |
+
Workaround for ModelEma._load_checkpoint to accept an already-loaded object
|
| 171 |
+
"""
|
| 172 |
+
mem_file = io.BytesIO()
|
| 173 |
+
torch.save(checkpoint, mem_file)
|
| 174 |
+
mem_file.seek(0)
|
| 175 |
+
model_ema._load_checkpoint(mem_file)
|
| 176 |
+
|
| 177 |
+
"""
|
| 178 |
+
def load_checkpoint(model, state_dict, mode=None):
|
| 179 |
+
|
| 180 |
+
# reuse Checkpointer in fvcore to support flexible loading
|
| 181 |
+
ckpt = Checkpointer(model, save_to_disk=False)
|
| 182 |
+
# since Checkpointer requires the weight to be put under `model` field, we need to save it to disk
|
| 183 |
+
tmp_path = tempfile.NamedTemporaryFile('w+b')
|
| 184 |
+
torch.save({'model': state_dict}, tmp_path.name)
|
| 185 |
+
ckpt.load(tmp_path.name)
|
| 186 |
+
"""
|
| 187 |
+
def load_checkpoint(model, state_dict):
    """Load ``state_dict`` into ``model`` non-strictly and report mismatches.

    If ``state_dict`` is a dict containing a ``'state_dict'`` entry, that
    entry is loaded instead. The number of missing and unexpected keys is
    printed when non-zero.
    """
    # Load checkpoint directly (avoid writing temp files on Windows)
    has_nested_weights = isinstance(state_dict, dict) and 'state_dict' in state_dict
    weights = state_dict['state_dict'] if has_nested_weights else state_dict

    missing_keys, unexpected_keys = model.load_state_dict(weights, strict=False)

    num_missing = len(missing_keys)
    num_unexpected = len(unexpected_keys)
    if num_missing > 0:
        print(f"[load_checkpoint] Missing keys: {num_missing}")
    if num_unexpected > 0:
        print(f"[load_checkpoint] Unexpected keys: {num_unexpected}")
|
| 198 |
+
|
| 199 |
+
def setup_for_distributed(is_master):
    """
    Replace the builtin ``print`` so non-master processes stay silent.

    The replacement accepts an extra ``force=True`` keyword that prints
    regardless of ``is_master``.
    """
    import builtins as __builtin__

    original_print = __builtin__.print

    def print(*args, **kwargs):
        # `force` is our own flag: always strip it before delegating.
        forced = kwargs.pop('force', False)
        if not (is_master or forced):
            return
        original_print(*args, **kwargs)

    __builtin__.print = print
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialized."""
    # Short-circuits: is_initialized() is not queried when unavailable.
    return dist.is_available() and dist.is_initialized()
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_world_size():
    """Return the distributed world size, or 1 when not running distributed."""
    return dist.get_world_size() if is_dist_avail_and_initialized() else 1
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when not running distributed."""
    return dist.get_rank() if is_dist_avail_and_initialized() else 0
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def is_main_process():
    """Return True when the current process has rank 0."""
    rank = get_rank()
    return rank == 0
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def save_on_master(*args, **kwargs):
    """Call ``torch.save`` with the given arguments on the main process only."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
def init_distributed_mode(args):
    """Configure torch.distributed from the environment, if present.

    With ``RANK`` and ``WORLD_SIZE`` set, rank / world size / GPU are read
    from ``RANK`` / ``WORLD_SIZE`` / ``LOCAL_RANK``. Otherwise, with
    ``SLURM_PROCID`` set, the rank comes from it and the GPU is the rank
    modulo the CUDA device count. With neither, ``args.distributed`` is set
    to False and nothing else happens.

    In distributed mode this sets ``args.distributed`` / ``args.dist_backend``,
    selects the CUDA device, initializes the NCCL process group, waits on a
    barrier and silences ``print`` on non-zero ranks.
    """
    env = os.environ
    has_rank_env = 'RANK' in env and 'WORLD_SIZE' in env

    if not has_rank_env and 'SLURM_PROCID' not in env:
        print('Not using distributed mode')
        args.distributed = False
        return

    if has_rank_env:
        args.rank = int(env["RANK"])
        args.world_size = int(env['WORLD_SIZE'])
        args.gpu = int(env['LOCAL_RANK'])
    else:
        args.rank = int(env['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}'.format(
        args.rank, args.dist_url), flush=True)
    torch.distributed.init_process_group(
        backend=args.dist_backend,
        init_method=args.dist_url,
        world_size=args.world_size,
        rank=args.rank,
    )
    torch.distributed.barrier()
    # Only rank 0 keeps printing from here on.
    setup_for_distributed(args.rank == 0)
|
video_dataset.py
ADDED
|
@@ -0,0 +1,1228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import six
|
| 3 |
+
from typing import Union
|
| 4 |
+
import random
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
import torch.utils.data as data
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import lmdb
|
| 12 |
+
import pyarrow as pa
|
| 13 |
+
_HAS_LMDB = True
|
| 14 |
+
except ImportError as e:
|
| 15 |
+
_HAS_LMDB = False
|
| 16 |
+
_LMDB_ERROR_MSG = e
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import av
|
| 20 |
+
_HAS_PYAV = True
|
| 21 |
+
except ImportError as e:
|
| 22 |
+
_HAS_PYAV = False
|
| 23 |
+
_PYAV_ERROR_MSG = e
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def random_clip(video_frames, sampling_rate, frames_per_clip, fixed_offset=False, start_frame_idx=0, end_frame_idx=None):
    """
    Pick the frame indices of one clip.

    Args:
        video_frames (int): total frame number of a video
        sampling_rate (int): sampling rate for clip, pick one every k frames
        frames_per_clip (int): number of frames of a clip
        fixed_offset (bool): if True, use the deterministic centred offset
            ``(video_frames - sampling_rate * frames_per_clip) // 2``
            instead of a random one.
        start_frame_idx (int): inclusive lower bound for the random offset.
        end_frame_idx (int, optional): exclusive upper bound for the random
            offset; defaults to ``video_frames - sampling_rate * frames_per_clip``.

    Returns:
        list[int]: frame indices (started from zero). Indices wrap around
        modulo ``video_frames``; when the upper bound is <= 0 the clip
        starts at frame 0.
    """
    clip_span = sampling_rate * frames_per_clip
    highest_idx = video_frames - clip_span if end_frame_idx is None else end_frame_idx
    if highest_idx <= 0:
        random_offset = 0
    elif fixed_offset:
        random_offset = (video_frames - clip_span) // 2
    else:
        # Draw a scalar directly: int() on a 1-element array (the old
        # `randint(..., 1)` form) is deprecated in recent NumPy releases.
        random_offset = int(np.random.randint(start_frame_idx, highest_idx))
    frame_idx = [int(random_offset + i * sampling_rate) % video_frames for i in range(frames_per_clip)]
    return frame_idx
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def compute_img_diff(image_1, image_2, bound=255.0):
    """Return ``image_1 - image_2`` rescaled into a uint8 image.

    The signed difference is shifted by ``bound`` and scaled by
    ``255 / (2 * bound)``, so a zero difference lands mid-range.

    Args:
        image_1: array-like image (minuend).
        image_2: array-like image (subtrahend).
        bound (float): offset added to the difference before scaling.

    Returns:
        PIL.Image.Image: the rescaled difference image.
    """
    # np.float was a plain alias of the builtin float (float64) and has been
    # removed from NumPy; np.float64 is the same dtype.
    image_diff = np.asarray(image_1, dtype=np.float64) - np.asarray(image_2, dtype=np.float64)
    image_diff += bound
    image_diff *= (255.0 / float(2 * bound))
    image_diff = image_diff.astype(np.uint8)
    image_diff = Image.fromarray(image_diff)
    return image_diff
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def load_image(root_path, directory, image_tmpl, idx, modality):
    """
    Load the frame image(s) of one video directory.

    :param root_path: dataset root folder
    :param directory: video folder, relative to ``root_path``
    :param image_tmpl: filename template, filled with ``image_tmpl.format(i)``
    :param idx: frame index; if it is a list, load a batch of images
    :param modality: ``'rgb'`` (one image per index), ``'rgbdiff'`` (difference
        between frame ``i + 1`` and frame ``i``) or ``'flow'`` (the ``x_`` and
        ``y_`` prefixed images per index, in that order)
    :return: list of PIL images; empty for any other modality
    :raises ValueError: if an image still cannot be loaded after 10 attempts
    """

    def _safe_load_image(img_path):
        # Retry because reads can fail transiently; give up after 10 tries.
        for _ in range(10):
            try:
                # The context manager closes the file even if copy() raises.
                with Image.open(img_path) as img_tmp:
                    return img_tmp.copy()
            except Exception as e:
                print('[Will try load again] error loading image: {}, '
                      'error: {}'.format(img_path, str(e)))
        raise ValueError('[Fail 10 times] error loading image: {}'.format(img_path))

    if not isinstance(idx, list):
        idx = [idx]
    out = []
    if modality == 'rgb':
        for i in idx:
            image_path_file = os.path.join(root_path, directory, image_tmpl.format(i))
            out.append(_safe_load_image(image_path_file))
    elif modality == 'rgbdiff':
        # Each index needs frames i and i + 1; load every needed frame once.
        tmp = {}
        new_idx = np.unique(np.concatenate((np.asarray(idx), np.asarray(idx) + 1)))
        for i in new_idx:
            image_path_file = os.path.join(root_path, directory, image_tmpl.format(i))
            tmp[i] = _safe_load_image(image_path_file)
        for k in idx:
            img_ = compute_img_diff(tmp[k + 1], tmp[k])
            out.append(img_)
        del tmp
    elif modality == 'flow':
        for i in idx:
            flow_x_name = os.path.join(root_path, directory, "x_" + image_tmpl.format(i))
            flow_y_name = os.path.join(root_path, directory, "y_" + image_tmpl.format(i))
            out.extend([_safe_load_image(flow_x_name), _safe_load_image(flow_y_name)])

    return out
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def load_sound(data_dir, record, idx, fps, audio_length, resampling_rate,
               window_size=10, step_size=5, eps=1e-6):
    """Return the log power spectrogram of the audio window centred on frame ``idx``.

    idx must be the center frame of a clip.

    :param data_dir: folder that ``record.path`` is joined onto
    :param record: object providing ``start_frame`` and ``path`` (the audio file)
    :param idx: centre frame of the clip; it is added to ``record.start_frame``
    :param fps: video frame rate, used to convert the frame index to seconds
    :param audio_length: length of the audio window in seconds
    :param resampling_rate: sample rate (Hz) the audio is loaded at and that
        all sample arithmetic below is expressed in
    :param window_size: STFT window length in milliseconds
    :param step_size: STFT hop length in milliseconds
    :param eps: added to the power before taking the log, avoiding log(0)
    :return: one-element list holding the spectrogram as a PIL image; an
        all-zero image when the audio file does not exist
    """
    # Imported lazily so librosa is only required when sound is actually used.
    import librosa

    centre_sec = (record.start_frame + idx) / fps
    left_sec = centre_sec - (audio_length / 2.0)
    right_sec = centre_sec + (audio_length / 2.0)
    audio_fname = os.path.join(data_dir, record.path)
    # A missing audio file yields an all-zero placeholder instead of an error.
    if not os.path.exists(audio_fname):
        return [Image.fromarray(np.zeros((256, 256 * int(audio_length / 1.28))))]

    # Load at ``resampling_rate``: the duration and the sample positions below
    # are computed with that rate, so loading at the file's native rate would
    # cut the wrong region whenever the two differ. Files already stored at
    # ``resampling_rate`` are unaffected.
    samples, _ = librosa.core.load(audio_fname, sr=resampling_rate, mono=True)
    duration = samples.shape[0] / float(resampling_rate)

    left_sample = int(round(left_sec * resampling_rate))
    right_sample = int(round(right_sec * resampling_rate))

    required_samples = int(round(resampling_rate * audio_length))

    if left_sec < 0:
        # window starts before the audio: take the head instead
        samples = samples[:required_samples]
    elif right_sec > duration:
        # window ends after the audio: take the tail instead
        samples = samples[-required_samples:]
    else:
        samples = samples[left_sample:right_sample]

    # TODO: is the size of spec is fixed if number of samples are different?
    # if the samples is not long enough, repeat the waveform
    if len(samples) < required_samples:
        multiplies = required_samples / len(samples)
        samples = np.tile(samples, int(multiplies + 0.5) + 1)
        samples = samples[:required_samples]

    # log spectrogram
    nperseg = int(round(window_size * resampling_rate / 1e3))
    noverlap = int(round(step_size * resampling_rate / 1e3))
    spec = librosa.stft(samples, n_fft=511, window='hann', hop_length=noverlap,
                        win_length=nperseg, pad_mode='constant')
    # |spec|^2 computed as spec * conj(spec); eps keeps the log finite
    spec = np.log(np.real(spec * np.conj(spec)) + eps)
    img = Image.fromarray(spec)
    return [img]
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def load_data_lmdb(videos, idx, modality):
    """Decode the frames selected by ``idx`` from an unpacked video record.

    :param videos: indexable container of encoded image buffers
    :param idx: iterable of frame indices
    :param modality: 'rgb', 'flow' or 'rgbdiff'; any other value returns an empty list
    :return: list of PIL images. 'rgb' gives one RGB image per index. 'flow'
        gives two greyscale images per index, read from positions ``2 * i - 1``
        and ``2 * i`` of ``videos``. 'rgbdiff' gives
        ``compute_img_diff(frame[i + 1], frame[i])`` per index.
    """

    def _decode(raw_bytes, is_flow=False):
        # Wrap the raw bytes in a file-like object so PIL can parse them.
        stream = six.BytesIO()
        stream.write(raw_bytes)
        stream.seek(0)
        target_mode = 'L' if is_flow else 'RGB'
        converted = Image.open(stream).convert(target_mode)
        detached = converted.copy()
        converted.close()
        return detached

    frames = []
    if modality == 'rgb':
        encoded = [videos[i] for i in idx]
        frames.extend(_decode(raw) for raw in encoded)
    elif modality == 'flow':
        x_positions = np.asarray(idx) * 2 - 1
        encoded_pairs = [[videos[p], videos[p + 1]] for p in x_positions]
        for raw_x, raw_y in encoded_pairs:
            frames.append(_decode(raw_x, True))
            frames.append(_decode(raw_y, True))
    elif modality == 'rgbdiff':
        # Decode each frame and its successor once, then diff neighbours.
        wanted = np.unique(np.concatenate((np.asarray(idx), np.asarray(idx) + 1)))
        decoded = {}
        for pos in wanted:
            decoded[pos] = _decode(videos[pos])
        for k in idx:
            frames.append(compute_img_diff(decoded[k + 1], decoded[k]))
    return frames
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def sample_train_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling, num_clips=1):
    """Pick 1-based frame indices for one training sample.

    :param video_length: number of frames in the video
    :param num_consecutive_frames: frames read starting at each index; the
        last usable start index is ``video_length - num_consecutive_frames + 1``
        (at least 1)
    :param num_frames: frames per clip (dense) or number of groups (uniform)
    :param sample_freq: passed to ``random_clip`` for dense sampling; number of
        indices drawn per group for uniform sampling
    :param dense_sampling: if True, clips come from ``random_clip``; otherwise
        the video is split into ``num_frames`` equal groups
    :param num_clips: number of clips, only used by dense sampling
    :return: 1-D numpy array of 1-based indices. Dense sampling returns
        ``num_clips * num_frames`` values, uniform sampling returns
        ``num_frames * sample_freq`` sorted values.
    """
    max_frame_idx = max(1, video_length - num_consecutive_frames + 1)

    if dense_sampling:
        frame_idx = np.zeros((num_clips, num_frames), dtype=int)
        if num_clips == 1:  # backward compatibility
            frame_idx[0] = np.asarray(random_clip(max_frame_idx, sample_freq, num_frames, False))
        else:
            # Split the range of valid start positions into one segment per
            # clip; the two extra arguments presumably restrict random_clip to
            # that segment -- confirm against random_clip's definition.
            max_start_frame_idx = max_frame_idx - sample_freq * num_frames
            frames_per_segment = max_start_frame_idx // num_clips
            for i in range(num_clips):
                if frames_per_segment <= 0:
                    # video too short to segment: sample each clip freely
                    clip = random_clip(max_frame_idx, sample_freq, num_frames, False)
                else:
                    clip = random_clip(max_frame_idx, sample_freq, num_frames, False,
                                       i * frames_per_segment, (i + 1) * frames_per_segment)
                frame_idx[i] = np.asarray(clip)
        frame_idx = frame_idx.flatten()
    else:  # uniform sampling
        total_frames = num_frames * sample_freq
        ave_frames_per_group = max_frame_idx // num_frames
        if ave_frames_per_group >= sample_freq:
            # randomly sample sample_freq distinct offsets and reuse them in
            # every group
            frame_idx = np.arange(0, num_frames) * ave_frames_per_group
            frame_idx = np.repeat(frame_idx, repeats=sample_freq)
            offsets = np.random.choice(ave_frames_per_group, sample_freq, replace=False)
            offsets = np.tile(offsets, num_frames)
            frame_idx = frame_idx + offsets
        else:
            # max_frame_idx // num_frames < sample_freq implies
            # max_frame_idx < total_frames, so there are not enough frames and
            # the same images have to be sampled more than once.
            frame_idx = np.random.choice(max_frame_idx, total_frames)
        frame_idx = np.sort(frame_idx)

    # convert the 0-based positions to 1-based frame numbers
    frame_idx = frame_idx + 1
    return frame_idx
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def sample_val_test_clip(video_length, num_consecutive_frames, num_frames, sample_freq, dense_sampling,
                         fixed_offset, num_clips, whole_video):
    """Pick 1-based frame indices for one validation/test sample.

    :param video_length: number of frames in the video
    :param num_consecutive_frames: frames read starting at each index; the
        last usable start index is ``video_length - num_consecutive_frames + 1``
        (at least 1)
    :param num_frames: frames per clip (dense) or number of groups (uniform)
    :param sample_freq: stride between frames (dense) or indices per group
        (uniform, non-fixed offset)
    :param dense_sampling: dense clips if True, uniform groups otherwise
    :param fixed_offset: if True the positions are deterministic
    :param num_clips: number of clips to sample
    :param whole_video: if True, return every ``sample_freq``-th index from 1
        up to (not including) the last usable start index
    :return: 1-D numpy array of 1-based frame indices

    Seeded draws use local ``RandomState`` objects. They produce the same
    values as seeding the global NumPy generator would, but leave the global
    random state untouched.
    """
    max_frame_idx = max(1, video_length - num_consecutive_frames + 1)
    if whole_video:
        return np.arange(1, max_frame_idx, step=sample_freq, dtype=int)

    if dense_sampling:
        frame_idx = []
        if fixed_offset:
            # evenly spaced clip starts over the range of valid start positions
            sample_pos = max(1, 1 + max_frame_idx - sample_freq * num_frames)
            t_stride = sample_freq
            start_list = np.linspace(0, sample_pos - 1, num=num_clips, dtype=int)
            for start_idx in start_list.tolist():
                # the modulo wraps around when the clip runs past the last frame
                frame_idx += [(idx * t_stride + start_idx) % max_frame_idx for idx in
                              range(num_frames)]
        else:
            for _ in range(num_clips):
                frame_idx.extend(random_clip(max_frame_idx, sample_freq, num_frames))
        return np.asarray(frame_idx) + 1

    # uniform sampling
    frame_indices = []
    if fixed_offset:
        first_offset = -num_clips // 2 + 1
        for sample_offset in range(first_offset, num_clips // 2 + 1):
            if max_frame_idx > num_frames:
                tick = max_frame_idx / float(num_frames)
                # keep the shifted centre inside its own group
                curr_sample_offset = sample_offset
                if curr_sample_offset >= tick / 2.0:
                    curr_sample_offset = tick / 2.0 - 1e-4
                elif curr_sample_offset < -tick / 2.0:
                    curr_sample_offset = -tick / 2.0
                frame_idx = np.array([int(tick / 2.0 + curr_sample_offset + tick * x) for x in
                                      range(num_frames)])
            else:
                # fewer frames than groups: repeatable draw with replacement,
                # seeded with the clip's position (0, 1, 2, ...)
                rng = np.random.RandomState(sample_offset - first_offset)
                frame_idx = rng.choice(max_frame_idx, num_frames)
            frame_idx = np.sort(frame_idx)
            frame_indices.extend(frame_idx.tolist())
    else:
        total_frames = num_frames * sample_freq
        ave_frames_per_group = max_frame_idx // num_frames
        for i in range(num_clips):
            if ave_frames_per_group >= sample_freq:
                # randomly sample sample_freq distinct offsets and reuse them
                # in every group (drawn from the global generator)
                frame_idx = np.arange(0, num_frames) * ave_frames_per_group
                frame_idx = np.repeat(frame_idx, repeats=sample_freq)
                offsets = np.random.choice(ave_frames_per_group, sample_freq,
                                           replace=False)
                offsets = np.tile(offsets, num_frames)
                frame_idx = frame_idx + offsets
            else:
                # max_frame_idx // num_frames < sample_freq implies
                # max_frame_idx < total_frames: the same images have to be
                # sampled repeatedly; seeded per clip so it is repeatable
                frame_idx = np.random.RandomState(i).choice(max_frame_idx, total_frames)
            frame_idx = np.sort(frame_idx)
            frame_indices.extend(frame_idx.tolist())
    return np.asarray(frame_indices) + 1
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
class VideoRecord(object):
    """A single list entry: a path plus its inclusive frame range and label."""

    def __init__(self, path, start_frame, end_frame, label, reverse=False):
        self.path = path
        # the identifier is the last component of the path
        self.video_id = os.path.basename(path)
        self.start_frame, self.end_frame = start_frame, end_frame
        self.label, self.reverse = label, reverse

    def __str__(self):
        """The record prints as its path."""
        return self.path

    @property
    def num_frames(self):
        """Frame count of the range, with both end points included."""
        span = self.end_frame - self.start_frame
        return span + 1
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
class VideoDataSet(data.Dataset):
    """Dataset over videos listed in a text file and stored as frame images
    (or as audio files for the 'sound' modality)."""

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=True, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list, joined onto root_path; each line
                holds folder_path, start_frame, end_frame, label_id(s)
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
                (only stored by this class)
            num_clips (int): number of clips sampled per video
            modality (str): 'rgb', 'flow', 'rgbdiff' or 'sound'
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): use the training sampler instead of the val/test sampler
            test_mode (bool): testing mode, no label; no length filtering
            seperator (str): field separator of the list file
            filter_video (int): outside test mode, drop videos with fewer frames
            num_classes (int): length of the multi-label target vector
            whole_video (bool): take whole video
            fps (float): frame rate per second, used to localize sound when frame idx is selected.
            audio_length (float): the time window to extract audio feature.
            resampling_rate (int): used to resampling audio extracted from wav

        Raises:
            ValueError: for an unknown modality or when no video is left
                after filtering
        """
        if modality not in ['flow', 'rgb', 'rgbdiff', 'sound']:
            raise ValueError("modality should be 'flow' or 'rgb' or 'rgbdiff' or 'sound'.")

        self.root_path = root_path
        self.list_file = os.path.join(root_path, list_file)
        self.num_groups = num_groups
        self.num_frames = num_groups
        self.frames_per_group = frames_per_group
        self.sample_freq = frames_per_group
        self.num_clips = num_clips
        self.sample_offset = sample_offset
        self.fixed_offset = fixed_offset
        self.dense_sampling = dense_sampling
        self.modality = modality.lower()
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.is_train = is_train
        self.test_mode = test_mode
        self.separator = seperator
        self.filter_video = filter_video
        self.whole_video = whole_video
        self.fps = fps
        self.audio_length = audio_length
        self.resampling_rate = resampling_rate
        # duration in seconds covered by one clip
        self.video_length = (self.num_frames * self.sample_freq) / self.fps

        # flow and rgbdiff read 5 consecutive frames per sampled index
        if self.modality in ['flow', 'rgbdiff']:
            self.num_consecutive_frames = 5
        else:
            self.num_consecutive_frames = 1

        self.video_list, self.multi_label = self._parse_list()
        self.num_classes = num_classes

    def _parse_list(self):
        """Read the list file and build the records.

        Each line is [video_id, start_frame, end_frame, class_idx ...].

        Returns:
            tuple: (list of VideoRecord, flag telling whether the list is
            treated as multi-label)
        """
        tmp = []
        original_video_numbers = 0
        # the context manager makes sure the list file gets closed
        with open(self.list_file) as list_fh:
            for line in list_fh:
                line = line.strip()
                if not line:
                    # tolerate blank lines instead of failing on elements[1]
                    continue
                elements = line.split(self.separator)
                start_frame = int(elements[1])
                end_frame = int(elements[2])
                total_frame = end_frame - start_frame + 1
                original_video_numbers += 1
                # test mode keeps every video; otherwise drop the short ones
                if self.test_mode or total_frame >= self.filter_video:
                    tmp.append(elements)

        num = len(tmp)
        print("The number of videos is {} (with more than {} frames) "
              "(original: {})".format(num, self.filter_video, original_video_numbers), flush=True)
        if num == 0:
            # an explicit raise also triggers when asserts are disabled (-O)
            raise ValueError("No video left in {} after filtering.".format(self.list_file))
        # TODO: a better way to check if multi-label or not
        # more than 4 fields per line on average means more than one label
        multi_label = np.mean(np.asarray([len(x) for x in tmp])) > 4.0
        file_list = []
        for item in tmp:
            if self.test_mode:
                file_list.append([item[0], int(item[1]), int(item[2]), -1])
            else:
                labels = [float(value) for value in item[3:]]
                if not multi_label:
                    labels = labels[0] if len(labels) == 1 else labels
                file_list.append([item[0], int(item[1]), int(item[2]), labels])

        video_list = [VideoRecord(item[0], item[1], item[2], item[3]) for item in file_list]
        # rgbdiff needs the frame after each sampled one, so the last frame
        # cannot be sampled itself
        if self.modality in ['rgbdiff']:
            for record in video_list:
                record.end_frame -= 1

        return video_list, multi_label

    def remove_data(self, idx):
        """Drop the videos whose positions in ``video_list`` are in ``idx``."""
        original_video_num = len(self.video_list)
        drop = set(idx)  # O(1) membership tests
        self.video_list = [v for i, v in enumerate(self.video_list) if i not in drop]
        removed = original_video_num - len(self.video_list)
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, removed, len(self.video_list)))

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def __getitem__(self, index):
        """
        Returns:
            torch.FloatTensor: (3xgxf)xHxW dimension, g is number of groups and f is the frames per group.
            torch.FloatTensor: the label
        """
        record = self.video_list[index]
        indices = self._sample_indices(record) if self.is_train else self._get_val_indices(record)
        images = self.get_data(record, indices)
        images = self.transform(images)
        label = self.get_label(record)

        return images, label

    def get_data(self, record, indices):
        """Load the raw images (or spectrograms) of ``record`` at ``indices``.

        Args:
            record: the video to read
            indices: 1-based frame indices relative to ``record.start_frame``

        Returns:
            list: PIL images in sampling order
        """
        images = []
        if self.whole_video:
            # keep only complete clips of num_frames indices
            tmp = len(indices) % self.num_frames
            if tmp != 0:
                indices = indices[:-tmp]
            num_clips = len(indices) // self.num_frames
        else:
            num_clips = self.num_clips
        if self.modality == 'sound':
            new_indices = [indices[i * self.num_frames: (i + 1) * self.num_frames]
                           for i in range(num_clips)]
            half = self.num_frames // 2
            for curr_indices in new_indices:
                # centre frame of the clip; mean of the two middle frames for
                # an even number of frames
                if self.num_frames % 2 == 0:
                    center_idx = (curr_indices[half - 1] + curr_indices[half]) // 2
                else:
                    center_idx = curr_indices[half]
                center_idx = min(record.num_frames, center_idx)
                seg_imgs = load_sound(self.root_path, record, center_idx,
                                      self.fps, self.audio_length, self.resampling_rate)
                images.extend(seg_imgs)
        else:
            for seg_ind in indices:
                # Absolute frame numbers, clamped to the last frame of the
                # record. The bound has to be end_frame: num_frames is a count
                # and only equals the last frame number when start_frame is 1.
                new_seg_ind = [min(seg_ind + record.start_frame - 1 + i, record.end_frame)
                               for i in range(self.num_consecutive_frames)]
                seg_imgs = load_image(self.root_path, record.path, self.image_tmpl,
                                      new_seg_ind, self.modality)
                images.extend(seg_imgs)
        return images

    def get_label(self, record):
        """Return the training target of ``record``.

        In test mode this is the video id; for single-label data the label as
        int; for multi-label data a binary vector of length ``num_classes``.
        """
        if self.test_mode:
            # in test mode, return the video id as label
            return record.video_id
        if not self.multi_label:
            return int(record.label)
        # create a binary vector.
        label = torch.zeros(self.num_classes, dtype=torch.float)
        for x in record.label:
            label[int(x)] = 1.0
        return label

    def __len__(self):
        return len(self.video_list)
|
| 541 |
+
|
| 542 |
+
|
| 543 |
+
class VideoDataSetLMDB(data.Dataset):
    """Dataset over videos packed into an LMDB database (sound is not supported)."""

    def __init__(self, datadir, db_name, num_groups=16, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False,
                 seperator=' ', filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            datadir (str): folder that contains the database
            db_name (str): name of the LMDB database inside datadir; the list
                file used for filtering is the same path with ".lmdb"
                replaced by ".txt"
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
                (only stored by this class)
            num_clips (int): number of clips sampled per video
            modality (str): 'rgb', 'flow' or 'rgbdiff'
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): use the training sampler instead of the val/test sampler
            test_mode (bool): testing mode, no label; no length filtering
            seperator (str): field separator of the list file
            filter_video (int): if > 0, outside test mode drop videos with fewer frames
            num_classes (int): length of the multi-label target vector
            whole_video (bool): take whole video

        Raises:
            ValueError: if lmdb is unavailable, the modality is unknown, or
                the filtered key list does not match the filtered count
        """
        # TODO: handle multi-label?
        # TODO: flow data?

        if not _HAS_LMDB:
            raise ValueError(_LMDB_ERROR_MSG)

        if modality not in ['flow', 'rgb', 'rgbdiff']:
            raise ValueError("modality should be 'flow' or 'rgb' or 'rgbdiff'.")

        self.db_path = os.path.join(datadir, db_name)

        self.num_groups = num_groups
        self.num_frames = num_groups
        self.frames_per_group = frames_per_group
        self.sample_freq = frames_per_group
        self.num_clips = num_clips
        self.sample_offset = sample_offset
        self.fixed_offset = fixed_offset
        self.dense_sampling = dense_sampling
        self.modality = modality.lower()
        self.image_tmpl = image_tmpl
        self.transform = transform
        self.is_train = is_train
        self.test_mode = test_mode
        self.seperator = seperator
        self.filter_video = filter_video
        self.whole_video = whole_video
        self.fps = fps
        self.audio_length = audio_length
        self.resampling_rate = resampling_rate
        # duration in seconds covered by one clip
        self.video_length = (self.num_frames * self.sample_freq) / self.fps

        # flow and rgbdiff read 5 consecutive frames per sampled index
        if self.modality in ['flow', 'rgbdiff']:
            self.num_consecutive_frames = 5
        else:
            self.num_consecutive_frames = 1

        # never set to a true value here, so get_label treats every record
        # as single-label
        self.multi_label = None

        # The handle used for reading samples is opened lazily in
        # maybe_open_and_get_buffer; this temporary one only reads metadata.
        self.db = None
        db = lmdb.open(self.db_path, max_readers=1, subdir=os.path.isdir(self.db_path),
                       readonly=True, lock=False, readahead=False, meminit=False)
        try:
            with db.begin(write=False) as txn:
                self.length = pa.deserialize(txn.get(b'__len__'))
                self.keys = pa.deserialize(txn.get(b'__keys__'))
        finally:
            # close the handle even when reading the metadata fails
            db.close()

        # TODO: a hack way to filter video
        self.list_file = self.db_path.replace(".lmdb", ".txt")

        valid_video_numbers = self.length
        invalid_video_ids = set()
        if self.filter_video > 0:
            valid_video_numbers = 0
            # the context manager makes sure the list file gets closed
            with open(self.list_file) as list_fh:
                for line in list_fh:
                    elements = line.strip().split(self.seperator)
                    start_frame = int(elements[1])
                    end_frame = int(elements[2])
                    total_frame = end_frame - start_frame + 1
                    if self.test_mode or total_frame >= self.filter_video:
                        valid_video_numbers += 1
                    else:
                        # database keys are the ascii-encoded last path component
                        name = u'{}'.format(elements[0].split("/")[-1]).encode('ascii')
                        invalid_video_ids.add(name)

        print("The number of videos is {} (with more than {} frames) "
              "(original: {})".format(valid_video_numbers, self.filter_video, self.length),
              flush=True)

        # remove keys and update length
        self.length = valid_video_numbers
        self.keys = [k for k in self.keys if k not in invalid_video_ids]

        if self.length != len(self.keys):
            raise ValueError("Do not filter video correctly.")

        self.num_classes = num_classes
        self.unpacked_video = None

    def remove_data(self, idx):
        """Drop the videos whose positions in ``keys`` are in ``idx``."""
        original_video_num = self.length
        drop = set(idx)  # O(1) membership tests
        self.keys = [v for i, v in enumerate(self.keys) if i not in drop]
        # Derive the length from the keys actually kept, so both stay in sync
        # even when idx holds duplicates or positions that do not exist.
        self.length = len(self.keys)
        removed = original_video_num - self.length
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, removed, self.length))

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def __getitem__(self, index):
        """Return ``(transformed images, label)`` for the video at ``index``."""
        unpacked_video = self.maybe_open_and_get_buffer(index)
        # The first element is used as the frame count and the last one as the
        # label. rgbdiff needs the frame after each sampled one, hence one less.
        num_frames = unpacked_video[0] - 1 if self.modality == 'rgbdiff' else unpacked_video[0]
        record = VideoRecord(self.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
        indices = self._sample_indices(record) if self.is_train else self._get_val_indices(record)
        images = self.get_data(record, indices, unpacked_video)
        images = self.transform(images)
        label = self.get_label(record)
        # drop the reference to the decoded buffer once the sample is built
        self.unpacked_video = None
        return images, label

    def maybe_open_and_get_buffer(self, index):
        """Open the database on first use and return the unpacked video at ``index``.

        If the entry cannot be deserialized, the error is printed and the
        first video of the database is returned instead.
        """
        if self.db is None:
            self.db = lmdb.open(self.db_path, max_readers=1, subdir=os.path.isdir(self.db_path),
                                readonly=True, lock=False, readahead=False, meminit=False)

        with self.db.begin(write=False) as txn:
            byteflow = txn.get(self.keys[index])
        try:
            unpacked_video = pa.deserialize(byteflow)
        except Exception as e:
            # best-effort fallback: substitute the first video
            with self.db.begin(write=False) as txn:
                byteflow = txn.get(self.keys[0])
            unpacked_video = pa.deserialize(byteflow)
            print(self.keys[index], e, flush=True)

        self.unpacked_video = unpacked_video
        return unpacked_video

    def get_data(self, record, indices, unpacked_video):
        """Decode the frames of ``unpacked_video`` selected by ``indices``.

        Args:
            record: the video; supplies start_frame and num_frames
            indices: 1-based frame indices relative to ``record.start_frame``
            unpacked_video: the buffer returned by maybe_open_and_get_buffer

        Returns:
            list: PIL images in sampling order
        """
        images = []
        for seg_ind in indices:
            # clamp so the consecutive frames never run past the last frame
            new_seg_ind = [min(seg_ind + record.start_frame - 1 + i, record.num_frames)
                           for i in range(self.num_consecutive_frames)]
            img = load_data_lmdb(unpacked_video, new_seg_ind, self.modality)
            images.extend(img)
        return images

    def get_label(self, record):
        """Return the training target of ``record``.

        In test mode this is the video id; for single-label data the label as
        int; for multi-label data a binary vector of length ``num_classes``.
        """
        if self.test_mode:
            # in test mode, return the video id as label
            return record.video_id
        if not self.multi_label:
            return int(record.label)
        # create a binary vector.
        label = torch.zeros(self.num_classes, dtype=torch.float)
        for x in record.label:
            label[int(x)] = 1.0
        return label

    def __len__(self):
        return self.length
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
class MultiVideoDataSet(data.Dataset):
    """Several VideoDataSet objects, one per modality, read with shared frame indices."""

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """
        # root_path, modality and transform become list, each for one modality

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (list of str): the file paths to the root of video folder, one per modality
            list_file (str): the file list, looked up under every root path;
                each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            num_clips (int): number of clips sampled per video
            modality (list of str): one modality per root path
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform (list): the transformers for preprocessing, one per modality
            is_train (bool): use the training sampler instead of the val/test sampler
            test_mode (bool): testing mode, no label
        """

        video_datasets = []
        for i in range(len(modality)):
            # VideoDataSet joins root_path and list_file itself. Joining here
            # as well would produce "root/root/list_file" for a relative root.
            tmp = VideoDataSet(root_path[i], list_file,
                               num_groups, frames_per_group, sample_offset,
                               num_clips, modality[i], dense_sampling, fixed_offset,
                               image_tmpl, transform[i], is_train, test_mode, seperator,
                               filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)
            video_datasets.append(tmp)

        self.video_datasets = video_datasets
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_frames = num_groups
        self.sample_freq = frames_per_group
        self.dense_sampling = dense_sampling
        self.num_clips = num_clips
        self.fixed_offset = fixed_offset
        self.modality = modality
        self.num_classes = num_classes
        self.whole_video = whole_video

        # the first modality's list drives the length and the frame sampling
        self.video_list = video_datasets[0].video_list
        # sample with the largest window so the indices are valid for every modality
        self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets])

    def _sample_indices(self, record):
        """Sample training frame indices for ``record``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test frame indices for ``record``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def remove_data(self, idx):
        """Drop the videos at positions ``idx`` from every per-modality dataset."""
        for video_dataset in self.video_datasets:
            video_dataset.remove_data(idx)
        self.video_list = self.video_datasets[0].video_list

    def __getitem__(self, index):
        """
        Returns:
            list: the transformed data of every modality, in the order of ``modality``
            the label, taken from the first modality
        """

        record = self.video_list[index]
        if self.is_train:
            indices = self._sample_indices(record)
        else:
            indices = self._get_val_indices(record)

        # the same indices are used for every modality
        multi_modalities = []
        for modality, video_dataset in zip(self.modality, self.video_datasets):
            record = video_dataset.video_list[index]
            images = video_dataset.get_data(record, indices)
            images = video_dataset.transform(images)
            label = video_dataset.get_label(record)
            multi_modalities.append((images, label))

        return [x for x, y in multi_modalities], multi_modalities[0][1]

    def __len__(self):
        return len(self.video_list)
|
| 818 |
+
|
| 819 |
+
|
| 820 |
+
class MultiVideoDataSetLMDB(data.Dataset):
    # Multi-modality dataset: one sub-dataset per modality, all indexed together.
    # 'sound' entries are backed by VideoDataSet, every other modality by VideoDataSetLMDB.

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """
        # root_path, modality and transform become list, each for one modality

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality (str): rgb or flow
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
        """

        video_datasets = []
        # NOTE(review): modality/root_path/transform are indexed per position, so the
        # string default modality='rgb' would be walked character by character -
        # callers presumably always pass lists; confirm.
        for i in range(len(modality)):
            if modality[i] == 'sound':
                # The sound modality reads a text list: swap the .lmdb name for .txt.
                list_file_ = list_file.replace(".lmdb", ".txt")
                tmp = VideoDataSet(root_path[i], os.path.join(root_path[i], list_file_),
                                   num_groups, frames_per_group, sample_offset,
                                   num_clips, modality[i], dense_sampling, fixed_offset,
                                   image_tmpl, transform[i], is_train, test_mode, seperator,
                                   filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)
            else:
                tmp = VideoDataSetLMDB(root_path[i], list_file, num_groups, frames_per_group,
                                       sample_offset, num_clips, modality[i], dense_sampling,
                                       fixed_offset, image_tmpl, transform[i], is_train, test_mode,
                                       seperator, filter_video, num_classes, whole_video, fps, audio_length,
                                       resampling_rate)
            video_datasets.append(tmp)

        self.video_datasets = video_datasets
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_frames = num_groups
        self.sample_freq = frames_per_group
        self.dense_sampling = dense_sampling
        self.num_clips = num_clips
        self.fixed_offset = fixed_offset
        self.modality = modality
        self.num_classes = num_classes
        self.whole_video = whole_video

        # Largest value reported by any of the sub-datasets.
        self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets])

    def _sample_indices(self, record):
        # Training-mode index sampling, delegated to sample_train_clip.
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        # Validation/test-mode index sampling, delegated to sample_val_test_clip.
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def remove_data(self, idx):
        # Forward the removal to every per-modality dataset.
        for i in range(len(self.video_datasets)):
            self.video_datasets[i].remove_data(idx)

    def __getitem__(self, index):
        """
        Load sample ``index`` from every modality with one shared set of frame indices.

        Returns:
            list: the transformed data of each modality, in the order of ``self.video_datasets``
            the label reported by the first modality's dataset
        """
        multi_modalities = []
        indices = None
        for modality, video_dataset in zip(self.modality, self.video_datasets):
            if indices is None:
                # Indices are sampled once, using the first modality's record and
                # sampler, and reused for the remaining modalities.
                if modality == 'sound':
                    record = video_dataset.video_list[index]
                else:
                    unpacked_video = video_dataset.maybe_open_and_get_buffer(index)
                    # rgbdiff gets one frame fewer than the stored count.
                    num_frames = unpacked_video[0] - 1 if modality == 'rgbdiff' else unpacked_video[0]
                    # Element 0 of the buffer is used as the frame count and the last
                    # element as the record's final field (presumably the label - confirm).
                    record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
                indices = video_dataset._sample_indices(record) if video_dataset.is_train else video_dataset._get_val_indices(record)

            if modality == 'sound':
                record = video_dataset.video_list[index]
                images = video_dataset.get_data(record, indices)
            else:
                # Reuse the buffer if the dataset still holds one, otherwise load it now.
                if video_dataset.unpacked_video is None:
                    video_dataset.maybe_open_and_get_buffer(index)
                unpacked_video = video_dataset.unpacked_video
                num_frames = unpacked_video[0] - 1 if modality == 'rgbdiff' else unpacked_video[0]
                record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])

                images = video_dataset.get_data(record, indices, video_dataset.unpacked_video)
                # Clear the buffer so the next access has to load its own sample.
                video_dataset.unpacked_video = None

            images = video_dataset.transform(images)
            label = video_dataset.get_label(record)
            multi_modalities.append((images, label))

        return [x for x, y in multi_modalities], multi_modalities[0][1]

    def __len__(self):
        # Length of the first per-modality dataset.
        return len(self.video_datasets[0])
|
| 933 |
+
|
| 934 |
+
|
| 935 |
+
class VideoDataSetOnline(VideoDataSet):
    """Video dataset that decodes the requested frames from video files with PyAV on access."""

    def __init__(self, root_path, list_file, num_groups=8, frames_per_group=1, sample_offset=0,
                 num_clips=1, modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality (str): 'rgb' or 'rgbdiff'
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label
            fps (float): frame rate per second, used to localize sound when frame idx is selected.
            audio_length (float): the time window to extract audio feature.
            resampling_rate (int): used to resampling audio extracted from wav

        Raises:
            ValueError: if PyAV is not available, or ``modality`` is neither 'rgb' nor 'rgbdiff'.
        """

        if not _HAS_PYAV:
            raise ValueError(_PYAV_ERROR_MSG)
        if modality not in ['rgb', 'rgbdiff']:
            raise ValueError("modality should be 'rgb' or 'rgbdiff'.")

        super().__init__(root_path, list_file, num_groups, frames_per_group, sample_offset,
                         num_clips, modality, dense_sampling, fixed_offset,
                         image_tmpl, transform, is_train, test_mode, seperator,
                         filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)

    def remove_data(self, idx):
        """Drop the videos whose positions are listed in ``idx`` and print a summary."""
        original_video_num = len(self.video_list)
        # Build the lookup once: set membership is O(1), whereas testing against
        # the raw sequence rescans it for every video.
        dropped = {int(i) for i in idx}
        self.video_list = [v for i, v in enumerate(self.video_list) if i not in dropped]
        print("Original videos: {}\t remove {} videos, remaining {} videos".format(original_video_num, len(idx), len(self.video_list)))

    def get_data(self, record, indices):
        """Decode the frames selected by ``indices`` from the video file of ``record``.

        Args:
            record: entry of ``self.video_list``; its ``path`` and ``num_frames`` are used.
            indices: 1-based frame indices; they are shifted and rescaled with array
                arithmetic, so a numpy array is expected.

        Returns:
            list: one ``PIL.Image`` per decoded frame.
        """
        # Sampled indices start at 1; decoded frames are addressed from 0.
        indices = indices - 1
        container = av.open(os.path.join(self.root_path, record.path))
        try:
            video_frames = self._decode_frames(container, record, indices)
        finally:
            # Release the file handle even when decoding raises.
            container.close()
        # TODO: support rgb diff, calculate end_pts differently.
        return [Image.fromarray(frame) for frame in video_frames]

    def _decode_frames(self, container, record, indices):
        """Return the selected frames as an array, decoding selectively when possible."""
        stream = container.streams.video[0]
        stream.thread_type = "AUTO"
        frames_length = stream.frames
        duration = stream.duration

        video_frames = None
        # Selective decoding needs the stream's duration and frame count; without
        # them the whole video is decoded below.
        if duration is not None and frames_length != 0:
            clip_indices = indices
            if frames_length != record.num_frames:
                # The container reports a different length than the list file:
                # remap the indices onto the container's frame count.
                length_ratio = frames_length / record.num_frames
                clip_indices = np.around(indices * length_ratio).astype(int)
            timebase = duration / frames_length

            clips = []
            for i in range(self.num_clips):
                curr_index = clip_indices[i * self.num_frames: (i + 1) * self.num_frames]
                curr_video_frames, success = self._selective_decoding(container, curr_index, timebase)
                if not success:
                    # Give up on selective decoding and decode the whole video.
                    clips = []
                    break
                clips.append(curr_video_frames)
            if clips:
                video_frames = np.concatenate(clips, axis=0)

        if video_frames is None:
            container.seek(0, any_frame=False, backward=True, stream=container.streams.video[0])
            frames = {}
            for frame in container.decode({'video': 0}):
                frames[frame.pts] = frame
            video_frames = np.asarray([frames[pts].to_rgb().to_ndarray() for pts in sorted(frames)])
            total_frames = len(video_frames)
            # Always start from the caller's indices: the selective path may already
            # have rescaled its own copy, and rescaling twice would overshoot.
            frame_indices = indices
            if total_frames != record.num_frames:
                # remap the index
                length_ratio = total_frames / record.num_frames
                frame_indices = np.around(indices * length_ratio).astype(int)
            # Rounding can land one past the last decoded frame; keep indices in range.
            frame_indices = np.clip(frame_indices, 0, total_frames - 1)
            video_frames = video_frames[frame_indices, ...]

        return video_frames

    def _selective_decoding(self, container, index, timebase):
        """Decode only the pts range spanned by ``index``.

        Returns:
            tuple: ``(frames, success)``. ``frames`` holds ``self.num_frames`` frames
            picked evenly from the decoded range; ``success`` is False when nothing
            could be decoded, in which case the caller decodes the whole video.
        """
        # Seeking in the stream is imprecise. Thus, seek to an earlier PTS by a
        # margin pts.
        margin = 1024
        start_idx, end_idx = min(index), max(index)
        video_start_pts = int(start_idx * timebase)
        video_end_pts = int(end_idx * timebase)
        seek_offset = max(video_start_pts - margin, 0)
        container.seek(seek_offset, any_frame=False, backward=True,
                       stream=container.streams.video[0])
        success = True
        video_frames = None
        try:
            frames = {}
            for frame in container.decode({'video': 0}):
                if frame.pts < video_start_pts:
                    continue
                if frame.pts <= video_end_pts:
                    frames[frame.pts] = frame
                else:
                    break
            # the decoded frames is a whole region but we might subsample it
            video_frames = np.asarray([frames[pts].to_rgb().to_ndarray() for pts in sorted(frames)])
            index = np.linspace(0, max(0, len(video_frames) - 1), num=self.num_frames, dtype=int)
            if len(video_frames) == 0:  # somehow decoding is wrong
                success = False
            else:
                video_frames = video_frames[index, ...]
        except Exception:
            # Deliberate best effort: any decoding problem is reported as a failure
            # so the caller can fall back to decoding the whole video.
            success = False

        return video_frames, success
|
| 1083 |
+
|
| 1084 |
+
|
| 1085 |
+
class MultiVideoDataSetOnline(data.Dataset):
    """Multi-modality dataset that builds one sub-dataset per modality.

    'rgb' and 'rgbdiff' use ``VideoDataSetOnline``, 'sound' uses ``VideoDataSet``
    and 'flow' uses ``VideoDataSetLMDB``.
    """

    def __init__(self, root_path, list_file, num_groups=64, frames_per_group=1, sample_offset=0, num_clips=1,
                 modality='rgb', dense_sampling=False, fixed_offset=True,
                 image_tmpl='{:05d}.jpg', transform=None, is_train=True, test_mode=False, seperator=' ',
                 filter_video=0, num_classes=None, whole_video=False,
                 fps=29.97, audio_length=1.28, resampling_rate=24000):
        """
        # root_path, modality and transform become list, each for one modality

        Arguments have different meaning when dense_sampling is True:
            - num_groups ==> number of frames
            - frames_per_group ==> sample every K frame
            - sample_offset ==> number of clips used in validation or test mode

        Args:
            root_path (str): the file path to the root of video folder
            list_file (str): the file list, each line with folder_path, start_frame, end_frame, label_id
            num_groups (int): number of frames per data sample
            frames_per_group (int): number of frames within one group
            sample_offset (int): used in validation/test, the offset when sampling frames from a group
            modality (str): rgb, rgbdiff, sound or flow
            dense_sampling (bool): dense sampling in I3D
            fixed_offset (bool): used for generating the same videos used in TSM
            image_tmpl (str): template of image ids
            transform: the transformer for preprocessing
            is_train (bool): shuffle the video but keep the causality
            test_mode (bool): testing mode, no label

        Raises:
            ValueError: if an entry of ``modality`` is not one of the supported names.
        """

        # TODO: support mixed LMDB, pyAV, etc.

        video_datasets = []
        for i in range(len(modality)):
            if modality[i] == 'rgb' or modality[i] == 'rgbdiff':
                video_dataset_cls = VideoDataSetOnline
                list_file_ = list_file
            elif modality[i] == 'sound':
                video_dataset_cls = VideoDataSet
                list_file_ = list_file
            elif modality[i] == 'flow':
                video_dataset_cls = VideoDataSetLMDB
                list_file_ = list_file.replace(".txt", ".lmdb")
            else:
                # Without this branch an unknown modality would either hit an
                # unbound variable or silently reuse the class chosen for the
                # previous modality.
                raise ValueError(f'Unsupported modality: {modality[i]}')

            tmp = video_dataset_cls(root_path[i], list_file_,
                                    num_groups, frames_per_group, sample_offset,
                                    num_clips, modality[i], dense_sampling, fixed_offset,
                                    image_tmpl, transform[i], is_train, test_mode, seperator,
                                    filter_video, num_classes, whole_video, fps, audio_length, resampling_rate)
            video_datasets.append(tmp)

        self.video_datasets = video_datasets
        self.is_train = is_train
        self.test_mode = test_mode
        self.num_frames = num_groups
        self.sample_freq = frames_per_group
        self.dense_sampling = dense_sampling
        self.num_clips = num_clips
        self.fixed_offset = fixed_offset
        self.modality = modality
        self.num_classes = num_classes
        self.whole_video = whole_video

        self.video_list = video_datasets[0].video_list
        # Largest value reported by any of the sub-datasets.
        self.num_consecutive_frames = max([x.num_consecutive_frames for x in self.video_datasets])

    def _sample_indices(self, record):
        """Sample training-mode frame indices for ``record`` via ``sample_train_clip``."""
        return sample_train_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                 self.sample_freq, self.dense_sampling, self.num_clips)

    def _get_val_indices(self, record):
        """Sample validation/test-mode frame indices for ``record`` via ``sample_val_test_clip``."""
        return sample_val_test_clip(record.num_frames, self.num_consecutive_frames, self.num_frames,
                                    self.sample_freq, self.dense_sampling, self.fixed_offset,
                                    self.num_clips, self.whole_video)

    def remove_data(self, idx):
        """Forward the removal of the entries ``idx`` to every per-modality dataset."""
        for i in range(len(self.video_datasets)):
            self.video_datasets[i].remove_data(idx)
        # Re-read the list so this wrapper sees the result of the removal.
        self.video_list = self.video_datasets[0].video_list

    def __getitem__(self, index):
        """
        Load sample ``index`` from every modality with one shared set of frame indices.

        Returns:
            list: the transformed data of each modality, in the order of ``self.video_datasets``
            the label reported by the first modality's dataset
        """
        multi_modalities = []
        indices = None
        for modality, video_dataset in zip(self.modality, self.video_datasets):
            if indices is None:
                # Indices are sampled once, using the first modality's record and
                # sampler, and reused for the remaining modalities.
                if modality != 'flow':
                    record = video_dataset.video_list[index]
                else:
                    unpacked_video = video_dataset.maybe_open_and_get_buffer(index)
                    num_frames = unpacked_video[0]
                    record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])
                indices = video_dataset._sample_indices(record) if video_dataset.is_train else video_dataset._get_val_indices(record)

            if modality != 'flow':
                record = video_dataset.video_list[index]
                images = video_dataset.get_data(record, indices)
            else:
                # Reuse the buffer if the dataset still holds one, otherwise load it now.
                if video_dataset.unpacked_video is None:
                    video_dataset.maybe_open_and_get_buffer(index)
                unpacked_video = video_dataset.unpacked_video
                num_frames = unpacked_video[0]
                record = VideoRecord(video_dataset.keys[index].decode("utf-8"), 1, num_frames, unpacked_video[-1])

                images = video_dataset.get_data(record, indices, video_dataset.unpacked_video)
                # Clear the buffer so the next access has to load its own sample.
                video_dataset.unpacked_video = None

            images = video_dataset.transform(images)
            label = video_dataset.get_label(record)
            multi_modalities.append((images, label))

        return [x for x, y in multi_modalities], multi_modalities[0][1]

    def __len__(self):
        """Return the length of the first per-modality dataset."""
        return len(self.video_datasets[0])
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def get_dataloader(loader_type, *args, **kwargs) -> \
        Union[VideoDataSetLMDB, VideoDataSetOnline, VideoDataSet]:
    """Build the single-modality dataset registered under ``loader_type``.

    ``loader_type`` is one of 'lmdb', 'pyav' or 'jpeg'; the remaining arguments
    are passed through to the dataset constructor unchanged.

    Raises:
        ValueError: if ``loader_type`` is not one of the known names.
    """
    known_loaders = (
        ('lmdb', VideoDataSetLMDB),
        ('pyav', VideoDataSetOnline),
        ('jpeg', VideoDataSet),
    )
    for type_name, dataset_cls in known_loaders:
        if loader_type == type_name:
            return dataset_cls(*args, **kwargs)
    raise ValueError(f'Unknown dataloader type: {loader_type}')
|
| 1216 |
+
|
| 1217 |
+
|
| 1218 |
+
def get_multimodality_dataloader(loader_type, *args, **kwargs) -> \
        Union[MultiVideoDataSetLMDB, MultiVideoDataSetOnline, MultiVideoDataSet]:
    """Build the multi-modality dataset registered under ``loader_type``.

    ``loader_type`` is one of 'lmdb', 'pyav' or 'jpeg'; the remaining arguments
    are passed through to the dataset constructor unchanged.

    Raises:
        ValueError: if ``loader_type`` is not one of the known names.
    """
    known_loaders = (
        ('lmdb', MultiVideoDataSetLMDB),
        ('pyav', MultiVideoDataSetOnline),
        ('jpeg', MultiVideoDataSet),
    )
    for type_name, dataset_cls in known_loaders:
        if loader_type == type_name:
            return dataset_cls(*args, **kwargs)
    raise ValueError(f'Unknown dataloader type: {loader_type}')
|
| 1228 |
+
|
video_dataset_aug.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
from typing import List
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn.parallel
|
| 6 |
+
import torch.optim
|
| 7 |
+
import torch.utils.data
|
| 8 |
+
import torch.utils.data.distributed
|
| 9 |
+
import torchvision.transforms as transforms
|
| 10 |
+
from video_transforms import (GroupRandomHorizontalFlip, GroupOverSample,
|
| 11 |
+
GroupMultiScaleCrop, GroupScale, GroupCenterCrop, GroupRandomCrop,
|
| 12 |
+
GroupNormalize, Stack, ToTorchFormatTensor, GroupRandomScale,GroupCutout)
|
| 13 |
+
|
| 14 |
+
def get_augmentor(is_train: bool, image_size: int, mean: List[float] = None,
                  std: List[float] = None, disable_scaleup: bool = False,
                  threed_data: bool = False, version: str = 'v1', scale_range: [int] = None,
                  modality: str = 'rgb', num_clips: int = 1, num_crops: int = 1, cut_out=True,dataset: str = ''):
    """Compose the transform pipeline applied to a group of frames.

    Training uses a random crop ('v1': multi-scale crop, 'v2': random scale then
    random crop) plus a random horizontal flip, which is skipped for dataset
    names starting with 'ststv' or containing 'jester' / 'mini_ststv'.
    Evaluation uses scale + center crop, or ``GroupOverSample`` when
    ``num_crops`` is not 1. Both modes then stack, convert to tensor and
    normalize; ``GroupCutout`` is appended whenever ``cut_out`` is true.

    Returns:
        torchvision.transforms.Compose: the assembled pipeline.
    """
    if mean is None:
        mean = [0.485, 0.456, 0.406]
    if std is None:
        std = [0.229, 0.224, 0.225]
    if scale_range is None:
        scale_range = [256, 320]

    pipeline = []
    if is_train:
        if version == 'v1':
            pipeline.append(GroupMultiScaleCrop(image_size, [1, .875, .75, .66]))
        elif version == 'v2':
            pipeline.extend([GroupRandomScale(scale_range), GroupRandomCrop(image_size)])
        skip_flip = dataset.startswith('ststv') or 'jester' in dataset or 'mini_ststv' in dataset
        if not skip_flip:
            pipeline.append(GroupRandomHorizontalFlip(is_flow=(modality == 'flow')))
    else:
        if disable_scaleup:
            scaled_size = image_size
        else:
            scaled_size = int(image_size / 0.875 + 0.5)
        if num_crops == 1:
            pipeline.extend([GroupScale(scaled_size), GroupCenterCrop(image_size)])
        else:
            # Flipped copies are only requested for the 10-crop setting.
            flip = bool(num_crops == 10)
            pipeline.append(GroupOverSample(image_size, scaled_size, num_crops=num_crops, flip=flip))

    pipeline.extend([
        Stack(threed_data=threed_data),
        ToTorchFormatTensor(num_clips_crops=num_clips * num_crops),
        GroupNormalize(mean=mean, std=std, threed_data=threed_data),
    ])
    if cut_out:
        pipeline.append(GroupCutout(n_holes=1,length=16))

    return transforms.Compose(pipeline)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def build_dataflow(dataset, is_train, batch_size, workers=36, is_distributed=False):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    The worker count is capped at the number of CPUs. In distributed mode a
    ``DistributedSampler`` is used; otherwise the data is shuffled only when
    ``is_train`` is true. The last incomplete batch is always dropped.
    """
    sampler = None
    if is_distributed:
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)

    # Shuffling is left to the sampler whenever one is supplied.
    shuffle = bool(is_train) and sampler is None
    num_workers = min(workers, multiprocessing.cpu_count())

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=True,
        sampler=sampler,
    )
|
| 73 |
+
|
video_dataset_config.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
|
| 3 |
+
# Per-dataset settings, keyed by dataset name.
DATASET_CONFIG = {
    'ffpp': {
        'num_classes': 2,
        'train_list_name': 'cdf_test_fold.txt',
        'val_list_name': 'cdf_test_fold.txt',
        'test_list_name': 'cdf_test_fold.txt',
        'filename_seperator': " ",
        'image_tmpl': '{:04d}.jpg',
        'filter_video': 3,
    }
}


def _lmdb_list_name(list_name):
    """Return ``list_name`` with a trailing ``.txt`` extension replaced by ``.lmdb``."""
    txt_ext = '.txt'
    # Only the extension is rewritten; a plain str.replace("txt", "lmdb") would
    # also rewrite "txt" anywhere else in the name (e.g. a "txt_lists/" folder).
    if list_name.endswith(txt_ext):
        return list_name[:-len(txt_ext)] + '.lmdb'
    return list_name


def get_dataset_config(dataset, use_lmdb=False):
    """Look up the settings of ``dataset`` in ``DATASET_CONFIG``.

    Args:
        dataset (str): key of ``DATASET_CONFIG``.
        use_lmdb (bool): when true, the three list names are returned with their
            ``.txt`` extension replaced by ``.lmdb``.

    Returns:
        tuple: ``(num_classes, train_list_name, val_list_name, test_list_name,
        filename_seperator, image_tmpl, filter_video, label_file)``.
        ``filter_video`` defaults to 0 and ``label_file`` to None when absent.

    Raises:
        KeyError: if ``dataset`` or one of its required settings is missing.
    """
    ret = DATASET_CONFIG[dataset]
    num_classes = ret['num_classes']
    train_list_name = ret['train_list_name']
    val_list_name = ret['val_list_name']
    test_list_name = ret['test_list_name']
    if use_lmdb:
        train_list_name = _lmdb_list_name(train_list_name)
        val_list_name = _lmdb_list_name(val_list_name)
        test_list_name = _lmdb_list_name(test_list_name)
    filename_seperator = ret['filename_seperator']
    image_tmpl = ret['image_tmpl']
    filter_video = ret.get('filter_video', 0)
    label_file = ret.get('label_file', None)

    return num_classes, train_list_name, val_list_name, test_list_name, filename_seperator, \
        image_tmpl, filter_video, label_file
|
video_transforms.py
ADDED
|
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torchvision
|
| 2 |
+
import random
|
| 3 |
+
from PIL import Image, ImageOps
|
| 4 |
+
import numbers
|
| 5 |
+
import torch
|
| 6 |
+
import numpy as np
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
class GroupRandomCrop(object):
    """Crop every image of a group with one randomly placed window.

    ``size`` is either a single number (square crop) or a pair that is unpacked
    as ``(height, width)``.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img_group):
        width, height = img_group[0].size
        crop_h, crop_w = self.size

        # One offset is drawn per group so all frames share the same window.
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        window = (left, top, left + crop_w, top + crop_h)
        already_sized = width == crop_w and height == crop_h

        cropped = []
        for frame in img_group:
            assert(frame.size[0] == width and frame.size[1] == height)
            cropped.append(frame if already_sized else frame.crop(window))
        return cropped
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class GroupCenterCrop(object):
    """Apply ``torchvision.transforms.CenterCrop(size)`` to every image of a group."""

    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        cropped = []
        for img in img_group:
            cropped.append(self.worker(img))
        return cropped
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class GroupRandomHorizontalFlip(object):
    """Horizontally flip a whole group of PIL images with probability 0.5.

    When ``is_flow`` is set on the instance, every second image of a flipped
    group (positions 0, 2, 4, ...) is additionally inverted.
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        # NOTE: the ``is_flow`` call argument is not used; only the instance flag is read.
        if random.random() >= 0.5:
            return img_group

        flipped = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
        if self.is_flow:
            for pos in range(0, len(flipped), 2):
                # invert flow pixel values when flipping
                flipped[pos] = ImageOps.invert(flipped[pos])
        return flipped
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class GroupNormalize(object):
    """Normalize a stacked frame tensor in place: ``(x - mean) / std`` per channel.

    Args:
        mean: per-channel means (sequence of floats).
        std: per-channel standard deviations (sequence of floats).
        threed_data (bool): when true, ``mean``/``std`` are stored as
            ``(C, 1, 1, 1)`` tensors and broadcast against the input; otherwise
            they are repeated along the first dimension of the input so that
            every group of ``len(mean)`` channels gets the same statistics.
    """

    def __init__(self, mean, std, threed_data=False):
        self.threed_data = threed_data
        if self.threed_data:
            # convert to the proper format
            self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1)
            self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1)
        else:
            self.mean = mean
            self.std = std

    def __call__(self, tensor):
        """Normalize ``tensor`` in place and return it."""
        if self.threed_data:
            tensor.sub_(self.mean).div_(self.std)
            return tensor

        # Build the statistics as tensors so the channels are normalized in one
        # vectorized pass, and so sequences, tuples and numpy arrays all behave
        # the same way (values are repeated, never scaled).
        stat_dtype = tensor.dtype if tensor.is_floating_point() else torch.float32
        mean = torch.as_tensor(self.mean, dtype=stat_dtype, device=tensor.device).reshape(-1)
        std = torch.as_tensor(self.std, dtype=stat_dtype, device=tensor.device).reshape(-1)

        channels = tensor.size()[0]
        mean = mean.repeat(channels // mean.numel())
        std = std.repeat(channels // std.numel())

        # Only channels covered by both repeated statistics are normalized;
        # leftover channels (when the count is not a multiple) stay untouched.
        covered = min(mean.numel(), std.numel())
        if covered:
            stat_shape = (covered,) + (1,) * (tensor.dim() - 1)
            tensor[:covered].sub_(mean[:covered].view(stat_shape)).div_(std[:covered].view(stat_shape))

        return tensor
|
| 87 |
+
|
| 88 |
+
class GroupCutout(object):
    """Randomly mask out one or more patches, identically for every frame.

    Args:
        n_holes (int): Number of patches to cut out of each image.
        length (int): The length (in pixels) of each square patch. The patch
            spans ``length // 2`` pixels on each side of its centre, so odd
            lengths produce a patch one pixel smaller than ``length``.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, imgs):
        """
        Args:
            imgs (Tensor): stacked frames of size (C, H, W).
        Returns:
            Tensor: new tensor of the same shape with ``n_holes`` patches
            zeroed out; the same patches are applied to every channel so all
            frames of the clip share one mask. The input is not modified.
        """
        # The last two dimensions are (H, W). The previous code unpacked them
        # as (W, H), which scrambled non-square inputs on the final reshape.
        height, width = imgs.shape[-2], imgs.shape[-1]
        half = self.length // 2

        # Build the mask once; it is shared by all frames of the group.
        mask = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # Draw a fresh centre per hole so n_holes > 1 really cuts several
            # patches (y first, then x, matching the original draw order).
            y = np.random.randint(height)
            x = np.random.randint(width)
            y1, y2 = max(y - half, 0), min(y + half, height)
            x1, x2 = max(x - half, 0), min(x + half, width)
            mask[y1:y2, x1:x2] = 0.

        mask = torch.from_numpy(mask).to(imgs.device)
        # Broadcasting applies the (H, W) mask to every channel.
        return imgs * mask
|
| 136 |
+
|
| 137 |
+
class GroupScale(object):
    """Resize every image of a group with ``torchvision.transforms.Resize``.

    size: forwarded to ``Resize``; for an int this is the size of the
        smaller edge, e.g. if height > width the image is rescaled to
        (size * height / width, size).
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        """Return a list with the resized version of each image in ``img_group``."""
        return list(map(self.worker, img_group))
|
| 151 |
+
|
| 152 |
+
class GroupRandomScale(object):
    """Rescale the image group so its smaller edge has a randomly chosen size.

    The smaller-edge size is drawn uniformly from the inclusive integer range
    ``[size[0], size[1]]`` on every call and applied to all images via
    ``GroupScale``. For example, if height > width, the images are rescaled
    to (selected * height / width, selected).

    size: pair (low, high) bounding the smaller-edge size
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        """Return the group resized to one randomly selected smaller-edge size."""
        # np.random.randint returns a NumPy integer scalar, which is not an
        # instance of the built-in ``int``; torchvision's Resize validates
        # ``size`` with isinstance(size, (int, Sequence)) and would reject
        # it, so convert to a plain Python int first.
        selected_size = int(
            np.random.randint(low=self.size[0], high=self.size[1] + 1, dtype=int)
        )
        scale = GroupScale(selected_size, interpolation=self.interpolation)
        return scale(img_group)
|
| 170 |
+
|
| 171 |
+
class GroupOverSample(object):
    """Produce several fixed-position crops (optionally mirrored) of a group.

    Args:
        crop_size: int or (width, height) of each crop.
        scale_size: if given, the group is first resized with ``GroupScale``.
        num_crops: one of 1, 3, 5, 10. Only the value 3 selects the
            three-crop layout; every other accepted value uses the five
            offsets returned by ``GroupMultiScaleCrop.fill_fix_offset``.
        flip: when true, a horizontally mirrored copy of every crop is
            appended after the regular crops of the same offset.
    """

    def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False):
        if isinstance(crop_size, int):
            crop_size = (crop_size, crop_size)
        self.crop_size = crop_size

        self.scale_worker = GroupScale(scale_size) if scale_size is not None else None

        if num_crops not in [1, 3, 5, 10]:
            raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops))
        self.num_crops = num_crops

        self.flip = flip

    def __call__(self, img_group):
        """Return a flat list: for each offset, all crops, then (if enabled) all flips."""
        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        if self.num_crops == 3:
            w_step = (image_w - crop_w) // 4
            h_step = (image_h - crop_h) // 4
            # Each entry is (multiple of w_step, multiple of h_step).
            if image_w != crop_w and image_h != crop_h:
                layout = [(0, 0), (4, 4), (2, 2)]  # top, bottom, center
            elif image_w < image_h:
                layout = [(2, 0), (2, 4), (2, 2)]  # top, bottom, center
            else:
                layout = [(0, 2), (4, 2), (2, 2)]  # left, right, center
            offsets = [(mw * w_step, mh * h_step) for mw, mh in layout]
        else:
            offsets = GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)

        oversampled = []
        for off_w, off_h in offsets:
            box = (off_w, off_h, off_w + crop_w, off_h + crop_h)
            crops = [img.crop(box) for img in img_group]
            oversampled.extend(crops)

            if self.flip:
                for idx, (img, crop) in enumerate(zip(img_group, crops)):
                    mirrored = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
                    # Even-indexed single-channel images are also inverted.
                    if img.mode == 'L' and idx % 2 == 0:
                        mirrored = ImageOps.invert(mirrored)
                    oversampled.append(mirrored)
        return oversampled
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class GroupMultiScaleCrop(object):
    """Crop the group at a randomly chosen scale/aspect and resize to ``input_size``.

    Args:
        input_size: int or [width, height] of the output images.
        scales: candidate crop sizes as fractions of the shorter image edge;
            defaults to ``[1, .875, .75, .66]``.
        max_distort: maximum index distance between the scales used for the
            crop width and the crop height.
        fix_crop: when true the crop offset is picked from a fixed grid of
            positions (see ``fill_fix_offset``), otherwise uniformly at random.
        more_fix_crop: when true the fixed grid has 13 positions instead of 5.
    """

    # Fractions of the shorter edge. The second entry used to be written as
    # ``875`` (missing decimal point), which produced crop boxes hundreds of
    # times larger than the image.
    _DEFAULT_SCALES = (1, .875, .75, .66)

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else list(self._DEFAULT_SCALES)
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Apply one sampled crop box to every image, then resize to ``input_size``."""
        im_size = img_group[0].size

        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        out_size = (self.input_size[0], self.input_size[1])
        return [img.crop(box).resize(out_size, self.interpolation) for img in img_group]

    def _sample_crop_size(self, im_size):
        """Sample ``(crop_w, crop_h, w_offset, h_offset)`` for an image of ``im_size`` (w, h)."""
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap sizes within 3 px of the target size to exactly the target.
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        # Allow width/height to come from different scales, up to max_distort apart.
        pairs = [
            (w, h)
            for i, h in enumerate(crop_h)
            for j, w in enumerate(crop_w)
            if abs(i - j) <= self.max_distort
        ]

        crop_pair = random.choice(pairs)
        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])

        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset at random from the fixed grid of crop positions."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """Return the list of fixed ``(w_offset, h_offset)`` crop positions.

        The free space in each direction is divided into four steps; five
        positions (corners + center) are always returned and eight more are
        added when ``more_fix_crop`` is true.
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = [
            (0, 0),                    # upper left
            (4 * w_step, 0),           # upper right
            (0, 4 * h_step),           # lower left
            (4 * w_step, 4 * h_step),  # lower right
            (2 * w_step, 2 * h_step),  # center
        ]

        if more_fix_crop:
            ret.extend([
                (0, 2 * h_step),           # center left
                (4 * w_step, 2 * h_step),  # center right
                (2 * w_step, 4 * h_step),  # lower center
                (2 * w_step, 0 * h_step),  # upper center

                (1 * w_step, 1 * h_step),  # upper left quarter
                (3 * w_step, 1 * h_step),  # upper right quarter
                (1 * w_step, 3 * h_step),  # lower left quarter
                (3 * w_step, 3 * h_step),  # lower right quarter
            ])

        return ret
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
class GroupRandomSizedCrop(object):
    """Crop the group to a random area and aspect ratio, then resize to a square.

    A crop covering 0.08 to 1.0 of the image area with an aspect ratio drawn
    from 3/4 to 4/3 is sampled (up to 10 attempts) and applied identically
    to every image. This is popularly used to train the Inception networks.

    size: edge length of the square output
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        """Return the cropped-and-resized group, falling back to scale + random crop."""
        img_w, img_h = img_group[0].size

        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * (img_w * img_h)
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))

            if random.random() < 0.5:
                w, h = h, w

            if w <= img_w and h <= img_h:
                x1 = random.randint(0, img_w - w)
                y1 = random.randint(0, img_h - h)
                break
        else:
            # Fallback: no sampled box fitted inside the image.
            scale = GroupScale(self.size, interpolation=self.interpolation)
            crop = GroupRandomCrop(self.size)
            return crop(scale(img_group))

        box = (x1, y1, x1 + w, y1 + h)
        out_group = []
        for img in img_group:
            piece = img.crop(box)
            assert(piece.size == (w, h))
            out_group.append(piece.resize((self.size, self.size), self.interpolation))
        return out_group
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class Stack(object):
    """Combine a group of same-sized images into a single numpy array.

    Args:
        roll: for RGB input without ``threed_data``, reverse the channel
            order of every image before concatenating.
        threed_data: for RGB input, stack the images along a new leading
            axis instead of concatenating them along the channel axis.
    """

    def __init__(self, roll=False, threed_data=False):
        self.roll = roll
        self.threed_data = threed_data

    def __call__(self, img_group):
        """Return the stacked array for ``img_group``.

        The layout is decided by the mode of the first image:
        'L' images get a trailing axis and are concatenated along it;
        'RGB' images are stacked on axis 0 (``threed_data``) or concatenated
        along axis 2, optionally with reversed channels (``roll``).

        Raises:
            ValueError: if the first image has a mode other than 'L' or
                'RGB'. Previously this case silently returned ``None``,
                which only failed later in the pipeline.
        """
        mode = img_group[0].mode
        if mode == 'L':
            return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
        if mode == 'RGB':
            if self.threed_data:
                return np.stack(img_group, axis=0)
            if self.roll:
                return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            return np.concatenate(img_group, axis=2)
        raise ValueError(
            "Stack only supports images in mode 'L' or 'RGB', got {!r}".format(mode)
        )
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class ToTorchFormatTensor(object):
    """Convert a PIL.Image or a channels-last numpy.ndarray to a float tensor.

    The channel axis (last axis of the input) is moved to the front. With
    ``div=True`` the values are divided by 255, mapping [0, 255] inputs to
    [0.0, 1.0].

    Args:
        div: whether to divide the result by 255.
        num_clips_crops: stored on the instance; not used by ``__call__``.
    """

    def __init__(self, div=True, num_clips_crops=1):
        self.div = div
        self.num_clips_crops = num_clips_crops

    def __call__(self, pic):
        """Return ``pic`` as a contiguous ``torch.FloatTensor`` with channels first."""
        if not isinstance(pic, np.ndarray):
            # PIL Image: read the raw bytes as an H x W x C byte tensor.
            raw = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            hwc = raw.view(pic.size[1], pic.size[0], len(pic.mode))
            # HWC -> CHW (this transpose dominates the loading time).
            img = hwc.transpose(0, 1).transpose(0, 2).contiguous()
        elif len(pic.shape) == 4:
            # 4-D array: the last axis is moved to the front, (d0, d1, d2, C) -> (C, d0, d1, d2).
            img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
        else:
            # 3-D array: (H, W, F*C) -> (F*C, H, W).
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()

        as_float = img.float()
        return as_float.div(255) if self.div else as_float
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class IdentityTransform(object):
    """No-op transform: hands its input back untouched."""

    def __call__(self, data):
        """Return ``data`` itself (same object, no copy)."""
        return data
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
if __name__ == "__main__":
    # Build (but do not apply) a sample pipeline: scale, random crop,
    # three-crop oversampling, stacking, tensor conversion, normalization.
    pipeline = [
        GroupScale(256),
        GroupRandomCrop(224),
        GroupOverSample(224, 224, num_crops=3, flip=False),
        Stack(),
        ToTorchFormatTensor(num_clips_crops=9),
        GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
    ]
    trans = torchvision.transforms.Compose(pipeline)
|
| 419 |
+
|
| 420 |
+
|