Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- README.md +248 -0
- __pycache__/inference.cpython-313.pyc +0 -0
- configs/spai.yaml +54 -0
- inference.py +208 -0
- requirements.txt +37 -0
- spai/__init__.py +15 -0
- spai/__main__.py +208 -0
- spai/__pycache__/__init__.cpython-310.pyc +0 -0
- spai/__pycache__/__init__.cpython-313.pyc +0 -0
- spai/__pycache__/__main__.cpython-310.pyc +0 -0
- spai/__pycache__/__main__.cpython-313.pyc +0 -0
- spai/__pycache__/config.cpython-310.pyc +0 -0
- spai/__pycache__/config.cpython-313.pyc +0 -0
- spai/__pycache__/data_utils.cpython-310.pyc +0 -0
- spai/__pycache__/data_utils.cpython-313.pyc +0 -0
- spai/__pycache__/logger.cpython-310.pyc +0 -0
- spai/__pycache__/logger.cpython-313.pyc +0 -0
- spai/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
- spai/__pycache__/lr_scheduler.cpython-313.pyc +0 -0
- spai/__pycache__/metrics.cpython-310.pyc +0 -0
- spai/__pycache__/metrics.cpython-313.pyc +0 -0
- spai/__pycache__/onnx.cpython-310.pyc +0 -0
- spai/__pycache__/onnx.cpython-313.pyc +0 -0
- spai/__pycache__/optimizer.cpython-310.pyc +0 -0
- spai/__pycache__/optimizer.cpython-313.pyc +0 -0
- spai/__pycache__/utils.cpython-310.pyc +0 -0
- spai/__pycache__/utils.cpython-313.pyc +0 -0
- spai/config.py +494 -0
- spai/data/__init__.py +26 -0
- spai/data/__pycache__/__init__.cpython-310.pyc +0 -0
- spai/data/__pycache__/__init__.cpython-313.pyc +0 -0
- spai/data/__pycache__/blur_kernels.cpython-310.pyc +0 -0
- spai/data/__pycache__/blur_kernels.cpython-313.pyc +0 -0
- spai/data/__pycache__/data_finetune.cpython-310.pyc +0 -0
- spai/data/__pycache__/data_finetune.cpython-313.pyc +0 -0
- spai/data/__pycache__/data_mfm.cpython-310.pyc +0 -0
- spai/data/__pycache__/data_mfm.cpython-313.pyc +0 -0
- spai/data/__pycache__/filestorage.cpython-310.pyc +0 -0
- spai/data/__pycache__/filestorage.cpython-313.pyc +0 -0
- spai/data/__pycache__/random_degradations.cpython-310.pyc +0 -0
- spai/data/__pycache__/random_degradations.cpython-313.pyc +0 -0
- spai/data/__pycache__/readers.cpython-310.pyc +0 -0
- spai/data/__pycache__/readers.cpython-313.pyc +0 -0
- spai/data/blur_kernels.py +539 -0
- spai/data/data_finetune.py +723 -0
- spai/data/data_mfm.py +131 -0
- spai/data/filestorage.py +387 -0
- spai/data/random_degradations.py +462 -0
- spai/data/readers.py +178 -0
- spai/data_utils.py +50 -0
README.md
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPAI: Spectral AI-Generated Image Detector
|
| 2 |
+
|
| 3 |
+
Official repository for the CVPR 2025 paper:
|
| 4 |
+
Any-Resolution AI-Generated Image Detection by Spectral Learning.
|
| 5 |
+
|
| 6 |
+
SPAI learns the spectral distribution of real images and detects AI-generated
|
| 7 |
+
images as out-of-distribution samples using spectral reconstruction similarity.
|
| 8 |
+
|
| 9 |
+
## Repository Status
|
| 10 |
+
|
| 11 |
+
This repository currently contains:
|
| 12 |
+
|
| 13 |
+
- Core SPAI package in `spai/`.
|
| 14 |
+
- Main config in `configs/spai.yaml`.
|
| 15 |
+
- A trained checkpoint in `spai/weights/spai.pth`.
|
| 16 |
+
- Unit tests in `tests/`.
|
| 17 |
+
- Utility scripts for data prep, crawling, Fourier analysis and reporting in `tools/` and `spai/tools/`.
|
| 18 |
+
- Hugging Face inference handler in `inference.py`.
|
| 19 |
+
|
| 20 |
+
## Project Structure
|
| 21 |
+
|
| 22 |
+
```text
|
| 23 |
+
.
|
| 24 |
+
├── configs/
|
| 25 |
+
│ └── spai.yaml
|
| 26 |
+
├── spai/
|
| 27 |
+
│ ├── data/ # datasets, readers, augmentations, filestorage (LMDB)
|
| 28 |
+
│ ├── models/ # backbones, SID, MFM, losses, filters
|
| 29 |
+
│ ├── tools/ # CSV generation and dataset utilities
|
| 30 |
+
│ ├── weights/
|
| 31 |
+
│ │ └── spai.pth # included checkpoint
|
| 32 |
+
│ ├── config.py # yacs configuration
|
| 33 |
+
│ ├── hf_utils.py # Hugging Face Hub upload/model card helpers
|
| 34 |
+
│ ├── main_mfm.py # MFM pretraining entrypoint
|
| 35 |
+
│ └── ...
|
| 36 |
+
├── tests/
|
| 37 |
+
│ ├── data/
|
| 38 |
+
│ └── models/
|
| 39 |
+
├── tools/ # analysis, crawling, preprocessing, HF execution logs
|
| 40 |
+
├── inference.py # HF EndpointHandler + local single-image inference
|
| 41 |
+
└── requirements.txt
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Installation
|
| 45 |
+
|
| 46 |
+
Recommended environment:
|
| 47 |
+
|
| 48 |
+
```bash
|
| 49 |
+
conda create -n spai python=3.11
|
| 50 |
+
conda activate spai
|
| 51 |
+
conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
|
| 52 |
+
pip install -r requirements.txt
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
Notes:
|
| 56 |
+
|
| 57 |
+
- Training code may require NVIDIA APEX.
|
| 58 |
+
- `requirements.txt` includes packages for training, inference, ONNX, crawling and Hugging Face utilities.
|
| 59 |
+
|
| 60 |
+
## Configuration and Weights
|
| 61 |
+
|
| 62 |
+
- Main config: `configs/spai.yaml`
|
| 63 |
+
- Default included checkpoint: `spai/weights/spai.pth`
|
| 64 |
+
|
| 65 |
+
The included config is set for SID finetuning/inference with arbitrary-resolution processing.
|
| 66 |
+
|
| 67 |
+
## Inference
|
| 68 |
+
|
| 69 |
+
### 1) Hugging Face Endpoint Handler (recommended)
|
| 70 |
+
|
| 71 |
+
File: `inference.py`
|
| 72 |
+
|
| 73 |
+
The `EndpointHandler` supports these input formats:
|
| 74 |
+
|
| 75 |
+
- image URL (`http://...` / `https://...`)
|
| 76 |
+
- local path
|
| 77 |
+
- base64 string
|
| 78 |
+
- raw bytes
|
| 79 |
+
- PIL image
|
| 80 |
+
- dict with one of keys: `url`, `path`, `b64`, `bytes`
|
| 81 |
+
|
| 82 |
+
Output format:
|
| 83 |
+
|
| 84 |
+
```json
|
| 85 |
+
{
|
| 86 |
+
"score": 0.8732,
|
| 87 |
+
"predicted_label": 1,
|
| 88 |
+
"predicted_label_name": "ai-generated",
|
| 89 |
+
"threshold": 0.5
|
| 90 |
+
}
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
Label convention used in repository tooling:
|
| 94 |
+
|
| 95 |
+
- `0` -> real
|
| 96 |
+
- `1` -> ai-generated
|
| 97 |
+
|
| 98 |
+
Run locally:
|
| 99 |
+
|
| 100 |
+
```bash
|
| 101 |
+
python inference.py --image "/path/to/image.jpg" --model-dir .
|
| 102 |
+
```
|
| 103 |
+
|
| 104 |
+
Environment overrides:
|
| 105 |
+
|
| 106 |
+
- `SPAI_THRESHOLD` (default `0.5`)
|
| 107 |
+
- `SPAI_CONFIG` (custom config path)
|
| 108 |
+
- `SPAI_CHECKPOINT` (custom checkpoint path)
|
| 109 |
+
- `SPAI_FORCE_CPU=1` (force CPU)
|
| 110 |
+
|
| 111 |
+
### 2) Python usage
|
| 112 |
+
|
| 113 |
+
```python
|
| 114 |
+
from inference import EndpointHandler
|
| 115 |
+
|
| 116 |
+
handler = EndpointHandler(path=".")
|
| 117 |
+
result = handler({"inputs": "https://example.com/image.jpg"})
|
| 118 |
+
print(result)
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
## Training
|
| 122 |
+
|
| 123 |
+
### MFM pretraining entrypoint
|
| 124 |
+
|
| 125 |
+
Use:
|
| 126 |
+
|
| 127 |
+
```bash
|
| 128 |
+
python spai/main_mfm.py --cfg configs/spai.yaml --data-path /path/to/data.csv --output output/mfm
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
`spai/main_mfm.py` also supports optional Hugging Face push flags:
|
| 132 |
+
|
| 133 |
+
- `--push-to-hub`
|
| 134 |
+
- `--hub-repo-id`
|
| 135 |
+
- `--hub-token`
|
| 136 |
+
- `--hub-create-model-card`
|
| 137 |
+
|
| 138 |
+
## Dataset CSV Format
|
| 139 |
+
|
| 140 |
+
Core dataset readers in `spai/data/data_finetune.py` expect CSVs with at least:
|
| 141 |
+
|
| 142 |
+
- `image`: image path
|
| 143 |
+
- `class`: class id
|
| 144 |
+
- `split`: one of `train`, `val`, `test`
|
| 145 |
+
|
| 146 |
+
Paths are resolved relatively to a configurable CSV root directory.
|
| 147 |
+
|
| 148 |
+
## LMDB Dataset File Storage
|
| 149 |
+
|
| 150 |
+
Module: `spai/data/filestorage.py`
|
| 151 |
+
|
| 152 |
+
Available commands:
|
| 153 |
+
|
| 154 |
+
```bash
|
| 155 |
+
python spai/data/filestorage.py add-csv --help
|
| 156 |
+
python spai/data/filestorage.py add-db --help
|
| 157 |
+
python spai/data/filestorage.py verify-csv --help
|
| 158 |
+
python spai/data/filestorage.py list-db --help
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
Use this workflow when you want to package many files into LMDB for faster or centralized IO.
|
| 162 |
+
|
| 163 |
+
## Utility Scripts
|
| 164 |
+
|
| 165 |
+
### Repository-level tools (`tools/`)
|
| 166 |
+
|
| 167 |
+
- `tools/simple_crawler.py`: crawl and download images with metadata.
|
| 168 |
+
- `tools/web_image_crawler.py`: crawl URLs/CSVs, download images, filter ad-like images.
|
| 169 |
+
- `tools/image_quality_processor.py`: quality filtering, deduplication and reports.
|
| 170 |
+
- `tools/preprocess_for_spai.py`: image preprocessing before SPAI.
|
| 171 |
+
- `tools/create_spai_metadata.py`: build metadata CSV from an image folder.
|
| 172 |
+
- `tools/extract_fourier_features.py`: compute Fourier-derived features.
|
| 173 |
+
- `tools/visualize_fourier.py`: Fourier spectrum visualizations.
|
| 174 |
+
- `tools/visualize_noise_decomposition.py`: advanced noise decomposition visualizations.
|
| 175 |
+
- `tools/analyze_spai_results.py`: plots/analysis for prediction results.
|
| 176 |
+
- `tools/analyze_normalization_impact.py`: study resize normalization impact.
|
| 177 |
+
- `tools/hf_log_execution.py`: generate execution artifacts and optionally upload to HF datasets.
|
| 178 |
+
|
| 179 |
+
Example:
|
| 180 |
+
|
| 181 |
+
```bash
|
| 182 |
+
python tools/hf_log_execution.py --results-csv output/preds.csv --output-dir output/hf_artifacts
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
### Package tools (`spai/tools/`)
|
| 186 |
+
|
| 187 |
+
- `spai.tools.create_dir_csv`: create train/val/test CSV from directories.
|
| 188 |
+
- `spai.tools.create_dmid_ldm_train_val_csv`: create DMID/LDM training CSV.
|
| 189 |
+
- `spai.tools.augment_dataset`: augment a dataset and export updated CSV.
|
| 190 |
+
- `spai.tools.reduce_csv_column`: conditional column reduction/aggregation.
|
| 191 |
+
- `spai/tools/create_synthbuster_csv.py`: Synthbuster CSV generation utility.
|
| 192 |
+
|
| 193 |
+
Examples:
|
| 194 |
+
|
| 195 |
+
```bash
|
| 196 |
+
python -m spai.tools.create_dir_csv --help
|
| 197 |
+
python -m spai.tools.create_dmid_ldm_train_val_csv --help
|
| 198 |
+
python -m spai.tools.augment_dataset --help
|
| 199 |
+
python -m spai.tools.reduce_csv_column --help
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
For `create_synthbuster_csv.py`, use a `PYTHONPATH` that includes `spai/` due to its import style:
|
| 203 |
+
|
| 204 |
+
```bash
|
| 205 |
+
PYTHONPATH=spai python spai/tools/create_synthbuster_csv.py --help
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
## Tests
|
| 209 |
+
|
| 210 |
+
Run all tests:
|
| 211 |
+
|
| 212 |
+
```bash
|
| 213 |
+
pytest tests -q
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
Current test folders:
|
| 217 |
+
|
| 218 |
+
- `tests/data/`
|
| 219 |
+
- `tests/models/`
|
| 220 |
+
|
| 221 |
+
## Acknowledgments
|
| 222 |
+
|
| 223 |
+
This work was partly supported by Horizon Europe projects ELIAS and vera.ai,
|
| 224 |
+
and computational resources from GRNET.
|
| 225 |
+
|
| 226 |
+
Parts of the implementation build upon ideas/code from:
|
| 227 |
+
https://github.com/Jiahao000/MFM
|
| 228 |
+
|
| 229 |
+
## License
|
| 230 |
+
|
| 231 |
+
Source code is licensed under Apache 2.0.
|
| 232 |
+
Third-party datasets and dependencies keep their own licenses.
|
| 233 |
+
|
| 234 |
+
## Contact
|
| 235 |
+
|
| 236 |
+
For questions: d.karageorgiou@uva.nl
|
| 237 |
+
|
| 238 |
+
## Citation
|
| 239 |
+
|
| 240 |
+
```text
|
| 241 |
+
@inproceedings{karageorgiou2025any,
|
| 242 |
+
title={Any-resolution ai-generated image detection by spectral learning},
|
| 243 |
+
author={Karageorgiou, Dimitrios and Papadopoulos, Symeon and Kompatsiaris, Ioannis and Gavves, Efstratios},
|
| 244 |
+
booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
|
| 245 |
+
pages={18706--18717},
|
| 246 |
+
year={2025}
|
| 247 |
+
}
|
| 248 |
+
```
|
__pycache__/inference.cpython-313.pyc
ADDED
|
Binary file (12.7 kB). View file
|
|
|
configs/spai.yaml
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MODEL:
|
| 2 |
+
SID_APPROACH: "freq_restoration"
|
| 3 |
+
TYPE: vit
|
| 4 |
+
NAME: finetune
|
| 5 |
+
DROP_PATH_RATE: 0.1
|
| 6 |
+
NUM_CLASSES: 2
|
| 7 |
+
REQUIRED_NORMALIZATION: "positive_0_1"
|
| 8 |
+
RESOLUTION_MODE: "arbitrary"
|
| 9 |
+
FEATURE_EXTRACTION_BATCH: 400
|
| 10 |
+
VIT:
|
| 11 |
+
EMBED_DIM: 768
|
| 12 |
+
DEPTH: 12
|
| 13 |
+
NUM_HEADS: 12
|
| 14 |
+
INIT_VALUES: None
|
| 15 |
+
USE_APE: True
|
| 16 |
+
USE_RPB: False
|
| 17 |
+
USE_SHARED_RPB: False
|
| 18 |
+
USE_MEAN_POOLING: True
|
| 19 |
+
USE_INTERMEDIATE_LAYERS: True
|
| 20 |
+
PROJECTION_DIM: 1024
|
| 21 |
+
PROJECTION_LAYERS: 2
|
| 22 |
+
PATCH_PROJECTION: True
|
| 23 |
+
PATCH_PROJECTION_PER_FEATURE: True
|
| 24 |
+
INTERMEDIATE_LAYERS: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
|
| 25 |
+
FRE:
|
| 26 |
+
MASKING_RADIUS: 16
|
| 27 |
+
PROJECTOR_LAST_LAYER_ACTIVATION_TYPE: None
|
| 28 |
+
ORIGINAL_IMAGE_FEATURES_BRANCH: True
|
| 29 |
+
CLS_HEAD:
|
| 30 |
+
MLP_RATIO: 3
|
| 31 |
+
PATCH_VIT:
|
| 32 |
+
MINIMUM_PATCHES: 4
|
| 33 |
+
DATA:
|
| 34 |
+
DATASET: csv_sid
|
| 35 |
+
IMG_SIZE: 224
|
| 36 |
+
NUM_WORKERS: 8
|
| 37 |
+
AUGMENTED_VIEWS: 4
|
| 38 |
+
TEST_PREFETCH_FACTOR: 1
|
| 39 |
+
AUG:
|
| 40 |
+
COLOR_JITTER: 0.
|
| 41 |
+
TRAIN:
|
| 42 |
+
EPOCHS: 35
|
| 43 |
+
WARMUP_EPOCHS: 5
|
| 44 |
+
BASE_LR: 5e-4
|
| 45 |
+
WARMUP_LR: 2.5e-7
|
| 46 |
+
MIN_LR: 2.5e-7
|
| 47 |
+
WEIGHT_DECAY: 0.05
|
| 48 |
+
LAYER_DECAY: 0.8
|
| 49 |
+
CLIP_GRAD: None
|
| 50 |
+
LOSS: "bce"
|
| 51 |
+
TEST:
|
| 52 |
+
ORIGINAL_RESOLUTION: True
|
| 53 |
+
PRINT_FREQ: 100
|
| 54 |
+
SAVE_FREQ: 10
|
inference.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import base64
|
| 5 |
+
import io
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import requests
|
| 12 |
+
import torch
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
from spai.config import get_custom_config
|
| 16 |
+
from spai.data.data_finetune import build_transform
|
| 17 |
+
from spai.models import build_cls_model
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for SPAI.

    Loads the SPAI config, builds the classification model, restores weights
    from a checkpoint, and exposes ``__call__`` for single- or batch-image
    scoring. Resolution of config/checkpoint paths can be overridden through
    the ``SPAI_CONFIG`` / ``SPAI_CHECKPOINT`` environment variables; the
    decision threshold through ``SPAI_THRESHOLD`` (default 0.5) and device
    selection through ``SPAI_FORCE_CPU=1``.
    """

    def __init__(self, path: str = "") -> None:
        """Initialize the handler.

        Args:
            path: Directory containing the model config and checkpoint.
                Defaults to the current directory when empty.
        """
        self.model_dir = Path(path) if path else Path(".")
        # Decision threshold applied to the sigmoid score in _predict_one.
        self.threshold = float(os.getenv("SPAI_THRESHOLD", "0.5"))

        cfg_path = self._resolve_config_path()
        self.config = get_custom_config(str(cfg_path))

        self.device = self._resolve_device()
        self.model = build_cls_model(self.config)
        checkpoint_path = self._resolve_checkpoint_path()
        state_dict = self._load_state_dict(checkpoint_path)
        # strict=False: tolerates missing/unexpected keys between checkpoint
        # and model definition. NOTE(review): this also silently ignores
        # genuine mismatches — confirm the checkpoint matches the config.
        self.model.load_state_dict(state_dict, strict=False)
        self.model.to(self.device)
        self.model.eval()

        # Evaluation-time preprocessing pipeline (albumentations-style:
        # called as transform(image=ndarray)["image"] in _predict_one).
        self.transform = build_transform(is_train=False, config=self.config)

    def __call__(self, data: dict[str, Any]) -> dict[str, Any] | list[dict[str, Any]]:
        """Score one image or a list of images.

        Accepts the HF endpoint payload shape: the image input is looked up
        under ``data["inputs"]``, then ``data["image"]``, falling back to
        ``data`` itself. A list input yields a list of per-item results.
        """
        inputs = data.get("inputs", data.get("image", data))

        if isinstance(inputs, list):
            return [self._predict_one(item) for item in inputs]
        return self._predict_one(inputs)

    def _predict_one(self, raw_input: Any) -> dict[str, Any]:
        """Run inference for a single raw input and return the result dict.

        Returns a dict with keys: ``score`` (sigmoid probability),
        ``predicted_label`` (0=real, 1=ai-generated),
        ``predicted_label_name``, and ``threshold``.
        """
        image = self._load_image(raw_input)
        image_np = np.array(image)
        image_tensor = self.transform(image=image_np)["image"]

        if self.config.MODEL.RESOLUTION_MODE == "arbitrary":
            # Arbitrary-resolution mode: the model consumes a list of image
            # tensors plus a feature-extraction batch size (see config
            # MODEL.FEATURE_EXTRACTION_BATCH).
            model_input = [image_tensor.unsqueeze(0).to(self.device)]
            feature_batch_size = self.config.MODEL.FEATURE_EXTRACTION_BATCH
            with torch.no_grad():
                logits = self.model(model_input, feature_batch_size)
        else:
            # Fixed-resolution mode: plain batched tensor input.
            model_input = image_tensor.unsqueeze(0).to(self.device)
            with torch.no_grad():
                logits = self.model(model_input)

        # Single-image batch: take the first logit regardless of output shape.
        score = float(torch.sigmoid(logits).flatten()[0].item())
        predicted_label = int(score >= self.threshold)

        return {
            "score": score,
            "predicted_label": predicted_label,
            "predicted_label_name": "ai-generated" if predicted_label == 1 else "real",
            "threshold": self.threshold,
        }

    def _resolve_config_path(self) -> Path:
        """Locate the YAML model config.

        Precedence: ``SPAI_CONFIG`` env var (must exist, else raises), then
        ``configs/spai.yaml``, ``spai.yaml``, ``config.yaml`` under the
        model directory.

        Raises:
            FileNotFoundError: If no candidate config file exists.
        """
        env_cfg = os.getenv("SPAI_CONFIG")
        if env_cfg:
            cfg_path = Path(env_cfg)
            if cfg_path.exists():
                return cfg_path
            raise FileNotFoundError(f"SPAI_CONFIG points to a missing file: {cfg_path}")

        candidates = [
            self.model_dir / "configs" / "spai.yaml",
            self.model_dir / "spai.yaml",
            self.model_dir / "config.yaml",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate

        raise FileNotFoundError(
            "Could not locate model config. Expected one of: "
            "configs/spai.yaml, spai.yaml, config.yaml, or SPAI_CONFIG env var."
        )

    def _resolve_checkpoint_path(self) -> Path:
        """Locate the model checkpoint.

        Precedence: ``SPAI_CHECKPOINT`` env var (must exist, else raises),
        then well-known filenames under the model directory, then the first
        (alphabetically sorted) ``*.pth`` file found there.

        Raises:
            FileNotFoundError: If no checkpoint can be found.
        """
        env_ckpt = os.getenv("SPAI_CHECKPOINT")
        if env_ckpt:
            ckpt_path = Path(env_ckpt)
            if ckpt_path.exists():
                return ckpt_path
            raise FileNotFoundError(f"SPAI_CHECKPOINT points to a missing file: {ckpt_path}")

        candidates = [
            self.model_dir / "spai.pth",
            self.model_dir / "pytorch_model.bin",
            self.model_dir / "weights" / "spai.pth",
            self.model_dir / "spai" / "weights" / "spai.pth",
        ]
        for candidate in candidates:
            if candidate.exists():
                return candidate

        # Fallback: any .pth in the model dir, deterministic (sorted) choice.
        pth_files = sorted(self.model_dir.glob("*.pth"))
        if pth_files:
            return pth_files[0]

        raise FileNotFoundError(
            "Could not locate model checkpoint. Expected one of: "
            "spai.pth, pytorch_model.bin, weights/spai.pth, spai/weights/spai.pth, "
            "or SPAI_CHECKPOINT env var."
        )

    @staticmethod
    def _resolve_device() -> torch.device:
        """Pick CUDA when available, unless SPAI_FORCE_CPU=1 forces CPU."""
        force_cpu = os.getenv("SPAI_FORCE_CPU", "0") == "1"
        if (not force_cpu) and torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")

    @staticmethod
    def _load_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]:
        """Load a checkpoint and extract its state dict.

        Supports two formats: a training checkpoint dict with a ``"model"``
        sub-dict, or a raw state_dict (dict of tensors).

        Raises:
            RuntimeError: For any other checkpoint structure.
        """
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects — only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
        if isinstance(checkpoint, dict) and "model" in checkpoint and isinstance(checkpoint["model"], dict):
            return checkpoint["model"]

        if isinstance(checkpoint, dict):
            tensor_values = all(isinstance(v, torch.Tensor) for v in checkpoint.values())
            if tensor_values:
                return checkpoint

        raise RuntimeError(
            "Unsupported checkpoint format. Expected a dict with key 'model' or a raw state_dict."
        )

    def _load_image(self, raw_input: Any) -> Image.Image:
        """Decode any supported input form into an RGB PIL image.

        Supported forms: PIL image, raw bytes, dict with one of the keys
        ``bytes``/``b64``/``url``/``path``, or a string that is an
        http(s) URL, a ``data:image`` URI, an existing file path, or a
        base64 payload (tried in that order).

        Raises:
            ValueError: If a string input matches none of the string forms.
            TypeError: For unsupported input types.
        """
        if isinstance(raw_input, Image.Image):
            return raw_input.convert("RGB")

        if isinstance(raw_input, bytes):
            return Image.open(io.BytesIO(raw_input)).convert("RGB")

        if isinstance(raw_input, dict):
            if "bytes" in raw_input:
                raw_bytes = raw_input["bytes"]
                # "bytes" may arrive base64-encoded when the payload is JSON.
                if isinstance(raw_bytes, str):
                    raw_bytes = base64.b64decode(raw_bytes)
                return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
            if "b64" in raw_input:
                return Image.open(io.BytesIO(base64.b64decode(raw_input["b64"]))).convert("RGB")
            if "url" in raw_input:
                return self._load_image_from_url(raw_input["url"])
            if "path" in raw_input:
                return Image.open(Path(raw_input["path"])).convert("RGB")

        if isinstance(raw_input, str):
            if raw_input.startswith("http://") or raw_input.startswith("https://"):
                return self._load_image_from_url(raw_input)

            if raw_input.startswith("data:image") and "," in raw_input:
                # data URI: strip the "data:image/...;base64," prefix.
                _, encoded = raw_input.split(",", 1)
                return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")

            maybe_path = Path(raw_input)
            if maybe_path.exists():
                return Image.open(maybe_path).convert("RGB")

            # Last resort: treat the string as a bare base64 payload.
            try:
                decoded = base64.b64decode(raw_input, validate=True)
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as exc:
                raise ValueError(
                    "String input is neither a valid URL, file path, nor base64 image payload."
                ) from exc

        raise TypeError(
            "Unsupported input type. Use a URL/path/base64 string, bytes, PIL.Image, "
            "or dict with one of keys: bytes, b64, url, path."
        )

    @staticmethod
    def _load_image_from_url(url: str) -> Image.Image:
        """Download an image over HTTP(S) and return it as RGB PIL image.

        Raises ``requests.HTTPError`` on non-2xx responses; 15s timeout.
        """
        response = requests.get(url, timeout=15)
        response.raise_for_status()
        return Image.open(io.BytesIO(response.content)).convert("RGB")
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def _main() -> None:
    """Command-line entry point: score a single image with SPAI and print the result."""
    cli = argparse.ArgumentParser(description="Run SPAI inference for a single image.")
    cli.add_argument("--image", type=str, required=True, help="Image path/URL/base64 input")
    cli.add_argument("--model-dir", type=str, default=".", help="Directory with config/checkpoint")
    parsed = cli.parse_args()

    # Build the endpoint handler once and feed it the HF-style payload shape.
    handler = EndpointHandler(path=parsed.model_dir)
    prediction = handler({"inputs": parsed.image})
    print(prediction)


if __name__ == "__main__":
    _main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Install through conda
|
| 2 |
+
# pytorch
|
| 3 |
+
# torchvision~=0.18.1
|
| 4 |
+
# tsnecuda
|
| 5 |
+
# Compile from sources
|
| 6 |
+
# apex
|
| 7 |
+
# Install through pip
|
| 8 |
+
opencv-python~=4.10.0.84
|
| 9 |
+
pyyaml~=6.0.1
|
| 10 |
+
scipy~=1.14.0
|
| 11 |
+
tensorboard
|
| 12 |
+
termcolor~=2.4.0
|
| 13 |
+
timm==0.4.12
|
| 14 |
+
yacs~=0.1.8
|
| 15 |
+
numpy~=1.26.4
|
| 16 |
+
torchmetrics~=1.4.0.post0
|
| 17 |
+
tqdm~=4.66.4
|
| 18 |
+
pillow~=10.4.0
|
| 19 |
+
# PyYAML already pinned above as pyyaml~=6.0.1
|
| 20 |
+
click~=8.1.7
|
| 21 |
+
neptune~=1.11.1
|
| 22 |
+
albumentations==1.4.14
|
| 23 |
+
albucore==0.0.16
|
| 24 |
+
lmdb~=1.5.1
|
| 25 |
+
networkx~=3.3
|
| 26 |
+
seaborn~=0.13.2
|
| 27 |
+
pandas~=2.2.2
|
| 28 |
+
# neptune already pinned above as neptune~=1.11.1
|
| 29 |
+
einops~=0.8.0
|
| 30 |
+
git+https://github.com/openai/CLIP.git
|
| 31 |
+
onnx
|
| 32 |
+
onnxscript
|
| 33 |
+
huggingface_hub~=0.21.0
|
| 34 |
+
datasets~=2.19.0
|
| 35 |
+
requests~=2.32.3
|
| 36 |
+
beautifulsoup4~=4.12.3
|
| 37 |
+
imagehash~=4.3.1
|
spai/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
spai/__main__.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
import os
|
| 19 |
+
import pathlib
|
| 20 |
+
import time
|
| 21 |
+
import datetime
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Optional
|
| 24 |
+
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
import neptune
|
| 28 |
+
import cv2
|
| 29 |
+
import click
|
| 30 |
+
import torch
|
| 31 |
+
import torch.backends.cudnn as cudnn
|
| 32 |
+
import torch.utils.data
|
| 33 |
+
import yacs
|
| 34 |
+
import filetype
|
| 35 |
+
from torch import nn
|
| 36 |
+
from torch.nn import TripletMarginLoss
|
| 37 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 38 |
+
from timm.utils import AverageMeter
|
| 39 |
+
from yacs.config import CfgNode
|
| 40 |
+
|
| 41 |
+
import spai.data.data_finetune
|
| 42 |
+
from spai.config import get_config
|
| 43 |
+
from spai.models import build_cls_model
|
| 44 |
+
from spai.data import build_loader, build_loader_test
|
| 45 |
+
from spai.lr_scheduler import build_scheduler
|
| 46 |
+
from spai.models.sid import AttentionMask
|
| 47 |
+
from spai.onnx import compare_pytorch_onnx_models
|
| 48 |
+
from spai.optimizer import build_optimizer
|
| 49 |
+
from spai.logger import create_logger
|
| 50 |
+
from spai.utils import (
|
| 51 |
+
load_pretrained,
|
| 52 |
+
save_checkpoint,
|
| 53 |
+
get_grad_norm,
|
| 54 |
+
find_pretrained_checkpoints,
|
| 55 |
+
inf_nan_to_num
|
| 56 |
+
)
|
| 57 |
+
from spai.models import losses
|
| 58 |
+
from spai import metrics
|
| 59 |
+
from spai import data_utils
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def _cuda_enabled() -> bool:
|
| 63 |
+
# Allow forcing CPU mode and avoid probing CUDA on incompatible drivers.
|
| 64 |
+
if os.environ.get("SPAI_FORCE_CPU", "0") == "1":
|
| 65 |
+
return False
|
| 66 |
+
if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
|
| 67 |
+
return False
|
| 68 |
+
try:
|
| 69 |
+
return torch.cuda.is_available()
|
| 70 |
+
except Exception:
|
| 71 |
+
return False
|
| 72 |
+
|
| 73 |
+
try:
|
| 74 |
+
# noinspection PyUnresolvedReferences
|
| 75 |
+
from apex import amp
|
| 76 |
+
except ImportError:
|
| 77 |
+
amp = None
|
| 78 |
+
|
| 79 |
+
cv2.setNumThreads(1)
|
| 80 |
+
logger: Optional[logging.Logger] = None
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
@click.group()
def cli() -> None:
    """Root command group for the SPAI command-line interface."""
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
@cli.command()
|
| 89 |
+
@click.option("--cfg", required=True,
|
| 90 |
+
type=click.Path(exists=True, dir_okay=False, path_type=Path))
|
| 91 |
+
@click.option("--batch-size", type=int,
|
| 92 |
+
help="Batch size for a single GPU.")
|
| 93 |
+
@click.option("--learning-rate", type=float)
|
| 94 |
+
@click.option("--data-path", required=True,
|
| 95 |
+
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
| 96 |
+
help="path to dataset")
|
| 97 |
+
@click.option("--csv-root-dir",
|
| 98 |
+
type=click.Path(exists=True, file_okay=False, path_type=Path))
|
| 99 |
+
@click.option("--lmdb", "lmdb_path",
|
| 100 |
+
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
| 101 |
+
help="Path to an LMDB file storage that contains the files defined in the "
|
| 102 |
+
"dataset's CSV file. If this option is not provided, the data will be "
|
| 103 |
+
"loaded from the filesystem.")
|
| 104 |
+
@click.option("--pretrained",
|
| 105 |
+
type=click.Path(exists=True, dir_okay=False),
|
| 106 |
+
help="path to pre-trained model")
|
| 107 |
+
@click.option("--resume", is_flag=True,
|
| 108 |
+
help="resume from checkpoint")
|
| 109 |
+
@click.option("--accumulation-steps", type=int, default=1,
|
| 110 |
+
help="Gradient accumulation steps.")
|
| 111 |
+
@click.option("--use-checkpoint", is_flag=True,
|
| 112 |
+
help="Whether to use gradient checkpointing to save memory.")
|
| 113 |
+
@click.option("--amp-opt-level", type=click.Choice(["O0", "O1", "O2"]), default="O1",
|
| 114 |
+
help="mixed precision opt level, if O0, no amp is used")
|
| 115 |
+
@click.option("--output", type=click.Path(file_okay=False, path_type=Path),
|
| 116 |
+
help="root of output folder, the full path is "
|
| 117 |
+
"<output>/<model_name>/<tag> (default: output)")
|
| 118 |
+
@click.option("--tag", type=str,
|
| 119 |
+
help="tag of experiment")
|
| 120 |
+
@click.option("--local_rank", type=int, default=0,
|
| 121 |
+
help="local_rank for distributed training")
|
| 122 |
+
@click.option("--test-csv", multiple=True,
|
| 123 |
+
type=click.Path(exists=True, dir_okay=False, path_type=Path),
|
| 124 |
+
help="Path to a CSV with test data. If this option is provided after the "
|
| 125 |
+
"validation of each epoch, a testing will also take place. This option "
|
| 126 |
+
"intends to facilitate understanding the progression of the generalization "
|
| 127 |
+
"ability of a model among the epochs and should not be used for selecting "
|
| 128 |
+
"the final model. This option can be repeated several times. For each provided "
|
| 129 |
+
"csv file, a separate testing run is going to take place.")
|
| 130 |
+
@click.option("--test-csv-root-dir", multiple=True,
|
| 131 |
+
type=click.Path(exists=True, file_okay=False, path_type=Path),
|
| 132 |
+
help="Root directory for the relative paths included into the test csv files. "
|
| 133 |
+
"If this option is omitted, the parent directory of each test csv file will "
|
| 134 |
+
"be used as the root dir for the paths it contains. If this option is provided "
|
| 135 |
+
"a single time, it will be used as the root dir for all the test csv files. If "
|
| 136 |
+
"it is provided multiple times, each value will be matched with a corresponding "
|
| 137 |
+
"test csv file. In that case, the number of provided test csv files and the "
|
| 138 |
+
"number of provided root directories should match. The order of the provided "
|
| 139 |
+
"arguments will be used for the matching.")
|
| 140 |
+
@click.option("--data-workers", type=int,
|
| 141 |
+
help="Number of worker processes to be used for data loading.")
|
| 142 |
+
@click.option("--disable-pin-memory", is_flag=True)
|
| 143 |
+
@click.option("--data-prefetch-factor", type=int)
|
| 144 |
+
@click.option("--save-all", is_flag=True)
|
| 145 |
+
@click.option("--opt", "extra_options", type=(str, str), multiple=True)
|
| 146 |
+
def train(
|
| 147 |
+
cfg: Path,
|
| 148 |
+
batch_size: Optional[int],
|
| 149 |
+
learning_rate: Optional[float],
|
| 150 |
+
data_path: Path,
|
| 151 |
+
csv_root_dir: Optional[Path],
|
| 152 |
+
lmdb_path: Optional[Path],
|
| 153 |
+
pretrained: Optional[Path],
|
| 154 |
+
resume: bool,
|
| 155 |
+
accumulation_steps: int,
|
| 156 |
+
use_checkpoint: bool,
|
| 157 |
+
amp_opt_level: str,
|
| 158 |
+
output: Path,
|
| 159 |
+
tag: str,
|
| 160 |
+
local_rank: int,
|
| 161 |
+
test_csv: list[Path],
|
| 162 |
+
test_csv_root_dir: list[Path],
|
| 163 |
+
data_workers: Optional[int],
|
| 164 |
+
disable_pin_memory: bool,
|
| 165 |
+
data_prefetch_factor: Optional[int],
|
| 166 |
+
save_all: bool,
|
| 167 |
+
extra_options: tuple[str, str]
|
| 168 |
+
) -> None:
|
| 169 |
+
if csv_root_dir is None:
|
| 170 |
+
csv_root_dir = data_path.parent
|
| 171 |
+
config = get_config({
|
| 172 |
+
"cfg": str(cfg),
|
| 173 |
+
"batch_size": batch_size,
|
| 174 |
+
"learning_rate": learning_rate,
|
| 175 |
+
"data_path": str(data_path),
|
| 176 |
+
"csv_root_dir": str(csv_root_dir),
|
| 177 |
+
"lmdb_path": str(lmdb_path),
|
| 178 |
+
"pretrained": str(pretrained) if pretrained is not None else None,
|
| 179 |
+
"resume": resume,
|
| 180 |
+
"accumulation_steps": accumulation_steps,
|
| 181 |
+
"use_checkpoint": use_checkpoint,
|
| 182 |
+
"amp_opt_level": amp_opt_level,
|
| 183 |
+
"output": str(output),
|
| 184 |
+
"tag": tag,
|
| 185 |
+
"local_rank": local_rank,
|
| 186 |
+
"test_csv": [str(p) for p in test_csv],
|
| 187 |
+
"test_csv_root": [str(p) for p in test_csv_root_dir],
|
| 188 |
+
"data_workers": data_workers,
|
| 189 |
+
"disable_pin_memory": disable_pin_memory,
|
| 190 |
+
"data_prefetch_factor": data_prefetch_factor,
|
| 191 |
+
"opts": extra_options
|
| 192 |
+
})
|
| 193 |
+
if 'LOCAL_RANK' not in os.environ:
|
| 194 |
+
os.environ['LOCAL_RANK'] = str(local_rank)
|
| 195 |
+
|
| 196 |
+
if config.AMP_OPT_LEVEL != "O0":
|
| 197 |
+
assert amp is not None, "amp not installed!"
|
| 198 |
+
|
| 199 |
+
# Set a fixed seed to all the random number generators.
|
| 200 |
+
seed = config.SEED
|
| 201 |
+
torch.manual_seed(seed)
|
| 202 |
+
np.random.seed(seed)
|
| 203 |
+
# random.seed(seed)
|
| 204 |
+
cudnn.benchmark = True
|
| 205 |
+
|
| 206 |
+
if config.TRAIN.SCALE_LR:
|
| 207 |
+
# Linear scale the learning rate according to total batch size - may not be optimal.
|
| 208 |
+
linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZ
|
spai/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (164 Bytes). View file
|
|
|
spai/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (144 Bytes). View file
|
|
|
spai/__pycache__/__main__.cpython-310.pyc
ADDED
|
Binary file (28.9 kB). View file
|
|
|
spai/__pycache__/__main__.cpython-313.pyc
ADDED
|
Binary file (9.31 kB). View file
|
|
|
spai/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
spai/__pycache__/config.cpython-313.pyc
ADDED
|
Binary file (20.1 kB). View file
|
|
|
spai/__pycache__/data_utils.cpython-310.pyc
ADDED
|
Binary file (1.47 kB). View file
|
|
|
spai/__pycache__/data_utils.cpython-313.pyc
ADDED
|
Binary file (2.15 kB). View file
|
|
|
spai/__pycache__/logger.cpython-310.pyc
ADDED
|
Binary file (1.11 kB). View file
|
|
|
spai/__pycache__/logger.cpython-313.pyc
ADDED
|
Binary file (1.93 kB). View file
|
|
|
spai/__pycache__/lr_scheduler.cpython-310.pyc
ADDED
|
Binary file (5.14 kB). View file
|
|
|
spai/__pycache__/lr_scheduler.cpython-313.pyc
ADDED
|
Binary file (7.59 kB). View file
|
|
|
spai/__pycache__/metrics.cpython-310.pyc
ADDED
|
Binary file (5.75 kB). View file
|
|
|
spai/__pycache__/metrics.cpython-313.pyc
ADDED
|
Binary file (10.9 kB). View file
|
|
|
spai/__pycache__/onnx.cpython-310.pyc
ADDED
|
Binary file (3.98 kB). View file
|
|
|
spai/__pycache__/onnx.cpython-313.pyc
ADDED
|
Binary file (7.25 kB). View file
|
|
|
spai/__pycache__/optimizer.cpython-310.pyc
ADDED
|
Binary file (4.6 kB). View file
|
|
|
spai/__pycache__/optimizer.cpython-313.pyc
ADDED
|
Binary file (8.91 kB). View file
|
|
|
spai/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (14.1 kB). View file
|
|
|
spai/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (25.4 kB). View file
|
|
|
spai/config.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import Optional, Any
|
| 19 |
+
|
| 20 |
+
import yaml
|
| 21 |
+
from yacs.config import CfgNode as CN
|
| 22 |
+
|
| 23 |
+
_C = CN()
|
| 24 |
+
|
| 25 |
+
# Base config files
|
| 26 |
+
_C.BASE = ['']
|
| 27 |
+
|
| 28 |
+
# -----------------------------------------------------------------------------
|
| 29 |
+
# Data settings
|
| 30 |
+
# -----------------------------------------------------------------------------
|
| 31 |
+
_C.DATA = CN()
|
| 32 |
+
# Batch size for a single GPU, could be overwritten by command line argument
|
| 33 |
+
_C.DATA.BATCH_SIZE = 128
|
| 34 |
+
# Batch size for validation. If it is set to None, DATA.BATCH_SIZE will be used.
|
| 35 |
+
_C.DATA.VAL_BATCH_SIZE = None
|
| 36 |
+
# Batch size for test. If it is set to None, DATA.BATCH_SIZE will be used.
|
| 37 |
+
_C.DATA.TEST_BATCH_SIZE = None
|
| 38 |
+
# Path to dataset, could be overwritten by command line argument
|
| 39 |
+
_C.DATA.DATA_PATH = ''
|
| 40 |
+
# Root path for the relative paths included in a dataset csv file. Not-used when
|
| 41 |
+
# the DATA.DATA_PATH does not point to a csv file.
|
| 42 |
+
_C.DATA.CSV_ROOT = ''
|
| 43 |
+
# A list of paths to the test datasets. Can be overwritten by command line argument.
|
| 44 |
+
_C.DATA.TEST_DATA_PATH = []
|
| 45 |
+
# A list of paths that will be used as root directories for the paths in the test csv files.
|
| 46 |
+
_C.DATA.TEST_DATA_CSV_ROOT = []
|
| 47 |
+
# Path to an LMDB filestorage. When this option is not None, the dataset's file are loaded
|
| 48 |
+
# from this one, instead of the filesystem.
|
| 49 |
+
_C.DATA.LMDB_PATH = None
|
| 50 |
+
# Dataset name
|
| 51 |
+
_C.DATA.DATASET = 'imagenet'
|
| 52 |
+
# Input image size
|
| 53 |
+
_C.DATA.IMG_SIZE = 224
|
| 54 |
+
# Minimal crop scale
|
| 55 |
+
_C.DATA.MIN_CROP_SCALE = 0.2
|
| 56 |
+
# Interpolation to resize image (random, bilinear, bicubic)
|
| 57 |
+
_C.DATA.INTERPOLATION = 'bicubic'
|
| 58 |
+
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
|
| 59 |
+
_C.DATA.PIN_MEMORY = True
|
| 60 |
+
# Number of data loading threads
|
| 61 |
+
_C.DATA.NUM_WORKERS = 24
|
| 62 |
+
# Number of batches to be prefetched by each worker.
|
| 63 |
+
_C.DATA.PREFETCH_FACTOR = 2
|
| 64 |
+
# Prefetch factor for test data loaders.
|
| 65 |
+
_C.DATA.VAL_PREFETCH_FACTOR = None
|
| 66 |
+
# Prefetch factor for test data loaders.
|
| 67 |
+
_C.DATA.TEST_PREFETCH_FACTOR = None
|
| 68 |
+
|
| 69 |
+
# Filter type, support 'mfm', 'sr', 'deblur', 'denoise'
|
| 70 |
+
_C.DATA.FILTER_TYPE = 'mfm'
|
| 71 |
+
# [MFM] Sampling ratio for low-pass filters
|
| 72 |
+
_C.DATA.SAMPLE_RATIO = 0.5
|
| 73 |
+
# [MFM] First frequency mask radius
|
| 74 |
+
# should be smaller than half of the image size
|
| 75 |
+
_C.DATA.MASK_RADIUS1 = 16
|
| 76 |
+
# [MFM] Second frequency mask radius
|
| 77 |
+
# should be larger than the first radius
|
| 78 |
+
# only used when masking a frequency band
|
| 79 |
+
# setting a larger value than the image size, e.g., 999, will have no effect
|
| 80 |
+
_C.DATA.MASK_RADIUS2 = 999
|
| 81 |
+
# [SR] SR downsampling scale factor, only used when FILTER_TYPE == 'sr'
|
| 82 |
+
_C.DATA.SR_FACTOR = 8
|
| 83 |
+
# [Deblur] Deblur parameters, only used when FILTER_TYPE == 'deblur'
|
| 84 |
+
_C.DATA.BLUR = CN()
|
| 85 |
+
_C.DATA.BLUR.KERNEL_SIZE = [7, 9, 11, 13, 15, 17, 19, 21]
|
| 86 |
+
_C.DATA.BLUR.KERNEL_LIST = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso', 'sinc']
|
| 87 |
+
_C.DATA.BLUR.KERNEL_PROB = [0.405, 0.225, 0.108, 0.027, 0.108, 0.027, 0.1]
|
| 88 |
+
_C.DATA.BLUR.SIGMA_X = [0.2, 3]
|
| 89 |
+
_C.DATA.BLUR.SIGMA_Y = [0.2, 3]
|
| 90 |
+
_C.DATA.BLUR.ROTATE_ANGLE = [-3.1416, 3.1416]
|
| 91 |
+
_C.DATA.BLUR.BETA_GAUSSIAN = [0.5, 4]
|
| 92 |
+
_C.DATA.BLUR.BETA_PLATEAU = [1, 2]
|
| 93 |
+
# [Denoise] Denoise parameters, only used when FILTER_TYPE == 'denoise'
|
| 94 |
+
_C.DATA.NOISE = CN()
|
| 95 |
+
_C.DATA.NOISE.TYPE = ['gaussian', 'poisson']
|
| 96 |
+
_C.DATA.NOISE.PROB = [0.5, 0.5]
|
| 97 |
+
_C.DATA.NOISE.GAUSSIAN_SIGMA = [1, 30]
|
| 98 |
+
_C.DATA.NOISE.GAUSSIAN_GRAY_NOISE_PROB = 0.4
|
| 99 |
+
_C.DATA.NOISE.POISSON_SCALE = [0.05, 3]
|
| 100 |
+
_C.DATA.NOISE.POISSON_GRAY_NOISE_PROB = 0.4
|
| 101 |
+
# Number of augmented views for each batch. When SupCon loss is employed, this number
|
| 102 |
+
# should be at least 2.
|
| 103 |
+
_C.DATA.AUGMENTED_VIEWS = 1
|
| 104 |
+
|
| 105 |
+
# -----------------------------------------------------------------------------
|
| 106 |
+
# Model settings
|
| 107 |
+
# -----------------------------------------------------------------------------
|
| 108 |
+
_C.MODEL = CN()
|
| 109 |
+
# Model type
|
| 110 |
+
_C.MODEL.TYPE = 'vit'
|
| 111 |
+
# Type of weights that will be used to initialize the backbone. Supported "mfm", "clip", "dinov2".
|
| 112 |
+
_C.MODEL_WEIGHTS = "mfm"
|
| 113 |
+
# Model name
|
| 114 |
+
_C.MODEL.NAME = 'pretrain'
|
| 115 |
+
# Checkpoint to resume, could be overwritten by command line argument
|
| 116 |
+
_C.MODEL.RESUME = ''
|
| 117 |
+
# Number of classes, overwritten in data preparation
|
| 118 |
+
_C.MODEL.NUM_CLASSES = 1000
|
| 119 |
+
# Dropout rate for the backbone model.
|
| 120 |
+
_C.MODEL.DROP_RATE = 0.0
|
| 121 |
+
# Dropout rate for the trainable SID layers.
|
| 122 |
+
_C.MODEL.SID_DROPOUT = 0.5
|
| 123 |
+
# Drop path rate
|
| 124 |
+
_C.MODEL.DROP_PATH_RATE = 0.1
|
| 125 |
+
# Label Smoothing
|
| 126 |
+
_C.MODEL.LABEL_SMOOTHING = 0.1
|
| 127 |
+
# Required normalization to be applied to the image before provided to the model.
|
| 128 |
+
_C.MODEL.REQUIRED_NORMALIZATION = "imagenet"
|
| 129 |
+
# Approach used for the Synthetic Image Detection task. "single_extraction" and "freq_restoration"
|
| 130 |
+
# are currently supported.
|
| 131 |
+
_C.MODEL.SID_APPROACH = "single_extraction"
|
| 132 |
+
# Whether the model accepts a fixed resolution image to its input or an arbitrary resolution image.
|
| 133 |
+
# Supported values are "fixed" and "arbitrary"
|
| 134 |
+
_C.MODEL.RESOLUTION_MODE = "fixed"
|
| 135 |
+
# Batch size used internally by patched models for feature extraction. If not provided,
|
| 136 |
+
# it is determined by the batch size of the input.
|
| 137 |
+
_C.MODEL.FEATURE_EXTRACTION_BATCH = None
|
| 138 |
+
|
| 139 |
+
# Swin Transformer parameters
|
| 140 |
+
_C.MODEL.SWIN = CN()
|
| 141 |
+
_C.MODEL.SWIN.PATCH_SIZE = 4
|
| 142 |
+
_C.MODEL.SWIN.IN_CHANS = 3
|
| 143 |
+
_C.MODEL.SWIN.EMBED_DIM = 96
|
| 144 |
+
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
|
| 145 |
+
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
|
| 146 |
+
_C.MODEL.SWIN.WINDOW_SIZE = 7
|
| 147 |
+
_C.MODEL.SWIN.MLP_RATIO = 4.
|
| 148 |
+
_C.MODEL.SWIN.QKV_BIAS = True
|
| 149 |
+
_C.MODEL.SWIN.QK_SCALE = None
|
| 150 |
+
_C.MODEL.SWIN.APE = False
|
| 151 |
+
_C.MODEL.SWIN.PATCH_NORM = True
|
| 152 |
+
|
| 153 |
+
# Vision Transformer parameters
|
| 154 |
+
_C.MODEL.VIT = CN()
|
| 155 |
+
_C.MODEL.VIT.PATCH_SIZE = 16
|
| 156 |
+
_C.MODEL.VIT.IN_CHANS = 3
|
| 157 |
+
_C.MODEL.VIT.EMBED_DIM = 768
|
| 158 |
+
_C.MODEL.VIT.DEPTH = 12
|
| 159 |
+
_C.MODEL.VIT.NUM_HEADS = 12
|
| 160 |
+
_C.MODEL.VIT.MLP_RATIO = 4
|
| 161 |
+
_C.MODEL.VIT.QKV_BIAS = True
|
| 162 |
+
_C.MODEL.VIT.INIT_VALUES = 0.1
|
| 163 |
+
# learnable absolute positional embedding
|
| 164 |
+
_C.MODEL.VIT.USE_APE = True
|
| 165 |
+
# fixed sin-cos positional embedding
|
| 166 |
+
_C.MODEL.VIT.USE_FPE = False
|
| 167 |
+
# relative position bias
|
| 168 |
+
_C.MODEL.VIT.USE_RPB = False
|
| 169 |
+
_C.MODEL.VIT.USE_SHARED_RPB = False
|
| 170 |
+
_C.MODEL.VIT.USE_MEAN_POOLING = False
|
| 171 |
+
# Vision Transformer decoder parameters
|
| 172 |
+
_C.MODEL.VIT.DECODER = CN()
|
| 173 |
+
_C.MODEL.VIT.DECODER.EMBED_DIM = 512
|
| 174 |
+
_C.MODEL.VIT.DECODER.DEPTH = 0
|
| 175 |
+
_C.MODEL.VIT.DECODER.NUM_HEADS = 16
|
| 176 |
+
|
| 177 |
+
# Features processor parameter
|
| 178 |
+
# Supported features processors: "mean_norm", "norm_max", "rine"
|
| 179 |
+
_C.MODEL.VIT.FEATURES_PROCESSOR = "rine"
|
| 180 |
+
_C.MODEL.VIT.USE_INTERMEDIATE_LAYERS = False
|
| 181 |
+
_C.MODEL.VIT.INTERMEDIATE_LAYERS = [2, 5, 8, 11]
|
| 182 |
+
_C.MODEL.VIT.PROJECTION_DIM = 1024
|
| 183 |
+
_C.MODEL.VIT.PROJECTION_LAYERS = 2
|
| 184 |
+
_C.MODEL.VIT.PATCH_PROJECTION = False
|
| 185 |
+
_C.MODEL.VIT.PATCH_PROJECTION_PER_FEATURE = False
|
| 186 |
+
# Supported patch pooling: "mean", "l2_max"
|
| 187 |
+
_C.MODEL.VIT.PATCH_POOLING = "mean"
|
| 188 |
+
|
| 189 |
+
# Frequency Restoration Estimator parameters
|
| 190 |
+
_C.MODEL.FRE = CN()
|
| 191 |
+
_C.MODEL.FRE.MASKING_RADIUS = 16
|
| 192 |
+
_C.MODEL.FRE.PROJECTOR_LAST_LAYER_ACTIVATION_TYPE = "gelu"
|
| 193 |
+
_C.MODEL.FRE.ORIGINAL_IMAGE_FEATURES_BRANCH = False
|
| 194 |
+
_C.MODEL.FRE.DISABLE_RECONSTRUCTION_SIMILARITY = False
|
| 195 |
+
|
| 196 |
+
# PatchBasedMFViT related parameters
|
| 197 |
+
_C.MODEL.PATCH_VIT = CN()
|
| 198 |
+
_C.MODEL.PATCH_VIT.PATCH_STRIDE = 224
|
| 199 |
+
_C.MODEL.PATCH_VIT.NUM_HEADS = 12
|
| 200 |
+
_C.MODEL.PATCH_VIT.ATTN_EMBED_DIM = 1536
|
| 201 |
+
_C.MODEL.PATCH_VIT.MINIMUM_PATCHES = 1
|
| 202 |
+
|
| 203 |
+
# Classification head parameters
|
| 204 |
+
_C.MODEL.CLS_HEAD = CN()
|
| 205 |
+
_C.MODEL.CLS_HEAD.MLP_RATIO = 4
|
| 206 |
+
|
| 207 |
+
# ResNet parameters
|
| 208 |
+
_C.MODEL.RESNET = CN()
|
| 209 |
+
_C.MODEL.RESNET.LAYERS = [3, 4, 6, 3]
|
| 210 |
+
_C.MODEL.RESNET.IN_CHANS = 3
|
| 211 |
+
|
| 212 |
+
# [MFM] Reconstruction target type, support 'normal', 'masked'
|
| 213 |
+
_C.MODEL.RECOVER_TARGET_TYPE = 'normal'
|
| 214 |
+
# [MFM] Frequency loss parameters
|
| 215 |
+
_C.MODEL.FREQ_LOSS = CN()
|
| 216 |
+
_C.MODEL.FREQ_LOSS.LOSS_GAMMA = 1.
|
| 217 |
+
_C.MODEL.FREQ_LOSS.MATRIX_GAMMA = 1.
|
| 218 |
+
_C.MODEL.FREQ_LOSS.PATCH_FACTOR = 1
|
| 219 |
+
_C.MODEL.FREQ_LOSS.AVE_SPECTRUM = False
|
| 220 |
+
_C.MODEL.FREQ_LOSS.WITH_MATRIX = False
|
| 221 |
+
_C.MODEL.FREQ_LOSS.LOG_MATRIX = False
|
| 222 |
+
_C.MODEL.FREQ_LOSS.BATCH_MATRIX = False
|
| 223 |
+
|
| 224 |
+
# -----------------------------------------------------------------------------
|
| 225 |
+
# Training settings
|
| 226 |
+
# -----------------------------------------------------------------------------
|
| 227 |
+
_C.TRAIN = CN()
|
| 228 |
+
_C.TRAIN.START_EPOCH = 0
|
| 229 |
+
_C.TRAIN.EPOCHS = 300
|
| 230 |
+
_C.TRAIN.WARMUP_EPOCHS = 20
|
| 231 |
+
_C.TRAIN.WEIGHT_DECAY = 0.05
|
| 232 |
+
_C.TRAIN.BASE_LR = 3e-4
|
| 233 |
+
_C.TRAIN.WARMUP_LR = 2.5e-7
|
| 234 |
+
_C.TRAIN.MIN_LR = 2.5e-6
|
| 235 |
+
# Clip gradient norm
|
| 236 |
+
_C.TRAIN.CLIP_GRAD = 3.0
|
| 237 |
+
# Auto resume from latest checkpoint
|
| 238 |
+
_C.TRAIN.AUTO_RESUME = True
|
| 239 |
+
# Gradient accumulation steps
|
| 240 |
+
# could be overwritten by command line argument
|
| 241 |
+
_C.TRAIN.ACCUMULATION_STEPS = 1
|
| 242 |
+
# Whether to use gradient checkpointing to save memory
|
| 243 |
+
# could be overwritten by command line argument
|
| 244 |
+
_C.TRAIN.USE_CHECKPOINT = False
|
| 245 |
+
|
| 246 |
+
# LR scheduler
|
| 247 |
+
# Supported modes: "supervised", "contrastive"
|
| 248 |
+
_C.TRAIN.MODE = "supervised"
|
| 249 |
+
_C.TRAIN.LR_SCHEDULER = CN()
|
| 250 |
+
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
|
| 251 |
+
# Epoch interval to decay LR, used in StepLRScheduler
|
| 252 |
+
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
|
| 253 |
+
# LR decay rate, used in StepLRScheduler
|
| 254 |
+
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
|
| 255 |
+
# Gamma / Multi steps value, used in MultiStepLRScheduler
|
| 256 |
+
_C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
|
| 257 |
+
_C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
|
| 258 |
+
# A flag that indicates whether to scale lr according to batch size and grad accumulation steps.
|
| 259 |
+
_C.TRAIN.SCALE_LR = False
|
| 260 |
+
# Optimizer
|
| 261 |
+
_C.TRAIN.OPTIMIZER = CN()
|
| 262 |
+
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
|
| 263 |
+
# Optimizer Epsilon
|
| 264 |
+
_C.TRAIN.OPTIMIZER.EPS = 1e-8
|
| 265 |
+
# Optimizer Betas
|
| 266 |
+
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
|
| 267 |
+
# SGD momentum
|
| 268 |
+
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
|
| 269 |
+
_C.TRAIN.LOSS = "bce_supcont"
|
| 270 |
+
_C.TRAIN.TRIPLET_LOSS_MARGIN = 0.5
|
| 271 |
+
# Layer decay for fine-tuning
|
| 272 |
+
_C.TRAIN.LAYER_DECAY = 1.0
|
| 273 |
+
|
| 274 |
+
# -----------------------------------------------------------------------------
|
| 275 |
+
# Augmentation settings
|
| 276 |
+
# -----------------------------------------------------------------------------
|
| 277 |
+
_C.AUG = CN()
|
| 278 |
+
# Crop augmentation
|
| 279 |
+
_C.AUG.MIN_CROP_AREA = 0.2
|
| 280 |
+
_C.AUG.MAX_CROP_AREA = 1.0
|
| 281 |
+
# Flip augmentation
|
| 282 |
+
_C.AUG.HORIZONTAL_FLIP_PROB = 0.5
|
| 283 |
+
_C.AUG.VERTICAL_FLIP_PROB = 0.5
|
| 284 |
+
# Rotation augmentation
|
| 285 |
+
_C.AUG.ROTATION_PROB = 0.5
|
| 286 |
+
_C.AUG.ROTATION_DEGREES = 90
|
| 287 |
+
# Gaussian blur augmentation
|
| 288 |
+
_C.AUG.GAUSSIAN_BLUR_PROB = 0.5
|
| 289 |
+
_C.AUG.GAUSSIAN_BLUR_LIMIT = (3, 9)
|
| 290 |
+
_C.AUG.GAUSSIAN_BLUR_SIGMA = (0.01, 0.5)
|
| 291 |
+
# Gaussian noise augmentation
|
| 292 |
+
_C.AUG.GAUSSIAN_NOISE_PROB = 0.5
|
| 293 |
+
# JPEG compression augmentation
|
| 294 |
+
_C.AUG.JPEG_COMPRESSION_PROB = 0.5
|
| 295 |
+
_C.AUG.JPEG_MIN_QUALITY = 50
|
| 296 |
+
_C.AUG.JPEG_MAX_QUALITY = 100
|
| 297 |
+
# WEBP compression augmentation
|
| 298 |
+
_C.AUG.WEBP_COMPRESSION_PROB = .0
|
| 299 |
+
_C.AUG.WEBP_MIN_QUALITY = 50
|
| 300 |
+
_C.AUG.WEBP_MAX_QUALITY = 100
|
| 301 |
+
# Color jitter augmentation
|
| 302 |
+
_C.AUG.COLOR_JITTER = .0
|
| 303 |
+
_C.AUG.COLOR_JITTER_BRIGHTNESS_RANGE = (0.8, 1.2)
|
| 304 |
+
_C.AUG.COLOR_JITTER_CONTRAST_RANGE = (0.8, 1.2)
|
| 305 |
+
_C.AUG.COLOR_JITTER_SATURATION_RANGE = (0.8, 1.2)
|
| 306 |
+
_C.AUG.COLOR_JITTER_HUE_RANGE = (-0.1, 0.1)
|
| 307 |
+
# Sharpen augmentation
|
| 308 |
+
_C.AUG.SHARPEN_PROB = .0
|
| 309 |
+
_C.AUG.SHARPEN_ALPHA_RANGE = (0.01, 0.4)
|
| 310 |
+
_C.AUG.SHARPEN_LIGHTNESS_RANGE = (0.95, 1)
|
| 311 |
+
# Use AutoAugment policy. "v0" or "original"
|
| 312 |
+
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
|
| 313 |
+
# Random erase prob
|
| 314 |
+
_C.AUG.REPROB = 0.25
|
| 315 |
+
# Random erase mode
|
| 316 |
+
_C.AUG.REMODE = 'pixel'
|
| 317 |
+
# Random erase count
|
| 318 |
+
_C.AUG.RECOUNT = 1
|
| 319 |
+
# Probability of applying blurring
|
| 320 |
+
_C.AUG.BLUR_PROB = 0.25
|
| 321 |
+
# Mixup alpha, mixup enabled if > 0
|
| 322 |
+
_C.AUG.MIXUP = 0.8
|
| 323 |
+
# Cutmix alpha, cutmix enabled if > 0
|
| 324 |
+
_C.AUG.CUTMIX = 1.0
|
| 325 |
+
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
|
| 326 |
+
_C.AUG.CUTMIX_MINMAX = None
|
| 327 |
+
# Probability of performing mixup or cutmix when either/both is enabled
|
| 328 |
+
_C.AUG.MIXUP_PROB = 1.0
|
| 329 |
+
# Probability of switching to cutmix when both mixup and cutmix enabled
|
| 330 |
+
_C.AUG.MIXUP_SWITCH_PROB = 0.5
|
| 331 |
+
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
|
| 332 |
+
_C.AUG.MIXUP_MODE = 'batch'
|
| 333 |
+
|
| 334 |
+
# -----------------------------------------------------------------------------
|
| 335 |
+
# Testing settings
|
| 336 |
+
# -----------------------------------------------------------------------------
|
| 337 |
+
_C.TEST = CN()
|
| 338 |
+
# Whether to use center crop when testing.
|
| 339 |
+
_C.TEST.CROP = True
|
| 340 |
+
# Size for resizing images during testing.
|
| 341 |
+
_C.TEST.MAX_SIZE: Optional[int] = None
|
| 342 |
+
# When this option is set to True, the original resolution image is provided to the model.
|
| 343 |
+
# Setting this option to True automatically sets the batch size for validation/testing to 1.
|
| 344 |
+
_C.TEST.ORIGINAL_RESOLUTION = False
|
| 345 |
+
# Approach that will be used for generating different views of an image during testing.
|
| 346 |
+
# Currently, "tencrop" and None are supported.
|
| 347 |
+
_C.TEST.VIEWS_GENERATION_APPROACH = None
|
| 348 |
+
# Approach that will be used to combine the scores predicted for multiple views of the same
|
| 349 |
+
# image. This value is meaningful only when a view generation approach is used.
|
| 350 |
+
# Currently, "mean" and "max" are supported.
|
| 351 |
+
_C.TEST.VIEWS_REDUCTION_APPROACH = "mean"
|
| 352 |
+
# A flag that when set to True exports the analysis of the Spectral Context Attention.
|
| 353 |
+
_C.TEST.EXPORT_IMAGE_PATCHES = False
|
| 354 |
+
# -----------------------------------------------------------------------------
|
| 355 |
+
# Setting for Test-Time Perturbations
|
| 356 |
+
# -----------------------------------------------------------------------------
|
| 357 |
+
# Gaussian blur perturbation.
|
| 358 |
+
_C.TEST.GAUSSIAN_BLUR = False
|
| 359 |
+
_C.TEST.GAUSSIAN_BLUR_KERNEL_SIZE = 3
|
| 360 |
+
# Gaussian noise perturbation.
|
| 361 |
+
_C.TEST.GAUSSIAN_NOISE = False
|
| 362 |
+
_C.TEST.GAUSSIAN_NOISE_SIGMA = 1.0
|
| 363 |
+
# JPEG compression perturbation.
|
| 364 |
+
_C.TEST.JPEG_COMPRESSION = False
|
| 365 |
+
_C.TEST.JPEG_QUALITY = 100
|
| 366 |
+
# WEBP compression augmentation.
|
| 367 |
+
_C.TEST.WEBP_COMPRESSION = False
|
| 368 |
+
_C.TEST.WEBP_QUALITY = 100
|
| 369 |
+
# Scale perturbation.
|
| 370 |
+
_C.TEST.SCALE = False
|
| 371 |
+
_C.TEST.SCALE_FACTOR = 1.0
|
| 372 |
+
|
| 373 |
+
# -----------------------------------------------------------------------------
|
| 374 |
+
# Misc
|
| 375 |
+
# -----------------------------------------------------------------------------
|
| 376 |
+
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
|
| 377 |
+
# overwritten by command line argument
|
| 378 |
+
_C.AMP_OPT_LEVEL = ''
|
| 379 |
+
# Path to output folder, overwritten by command line argument
|
| 380 |
+
_C.OUTPUT = ''
|
| 381 |
+
# Tag of experiment, overwritten by command line argument
|
| 382 |
+
_C.TAG = 'default'
|
| 383 |
+
# Frequency to save checkpoint
|
| 384 |
+
_C.SAVE_FREQ = 10
|
| 385 |
+
# Frequency to logging info
|
| 386 |
+
_C.PRINT_FREQ = 10
|
| 387 |
+
# Fixed random seed
|
| 388 |
+
_C.SEED = 0
|
| 389 |
+
# Perform evaluation only, overwritten by command line argument
|
| 390 |
+
_C.EVAL_MODE = False
|
| 391 |
+
# Test throughput only, overwritten by command line argument
|
| 392 |
+
_C.THROUGHPUT_MODE = False
|
| 393 |
+
# Local rank for DistributedDataParallel, given by command line argument
|
| 394 |
+
_C.LOCAL_RANK = 0
|
| 395 |
+
|
| 396 |
+
# Path to pre-trained model
|
| 397 |
+
_C.PRETRAINED = ''
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def _update_config_from_file(config, cfg_file):
    """Merge a YAML configuration file into *config*, recursively merging
    any files listed under its ``BASE`` key first so that the current
    file's values take precedence.

    The config node is defrosted for the merge and frozen again afterwards.
    """
    config.defrost()
    with open(cfg_file, 'r') as stream:
        loaded = yaml.load(stream, Loader=yaml.FullLoader)

    base_dir = os.path.dirname(cfg_file)
    # Base configs are resolved relative to the directory of the current file.
    for base_cfg in loaded.setdefault('BASE', ['']):
        if not base_cfg:
            continue
        _update_config_from_file(config, os.path.join(base_dir, base_cfg))
    print(f'=> merge config from {cfg_file}')
    config.merge_from_file(cfg_file)
    config.freeze()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
def update_config(config, args):
    """Merge a YAML config file and CLI-style overrides into *config*.

    Args:
        config: A yacs ``CfgNode`` (typically a clone of the defaults).
        args (dict): Mapping of option name -> value.  Must contain "cfg"
            (path to the YAML file); every other key is optional and is
            only applied when its value is truthy.

    Mutates ``config`` in place and leaves it frozen.
    """
    _update_config_from_file(config, args["cfg"])

    config.defrost()
    if "opts" in args:
        options: list[Any] = []
        for (k, v) in args["opts"]:
            options.append(k)
            # NOTE(review): eval() executes arbitrary code from the supplied
            # option string — acceptable only for trusted local invocations.
            options.append(eval(v))
        config.merge_from_list(options)

    def _check_args(name):
        # Truthiness check: an option explicitly set to False/0/'' is
        # treated the same as "not provided".
        if name in args and args[name]:
            return True
        return False

    # merge from specific arguments
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args["batch_size"]
    if _check_args('data_path'):
        config.DATA.DATA_PATH = args["data_path"]
    if _check_args('csv_root_dir'):
        config.DATA.CSV_ROOT = args["csv_root_dir"]
    if _check_args("lmdb_path"):
        config.DATA.LMDB_PATH = args["lmdb_path"]
    if _check_args('resume'):
        config.MODEL.RESUME = args["resume"]
    if _check_args('pretrained'):
        config.PRETRAINED = args["pretrained"]
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args["accumulation_steps"]
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('amp_opt_level'):
        config.AMP_OPT_LEVEL = args["amp_opt_level"]
    if _check_args('output'):
        config.OUTPUT = args["output"]
    if _check_args('tag'):
        config.TAG = args["tag"]
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('throughput'):
        config.THROUGHPUT_MODE = True
    if _check_args('test_csv'):
        config.DATA.TEST_DATA_PATH = args["test_csv"]
    if _check_args('test_csv_root'):
        config.DATA.TEST_DATA_CSV_ROOT = args["test_csv_root"]
    if _check_args('learning_rate'):
        config.TRAIN.BASE_LR = args["learning_rate"]
    if _check_args('resize_to'):
        config.TEST.MAX_SIZE = args["resize_to"]
    if _check_args("local_rank"):
        # set local rank for distributed training
        config.LOCAL_RANK = args["local_rank"]
    if _check_args("data_workers"):
        config.DATA.NUM_WORKERS = args["data_workers"]
    if _check_args("disable_pin_memory"):
        config.PIN_MEMORY = False
    if _check_args("data_prefetch_factor"):
        config.DATA.PREFETCH_FACTOR = args["data_prefetch_factor"]
    # output folder layout: <output>/<model name>/<tag>
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_config(args):
    """Build a fresh config from the defaults and apply *args* overrides.

    A clone of the module-level defaults is used so repeated calls never
    mutate ``_C`` itself (the "local variable" use pattern).
    """
    cfg = _C.clone()
    update_config(cfg, args)
    return cfg
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
def get_custom_config(cfg):
    """Return a clone of the default config merged with the YAML file *cfg*."""
    custom = _C.clone()
    _update_config_from_file(custom, cfg)
    return custom
|
spai/data/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from .data_mfm import build_loader_mfm
|
| 18 |
+
from .data_finetune import build_loader_finetune, build_loader_test
|
| 19 |
+
|
| 20 |
+
def build_loader(config, logger, is_pretrain, is_test):
    """Dispatch to the matching data-loader factory.

    Pre-training takes precedence over testing; anything else falls back
    to the fine-tuning loader.
    """
    if is_pretrain:
        return build_loader_mfm(config, logger)
    if is_test:
        return build_loader_test(config, logger)
    return build_loader_finetune(config, logger)
|
spai/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (501 Bytes). View file
|
|
|
spai/data/__pycache__/__init__.cpython-313.pyc
ADDED
|
Binary file (577 Bytes). View file
|
|
|
spai/data/__pycache__/blur_kernels.cpython-310.pyc
ADDED
|
Binary file (14.5 kB). View file
|
|
|
spai/data/__pycache__/blur_kernels.cpython-313.pyc
ADDED
|
Binary file (20.2 kB). View file
|
|
|
spai/data/__pycache__/data_finetune.cpython-310.pyc
ADDED
|
Binary file (19.4 kB). View file
|
|
|
spai/data/__pycache__/data_finetune.cpython-313.pyc
ADDED
|
Binary file (38.6 kB). View file
|
|
|
spai/data/__pycache__/data_mfm.cpython-310.pyc
ADDED
|
Binary file (4.89 kB). View file
|
|
|
spai/data/__pycache__/data_mfm.cpython-313.pyc
ADDED
|
Binary file (9.35 kB). View file
|
|
|
spai/data/__pycache__/filestorage.cpython-310.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
spai/data/__pycache__/filestorage.cpython-313.pyc
ADDED
|
Binary file (17.5 kB). View file
|
|
|
spai/data/__pycache__/random_degradations.cpython-310.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
spai/data/__pycache__/random_degradations.cpython-313.pyc
ADDED
|
Binary file (20.9 kB). View file
|
|
|
spai/data/__pycache__/readers.cpython-310.pyc
ADDED
|
Binary file (6.41 kB). View file
|
|
|
spai/data/__pycache__/readers.cpython-313.pyc
ADDED
|
Binary file (9.32 kB). View file
|
|
|
spai/data/blur_kernels.py
ADDED
|
@@ -0,0 +1,539 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This code is referenced from BasicSR with modifications.
|
| 2 |
+
# Reference: https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py # noqa
|
| 3 |
+
# Original licence: Copyright (c) 2020 xinntao, under the Apache 2.0 license.
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
from scipy import special
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def get_rotated_sigma_matrix(sig_x, sig_y, theta):
    """Build the 2x2 covariance matrix of a rotated anisotropic Gaussian.

    Args:
        sig_x (float): Standard deviation along the horizontal direction.
        sig_y (float): Standard deviation along the vertical direction.
        theta (float): Rotation in radian.

    Returns:
        ndarray: Rotated sigma (covariance) matrix.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    variances = np.array([[sig_x ** 2, 0], [0, sig_y ** 2]]).astype(np.float32)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]]).astype(np.float32)

    # R * D * R^T rotates the axis-aligned covariance by theta.
    return np.matmul(rotation, np.matmul(variances, rotation.T))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _mesh_grid(kernel_size):
|
| 32 |
+
"""Generate the mesh grid, centering at zero.
|
| 33 |
+
|
| 34 |
+
Args:
|
| 35 |
+
kernel_size (int): The size of the kernel.
|
| 36 |
+
|
| 37 |
+
Returns:
|
| 38 |
+
x_grid (ndarray): x-coordinates with shape (kernel_size, kernel_size).
|
| 39 |
+
y_grid (ndarray): y-coordiantes with shape (kernel_size, kernel_size).
|
| 40 |
+
xy_grid (ndarray): stacked coordinates with shape
|
| 41 |
+
(kernel_size, kernel_size, 2).
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
range_ = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
|
| 45 |
+
x_grid, y_grid = np.meshgrid(range_, range_)
|
| 46 |
+
xy_grid = np.hstack((x_grid.reshape((kernel_size * kernel_size, 1)),
|
| 47 |
+
y_grid.reshape(kernel_size * kernel_size,
|
| 48 |
+
1))).reshape(kernel_size, kernel_size,
|
| 49 |
+
2)
|
| 50 |
+
|
| 51 |
+
return xy_grid, x_grid, y_grid
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def calculate_gaussian_pdf(sigma_matrix, grid):
    """Evaluate the (un-normalized) bivariate Gaussian PDF on a grid.

    Args:
        sigma_matrix (ndarray): The variance (covariance) matrix, shape (2, 2).
        grid (ndarray): Coordinates from :func:`_mesh_grid`, shape (K, K, 2),
            where K is the kernel size.

    Returns:
        ndarray: Un-normalized kernel values of shape (K, K).
    """
    sigma_inv = np.linalg.inv(sigma_matrix)
    # Pointwise quadratic form x^T * Sigma^-1 * x over the coordinate grid.
    quad_form = np.sum(np.matmul(grid, sigma_inv) * grid, 2)
    return np.exp(-0.5 * quad_form)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def bivariate_gaussian(kernel_size,
                       sig_x,
                       sig_y=None,
                       theta=None,
                       grid=None,
                       is_isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In isotropic mode, only `sig_x` is used. `sig_y` and `theta` are
    ignored.

    Args:
        kernel_size (int): The size of the kernel
        sig_x (float): Standard deviation along horizontal direction.
        sig_y (float | None, optional): Standard deviation along the vertical
            direction. If it is None, 'is_isotropic' must be set to True.
            Default: None.
        theta (float | None, optional): Rotation in radian. If it is None,
            'is_isotropic' must be set to True. Default: None.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): normalized kernel (i.e. sum to 1).

    Raises:
        ValueError: If ``is_isotropic`` is False and ``sig_y`` or ``theta``
            is None.
    """
    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0,
                                                 sig_x**2]]).astype(np.float32)
    else:
        # Both parameters are required to build the rotated covariance; a
        # None theta would otherwise fail inside np.cos with a cryptic
        # TypeError, so validate it alongside sig_y.
        if sig_y is None or theta is None:
            raise ValueError('"sig_y" and "theta" cannot be None if '
                             '"is_isotropic" is False.')

        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    kernel = calculate_gaussian_pdf(sigma_matrix, grid)
    # Normalize so the kernel weights sum to one.
    kernel = kernel / np.sum(kernel)

    return kernel
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def bivariate_generalized_gaussian(kernel_size,
                                   sig_x,
                                   sig_y=None,
                                   theta=None,
                                   beta=1,
                                   grid=None,
                                   is_isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    Described in `Parameter Estimation For Multivariate Generalized
    Gaussian Distributions` by Pascal et. al (2013). In isotropic mode,
    only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): The size of the kernel
        sig_x (float): Standard deviation along horizontal direction
        sig_y (float | None, optional): Standard deviation along the vertical
            direction. If it is None, 'is_isotropic' must be set to True.
            Default: None.
        theta (float | None, optional): Rotation in radian. If it is None,
            'is_isotropic' must be set to True. Default: None.
        beta (float, optional): Shape parameter, beta = 1 is the normal
            distribution. Default: 1.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): normalized kernel.

    Raises:
        ValueError: If ``is_isotropic`` is False and ``sig_y`` or ``theta``
            is None.
    """
    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0,
                                                 sig_x**2]]).astype(np.float32)
    else:
        # Guard added for consistency with bivariate_gaussian(): without it,
        # a None sig_y/theta crashed deep inside the sigma-matrix math with
        # an unclear TypeError.
        if sig_y is None or theta is None:
            raise ValueError('"sig_y" and "theta" cannot be None if '
                             '"is_isotropic" is False.')

        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    inverse_sigma = np.linalg.inv(sigma_matrix)
    # Generalized Gaussian: exp(-0.5 * (x^T Sigma^-1 x)^beta).
    kernel = np.exp(
        -0.5 *
        np.power(np.sum(np.matmul(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)

    return kernel
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def bivariate_plateau(kernel_size,
                      sig_x,
                      sig_y,
                      theta,
                      beta,
                      grid=None,
                      is_isotropic=True):
    """Generate a plateau-like anisotropic kernel of form 1 / (1 + x^beta).

    Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution # noqa
    In the isotropic mode, only `sig_x` is used; `sig_y` and `theta` are
    ignored.

    Args:
        kernel_size (int): The size of the kernel
        sig_x (float): Standard deviation along horizontal direction
        sig_y (float): Standard deviation along the vertical direction.
        theta (float): Rotation in radian.
        beta (float): Shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): normalized kernel (i.e. sum to 1).
    """
    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0,
                                                 sig_x**2]]).astype(np.float32)
    else:
        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    sigma_inv = np.linalg.inv(sigma_matrix)
    quad_form = np.sum(np.matmul(grid, sigma_inv) * grid, 2)
    # Plateau profile: flat near the centre, polynomial decay outwards.
    kernel = np.reciprocal(np.power(quad_form, beta) + 1)
    return kernel / np.sum(kernel)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def random_bivariate_gaussian_kernel(kernel_size,
                                     sigma_x_range,
                                     sigma_y_range,
                                     rotation_range,
                                     noise_range=None,
                                     is_isotropic=True):
    """Sample a bivariate (an)isotropic Gaussian kernel at random.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range`
    and `rotation_range` are ignored.

    Args:
        kernel_size (int): The size of the kernel.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): The kernel whose parameters are sampled from the
            specified range.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'

    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])
    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        # Isotropic kernels share one sigma and need no rotation.
        sigma_y, rotation = sigma_x, 0

    kernel = bivariate_gaussian(
        kernel_size, sigma_x, sigma_y, rotation, is_isotropic=is_isotropic)

    if noise_range is not None:
        # Perturb with multiplicative uniform noise, then re-normalize so
        # the kernel still sums to one.
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.FloatTensor(
            *(kernel.shape)).uniform_(noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def random_bivariate_generalized_gaussian_kernel(kernel_size,
                                                 sigma_x_range,
                                                 sigma_y_range,
                                                 rotation_range,
                                                 beta_range,
                                                 noise_range=None,
                                                 is_isotropic=True):
    """Sample a bivariate generalized Gaussian kernel at random.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range`
    and `rotation_range` are ignored.

    Args:
        kernel_size (int): The size of the kernel.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        beta_range (float): The range of the shape parameter, beta = 1 is the
            normal distribution.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'

    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])
    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y, rotation = sigma_x, 0

    # Sample beta from below or above 1 with equal probability
    # (assumes beta_range[0] <= 1 <= beta_range[1]).
    if random.random() <= 0.5:
        beta = random.uniform(beta_range[0], 1)
    else:
        beta = random.uniform(1, beta_range[1])

    kernel = bivariate_generalized_gaussian(
        kernel_size,
        sigma_x,
        sigma_y,
        rotation,
        beta,
        is_isotropic=is_isotropic)

    if noise_range is not None:
        # Perturb with multiplicative uniform noise, then re-normalize.
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.FloatTensor(
            *(kernel.shape)).uniform_(noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def random_bivariate_plateau_kernel(kernel_size,
                                    sigma_x_range,
                                    sigma_y_range,
                                    rotation_range,
                                    beta_range,
                                    noise_range=None,
                                    is_isotropic=True):
    """Sample a bivariate plateau kernel at random.

    In the isotropic mode, only `sigma_x_range` is used; `sigma_y_range`
    and `rotation_range` are ignored.

    Args:
        kernel_size (int): The size of the kernel.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        beta_range (float): The range of the shape parameter, beta = 1 is the
            normal distribution.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray):
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])

    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        sigma_y, rotation = sigma_x, 0

    # TODO: this may be not proper — beta sampled from either side of 1
    # with equal probability.
    if random.random() <= 0.5:
        beta = random.uniform(beta_range[0], 1)
    else:
        beta = random.uniform(1, beta_range[1])

    kernel = bivariate_plateau(
        kernel_size,
        sigma_x,
        sigma_y,
        rotation,
        beta,
        is_isotropic=is_isotropic)

    if noise_range is not None:
        # Perturb with multiplicative uniform noise, then re-normalize.
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.FloatTensor(
            *(kernel.shape)).uniform_(noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def random_circular_lowpass_kernel(omega_range, kernel_size, pad_to=0):
    """ Generate a 2D Sinc filter

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter # noqa

    Args:
        omega_range (tuple): The cutoff frequency in radian (pi is max).
        kernel_size (int): The size of the kernel. It must be an odd number.
        pad_to (int, optional): The size of the padded kernel. It must be odd
            or zero. Default: 0.

    Returns:
        ndarray: The Sinc kernel with specified parameters.
    """
    # Validate before touching global numpy state so a failing assert does
    # not leave the caller's error-handling settings clobbered.
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'

    err = np.geterr()
    # The centre element evaluates 0/0; silence the warnings and patch the
    # centre with its analytic limit below.
    np.seterr(divide='ignore', invalid='ignore')
    try:
        omega = random.uniform(omega_range[0], omega_range[-1])

        kernel = np.fromfunction(
            lambda x, y: omega * special.j1(omega * np.sqrt(
                (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) /
            (2 * np.pi * np.sqrt((x - (kernel_size - 1) / 2)**2 +
                                 (y - (kernel_size - 1) / 2)**2)),
            [kernel_size, kernel_size])
        # Analytic limit of the filter at the centre (r -> 0).
        kernel[(kernel_size - 1) // 2,
               (kernel_size - 1) // 2] = omega**2 / (4 * np.pi)
        kernel = kernel / np.sum(kernel)

        if pad_to > kernel_size:
            pad_size = (pad_to - kernel_size) // 2
            kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
    finally:
        # Always restore the caller's numpy error-handling state, even if an
        # exception escapes (the original code leaked 'ignore' settings then).
        np.seterr(**err)

    return kernel
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size,
                         sigma_x_range=[0.6, 5],
                         sigma_y_range=[0.6, 5],
                         rotation_range=[-np.pi, np.pi],
                         beta_gaussian_range=[0.5, 8],
                         beta_plateau_range=[1, 2],
                         omega_range=[0, np.pi],
                         noise_range=None):
    """Randomly generate a kernel of a randomly chosen type.

    Args:
        kernel_list (list): A list of kernel types. Supported choices are
            'iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso', 'sinc'.
        kernel_prob (list): The probability of choosing of the corresponding
            kernel.
        kernel_size (int): The size of the kernel.
        sigma_x_range (list, optional): The range of the standard deviation
            along the horizontal direction. Default: (0.6, 5).
        sigma_y_range (list, optional): The range of the standard deviation
            along the vertical direction. Default: (0.6, 5).
        rotation_range (list, optional): Range of rotation in radian.
            Default: (-np.pi, np.pi).
        beta_gaussian_range (list, optional): The range of the shape parameter
            for generalized Gaussian. Default: (0.5, 8).
        beta_plateau_range (list, optional): The range of the shape parameter
            for plateau kernel. Default: (1, 2).
        omega_range (list, optional): The range of omega used in Sinc kernel.
            Default: (0, np.pi).
        noise_range (list, optional): Multiplicative kernel noise.
            Default: None.

    Returns:
        kernel (ndarray): The kernel whose parameters are sampled from the
            specified range.

    Raises:
        ValueError: If the sampled kernel type is not supported.
    """
    kernel_type = random.choices(kernel_list, weights=kernel_prob)[0]
    if kernel_type == 'iso':
        kernel = random_bivariate_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range,
            is_isotropic=True)
    elif kernel_type == 'aniso':
        kernel = random_bivariate_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range,
            is_isotropic=False)
    elif kernel_type == 'generalized_iso':
        kernel = random_bivariate_generalized_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_gaussian_range,
            noise_range=noise_range,
            is_isotropic=True)
    elif kernel_type == 'generalized_aniso':
        kernel = random_bivariate_generalized_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_gaussian_range,
            noise_range=noise_range,
            is_isotropic=False)
    elif kernel_type == 'plateau_iso':
        # NOTE(review): plateau kernels pass noise_range=None, ignoring the
        # caller-supplied noise_range — confirm this is intentional.
        kernel = random_bivariate_plateau_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_plateau_range,
            noise_range=None,
            is_isotropic=True)
    elif kernel_type == 'plateau_aniso':
        kernel = random_bivariate_plateau_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_plateau_range,
            noise_range=None,
            is_isotropic=False)
    elif kernel_type == 'sinc':
        kernel = random_circular_lowpass_kernel(omega_range, kernel_size)
    else:
        # Previously an unsupported type (e.g. the docstring's 'skew', which
        # had no branch) fell through and crashed with a NameError on the
        # return statement; fail loudly with a clear message instead.
        raise ValueError('Unsupported kernel type: {}'.format(kernel_type))

    return kernel
|
spai/data/data_finetune.py
ADDED
|
@@ -0,0 +1,723 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import collections
|
| 18 |
+
import os
|
| 19 |
+
import pathlib
|
| 20 |
+
import random
|
| 21 |
+
from functools import partial
|
| 22 |
+
from typing import Any, Union, Optional, Iterable
|
| 23 |
+
from collections.abc import Callable
|
| 24 |
+
|
| 25 |
+
import albumentations as A
|
| 26 |
+
import torchvision.transforms.functional
|
| 27 |
+
from albumentations.augmentations.transforms import ImageCompressionType
|
| 28 |
+
from albumentations.pytorch import ToTensorV2
|
| 29 |
+
import numpy as np
|
| 30 |
+
import torch
|
| 31 |
+
from torch.utils.data import DataLoader
|
| 32 |
+
from PIL import Image
|
| 33 |
+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
| 34 |
+
from timm.data import Mixup
|
| 35 |
+
import cv2
|
| 36 |
+
from torchvision.transforms.v2.functional import ten_crop, pad
|
| 37 |
+
import filetype
|
| 38 |
+
|
| 39 |
+
from spai.data import readers
|
| 40 |
+
from spai.data import filestorage
|
| 41 |
+
from spai import data_utils
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CSVDataset(torch.utils.data.Dataset):
    """Image dataset described by a CSV index file.

    Each CSV row describes one sample (image path, split, class label).
    Rows are filtered down to the requested split. Images are read either
    from the filesystem (relative to ``csv_root_path``) or from an LMDB
    file storage, and one or more augmented views are returned per sample.
    """

    def __init__(
        self,
        csv_path: pathlib.Path,
        csv_root_path: pathlib.Path,
        split: str,
        transform,
        path_column: str = "image",
        split_column: str = "split",
        class_column: str = "class",
        views: int = 1,
        concatenate_views_horizontally: bool = False,
        lmdb_storage: Optional[pathlib.Path] = None,
        views_generator: Optional[Callable[[Image.Image], tuple[Image.Image, ...]]] = None
    ):
        """
        :param csv_path: Path to the CSV index file. Expected to be absolute.
        :param csv_root_path: Root directory the CSV image paths are relative to.
        :param split: One of "train", "val", "test".
        :param transform: Albumentations-style callable, invoked as
            ``transform(image=ndarray)`` and expected to return a dict with
            an ``"image"`` tensor.
        :param path_column: CSV column containing the image path.
        :param split_column: CSV column containing the split name.
        :param class_column: CSV column containing the class label.
        :param views: Number of augmented views per sample when no
            ``views_generator`` is provided.
        :param concatenate_views_horizontally: When True, views are concatenated
            into a single wide image instead of stacked along a new dimension.
        :param lmdb_storage: Optional path to an LMDB file storage; when given,
            images are read from it instead of the filesystem.
        :param views_generator: Optional callable producing a tuple of views
            from a PIL image (e.g. a ten-crop function).
        """
        super().__init__()
        self.csv_path: pathlib.Path = csv_path
        self.csv_root_path: pathlib.Path = csv_root_path
        self.split: str = split
        self.path_column: str = path_column
        self.split_column: str = split_column
        self.class_column: str = class_column
        self.transform = transform
        self.views: int = views
        self.views_generator: Optional[
            Callable[[Image.Image], tuple[Image.Image, ...]]] = views_generator
        self.concatenate_views_horizontally: bool = concatenate_views_horizontally
        self.lmdb_storage: Optional[pathlib.Path] = lmdb_storage

        # Reader to be used for data loading. Its creation is deferred to the
        # first __getitem__ call so each DataLoader worker process builds its
        # own reader instance (see _create_data_reader).
        self.data_reader: Optional[readers.DataReader] = None

        if split not in ["train", "val", "test"]:
            raise RuntimeError(f"Unsupported split: {split}")

        # Path of the CSV file is expected to be absolute.
        reader = readers.FileSystemReader(pathlib.Path("/"))
        self.entries: list[dict[str, Any]] = reader.read_csv_file(str(self.csv_path))
        self.entries = [e for e in self.entries if e[self.split_column] == self.split]

        # Number of distinct class labels present in the selected split.
        self.num_classes: int = len(
            collections.Counter([e[self.class_column] for e in self.entries]).keys()
        )

    def __len__(self):
        return len(self.entries)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, np.ndarray, int]:
        """Returns the requested image sample from the dataset.

        :returns: A tuple containing the image tensor, the labels numpy array and the
            index in dataset.
            Image tensor: (V x 3 x H x W) where V is the number of augmented views.
            Label array: (1, )
            Index
        """
        # Defer the creation of the data reader until the first read operation in order to
        # properly handle the spawning of multiple processes by DataLoader, where each one
        # should contain a separate reader object.
        if self.data_reader is None:
            self._create_data_reader()

        # Load sample.
        img_obj: Image.Image = self.data_reader.load_image(
            self.entries[idx][self.path_column], channels=3
        )
        label: int = int(self.entries[idx][self.class_column])

        # Generate multiple views of an image either through a provided views generation
        # function or through multiple augmentations of the image.
        if self.views_generator is not None:
            augmented_views: tuple[Image.Image, ...] = self.views_generator(img_obj)
            augmented_views: list[np.ndarray] = [np.array(v) for v in augmented_views]
            augmented_views: list[torch.Tensor] = [
                self.transform(image=v)["image"] for v in augmented_views
            ]
        else:
            img: np.ndarray = np.array(img_obj)
            augmented_views: list[torch.Tensor] = []
            for _ in range(self.views):
                augmented_views.append(self.transform(image=img)["image"])

        # Either concatenate the views in a single big image, or provide them stacked
        # into a new tensor dimension.
        if self.concatenate_views_horizontally:
            augmented_img: torch.Tensor = torch.cat(augmented_views, dim=-1)
            augmented_img = augmented_img.unsqueeze(dim=0)
        else:
            augmented_img: torch.Tensor = torch.stack(augmented_views, dim=0)

        # Cleanup resources.
        img_obj.close()

        return augmented_img, np.array(label, dtype=float), idx

    def get_classes_num(self) -> int:
        """Returns the number of distinct classes in the selected split."""
        return self.num_classes

    def get_dataset_root_path(self) -> pathlib.Path:
        """Returns the storage location the images are actually read from."""
        if self.lmdb_storage is not None:
            return self.lmdb_storage
        else:
            return self.csv_root_path

    def update_dataset_csv(
        self,
        column_name: str,
        values: dict[int, Any],
        export_dir: Optional[pathlib.Path] = None
    ) -> None:
        """Writes per-entry values into a CSV column, optionally exporting the CSV.

        :param column_name: Name of the column to create or update.
        :param values: Mapping of entry index -> value to store.
        :param export_dir: When given, the updated CSV is written there under
            the original CSV filename.
        """
        for idx, v in values.items():
            self.entries[idx][column_name] = v

        # Make sure that a valid value for the updated column exists for all entries.
        for e in self.entries:
            if column_name not in e:
                e[column_name] = ""

        if export_dir:
            export_path: pathlib.Path = export_dir / self.csv_path.name
            data_utils.write_csv_file(self.entries, export_path, delimiter=",")

    def _create_data_reader(self) -> None:
        # Limit the number of OpenCV threads to 1 to utilize multiple processes. Otherwise,
        # each process spawns a number of threads equal to the number of logical cores and
        # the overall performance gets worse due to threads congestion.
        cv2.setNumThreads(1)

        if self.lmdb_storage is None:
            self.data_reader: readers.FileSystemReader = readers.FileSystemReader(
                pathlib.Path(self.csv_root_path)
            )
        else:
            self.data_reader: readers.LMDBFileStorageReader = readers.LMDBFileStorageReader(
                filestorage.LMDBFileStorage(self.lmdb_storage, read_only=True)
            )
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class CSVDatasetTriplet(torch.utils.data.Dataset):
    """Triplet-sampling dataset backed by a CSV index file.

    For every entry (anchor) a positive sample of the same class and a
    negative sample of a different class are pre-selected at construction
    time, so ``__getitem__`` returns an (anchor, positive, negative) tuple
    of transformed image tensors.
    """

    def __init__(
        self,
        csv_path: pathlib.Path,
        csv_root_path: pathlib.Path,
        split: str,
        transform,
        path_column: str = "image",
        split_column: str = "split",
        class_column: str = "class",
        lmdb_storage: Optional[pathlib.Path] = None
    ):
        """
        :param csv_path: Path to the CSV index file. Expected to be absolute.
        :param csv_root_path: Root directory the CSV image paths are relative to.
        :param split: One of "train", "val", "test".
        :param transform: Albumentations-style callable applied to each image.
        :param path_column: CSV column containing the image path.
        :param split_column: CSV column containing the split name.
        :param class_column: CSV column containing the class label.
        :param lmdb_storage: Optional path to an LMDB file storage; when given,
            images are read from it instead of the filesystem.
        """
        super().__init__()
        self.csv_path: pathlib.Path = csv_path
        self.csv_root_path: pathlib.Path = csv_root_path
        self.split: str = split
        self.path_column: str = path_column
        self.split_column: str = split_column
        self.class_column: str = class_column
        self.transform = transform
        self.lmdb_storage: Optional[pathlib.Path] = lmdb_storage

        # Reader to be used for data loading. Its creation is deferred to the
        # first __getitem__ call so each DataLoader worker process builds its
        # own reader instance (see _create_data_reader).
        self.data_reader: Optional[readers.DataReader] = None

        if split not in ["train", "val", "test"]:
            raise RuntimeError(f"Unsupported split: {split}")

        # Path of the CSV file is expected to be absolute.
        reader = readers.FileSystemReader(pathlib.Path("/"))
        self.entries: list[dict[str, Any]] = reader.read_csv_file(str(self.csv_path))
        self.entries = [e for e in self.entries if e[self.split_column] == self.split]

        self.num_classes: int = len(
            collections.Counter([e[self.class_column] for e in self.entries]).keys()
        )

        # Save paths that will be accessed by different dataloaders as numpy arrays in
        # order to avoid copy-on-read of python objects, and thus child processes to
        # take huge amounts of memory.
        self.anchor_v: Optional[np.ndarray] = None
        self.anchor_o: Optional[np.ndarray] = None
        self.positive_v: Optional[np.ndarray] = None
        self.positive_o: Optional[np.ndarray] = None
        self.negative_v: Optional[np.ndarray] = None
        self.negative_o: Optional[np.ndarray] = None
        self.triplets_num: Optional[int] = None
        self.generate_triplets()

    def __len__(self) -> int:
        return self.triplets_num

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns the triplet with the specified index.

        :returns: A tuple in the form of (anchor_img, positive_img, negative_img)
        """
        # Defer the creation of the data reader until the first read operation in order to
        # properly handle the spawning of multiple processes by DataLoader, where each one
        # should contain a separate reader object.
        if self.data_reader is None:
            self._create_data_reader()

        # Decode the packed path arrays back into path strings.
        # NOTE(review): sequence_to_string / unpack_sequence are defined
        # elsewhere in the package; presumed to invert pack_sequences /
        # string_to_sequence used in generate_triplets.
        anchor_path: str = sequence_to_string(unpack_sequence(self.anchor_v, self.anchor_o, idx))
        positive_path: str = sequence_to_string(
            unpack_sequence(self.positive_v, self.positive_o, idx))
        negative_path: str = sequence_to_string(
            unpack_sequence(self.negative_v, self.negative_o, idx))

        anchor_img_obj: Image.Image = self.data_reader.load_image(anchor_path, channels=3)
        positive_img_obj: Image.Image = self.data_reader.load_image(positive_path, channels=3)
        negative_img_obj: Image.Image = self.data_reader.load_image(negative_path, channels=3)

        anchor_img: np.ndarray = np.array(anchor_img_obj)
        positive_img: np.ndarray = np.array(positive_img_obj)
        negative_img: np.ndarray = np.array(negative_img_obj)

        anchor_img_obj.close()
        positive_img_obj.close()
        negative_img_obj.close()

        return (self.transform(image=anchor_img)["image"],
                self.transform(image=positive_img)["image"],
                self.transform(image=negative_img)["image"])

    def get_classes_num(self) -> int:
        """Returns the number of distinct classes in the selected split."""
        return self.num_classes

    def get_dataset_root_path(self) -> pathlib.Path:
        """Returns the storage location the images are actually read from."""
        if self.lmdb_storage is not None:
            return self.lmdb_storage
        else:
            return self.csv_root_path

    def generate_triplets(self) -> None:
        """Pre-computes one (anchor, positive, negative) triplet per entry.

        Paths are stored as packed numpy arrays (values + offsets) to avoid
        per-worker copy-on-read memory blow-up in DataLoader processes.
        """
        # Separate the entries into groups of each class.
        # NOTE(review): assumes class labels are contiguous ints 0..num_classes-1;
        # a label outside that range would raise KeyError below — confirm upstream.
        entries_per_class: dict[int, list[dict[str, Any]]] = {
            i: [] for i in range(self.num_classes)
        }
        for e in self.entries:
            entries_per_class[int(e[self.class_column])].append(e)

        triplets: list[tuple[dict[str,Any], dict[str, Any], dict[str, Any]]] = []
        for class_id, class_group in entries_per_class.items():
            class_group: list[dict[str, Any]] = list(class_group)
            # Candidate groups for negatives: every class except the current one.
            # Indexing by class_id is valid because the dict was built from range().
            rest_groups: list[list[dict[str, Any]]] = list(entries_per_class.values())
            del rest_groups[class_id]

            for i, e in enumerate(class_group):
                negative_sample: dict[str, Any] = random.choice(random.choice(rest_groups))

                # Re-draw until the positive differs from the anchor.
                # NOTE(review): assumes every class has at least 2 samples,
                # otherwise this loop never terminates — confirm data guarantees.
                positive_sample: dict[str, Any] = random.choice(class_group)
                while e == positive_sample:
                    positive_sample: dict[str, Any] = random.choice(class_group)

                triplets.append((e, positive_sample, negative_sample))

        self.anchor_v, self.anchor_o = pack_sequences(
            [string_to_sequence(t[0][self.path_column]) for t in triplets]
        )
        self.positive_v, self.positive_o = pack_sequences(
            [string_to_sequence(t[1][self.path_column]) for t in triplets]
        )
        self.negative_v, self.negative_o = pack_sequences(
            [string_to_sequence(t[2][self.path_column]) for t in triplets]
        )
        self.anchor_labels: np.ndarray = np.array([int(t[0][self.class_column]) for t in triplets])
        self.triplets_num = len(triplets)

    def _create_data_reader(self) -> None:
        # Limit the number of OpenCV threads to 1 to utilize multiple processes. Otherwise,
        # each process spawns a number of threads equal to the number of logical cores and
        # the overall performance gets worse due to threads congestion.
        cv2.setNumThreads(1)

        if self.lmdb_storage is None:
            self.data_reader: readers.FileSystemReader = readers.FileSystemReader(
                pathlib.Path(self.csv_root_path)
            )
        else:
            self.data_reader: readers.LMDBFileStorageReader = readers.LMDBFileStorageReader(
                filestorage.LMDBFileStorage(self.lmdb_storage, read_only=True)
            )
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def build_loader_finetune(config, logger):
    """Builds the train/val datasets, data loaders and optional mixup function.

    The config is temporarily defrosted so that ``MODEL.NUM_CLASSES`` can be
    written back from the training dataset.

    :param config: Experiment configuration node.
    :param logger: Logger for reporting dataset statistics.
    :returns: ``(dataset_train, dataset_val, data_loader_train,
        data_loader_val, mixup_fn)`` where ``mixup_fn`` is ``None`` when
        mixup/cutmix augmentation is disabled.
    """
    # Building the train split also determines the number of classes, which
    # is stored into the (temporarily mutable) config before re-freezing it.
    config.defrost()
    train_set, config.MODEL.NUM_CLASSES = build_dataset(
        config.DATA.DATA_PATH,
        config.DATA.CSV_ROOT,
        config=config,
        split_name="train",
        logger=logger
    )
    config.freeze()
    val_set, _ = build_dataset(
        config.DATA.DATA_PATH,
        config.DATA.CSV_ROOT,
        config=config,
        split_name="val",
        logger=logger
    )
    logger.info(f"Train images: {len(train_set)} | Validation images: {len(val_set)}")
    logger.info(f"Train Images Source: {train_set.get_dataset_root_path()}")
    logger.info(f"Validation Images Source: {val_set.get_dataset_root_path()}")

    # Arbitrary-resolution models receive variable-size images, so validation
    # batches cannot be stacked into one tensor and are returned as lists.
    if config.MODEL.RESOLUTION_MODE == "arbitrary":
        val_collate = image_enlisting_collate_fn
    else:
        val_collate = torch.utils.data.default_collate

    train_loader = DataLoader(
        train_set,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
        shuffle=True,
        prefetch_factor=config.DATA.PREFETCH_FACTOR
    )
    val_loader = DataLoader(
        val_set,
        batch_size=config.DATA.VAL_BATCH_SIZE or config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        prefetch_factor=config.DATA.VAL_PREFETCH_FACTOR or config.DATA.PREFETCH_FACTOR,
        collate_fn=val_collate
    )

    # Setup mixup / cutmix augmentation when any of the relevant knobs is on.
    mixup_fn = None
    if (config.AUG.MIXUP > 0
            or config.AUG.CUTMIX > 0.
            or config.AUG.CUTMIX_MINMAX is not None):
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP,
            cutmix_alpha=config.AUG.CUTMIX,
            cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB,
            switch_prob=config.AUG.MIXUP_SWITCH_PROB,
            mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING,
            num_classes=config.MODEL.NUM_CLASSES
        )

    return train_set, val_set, train_loader, val_loader, mixup_fn
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def build_loader_test(
    config,
    logger,
    split: str = "test",
    dummy_csv_dir: Optional[pathlib.Path] = None,
) -> tuple[list[str], list[torch.utils.data.Dataset], list[torch.utils.data.DataLoader]]:
    """Builds one dataset and one data loader per configured test input.

    Each entry of ``config.DATA.TEST_DATA_PATH`` may be either a CSV index
    file or a directory of images; for directories a dummy CSV is generated
    on the fly.

    :param config: Experiment configuration node.
    :param logger: Logger for reporting dataset statistics.
    :param split: Split name to select from the CSVs ("train"/"val"/"test").
    :param dummy_csv_dir: Directory where generated dummy CSVs are written;
        defaults to ``./outputs`` when a directory input is encountered.
    :returns: ``(test_datasets_names, test_datasets, test_data_loaders)``.
    :raises RuntimeError: When the test sets disagree on the number of classes.
    """
    # Obtain the root directory for each test input (either a CSV file or a directory).
    input_root_paths: list[pathlib.Path]
    if len(config.DATA.TEST_DATA_CSV_ROOT) > 1:
        # One explicit root per input.
        input_root_paths = [pathlib.Path(p) for p in config.DATA.TEST_DATA_CSV_ROOT]
    elif len(config.DATA.TEST_DATA_CSV_ROOT) == 1:
        # A single root shared by all inputs.
        input_root_paths = [pathlib.Path(config.DATA.TEST_DATA_CSV_ROOT[0])
                            for _ in config.DATA.TEST_DATA_PATH]
    else:
        # No roots given: default to each input's parent directory.
        input_root_paths = [pathlib.Path(input_path).parent
                            for input_path in config.DATA.TEST_DATA_PATH]

    # If some input is a directory, create a dummy csv file for it.
    csv_paths: list[pathlib.Path] = []
    csv_root_paths: list[pathlib.Path] = []
    for input_path, input_root_path in zip(config.DATA.TEST_DATA_PATH, input_root_paths):
        input_path: pathlib.Path = pathlib.Path(input_path)
        if input_path.is_dir():
            # Create a dummy csv and point directories
            if dummy_csv_dir is None:
                dummy_csv_dir = pathlib.Path("./outputs")
            entries: list[dict[str, str]] = [
                {
                    "image": str(file_path.name),
                    "split": split,
                    "class": "1"  # TODO: Remove csv requirement for dummy ground-truth.
                }
                for file_path in input_path.iterdir() if filetype.is_image(file_path)
            ]
            dummy_csv_path: pathlib.Path = dummy_csv_dir / f"{input_path.stem}.csv"
            data_utils.write_csv_file(entries, dummy_csv_path, delimiter=",")
            csv_paths.append(dummy_csv_path.absolute())
            csv_root_paths.append(input_path.absolute())  # Paths in CSV are relative to input dir.
        else:
            csv_paths.append(input_path.absolute())
            csv_root_paths.append(input_root_path.absolute())

    # Obtain the separate testing sets and their names.
    test_datasets: list[CSVDataset] = []
    test_datasets_names: list[str] = []
    num_classes_per_dataset: list[int] = []
    for csv_path, csv_root_path in zip(csv_paths, csv_root_paths):
        csv_path: pathlib.Path = pathlib.Path(csv_path)
        dataset: CSVDataset
        dataset, num_classes = build_dataset(csv_path, csv_root_path, config, split, logger)
        test_datasets.append(dataset)
        test_datasets_names.append(csv_path.stem)  # Dataset name = CSV filename stem.
        num_classes_per_dataset.append(num_classes)
    # Check that the number of classes match among all test sets.
    unique_number_of_classes: list[int] = list(collections.Counter(num_classes_per_dataset).keys())
    if len(unique_number_of_classes) > 1:
        raise RuntimeError(
            f"Encountered different number of classes among test sets: {unique_number_of_classes}"
        )

    for dataset, dataset_name in zip(test_datasets, test_datasets_names):
        logger.info(f"Dataset \'{dataset_name}\' | Split: {split} | Total images: {len(dataset)} | "
                    f"Source: {dataset.get_dataset_root_path()}")

    # Create the corresponding data loaders.
    # On CPU-only runs, cap workers and disable pinned memory to avoid
    # oversubscription and useless page-locked allocations.
    force_cpu: bool = os.environ.get("SPAI_FORCE_CPU", "0") == "1"
    cuda_visible: str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    use_cuda: bool = (not force_cpu) and (cuda_visible != "")
    test_num_workers: int = config.DATA.NUM_WORKERS if use_cuda else min(config.DATA.NUM_WORKERS, 2)
    test_pin_memory: bool = config.DATA.PIN_MEMORY if use_cuda else False

    test_data_loaders: list[torch.utils.data.DataLoader] = [
        DataLoader(
            dataset,
            batch_size=config.DATA.TEST_BATCH_SIZE or config.DATA.BATCH_SIZE,
            num_workers=test_num_workers,
            pin_memory=test_pin_memory,
            drop_last=False,
            prefetch_factor=config.DATA.TEST_PREFETCH_FACTOR or config.DATA.PREFETCH_FACTOR,
            # Arbitrary-resolution models get list batches instead of stacked tensors.
            collate_fn=(torch.utils.data.default_collate
                        if not config.MODEL.RESOLUTION_MODE == "arbitrary"
                        else image_enlisting_collate_fn)
        )
        for dataset in test_datasets
    ]

    return test_datasets_names, test_datasets, test_data_loaders
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def build_dataset(
    csv_path: pathlib.Path,
    csv_root_dir: pathlib.Path,
    config,
    split_name: str,
    logger,
) -> tuple[Union[CSVDataset, CSVDatasetTriplet], int]:
    """Builds the dataset object for a single split.

    The concrete dataset type depends on the split and configuration:
      * train split + triplet loss -> ``CSVDatasetTriplet``
      * train split + supcont loss -> multi-view ``CSVDataset``
      * train split + arbitrary-resolution model -> ``CSVDataset`` with the
        views concatenated horizontally into one image
      * otherwise -> plain ``CSVDataset`` with an optional test-time views
        generator (e.g. ten-crop).

    :param csv_path: Path to the CSV index file of the dataset.
    :param csv_root_dir: Root directory the CSV image paths are relative to.
    :param config: Experiment configuration node.
    :param split_name: One of "train", "val", "test".
    :param logger: Logger used to report the selected transform.
    :returns: A tuple of the dataset and its number of classes.
    :raises RuntimeError: When ``split_name`` is not a supported split.
    :raises TypeError: When an unsupported views-generation approach is
        configured.
    """
    if split_name not in ["train", "val", "test"]:
        raise RuntimeError(f"Unsupported split: {split_name}")

    transform = build_transform(split_name == "train", config)
    logger.info(f"Data transform | mode: {config.TRAIN.MODE} | split: {split_name}:\n{transform}")

    if split_name == "train" and config.TRAIN.LOSS == "triplet":
        dataset = CSVDatasetTriplet(
            csv_path,
            csv_root_dir,
            split=split_name,
            transform=transform,
            lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
        )
    elif split_name == "train" and config.TRAIN.LOSS == "supcont":
        assert config.DATA.AUGMENTED_VIEWS > 1, "SupCon loss requires at least 2 views."
        dataset = CSVDataset(
            csv_path,
            csv_root_dir,
            split=split_name,
            transform=transform,
            views=config.DATA.AUGMENTED_VIEWS,
            lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
        )
    elif split_name == "train" and config.MODEL.RESOLUTION_MODE == "arbitrary":
        dataset = CSVDataset(
            csv_path,
            csv_root_dir,
            split=split_name,
            transform=transform,
            views=config.DATA.AUGMENTED_VIEWS,
            concatenate_views_horizontally=True,
            lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
        )
    else:
        views_generator: Optional[Callable[[Image.Image], tuple[Image.Image, ...]]]
        if config.TEST.VIEWS_GENERATION_APPROACH == "tencrop":
            def safe_ten_crop(img: Image.Image) -> tuple[Image.Image, ...]:
                """Ten-crops an image, padding it first when it is smaller than IMG_SIZE.

                Padding is split evenly between the two sides of each axis, with
                any odd remainder going to the right/bottom.
                """
                width = img.width
                height = img.height
                left_padding: int = max((config.DATA.IMG_SIZE - width) // 2, 0)
                right_padding: int = max(
                    (config.DATA.IMG_SIZE - width) // 2
                    + (((config.DATA.IMG_SIZE - width) % 2) if config.DATA.IMG_SIZE > width else 0),
                    0
                )
                top_padding: int = max((config.DATA.IMG_SIZE - height) // 2, 0)
                bottom_padding: int = max(
                    (config.DATA.IMG_SIZE - height) // 2
                    + (((config.DATA.IMG_SIZE - height) % 2) if config.DATA.IMG_SIZE > height else 0),
                    0
                )
                img = pad(img, [left_padding, top_padding, right_padding, bottom_padding])
                return ten_crop(img, size=config.DATA.IMG_SIZE)

            views_generator = safe_ten_crop
        elif config.TEST.VIEWS_GENERATION_APPROACH is None:
            views_generator = None
        else:
            # Fix: previously referenced the non-existent attribute
            # config.TEST.VIEW_GENERATION_APPROACH (missing "S"), which raised
            # AttributeError here instead of the intended TypeError.
            raise TypeError(f"{config.TEST.VIEWS_GENERATION_APPROACH} is not a supported "
                            f"view generation approach.")

        dataset = CSVDataset(
            csv_path,
            csv_root_dir,
            split=split_name,
            transform=transform,
            lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None,
            views_generator=views_generator
        )
    num_classes: int = dataset.get_classes_num()

    return dataset, num_classes
|
| 558 |
+
|
| 559 |
+
|
| 560 |
+
def build_transform(is_train, config) -> Callable[[np.ndarray], np.ndarray]:
    """Builds the albumentations pipeline for training or inference.

    :param is_train: When True, builds the stochastic training augmentations
        configured under ``config.AUG``. Otherwise, builds the deterministic
        test-time transforms configured under ``config.TEST``.
    :param config: Experiment configuration node (reads DATA, AUG, TEST, MODEL).
    :return: An ``A.Compose`` pipeline ending with ``ToTensorV2``.
    :raises RuntimeError: When ``config.MODEL.REQUIRED_NORMALIZATION`` is not
        one of ``"imagenet"`` or ``"positive_0_1"``.
    """
    if is_train:  # Training augmentations
        transforms_list = []

        # When min and max crop areas coincide no random-area cropping is wanted:
        # pad small images up to the target size and take a plain random crop.
        if config.AUG.MIN_CROP_AREA == config.AUG.MAX_CROP_AREA:
            transforms_list.append(
                A.PadIfNeeded(min_height=config.DATA.IMG_SIZE, min_width=config.DATA.IMG_SIZE)
            )
            transforms_list.append(
                A.RandomCrop(height=config.DATA.IMG_SIZE, width=config.DATA.IMG_SIZE)
            )
        else:
            transforms_list.append(
                A.RandomResizedCrop(size=(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
                                    scale=(config.AUG.MIN_CROP_AREA, config.AUG.MAX_CROP_AREA))
            )
        transforms_list.extend([
            A.HorizontalFlip(p=config.AUG.HORIZONTAL_FLIP_PROB),
            A.VerticalFlip(p=config.AUG.VERTICAL_FLIP_PROB),
            A.Rotate(limit=config.AUG.ROTATION_DEGREES,
                     crop_border=True,
                     p=config.AUG.ROTATION_PROB)
        ])
        if config.AUG.ROTATION_PROB > .0:
            # Rotation with crop_border set to True leads to images smaller than the target
            # size. So, restore the target size.
            transforms_list.append(
                A.Resize(height=config.DATA.IMG_SIZE, width=config.DATA.IMG_SIZE)
            )
        # Photometric / degradation augmentations, each gated by its own probability.
        transforms_list.extend([
            A.GaussianBlur(blur_limit=(3, 9),
                           sigma_limit=(0.01, 0.5),
                           p=config.AUG.GAUSSIAN_BLUR_PROB),
            A.GaussNoise(p=config.AUG.GAUSSIAN_NOISE_PROB),
            A.ColorJitter(
                p=config.AUG.COLOR_JITTER,
                brightness=config.AUG.COLOR_JITTER_BRIGHTNESS_RANGE,
                contrast=config.AUG.COLOR_JITTER_CONTRAST_RANGE,
                saturation=config.AUG.COLOR_JITTER_SATURATION_RANGE,
                hue=config.AUG.COLOR_JITTER_HUE_RANGE,
            ),
            A.Sharpen(p=config.AUG.SHARPEN_PROB,
                      alpha=config.AUG.SHARPEN_ALPHA_RANGE,
                      lightness=config.AUG.SHARPEN_LIGHTNESS_RANGE),
            A.ImageCompression(quality_lower=config.AUG.JPEG_MIN_QUALITY,
                               quality_upper=config.AUG.JPEG_MAX_QUALITY,
                               compression_type=ImageCompressionType.JPEG,
                               p=config.AUG.JPEG_COMPRESSION_PROB),
            A.ImageCompression(quality_lower=config.AUG.WEBP_MIN_QUALITY,
                               quality_upper=config.AUG.WEBP_MAX_QUALITY,
                               compression_type=ImageCompressionType.WEBP,
                               p=config.AUG.WEBP_COMPRESSION_PROB),
        ])
        if config.MODEL.REQUIRED_NORMALIZATION == "imagenet":
            transforms_list.append(
                A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
            )
        elif config.MODEL.REQUIRED_NORMALIZATION == "positive_0_1":
            transforms_list.append(
                A.Normalize(mean=0., std=1.)
            )
        else:
            raise RuntimeError(f"Unsupported Normalization: {config.MODEL.REQUIRED_NORMALIZATION}")
        transforms_list.append(ToTensorV2())
        transform = A.Compose(transforms_list)

    else:  # Inference augmentations
        # Each robustness perturbation is toggled by its config.TEST flag: p=1.0
        # when enabled, p=0.0 otherwise, keeping the pipeline deterministic.
        transforms_list = [
            A.ImageCompression(quality_lower=config.TEST.JPEG_QUALITY,
                               quality_upper=config.TEST.JPEG_QUALITY,
                               compression_type=ImageCompressionType.JPEG,
                               p=1.0 if config.TEST.JPEG_COMPRESSION else .0),
            A.ImageCompression(quality_lower=config.TEST.WEBP_QUALITY,
                               quality_upper=config.TEST.WEBP_QUALITY,
                               compression_type=ImageCompressionType.WEBP,
                               p=1.0 if config.TEST.WEBP_COMPRESSION else .0),
            A.GaussianBlur(blur_limit=(config.TEST.GAUSSIAN_BLUR_KERNEL_SIZE,
                                       config.TEST.GAUSSIAN_BLUR_KERNEL_SIZE),
                           sigma_limit=0,
                           p=1.0 if config.TEST.GAUSSIAN_BLUR else .0),
            A.GaussNoise(var_limit=(config.TEST.GAUSSIAN_NOISE_SIGMA**2,
                                    config.TEST.GAUSSIAN_NOISE_SIGMA**2),
                         p=1.0 if config.TEST.GAUSSIAN_NOISE else .0),
            A.RandomScale(scale_limit=(config.TEST.SCALE_FACTOR-1, config.TEST.SCALE_FACTOR-1),
                          p=1.0 if config.TEST.SCALE else .0)
        ]
        if config.TEST.MAX_SIZE is not None:
            transforms_list.append(A.SmallestMaxSize(max_size=config.TEST.MAX_SIZE))

        if config.TEST.ORIGINAL_RESOLUTION:
            # Keep the native resolution; only pad small images up to model input size.
            transforms_list.append(A.PadIfNeeded(min_height=config.DATA.IMG_SIZE,
                                                 min_width=config.DATA.IMG_SIZE))
        elif config.TEST.CROP:
            transforms_list.append(A.PadIfNeeded(min_height=config.DATA.IMG_SIZE,
                                                 min_width=config.DATA.IMG_SIZE))
            transforms_list.append(A.CenterCrop(height=config.DATA.IMG_SIZE,
                                                width=config.DATA.IMG_SIZE))
        else:
            transforms_list.append(A.Resize(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE))
        if config.MODEL.REQUIRED_NORMALIZATION == "imagenet":
            transforms_list.append(A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
        elif config.MODEL.REQUIRED_NORMALIZATION == "positive_0_1":
            transforms_list.append(A.Normalize(mean=0., std=1.))
        else:
            raise RuntimeError(f"Unsupported Normalization: {config.MODEL.REQUIRED_NORMALIZATION}")
        transforms_list.append(ToTensorV2())
        transform = A.Compose(transforms_list)

    return transform
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def string_to_sequence(s: str, dtype=np.int32) -> np.ndarray:
    """Encodes a string as an array of its unicode code points."""
    return np.fromiter(map(ord, s), dtype=dtype, count=len(s))
|
| 692 |
+
|
| 693 |
+
|
| 694 |
+
def sequence_to_string(seq: np.ndarray) -> str:
    """Decodes an array of unicode code points back into a string."""
    return "".join(map(chr, seq))
|
| 696 |
+
|
| 697 |
+
|
| 698 |
+
def pack_sequences(seqs: Union[np.ndarray, list]) -> tuple[np.ndarray, np.ndarray]:
    """Packs variable-length sequences into a flat values array plus end offsets.

    :param seqs: Iterable of 1-D arrays to pack.
    :return: ``(values, offsets)`` where ``values`` is the concatenation of all
        sequences and ``offsets[i]`` is the exclusive end index of sequence ``i``.
    """
    values = np.concatenate(seqs, axis=0)
    # Cumulative lengths yield the exclusive end offset of each sequence.
    offsets = np.cumsum([len(s) for s in seqs])
    return values, offsets
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def unpack_sequence(values: np.ndarray, offsets: np.ndarray, index: int) -> np.ndarray:
    """Extracts the *index*-th sequence from a packed ``(values, offsets)`` pair.

    :raises ValueError: When ``index`` is negative.
    """
    # Read the end offset first; this mirrors the packing convention where
    # offsets hold exclusive end positions.
    end = offsets[index]
    if index == 0:
        start = 0
    elif index > 0:
        start = offsets[index - 1]
    else:
        raise ValueError(index)
    return values[start:end]
|
| 713 |
+
|
| 714 |
+
|
| 715 |
+
def image_enlisting_collate_fn(
    batch: Iterable[tuple[torch.Tensor, np.ndarray, int]]
) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
    """Collate function that enlists its entries.

    Each image is collated individually (so variable-sized images stay in a
    list of batch-1 tensors), while targets and labels are batched normally.
    """
    images: list[torch.Tensor] = []
    targets: list[np.ndarray] = []
    labels: list[int] = []
    for image, target, label in batch:
        images.append(torch.utils.data.default_collate([image]))
        targets.append(target)
        labels.append(label)
    return (
        images,
        torch.utils.data.default_collate(targets),
        torch.utils.data.default_collate(labels),
    )
|
spai/data/data_mfm.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.distributed as dist
|
| 4 |
+
import torchvision.transforms as T
|
| 5 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 6 |
+
from torch.utils.data._utils.collate import default_collate
|
| 7 |
+
from torchvision.datasets import ImageFolder
|
| 8 |
+
from timm.data.transforms import _pil_interp
|
| 9 |
+
|
| 10 |
+
from .random_degradations import RandomBlur, RandomNoise
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FreqMaskGenerator:
    """Generates binary frequency-domain masks for Masked Frequency Modeling.

    The base mask is 1 inside the disc of radius ``mask_radius1`` and outside
    the disc of radius ``mask_radius2`` (both centered on the spectrum), and 0
    in the annulus in between. Each call randomly returns the mask (low-pass)
    or its complement (high-pass), according to ``sample_ratio``.
    """

    def __init__(self,
                 input_size=224,
                 mask_radius1=16,
                 mask_radius2=999,
                 sample_ratio=0.5):
        self.input_size = input_size
        self.mask_radius1 = mask_radius1
        self.mask_radius2 = mask_radius2
        self.sample_ratio = sample_ratio
        # Vectorized construction of the radial band mask (replaces an O(n^2)
        # Python double loop with identical results).
        center = self.input_size // 2
        ys, xs = np.ogrid[:self.input_size, :self.input_size]
        dist2 = (xs - center) ** 2 + (ys - center) ** 2
        in_band = (dist2 >= self.mask_radius1 ** 2) & (dist2 < self.mask_radius2 ** 2)
        self.mask = np.where(in_band, 0, 1).astype(int)

    def __call__(self):
        """Samples a low-pass or high-pass mask of shape (input_size, input_size)."""
        rnd = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
        if rnd == 0:  # high-pass
            return 1 - self.mask
        elif rnd == 1:  # low-pass
            return self.mask
        else:
            raise ValueError
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class MFMTransform:
    """Pre-training transform for Masked Frequency Modeling (MFM).

    Produces ``(img, img_lq, mask)`` triplets:
      * ``img``: the augmented image as a 0-1 CxHxW tensor.
      * ``img_lq``: a degraded (blurred or noisy) copy when ``FILTER_TYPE`` is
        ``'deblur'``/``'denoise'``, otherwise ``None``.
      * ``mask``: a frequency-domain mask when ``FILTER_TYPE`` is ``'mfm'``,
        otherwise ``None``.
    """

    def __init__(self, config):
        # Geometric augmentation applied to every sample before any degradation.
        # Operates on PIL images; non-RGB inputs are converted first.
        self.transform_img = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(config.DATA.MIN_CROP_SCALE, 1.), interpolation=_pil_interp(config.DATA.INTERPOLATION)),
            T.RandomHorizontalFlip(),
        ])

        self.filter_type = config.DATA.FILTER_TYPE

        # NOTE(review): model_patch_size is computed but never used afterwards;
        # the branch still rejects unsupported model types early.
        if config.MODEL.TYPE == 'swin':
            model_patch_size = config.MODEL.SWIN.PATCH_SIZE
        elif config.MODEL.TYPE == 'vit':
            model_patch_size = config.MODEL.VIT.PATCH_SIZE
        elif config.MODEL.TYPE == 'resnet':
            model_patch_size = 1
        else:
            raise NotImplementedError

        # For filter types other than the three below, no attribute is set and
        # __call__ returns (img, None, None) — presumably intentional; confirm
        # against the config schema before changing.
        if config.DATA.FILTER_TYPE == 'deblur':
            # Low-quality view obtained through random blurring.
            self.degrade_transform = RandomBlur(
                params=dict(
                    kernel_size=config.DATA.BLUR.KERNEL_SIZE,
                    kernel_list=config.DATA.BLUR.KERNEL_LIST,
                    kernel_prob=config.DATA.BLUR.KERNEL_PROB,
                    sigma_x=config.DATA.BLUR.SIGMA_X,
                    sigma_y=config.DATA.BLUR.SIGMA_Y,
                    rotate_angle=config.DATA.BLUR.ROTATE_ANGLE,
                    beta_gaussian=config.DATA.BLUR.BETA_GAUSSIAN,
                    beta_plateau=config.DATA.BLUR.BETA_PLATEAU),
            )
        elif config.DATA.FILTER_TYPE == 'denoise':
            # Low-quality view obtained through random noise injection.
            self.degrade_transform = RandomNoise(
                params=dict(
                    noise_type=config.DATA.NOISE.TYPE,
                    noise_prob=config.DATA.NOISE.PROB,
                    gaussian_sigma=config.DATA.NOISE.GAUSSIAN_SIGMA,
                    gaussian_gray_noise_prob=config.DATA.NOISE.GAUSSIAN_GRAY_NOISE_PROB,
                    poisson_scale=config.DATA.NOISE.POISSON_SCALE,
                    poisson_gray_noise_prob=config.DATA.NOISE.POISSON_GRAY_NOISE_PROB),
            )
        elif config.DATA.FILTER_TYPE == 'mfm':
            # Frequency mask sampler used to mask the spectrum during MFM.
            self.freq_mask_generator = FreqMaskGenerator(
                input_size=config.DATA.IMG_SIZE,
                mask_radius1=config.DATA.MASK_RADIUS1,
                mask_radius2=config.DATA.MASK_RADIUS2,
                sample_ratio=config.DATA.SAMPLE_RATIO
            )

    def __call__(self, img):
        """Applies augmentation and the configured degradation to a PIL image.

        :param img: Input PIL image.
        :return: ``(img, img_lq, mask)`` — see class docstring.
        """
        img = self.transform_img(img)  # PIL Image (HxWxC, 0-255), no normalization
        if self.filter_type in ['deblur', 'denoise']:
            # Degradations operate on float32 HxWxC arrays in [0, 1].
            img_lq = np.array(img).astype(np.float32) / 255.
            img_lq = self.degrade_transform(img_lq)
            img_lq = torch.from_numpy(img_lq.transpose(2, 0, 1))
        else:
            img_lq = None
        img = T.ToTensor()(img)  # Tensor (CxHxW, 0-1)
        if self.filter_type == 'mfm':
            mask = self.freq_mask_generator()
        else:
            mask = None

        return img, img_lq, mask
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def collate_fn(batch):
    """Collates MFM samples whose first element may be a tuple with None fields.

    Plain samples are delegated to ``default_collate``. Tuple samples are
    collated field-wise, passing ``None`` fields through unchanged, and the
    labels are appended as the final element of the returned list.
    """
    first_sample = batch[0][0]
    if not isinstance(first_sample, tuple):
        return default_collate(batch)

    collated = []
    for field_idx, field in enumerate(first_sample):
        if field is None:
            # None marks a field the transform did not produce; keep it as-is.
            collated.append(None)
        else:
            collated.append(default_collate([sample[0][field_idx] for sample in batch]))
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def build_loader_mfm(config, logger):
    """Builds the distributed training dataloader for MFM pre-training.

    Requires ``torch.distributed`` to be initialized (world size / rank are
    read from the default process group).
    """
    mfm_transform = MFMTransform(config)
    logger.info(f'Pre-train data transform:\n{mfm_transform}')

    train_dataset = ImageFolder(config.DATA.DATA_PATH, mfm_transform)
    logger.info(f'Build dataset: train images = {len(train_dataset)}')

    train_sampler = DistributedSampler(
        train_dataset,
        num_replicas=dist.get_world_size(),
        rank=dist.get_rank(),
        shuffle=True,
    )
    return DataLoader(
        train_dataset,
        config.DATA.BATCH_SIZE,
        sampler=train_sampler,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
|
spai/data/filestorage.py
ADDED
|
@@ -0,0 +1,387 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import csv
|
| 18 |
+
import hashlib
|
| 19 |
+
import io
|
| 20 |
+
import logging
|
| 21 |
+
import pathlib
|
| 22 |
+
from collections import Counter
|
| 23 |
+
from typing import Union, Optional
|
| 24 |
+
|
| 25 |
+
import click
|
| 26 |
+
import lmdb
|
| 27 |
+
import tqdm
|
| 28 |
+
import networkx as nx
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# Module metadata for the standalone LMDB file-storage CLI.
__version__: str = "0.1.0-alpha"
__revision__: int = 2
__author__: str = "Dimitrios Karageorgiou"
__email__: str = "dkarageo@iti.gr"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class LMDBFileStorage:
    """A file storage for handling large datasets based on LMDB.

    Files are stored as raw byte blobs under UTF-8 encoded key strings inside
    a single LMDB environment (``subdir=False``: the path is the DB file).
    """

    def __init__(self,
                 db_path: pathlib.Path,
                 map_size: int = 1024*1024*1024*1024,  # 1TB
                 read_only: bool = False,
                 max_readers: int = 128):
        """Opens (or creates) the LMDB environment backing the storage.

        :param db_path: Path of the LMDB database file.
        :param map_size: Maximum size the database may grow to.
        :param read_only: When True, the environment is opened read-only.
        :param max_readers: Maximum number of simultaneous read transactions.
        """
        # lock=False / sync=False trade crash-safety for speed — presumably
        # acceptable for dataset ingestion; confirm before reusing elsewhere.
        self.db: lmdb.Environment = lmdb.open(
            str(db_path),
            map_size=map_size,
            subdir=False,
            readonly=read_only,
            max_readers=max_readers,
            lock=False,
            sync=False
        )

    def open_file(self, file_id: str, mode: str = "r") -> Union[io.TextIOWrapper, io.BytesIO]:
        """Returns a file-like stream of a file in the database.

        :param file_id: Key of the file in the storage.
        :param mode: "r" for a text stream, "b" for a binary stream.
        :raises FileNotFoundError: When no entry exists under ``file_id``.
        :raises RuntimeError: When ``mode`` is not "r" or "b".
        """
        with self.db.begin(buffers=True) as trans:
            data = trans.get(file_id.encode("utf-8"))
            if data is None:
                # Previously io.BytesIO(None) raised an opaque TypeError here.
                raise FileNotFoundError(f"No entry with id '{file_id}' exists in the storage.")
            # BytesIO copies the buffer, so the stream stays valid after the
            # transaction (and its memoryview) is released.
            stream: io.BytesIO = io.BytesIO(data)

        if mode == "r":
            reader: io.TextIOWrapper = io.TextIOWrapper(stream)
        elif mode == "b":
            reader: io.BytesIO = stream
        else:
            raise RuntimeError(f"Unsupported file mode: '{mode}'. Only 'r' and 'b' are supported.")

        return reader

    def write_file(self, file_id: str, file_data: bytes) -> None:
        """Stores ``file_data`` under the UTF-8 encoded ``file_id`` key."""
        with self.db.begin(write=True) as trans:
            trans.put(file_id.encode("utf-8"), file_data)

    def get_all_ids(self) -> list[bytes]:
        """Returns all keys in the storage.

        Keys are returned as raw ``bytes`` (callers decode them with UTF-8);
        the previous ``list[str]`` annotation was incorrect.
        """
        with self.db.begin() as trans:
            cursor = trans.cursor()
            ids: list[bytes] = [k for k, _ in cursor]
        return ids

    def close(self) -> None:
        """Closes the underlying LMDB environment."""
        self.db.close()
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@click.group()
def cli() -> None:
    # Click group aggregating the add_csv / add_db / verify_csv / list_db
    # subcommands. (Intentionally no docstring: it would alter --help output.)
    pass
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@cli.command()
@click.option("-c", "--csv_file", required=True,
              type=click.Path(dir_okay=False, exists=True, path_type=pathlib.Path),
              help="Path to a CSV file containing relative paths to the dataset files.")
@click.option("-b", "--base_dir",
              type=click.Path(file_okay=False, exists=True, path_type=pathlib.Path),
              help="Base directory of the dataset. Paths inside the CSV should be relative "
                   "to that path. When not provided, the directory of the CSV file is "
                   "considered as the base directory.")
@click.option("-o", "--output_file", required=True,
              type=click.Path(dir_okay=False, path_type=pathlib.Path),
              help="Path to the database. If the file does not "
                   "exist, a new database is generated. Otherwise, it should point to a "
                   "previous instance of the LMDB, where data will be added.")
def add_csv(
    csv_file: pathlib.Path,
    base_dir: Optional[pathlib.Path],
    output_file: pathlib.Path
) -> None:
    # CLI: ingests every file referenced by the CSV into the LMDB storage.
    if base_dir is None:
        base_dir = csv_file.parent
    db: LMDBFileStorage = LMDBFileStorage(output_file)
    try:
        add_csv_to_db(csv_file, db, base_dir)
    finally:
        # Fix: previously the environment leaked when ingestion raised midway.
        db.close()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
@cli.command()
@click.option("-s", "--src", required=True,
              type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
              help="Database whose files will be added to the destination database.")
@click.option("-d", "--dest", required=True,
              type=click.Path(dir_okay=False, path_type=pathlib.Path),
              help="Database where file from source database will be added.")
def add_db(
    src: pathlib.Path,
    dest: pathlib.Path
) -> None:
    """Adds all the contents of a database to another."""
    src_db: LMDBFileStorage = LMDBFileStorage(src, read_only=True)
    try:
        dest_db: LMDBFileStorage = LMDBFileStorage(dest)
        try:
            for k in tqdm.tqdm(src_db.get_all_ids(), desc="Copying files", unit="file"):
                # Keys are stored as UTF-8 bytes; decode before re-using as ids.
                k = str(k, 'UTF-8')
                src_file: io.BytesIO = src_db.open_file(k, mode="b")
                dest_db.write_file(k, src_file.read())
        finally:
            dest_db.close()
    finally:
        # Fix: previously both environments leaked when copying raised midway.
        src_db.close()
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
@cli.command()
@click.option("-c", "--csv_file", required=True,
              type=click.Path(dir_okay=False, exists=True, path_type=pathlib.Path),
              help="Path to a CSV file containing relative paths to the dataset files.")
@click.option("-b", "--base_dir",
              type=click.Path(file_okay=False, exists=True, path_type=pathlib.Path),
              help="Base directory of the dataset. Paths inside the CSV should be relative "
                   "to that path. When not provided, the directory of the CSV file is "
                   "considered as the base directory.")
@click.option("-o", "--output_file", required=True,
              type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
              help="Path to the database to verify.")
def verify_csv(
    csv_file: pathlib.Path,
    base_dir: Optional[pathlib.Path],
    output_file: pathlib.Path
) -> None:
    # CLI: checks that every file referenced by the CSV matches its stored copy.
    if base_dir is None:
        base_dir = csv_file.parent
    db: LMDBFileStorage = LMDBFileStorage(output_file, read_only=True)
    try:
        verify_csv_in_db(csv_file, db, base_dir)
    finally:
        # Fix: previously the environment leaked when verification raised midway.
        db.close()
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
@cli.command()
@click.option("-d", "--database", required=True,
              type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
              help="Database whose keys will be printed.")
@click.option("-h", "--hierarchical", is_flag=True,
              help="List files in DB according to directories hierarchy.")
def list_db(
    database: pathlib.Path,
    hierarchical: bool
) -> None:
    """Lists the contents of a file storage."""
    db: LMDBFileStorage = LMDBFileStorage(database, read_only=True)
    try:
        if not hierarchical:
            # Keys are stored as UTF-8 bytes; decode them so paths print cleanly
            # instead of as b'...' reprs (the hierarchical branch already did).
            for k in db.get_all_ids():
                print(str(k, 'UTF-8'))
        else:
            # The db contains filenames as keys, so their parents will always be
            # the dir names.
            ids: list[str] = [str(pathlib.Path(str(k, 'UTF-8')).parent) for k in db.get_all_ids()]
            counts: Counter = Counter(ids)

            # Build a directory graph: an edge connects each directory to its parent.
            dir_graph: nx.DiGraph = nx.DiGraph()
            for k in counts.keys():
                dir_graph.add_edge(str(pathlib.Path(k).parent), k)
                dir_graph.nodes[k]["items_num"] = counts[k]

            # Roots are directories with no parent inside the graph.
            top_level_nodes: list[str] = [n for n in dir_graph.nodes if dir_graph.in_degree(n) == 0]
            top_level_nodes = sorted(top_level_nodes)

            for n in top_level_nodes:
                print_dirs_from_graph(dir_graph, n)
    finally:
        # Fix: previously the environment was never closed.
        db.close()
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def add_csv_to_db(
    csv_file: pathlib.Path,
    db: LMDBFileStorage,
    base_dir: pathlib.Path,
    key_base_dir: Optional[pathlib.Path] = None,
    verbose: bool = True
) -> int:
    """Adds the contents of the file paths included in a CSV file into an LMDB File Storage.

    Paths of the files, relative to the base dir, are utilized as keys into the storage.
    Thus, the maximum allowed path length is 511 bytes.

    The contents of nested CSV files are recursively added into the LMDB File Storage.
    In that case, keys represent the file structure relative to the base dir.

    :param csv_file: Path to a CSV file describing a dataset.
    :param db: An instance of LMDB File Storage, where files will be added.
    :param base_dir: Directory where paths included into the CSV file are relative to.
    :param key_base_dir: Directory where paths encoded into the keys of the LMDB File
        Storage will be relative to. It should be either the same or an upper directory
        compared to base dir. When this argument is omitted, the value of base dir
        is used.
    :param verbose: When set to False, progress messages will not be printed.
    :return: Total number of files written, including files from nested CSVs.
    """
    entries: list[dict[str, str]] = read_csv_file(csv_file, verbose=verbose)

    if key_base_dir is None:
        key_base_dir = base_dir

    if verbose:
        pbar = tqdm.tqdm(entries, desc="Writing CSV data to database", unit="file")
    else:
        pbar = entries

    files_written: int = 0
    for e in pbar:
        # Generate key-path pairs for each path in the CSV. Every column value
        # of the row is treated as a candidate path; non-existing ones are dropped.
        files_to_write: dict[str, pathlib.Path] = find_files(
            list(e.values()),
            base_dir,
            key_base_dir
        )

        files_written += write_files_to_db(files_to_write, db)

        # Recursively add the contents of the encountered CSV files.
        # Nested CSVs use their own directory as base, but keep the original
        # key_base_dir so keys stay rooted consistently.
        for p in files_to_write.values():
            if p.suffix == ".csv":
                files_written += add_csv_to_db(
                    p, db, p.parent, key_base_dir=key_base_dir, verbose=False
                )

        if verbose:
            pbar.set_postfix({"Files Written": files_written})

    return files_written
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def verify_csv_in_db(
    csv_file: pathlib.Path,
    db: LMDBFileStorage,
    base_dir: pathlib.Path,
    key_base_dir: Optional[pathlib.Path] = None,
    verbose: bool = True
) -> int:
    """Verifies that files referenced by a CSV match their copies in the storage.

    Mirrors :func:`add_csv_to_db`: each existing path in the CSV (relative to
    ``base_dir``) is compared, via md5, against the blob stored under the
    corresponding key. Nested CSV files are verified recursively.

    :param csv_file: Path to a CSV file describing a dataset.
    :param db: The LMDB File Storage to verify against.
    :param base_dir: Directory where paths included into the CSV file are relative to.
    :param key_base_dir: Directory the storage keys are relative to. Defaults to
        ``base_dir`` when omitted.
    :param verbose: When set to False, progress messages will not be printed.
    :return: Number of files whose hashes matched, including nested CSV contents.
    """
    entries: list[dict[str, str]] = read_csv_file(csv_file, verbose=verbose)

    if key_base_dir is None:
        key_base_dir = base_dir

    if verbose:
        pbar = tqdm.tqdm(entries, desc="Verifying CSV data in database", unit="file")
    else:
        pbar = entries

    files_verified: int = 0
    for e in pbar:
        # Generate key-path pairs for each path in the CSV.
        files: dict[str, pathlib.Path] = find_files(
            list(e.values()),
            base_dir,
            key_base_dir
        )

        files_verified += verify_files_in_db(files, db)

        # Recursively verify the contents of the encountered CSV files.
        for p in files.values():
            if p.suffix == ".csv":
                files_verified += verify_csv_in_db(
                    p, db, p.parent, key_base_dir=key_base_dir, verbose=False
                )

        if verbose:
            pbar.set_postfix({"Files Verified": files_verified})

    return files_verified
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def find_files(
    candidates: list[str],
    base_dir: pathlib.Path,
    key_base_dir: pathlib.Path
) -> dict[str, pathlib.Path]:
    """Maps storage keys to existing regular files among the candidate paths.

    Each candidate is resolved against ``base_dir``; its key is its path
    relative to ``key_base_dir``. Candidates that do not point to an existing
    regular file are silently skipped.
    """
    found: dict[str, pathlib.Path] = {}
    for candidate in candidates:
        resolved: pathlib.Path = base_dir / candidate
        # Key computed unconditionally, so paths outside key_base_dir raise
        # ValueError even when the file does not exist (as before).
        key: str = str(resolved.relative_to(key_base_dir))
        if resolved.exists() and resolved.is_file():
            found[key] = resolved
    return found
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def write_files_to_db(files: dict[str, pathlib.Path], db: LMDBFileStorage) -> int:
    """Stores the contents of each (key, path) pair into the storage.

    :return: The number of files written (the size of ``files``).
    """
    for key, path in files.items():
        with path.open("rb") as stream:
            db.write_file(key, stream.read())
    return len(files)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def verify_files_in_db(files: dict[str, pathlib.Path], db: LMDBFileStorage) -> int:
    """Compares md5 digests of on-disk files against their stored copies.

    Mismatches are logged as errors; :return: the number of matching files.
    """
    matching: int = 0
    for key, path in files.items():
        # Calculate md5 hash of the file in csv.
        with path.open("rb") as disk_stream:
            disk_hash: str = md5(disk_stream)
        # Calculate md5 hash of the file in db.
        stored_hash: str = md5(db.open_file(key, mode="b"))
        if disk_hash == stored_hash:
            matching += 1
        else:
            logging.error(f"File in DB not matching file in CSV: {str(path)}")
    return matching
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
def read_csv_file(csv_file: pathlib.Path, verbose: bool = True) -> list[dict[str, str]]:
    """Reads a CSV file into a list of row dicts (column name -> cell value)."""
    # Read the whole csv file.
    if verbose:
        logging.info(f"READING CSV: {str(csv_file)}")

    with csv_file.open() as f:
        reader: csv.DictReader = csv.DictReader(f, delimiter=",")
        rows_iter = tqdm.tqdm(reader, desc="Reading CSV entries", unit="entry") if verbose else reader
        entries: list[dict[str, str]] = list(rows_iter)

    if verbose:
        logging.info(f"TOTAL ENTRIES: {len(entries)}")

    return entries
|
| 352 |
+
|
| 353 |
+
|
| 354 |
+
def read_raw_file(p: pathlib.Path) -> bytes:
    """Returns the raw byte contents of the file at ``p``."""
    return p.read_bytes()
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def md5(stream) -> str:
    """Calculates md5 hash of a file-like stream."""
    digest = hashlib.md5()
    # Hash in 4 KiB chunks so arbitrarily large streams stay memory-bounded.
    while chunk := stream.read(4096):
        digest.update(chunk)
    return digest.hexdigest()
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
def print_dirs_from_graph(g: nx.DiGraph, n: str, depth: int = 0) -> None:
    """Pretty-print the directory tree rooted at node `n` of graph `g`.

    Children are visited in sorted order. Nodes carrying an "items_num"
    attribute additionally report their file count on an indented line
    after all of their children.
    """
    prefix: str = "" if depth == 0 else " " * (depth - 1) + "|.. "
    print(f"{prefix}{n}")

    for child in sorted(g.successors(n)):
        print_dirs_from_graph(g, child, depth + 1)

    node_attrs = g.nodes[n]
    if "items_num" in node_attrs:
        print(" " * (depth + 1) + f"({node_attrs['items_num']} files)")
|
| 383 |
+
|
| 384 |
+
|
| 385 |
+
# Script entry point: raise root-logger verbosity so the info messages
# emitted by the commands above are visible, then dispatch to the CLI.
if __name__ == "__main__":
    logging.getLogger().setLevel(logging.INFO)
    cli()
|
spai/data/random_degradations.py
ADDED
|
@@ -0,0 +1,462 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
| 2 |
+
import random
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
from . import blur_kernels as blur_kernels
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class RandomBlur:
    """Apply random blur to the input.

    A blur kernel family is sampled from ``params['kernel_list']`` and one
    kernel per input image is generated and convolved with it.

    Args:
        params (dict): A dictionary specifying the degradation settings.
            ``kernel_list``, ``kernel_prob`` and ``kernel_size`` are
            required; the remaining keys are optional and fall back to
            Real-ESRGAN-style defaults.
    """

    def __init__(self, params):
        # Degradation configuration; optional keys are read lazily (with
        # defaults) inside get_kernel().
        self.params = params

    def get_kernel(self, num_kernels):
        # Sample the kernel family and the initial shape parameters once;
        # per-kernel drift is applied inside the loop below.
        kernel_type = random.choices(
            self.params['kernel_list'], weights=self.params['kernel_prob'])[0]
        kernel_size = random.choice(self.params['kernel_size'])

        sigma_x_range = self.params.get('sigma_x', [0, 0])
        sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])
        sigma_x_step = self.params.get('sigma_x_step', 0)

        sigma_y_range = self.params.get('sigma_y', [0, 0])
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        sigma_y_step = self.params.get('sigma_y_step', 0)

        rotate_angle_range = self.params.get('rotate_angle', [-np.pi, np.pi])
        rotate_angle = random.uniform(rotate_angle_range[0],
                                      rotate_angle_range[1])
        rotate_angle_step = self.params.get('rotate_angle_step', 0)

        beta_gau_range = self.params.get('beta_gaussian', [0.5, 4])
        beta_gau = random.uniform(beta_gau_range[0], beta_gau_range[1])
        beta_gau_step = self.params.get('beta_gaussian_step', 0)

        beta_pla_range = self.params.get('beta_plateau', [1, 2])
        beta_pla = random.uniform(beta_pla_range[0], beta_pla_range[1])
        beta_pla_step = self.params.get('beta_plateau_step', 0)

        omega_range = self.params.get('omega', None)
        omega_step = self.params.get('omega_step', 0)
        if omega_range is None:  # follow Real-ESRGAN settings if not specified
            if kernel_size < 13:
                omega_range = [np.pi / 3., np.pi]
            else:
                omega_range = [np.pi / 5., np.pi]
        omega = random.uniform(omega_range[0], omega_range[1])

        # determine blurring kernel
        kernels = []
        for _ in range(0, num_kernels):
            kernel = blur_kernels.random_mixed_kernels(
                [kernel_type],
                [1],
                kernel_size,
                [sigma_x, sigma_x],
                [sigma_y, sigma_y],
                [rotate_angle, rotate_angle],
                [beta_gau, beta_gau],
                [beta_pla, beta_pla],
                [omega, omega],
                None,
            )
            kernels.append(kernel)

            # update kernel parameters: a random walk (clipped back into
            # the configured ranges) so consecutive kernels are correlated
            # rather than independent.
            sigma_x += random.uniform(-sigma_x_step, sigma_x_step)
            sigma_y += random.uniform(-sigma_y_step, sigma_y_step)
            rotate_angle += random.uniform(-rotate_angle_step,
                                           rotate_angle_step)
            beta_gau += random.uniform(-beta_gau_step, beta_gau_step)
            beta_pla += random.uniform(-beta_pla_step, beta_pla_step)
            omega += random.uniform(-omega_step, omega_step)

            sigma_x = np.clip(sigma_x, sigma_x_range[0], sigma_x_range[1])
            sigma_y = np.clip(sigma_y, sigma_y_range[0], sigma_y_range[1])
            rotate_angle = np.clip(rotate_angle, rotate_angle_range[0],
                                   rotate_angle_range[1])
            beta_gau = np.clip(beta_gau, beta_gau_range[0], beta_gau_range[1])
            beta_pla = np.clip(beta_pla, beta_pla_range[0], beta_pla_range[1])
            omega = np.clip(omega, omega_range[0], omega_range[1])

        return kernels

    def _apply_random_blur(self, imgs):
        # Accept either a single ndarray or a list of ndarrays; remember
        # which so the return type mirrors the input.
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        # get kernel and blur the input
        kernels = self.get_kernel(num_kernels=len(imgs))
        imgs = [
            cv2.filter2D(img, -1, kernel)
            for img, kernel in zip(imgs, kernels)
        ]

        if is_single_image:
            imgs = imgs[0]

        return imgs

    def __call__(self, results):
        # Skip the degradation entirely with probability 1 - params['prob'].
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_blur(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
class RandomResize:
    """Randomly resize the input.

    Either resizes to an explicit ``params['target_size']`` or samples an
    up/down/keep mode and a scale factor from ``params``.

    Args:
        params (dict): A dictionary specifying the degradation settings.
            ``resize_opt``/``resize_prob`` select the interpolation;
            ``resize_mode_prob``/``resize_scale`` control random scaling;
            optional ``resize_step`` makes the scale drift between
            consecutive images of a sequence.
    """

    def __init__(self, params):
        self.params = params

        # Map lowercase option names from the config to OpenCV flags.
        self.resize_dict = dict(
            bilinear=cv2.INTER_LINEAR,
            bicubic=cv2.INTER_CUBIC,
            area=cv2.INTER_AREA,
            lanczos=cv2.INTER_LANCZOS4)

    def _random_resize(self, imgs):
        # Accept either a single ndarray or a list of ndarrays; remember
        # which so the return type mirrors the input.
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        h, w = imgs[0].shape[:2]

        resize_opt = self.params['resize_opt']
        resize_prob = self.params['resize_prob']
        resize_opt = random.choices(resize_opt, weights=resize_prob)[0].lower()
        if resize_opt not in self.resize_dict:
            raise NotImplementedError(f'resize_opt [{resize_opt}] is not '
                                      'implemented')
        resize_opt = self.resize_dict[resize_opt]

        resize_step = self.params.get('resize_step', 0)

        # determine the target size, if not provided
        target_size = self.params.get('target_size', None)
        if target_size is None:
            resize_mode = random.choices(['up', 'down', 'keep'],
                                         weights=self.params['resize_mode_prob'])[0]
            resize_scale = self.params['resize_scale']
            if resize_mode == 'up':
                scale_factor = random.uniform(1, resize_scale[1])
            elif resize_mode == 'down':
                scale_factor = random.uniform(resize_scale[0], 1)
            else:
                scale_factor = 1

            # determine output size
            h_out, w_out = h * scale_factor, w * scale_factor
            if self.params.get('is_size_even', False):
                # Force even dimensions (e.g. for codecs requiring them).
                h_out, w_out = 2 * (h_out // 2), 2 * (w_out // 2)
            target_size = (int(h_out), int(w_out))
        else:
            # An explicit target size disables the per-image scale drift.
            resize_step = 0

        # resize the input
        if resize_step == 0:  # same target_size for all input images
            outputs = [
                cv2.resize(img, target_size[::-1], interpolation=resize_opt)
                for img in imgs
            ]
        else:  # different target_size for each input image
            outputs = []
            for img in imgs:
                img = cv2.resize(
                    img, target_size[::-1], interpolation=resize_opt)
                outputs.append(img)

                # update scale: random walk clipped into the allowed range
                scale_factor += random.uniform(-resize_step, resize_step)
                scale_factor = np.clip(scale_factor, resize_scale[0],
                                       resize_scale[1])

                # determine output size (always from the ORIGINAL h, w)
                h_out, w_out = h * scale_factor, w * scale_factor
                if self.params.get('is_size_even', False):
                    h_out, w_out = 2 * (h_out // 2), 2 * (w_out // 2)
                target_size = (int(h_out), int(w_out))

        if is_single_image:
            outputs = outputs[0]

        return outputs

    def __call__(self, results):
        # Skip the degradation entirely with probability 1 - params['prob'].
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._random_resize(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class RandomNoise:
    """Apply random noise to the input.

    Currently support Gaussian noise and Poisson noise.

    Args:
        params (dict): A dictionary specifying the degradation settings.
            ``noise_type``/``noise_prob`` select the family; the
            ``gaussian_*`` / ``poisson_*`` keys configure each family.
    """

    def __init__(self, params):
        self.params = params

    def _apply_gaussian_noise(self, imgs):
        # Sigma is configured on a [0, 255] scale but applied to images
        # assumed to be in [0, 1], hence the /255 rescaling throughout.
        sigma_range = self.params['gaussian_sigma']
        sigma = random.uniform(sigma_range[0], sigma_range[1]) / 255.

        sigma_step = self.params.get('gaussian_sigma_step', 0)

        gray_noise_prob = self.params['gaussian_gray_noise_prob']
        is_gray_noise = random.random() < gray_noise_prob

        outputs = []
        for img in imgs:
            noise = torch.randn(*(img.shape)).numpy() * sigma
            if is_gray_noise:
                # Keep a single channel so the same noise broadcasts over
                # all color channels of the image.
                noise = noise[:, :, :1]
            outputs.append(img + noise)

            # update noise level: random walk clipped into the range
            sigma += random.uniform(-sigma_step, sigma_step) / 255.
            sigma = np.clip(sigma, sigma_range[0] / 255.,
                            sigma_range[1] / 255.)

        return outputs

    def _apply_poisson_noise(self, imgs):
        scale_range = self.params['poisson_scale']
        scale = random.uniform(scale_range[0], scale_range[1])

        scale_step = self.params.get('poisson_scale_step', 0)

        gray_noise_prob = self.params['poisson_gray_noise_prob']
        is_gray_noise = random.random() < gray_noise_prob

        outputs = []
        for img in imgs:
            noise = img.copy()
            if is_gray_noise:
                # Channel flip suggests BGR input here -- TODO confirm the
                # expected channel order against the callers.
                noise = cv2.cvtColor(noise[..., [2, 1, 0]], cv2.COLOR_BGR2GRAY)
                noise = noise[..., np.newaxis]
            noise = np.clip((noise * 255.0).round(), 0, 255) / 255.
            # Scale by the number of distinct intensity levels so the
            # Poisson draw operates on (approximate) photon counts.
            unique_val = 2**np.ceil(np.log2(len(np.unique(noise))))
            noise = torch.poisson(torch.from_numpy(noise * unique_val)).numpy() / unique_val - noise

            outputs.append(img + noise * scale)

            # update noise level: random walk clipped into the range
            scale += random.uniform(-scale_step, scale_step)
            scale = np.clip(scale, scale_range[0], scale_range[1])

        return outputs

    def _apply_random_noise(self, imgs):
        noise_type = random.choices(
            self.params['noise_type'], weights=self.params['noise_prob'])[0]

        # Accept either a single ndarray or a list of ndarrays; remember
        # which so the return type mirrors the input.
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        if noise_type.lower() == 'gaussian':
            imgs = self._apply_gaussian_noise(imgs)
        elif noise_type.lower() == 'poisson':
            imgs = self._apply_poisson_noise(imgs)
        else:
            raise NotImplementedError(f'"noise_type" [{noise_type}] is '
                                      'not implemented.')

        if is_single_image:
            imgs = imgs[0]

        return imgs

    def __call__(self, results):
        # Skip the degradation entirely with probability 1 - params['prob'].
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_noise(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
|
| 330 |
+
|
| 331 |
+
|
| 332 |
+
class RandomJPEGCompression:
    """Apply random JPEG compression to the input.

    Args:
        params (dict): A dictionary specifying the degradation settings.
            ``quality`` is the (min, max) JPEG quality range; optional
            ``quality_step`` makes the quality drift between consecutive
            images and ``prob`` controls whether the degradation is
            applied at all.
    """

    def __init__(self, params):
        self.params = params

    def _apply_random_compression(self, imgs):
        # Accept either a single ndarray or a list of ndarrays; remember
        # which so the return type mirrors the input.
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        # determine initial compression level and the step size
        quality = self.params['quality']
        quality_step = self.params.get('quality_step', 0)
        jpeg_param = round(random.uniform(quality[0], quality[1]))

        # apply jpeg compression; images are scaled to [0, 255] for
        # encoding and back to [0, 1] floats after decoding (assumes
        # float inputs in [0, 1] -- TODO confirm against callers).
        outputs = []
        for img in imgs:
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_param]
            _, img_encoded = cv2.imencode('.jpg', img * 255., encode_param)
            outputs.append(np.float32(cv2.imdecode(img_encoded, 1)) / 255.)

            # update compression level: random walk clipped into range
            jpeg_param += random.uniform(-quality_step, quality_step)
            jpeg_param = round(np.clip(jpeg_param, quality[0], quality[1]))

        if is_single_image:
            outputs = outputs[0]

        return outputs

    def __call__(self, results):
        # Skip the degradation entirely with probability 1 - params['prob'].
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_compression(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Registry mapping the degradation type names used in config dicts to the
# implementing classes; consumed by DegradationsWithShuffle when it builds
# a degradation pipeline from configuration.
allowed_degradations = {
    'RandomBlur': RandomBlur,
    'RandomResize': RandomResize,
    'RandomNoise': RandomNoise,
    'RandomJPEGCompression': RandomJPEGCompression,
}
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
class DegradationsWithShuffle:
    """Apply random degradations to input, with degradations being shuffled.

    Degradation groups are supported: the relative order of degradations
    inside the same (nested) group is preserved, while the top-level
    positions listed in ``shuffle_idx`` are permuted on every call. For
    degradations = [a, b, [c, d]] and shuffle_idx = None the possible
    orders are all permutations of a, b and the group [c, d], with c
    always preceding d.

    Args:
        degradations (list[dict]): The list of degradation config dicts.
        shuffle_idx (list | None, optional): The degradations corresponding
            to these indices are shuffled. If None, all degradations are
            shuffled.
    """

    def __init__(self, degradations, shuffle_idx=None):
        self.degradations = self._build_degradations(degradations)

        # Default: every top-level position participates in the shuffle.
        self.shuffle_idx = (
            list(range(len(degradations))) if shuffle_idx is None
            else shuffle_idx)

    def _build_degradations(self, degradations):
        # Replace each config dict with an instantiated degradation object,
        # recursing into nested groups. Note: mutates the list in place.
        for position, entry in enumerate(degradations):
            if isinstance(entry, (list, tuple)):
                degradations[position] = self._build_degradations(entry)
            else:
                degradation_cls = allowed_degradations[entry['type']]
                degradations[position] = degradation_cls(entry['params'])

        return degradations

    def __call__(self, results):
        # Permute the degradations sitting at the shuffled positions.
        if self.shuffle_idx:
            permuted = [self.degradations[i] for i in self.shuffle_idx]
            random.shuffle(permuted)
            for position, stage in zip(self.shuffle_idx, permuted):
                self.degradations[position] = stage

        # Run the (possibly nested) pipeline over the input.
        for stage in self.degradations:
            if isinstance(stage, (tuple, list)):
                for sub_stage in stage:
                    results = sub_stage(results)
            else:
                results = stage(results)

        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(degradations={self.degradations}, '
                f'shuffle_idx={self.shuffle_idx})')
|
spai/data/readers.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import csv
|
| 18 |
+
import io
|
| 19 |
+
import pathlib
|
| 20 |
+
from typing import Any, Union, Optional
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from torchvision.io import read_image
|
| 26 |
+
|
| 27 |
+
from spai.data import filestorage
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DataReader:
    """Abstract interface for reading dataset files from a backing store.

    Concrete implementations resolve the relative ``path`` strings used in
    dataset CSVs against a specific backend (e.g. the local filesystem or
    an LMDB file storage)."""

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a comma-delimited CSV file into one dict per row."""
        raise NotImplementedError

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image, converted to grayscale when channels == 1 and
        to RGB otherwise."""
        raise NotImplementedError

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple."""
        raise NotImplementedError

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        The default values for the column name and the number of channels have
        been specified for the CSV containing the instance segmentation maps of
        an image."""
        # NOTE(review): the implementations below return PIL Images from
        # load_image, so the declared list[np.ndarray] return type looks
        # stale -- confirm against the callers.
        raise NotImplementedError

    def load_file_path_or_stream(self, path: str) -> Union[pathlib.Path, io.FileIO, io.BytesIO]:
        """Returns either a filesystem path or an open stream for ``path``,
        depending on what the backend can provide."""
        raise NotImplementedError
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class FileSystemReader(DataReader):
    """Reader that maps relative paths to absolute paths of the filesystem."""

    def __init__(self, root_path: pathlib.Path):
        super().__init__()
        # All relative paths handed to this reader resolve under this root.
        self.root_path: pathlib.Path = root_path

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a comma-delimited CSV file into one dict per row."""
        with (self.root_path / path).open("r") as csv_stream:
            return list(csv.DictReader(csv_stream, delimiter=","))

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple."""
        with Image.open(self.root_path / path) as image:
            return image.size

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image, as grayscale when channels == 1, RGB otherwise."""
        try:
            image = Image.open(self.root_path / path)
            image = image.convert("L" if channels == 1 else "RGB")
        except Exception as e:
            print(f"Failed to read: {path}")
            raise e
        return image

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        Paths in the CSV column are interpreted relative to the CSV file's
        own directory. Entries not matching ``data_specifier`` are skipped.
        """
        rows: list[dict[str, Any]] = self.read_csv_file(csv_path)

        signals: list[np.ndarray] = []
        for row in rows:
            # Skip entries that the optional specifier filters out.
            if data_specifier is not None and not data_specifier_matches_entry(row, data_specifier):
                continue

            absolute_path: pathlib.Path = (self.root_path / csv_path).parent / row[column_name]
            signals.append(self.load_image(
                str(absolute_path.relative_to(self.root_path)), channels=channels))

        return signals

    def load_file_path_or_stream(self, path: str) -> Union[pathlib.Path, io.FileIO, io.BytesIO]:
        """Returns the absolute filesystem path for ``path``."""
        return self.root_path / path
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class LMDBFileStorageReader(DataReader):
    """Reader that maps relative paths into an LMDBFileStorage."""

    def __init__(self, storage: filestorage.LMDBFileStorage):
        super().__init__()
        # Backing LMDB-based storage from which all paths are served.
        self.storage: filestorage.LMDBFileStorage = storage

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a CSV file stored in the LMDB storage into row dicts."""
        stream = self.storage.open_file(path)
        try:
            reader = csv.DictReader(stream, delimiter=",")
            contents: list[dict[str, Any]] = [row for row in reader]
        finally:
            # Fix: the stream was previously never closed (handle leak).
            stream.close()
        return contents

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple."""
        stream = self.storage.open_file(path, mode="b")
        try:
            with Image.open(stream) as image:
                image_size: tuple[int, int] = image.size
        finally:
            # Fix: the stream was previously never closed (handle leak).
            stream.close()
        return image_size

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image from the storage, converted to grayscale ("L")
        when channels == 1 and to "RGB" otherwise."""
        stream = self.storage.open_file(path, mode="b")
        try:
            with Image.open(stream) as image:
                if channels == 1:
                    image = image.convert("L")
                else:
                    image = image.convert("RGB")
        finally:
            # convert() produced a detached copy, so the source stream can
            # be released before returning (closed even on decode errors).
            stream.close()
        return image

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        Paths in the CSV column are interpreted relative to the CSV file's
        own (virtual) directory inside the storage. Entries not matching
        ``data_specifier`` are skipped.
        """
        csv_data: list[dict[str, Any]] = self.read_csv_file(csv_path)

        signals: list[np.ndarray] = []
        for row in csv_data:
            # Ignore entries that do not match with the given data specifier.
            if data_specifier is not None and not data_specifier_matches_entry(row, data_specifier):
                continue

            signal_path: str = str(pathlib.Path(csv_path).parent / row[column_name])
            signal: np.ndarray = self.load_image(signal_path, channels=channels)
            signals.append(signal)

        return signals
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
def data_specifier_matches_entry(entry: dict[str, str], specifier: dict[str, str]) -> bool:
    """Checks whether a CSV entry matches a data specifier.

    Every key of `specifier` must be present in `entry` with an equal
    value; an empty specifier matches any entry.
    """
    return all(key in entry and entry[key] == value
               for key, value in specifier.items())
|
spai/data_utils.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
|
| 2 |
+
# and University of Amsterdam. All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import pathlib
|
| 18 |
+
import csv
|
| 19 |
+
import hashlib
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def read_csv_file(
    path: pathlib.Path,
    delimiter: str = ","
) -> list[dict[str, Any]]:
    """Reads a delimited text file into one dictionary per data row,
    keyed by the header row."""
    with path.open("r") as stream:
        return list(csv.DictReader(stream, delimiter=delimiter))
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def write_csv_file(
|
| 34 |
+
data: list[dict[str, Any]],
|
| 35 |
+
output_file: pathlib.Path,
|
| 36 |
+
fieldnames: Optional[list[str]] = None,
|
| 37 |
+
delimiter: str = "|"
|
| 38 |
+
) -> None:
|
| 39 |
+
if fieldnames is None:
|
| 40 |
+
fieldnames = list(data[0].keys())
|
| 41 |
+
with output_file.open("w", newline="") as f:
|
| 42 |
+
writer: csv.DictWriter = csv.DictWriter(f, fieldnames=fieldnames, delimiter=delimiter)
|
| 43 |
+
writer.writeheader()
|
| 44 |
+
for r in data:
|
| 45 |
+
writer.writerow(r)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def compute_file_md5(path: pathlib.Path) -> str:
    """Computes the md5 hex digest of the file at `path`.

    The file is hashed in fixed-size chunks so arbitrarily large files can
    be processed without loading them fully into memory (the previous
    implementation read the whole file at once).
    """
    digest = hashlib.md5()
    with path.open("rb") as f:
        # 64 KiB chunks: large enough to amortize call overhead, small
        # enough to keep memory usage constant.
        for chunk in iter(lambda: f.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
|