aminasifar1 commited on
Commit
04f866d
·
verified ·
1 Parent(s): 599969f

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. README.md +248 -0
  2. __pycache__/inference.cpython-313.pyc +0 -0
  3. configs/spai.yaml +54 -0
  4. inference.py +208 -0
  5. requirements.txt +37 -0
  6. spai/__init__.py +15 -0
  7. spai/__main__.py +208 -0
  8. spai/__pycache__/__init__.cpython-310.pyc +0 -0
  9. spai/__pycache__/__init__.cpython-313.pyc +0 -0
  10. spai/__pycache__/__main__.cpython-310.pyc +0 -0
  11. spai/__pycache__/__main__.cpython-313.pyc +0 -0
  12. spai/__pycache__/config.cpython-310.pyc +0 -0
  13. spai/__pycache__/config.cpython-313.pyc +0 -0
  14. spai/__pycache__/data_utils.cpython-310.pyc +0 -0
  15. spai/__pycache__/data_utils.cpython-313.pyc +0 -0
  16. spai/__pycache__/logger.cpython-310.pyc +0 -0
  17. spai/__pycache__/logger.cpython-313.pyc +0 -0
  18. spai/__pycache__/lr_scheduler.cpython-310.pyc +0 -0
  19. spai/__pycache__/lr_scheduler.cpython-313.pyc +0 -0
  20. spai/__pycache__/metrics.cpython-310.pyc +0 -0
  21. spai/__pycache__/metrics.cpython-313.pyc +0 -0
  22. spai/__pycache__/onnx.cpython-310.pyc +0 -0
  23. spai/__pycache__/onnx.cpython-313.pyc +0 -0
  24. spai/__pycache__/optimizer.cpython-310.pyc +0 -0
  25. spai/__pycache__/optimizer.cpython-313.pyc +0 -0
  26. spai/__pycache__/utils.cpython-310.pyc +0 -0
  27. spai/__pycache__/utils.cpython-313.pyc +0 -0
  28. spai/config.py +494 -0
  29. spai/data/__init__.py +26 -0
  30. spai/data/__pycache__/__init__.cpython-310.pyc +0 -0
  31. spai/data/__pycache__/__init__.cpython-313.pyc +0 -0
  32. spai/data/__pycache__/blur_kernels.cpython-310.pyc +0 -0
  33. spai/data/__pycache__/blur_kernels.cpython-313.pyc +0 -0
  34. spai/data/__pycache__/data_finetune.cpython-310.pyc +0 -0
  35. spai/data/__pycache__/data_finetune.cpython-313.pyc +0 -0
  36. spai/data/__pycache__/data_mfm.cpython-310.pyc +0 -0
  37. spai/data/__pycache__/data_mfm.cpython-313.pyc +0 -0
  38. spai/data/__pycache__/filestorage.cpython-310.pyc +0 -0
  39. spai/data/__pycache__/filestorage.cpython-313.pyc +0 -0
  40. spai/data/__pycache__/random_degradations.cpython-310.pyc +0 -0
  41. spai/data/__pycache__/random_degradations.cpython-313.pyc +0 -0
  42. spai/data/__pycache__/readers.cpython-310.pyc +0 -0
  43. spai/data/__pycache__/readers.cpython-313.pyc +0 -0
  44. spai/data/blur_kernels.py +539 -0
  45. spai/data/data_finetune.py +723 -0
  46. spai/data/data_mfm.py +131 -0
  47. spai/data/filestorage.py +387 -0
  48. spai/data/random_degradations.py +462 -0
  49. spai/data/readers.py +178 -0
  50. spai/data_utils.py +50 -0
README.md ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPAI: Spectral AI-Generated Image Detector
2
+
3
+ Official repository for the CVPR 2025 paper:
4
+ Any-Resolution AI-Generated Image Detection by Spectral Learning.
5
+
6
+ SPAI learns the spectral distribution of real images and detects AI-generated
7
+ images as out-of-distribution samples using spectral reconstruction similarity.
8
+
9
+ ## Repository Status
10
+
11
+ This repository currently contains:
12
+
13
+ - Core SPAI package in `spai/`.
14
+ - Main config in `configs/spai.yaml`.
15
+ - A trained checkpoint in `spai/weights/spai.pth`.
16
+ - Unit tests in `tests/`.
17
+ - Utility scripts for data prep, crawling, Fourier analysis and reporting in `tools/` and `spai/tools/`.
18
+ - Hugging Face inference handler in `inference.py`.
19
+
20
+ ## Project Structure
21
+
22
+ ```text
23
+ .
24
+ ├── configs/
25
+ │ └── spai.yaml
26
+ ├── spai/
27
+ │ ├── data/ # datasets, readers, augmentations, filestorage (LMDB)
28
+ │ ├── models/ # backbones, SID, MFM, losses, filters
29
+ │ ├── tools/ # CSV generation and dataset utilities
30
+ │ ├── weights/
31
+ │ │ └── spai.pth # included checkpoint
32
+ │ ├── config.py # yacs configuration
33
+ │ ├── hf_utils.py # Hugging Face Hub upload/model card helpers
34
+ │ ├── main_mfm.py # MFM pretraining entrypoint
35
+ │ └── ...
36
+ ├── tests/
37
+ │ ├── data/
38
+ │ └── models/
39
+ ├── tools/ # analysis, crawling, preprocessing, HF execution logs
40
+ ├── inference.py # HF EndpointHandler + local single-image inference
41
+ └── requirements.txt
42
+ ```
43
+
44
+ ## Installation
45
+
46
+ Recommended environment:
47
+
48
+ ```bash
49
+ conda create -n spai python=3.11
50
+ conda activate spai
51
+ conda install pytorch torchvision torchaudio pytorch-cuda=12.4 -c pytorch -c nvidia
52
+ pip install -r requirements.txt
53
+ ```
54
+
55
+ Notes:
56
+
57
+ - Training code may require NVIDIA APEX.
58
+ - `requirements.txt` includes packages for training, inference, ONNX, crawling and Hugging Face utilities.
59
+
60
+ ## Configuration and Weights
61
+
62
+ - Main config: `configs/spai.yaml`
63
+ - Default included checkpoint: `spai/weights/spai.pth`
64
+
65
+ The included config is set for SID finetuning/inference with arbitrary-resolution processing.
66
+
67
+ ## Inference
68
+
69
+ ### 1) Hugging Face Endpoint Handler (recommended)
70
+
71
+ File: `inference.py`
72
+
73
+ The `EndpointHandler` supports these input formats:
74
+
75
+ - image URL (`http://...` / `https://...`)
76
+ - local path
77
+ - base64 string
78
+ - raw bytes
79
+ - PIL image
80
+ - dict with one of keys: `url`, `path`, `b64`, `bytes`
81
+
82
+ Output format:
83
+
84
+ ```json
85
+ {
86
+ "score": 0.8732,
87
+ "predicted_label": 1,
88
+ "predicted_label_name": "ai-generated",
89
+ "threshold": 0.5
90
+ }
91
+ ```
92
+
93
+ Label convention used in repository tooling:
94
+
95
+ - `0` -> real
96
+ - `1` -> ai-generated
97
+
98
+ Run locally:
99
+
100
+ ```bash
101
+ python inference.py --image "/path/to/image.jpg" --model-dir .
102
+ ```
103
+
104
+ Environment overrides:
105
+
106
+ - `SPAI_THRESHOLD` (default `0.5`)
107
+ - `SPAI_CONFIG` (custom config path)
108
+ - `SPAI_CHECKPOINT` (custom checkpoint path)
109
+ - `SPAI_FORCE_CPU=1` (force CPU)
110
+
111
+ ### 2) Python usage
112
+
113
+ ```python
114
+ from inference import EndpointHandler
115
+
116
+ handler = EndpointHandler(path=".")
117
+ result = handler({"inputs": "https://example.com/image.jpg"})
118
+ print(result)
119
+ ```
120
+
121
+ ## Training
122
+
123
+ ### MFM pretraining entrypoint
124
+
125
+ Use:
126
+
127
+ ```bash
128
+ python spai/main_mfm.py --cfg configs/spai.yaml --data-path /path/to/data.csv --output output/mfm
129
+ ```
130
+
131
+ `spai/main_mfm.py` also supports optional Hugging Face push flags:
132
+
133
+ - `--push-to-hub`
134
+ - `--hub-repo-id`
135
+ - `--hub-token`
136
+ - `--hub-create-model-card`
137
+
138
+ ## Dataset CSV Format
139
+
140
+ Core dataset readers in `spai/data/data_finetune.py` expect CSVs with at least:
141
+
142
+ - `image`: image path
143
+ - `class`: class id
144
+ - `split`: one of `train`, `val`, `test`
145
+
146
+ Paths are resolved relative to a configurable CSV root directory.
147
+
148
+ ## LMDB Dataset File Storage
149
+
150
+ Module: `spai/data/filestorage.py`
151
+
152
+ Available commands:
153
+
154
+ ```bash
155
+ python spai/data/filestorage.py add-csv --help
156
+ python spai/data/filestorage.py add-db --help
157
+ python spai/data/filestorage.py verify-csv --help
158
+ python spai/data/filestorage.py list-db --help
159
+ ```
160
+
161
+ Use this workflow when you want to package many files into LMDB for faster or centralized IO.
162
+
163
+ ## Utility Scripts
164
+
165
+ ### Repository-level tools (`tools/`)
166
+
167
+ - `tools/simple_crawler.py`: crawl and download images with metadata.
168
+ - `tools/web_image_crawler.py`: crawl URLs/CSVs, download images, filter ad-like images.
169
+ - `tools/image_quality_processor.py`: quality filtering, deduplication and reports.
170
+ - `tools/preprocess_for_spai.py`: image preprocessing before SPAI.
171
+ - `tools/create_spai_metadata.py`: build metadata CSV from an image folder.
172
+ - `tools/extract_fourier_features.py`: compute Fourier-derived features.
173
+ - `tools/visualize_fourier.py`: Fourier spectrum visualizations.
174
+ - `tools/visualize_noise_decomposition.py`: advanced noise decomposition visualizations.
175
+ - `tools/analyze_spai_results.py`: plots/analysis for prediction results.
176
+ - `tools/analyze_normalization_impact.py`: study resize normalization impact.
177
+ - `tools/hf_log_execution.py`: generate execution artifacts and optionally upload to HF datasets.
178
+
179
+ Example:
180
+
181
+ ```bash
182
+ python tools/hf_log_execution.py --results-csv output/preds.csv --output-dir output/hf_artifacts
183
+ ```
184
+
185
+ ### Package tools (`spai/tools/`)
186
+
187
+ - `spai.tools.create_dir_csv`: create train/val/test CSV from directories.
188
+ - `spai.tools.create_dmid_ldm_train_val_csv`: create DMID/LDM training CSV.
189
+ - `spai.tools.augment_dataset`: augment a dataset and export updated CSV.
190
+ - `spai.tools.reduce_csv_column`: conditional column reduction/aggregation.
191
+ - `spai/tools/create_synthbuster_csv.py`: Synthbuster CSV generation utility.
192
+
193
+ Examples:
194
+
195
+ ```bash
196
+ python -m spai.tools.create_dir_csv --help
197
+ python -m spai.tools.create_dmid_ldm_train_val_csv --help
198
+ python -m spai.tools.augment_dataset --help
199
+ python -m spai.tools.reduce_csv_column --help
200
+ ```
201
+
202
+ For `create_synthbuster_csv.py`, use a `PYTHONPATH` that includes `spai/` due to its import style:
203
+
204
+ ```bash
205
+ PYTHONPATH=spai python spai/tools/create_synthbuster_csv.py --help
206
+ ```
207
+
208
+ ## Tests
209
+
210
+ Run all tests:
211
+
212
+ ```bash
213
+ pytest tests -q
214
+ ```
215
+
216
+ Current test folders:
217
+
218
+ - `tests/data/`
219
+ - `tests/models/`
220
+
221
+ ## Acknowledgments
222
+
223
+ This work was partly supported by Horizon Europe projects ELIAS and vera.ai,
224
+ and computational resources from GRNET.
225
+
226
+ Parts of the implementation build upon ideas/code from:
227
+ https://github.com/Jiahao000/MFM
228
+
229
+ ## License
230
+
231
+ Source code is licensed under Apache 2.0.
232
+ Third-party datasets and dependencies keep their own licenses.
233
+
234
+ ## Contact
235
+
236
+ For questions: d.karageorgiou@uva.nl
237
+
238
+ ## Citation
239
+
240
+ ```text
241
+ @inproceedings{karageorgiou2025any,
242
+ title={Any-resolution ai-generated image detection by spectral learning},
243
+ author={Karageorgiou, Dimitrios and Papadopoulos, Symeon and Kompatsiaris, Ioannis and Gavves, Efstratios},
244
+ booktitle={Proceedings of the Computer Vision and Pattern Recognition Conference},
245
+ pages={18706--18717},
246
+ year={2025}
247
+ }
248
+ ```
__pycache__/inference.cpython-313.pyc ADDED
Binary file (12.7 kB). View file
 
configs/spai.yaml ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MODEL:
2
+ SID_APPROACH: "freq_restoration"
3
+ TYPE: vit
4
+ NAME: finetune
5
+ DROP_PATH_RATE: 0.1
6
+ NUM_CLASSES: 2
7
+ REQUIRED_NORMALIZATION: "positive_0_1"
8
+ RESOLUTION_MODE: "arbitrary"
9
+ FEATURE_EXTRACTION_BATCH: 400
10
+ VIT:
11
+ EMBED_DIM: 768
12
+ DEPTH: 12
13
+ NUM_HEADS: 12
14
+ INIT_VALUES: None
15
+ USE_APE: True
16
+ USE_RPB: False
17
+ USE_SHARED_RPB: False
18
+ USE_MEAN_POOLING: True
19
+ USE_INTERMEDIATE_LAYERS: True
20
+ PROJECTION_DIM: 1024
21
+ PROJECTION_LAYERS: 2
22
+ PATCH_PROJECTION: True
23
+ PATCH_PROJECTION_PER_FEATURE: True
24
+ INTERMEDIATE_LAYERS: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
25
+ FRE:
26
+ MASKING_RADIUS: 16
27
+ PROJECTOR_LAST_LAYER_ACTIVATION_TYPE: None
28
+ ORIGINAL_IMAGE_FEATURES_BRANCH: True
29
+ CLS_HEAD:
30
+ MLP_RATIO: 3
31
+ PATCH_VIT:
32
+ MINIMUM_PATCHES: 4
33
+ DATA:
34
+ DATASET: csv_sid
35
+ IMG_SIZE: 224
36
+ NUM_WORKERS: 8
37
+ AUGMENTED_VIEWS: 4
38
+ TEST_PREFETCH_FACTOR: 1
39
+ AUG:
40
+ COLOR_JITTER: 0.
41
+ TRAIN:
42
+ EPOCHS: 35
43
+ WARMUP_EPOCHS: 5
44
+ BASE_LR: 5e-4
45
+ WARMUP_LR: 2.5e-7
46
+ MIN_LR: 2.5e-7
47
+ WEIGHT_DECAY: 0.05
48
+ LAYER_DECAY: 0.8
49
+ CLIP_GRAD: None
50
+ LOSS: "bce"
51
+ TEST:
52
+ ORIGINAL_RESOLUTION: True
53
+ PRINT_FREQ: 100
54
+ SAVE_FREQ: 10
inference.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import base64
5
+ import io
6
+ import os
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import requests
12
+ import torch
13
+ from PIL import Image
14
+
15
+ from spai.config import get_custom_config
16
+ from spai.data.data_finetune import build_transform
17
+ from spai.models import build_cls_model
18
+
19
+
20
+ class EndpointHandler:
21
+ """Hugging Face Inference Endpoint handler for SPAI."""
22
+
23
+ def __init__(self, path: str = "") -> None:
24
+ self.model_dir = Path(path) if path else Path(".")
25
+ self.threshold = float(os.getenv("SPAI_THRESHOLD", "0.5"))
26
+
27
+ cfg_path = self._resolve_config_path()
28
+ self.config = get_custom_config(str(cfg_path))
29
+
30
+ self.device = self._resolve_device()
31
+ self.model = build_cls_model(self.config)
32
+ checkpoint_path = self._resolve_checkpoint_path()
33
+ state_dict = self._load_state_dict(checkpoint_path)
34
+ self.model.load_state_dict(state_dict, strict=False)
35
+ self.model.to(self.device)
36
+ self.model.eval()
37
+
38
+ self.transform = build_transform(is_train=False, config=self.config)
39
+
40
+ def __call__(self, data: dict[str, Any]) -> dict[str, Any] | list[dict[str, Any]]:
41
+ inputs = data.get("inputs", data.get("image", data))
42
+
43
+ if isinstance(inputs, list):
44
+ return [self._predict_one(item) for item in inputs]
45
+ return self._predict_one(inputs)
46
+
47
+ def _predict_one(self, raw_input: Any) -> dict[str, Any]:
48
+ image = self._load_image(raw_input)
49
+ image_np = np.array(image)
50
+ image_tensor = self.transform(image=image_np)["image"]
51
+
52
+ if self.config.MODEL.RESOLUTION_MODE == "arbitrary":
53
+ model_input = [image_tensor.unsqueeze(0).to(self.device)]
54
+ feature_batch_size = self.config.MODEL.FEATURE_EXTRACTION_BATCH
55
+ with torch.no_grad():
56
+ logits = self.model(model_input, feature_batch_size)
57
+ else:
58
+ model_input = image_tensor.unsqueeze(0).to(self.device)
59
+ with torch.no_grad():
60
+ logits = self.model(model_input)
61
+
62
+ score = float(torch.sigmoid(logits).flatten()[0].item())
63
+ predicted_label = int(score >= self.threshold)
64
+
65
+ return {
66
+ "score": score,
67
+ "predicted_label": predicted_label,
68
+ "predicted_label_name": "ai-generated" if predicted_label == 1 else "real",
69
+ "threshold": self.threshold,
70
+ }
71
+
72
+ def _resolve_config_path(self) -> Path:
73
+ env_cfg = os.getenv("SPAI_CONFIG")
74
+ if env_cfg:
75
+ cfg_path = Path(env_cfg)
76
+ if cfg_path.exists():
77
+ return cfg_path
78
+ raise FileNotFoundError(f"SPAI_CONFIG points to a missing file: {cfg_path}")
79
+
80
+ candidates = [
81
+ self.model_dir / "configs" / "spai.yaml",
82
+ self.model_dir / "spai.yaml",
83
+ self.model_dir / "config.yaml",
84
+ ]
85
+ for candidate in candidates:
86
+ if candidate.exists():
87
+ return candidate
88
+
89
+ raise FileNotFoundError(
90
+ "Could not locate model config. Expected one of: "
91
+ "configs/spai.yaml, spai.yaml, config.yaml, or SPAI_CONFIG env var."
92
+ )
93
+
94
+ def _resolve_checkpoint_path(self) -> Path:
95
+ env_ckpt = os.getenv("SPAI_CHECKPOINT")
96
+ if env_ckpt:
97
+ ckpt_path = Path(env_ckpt)
98
+ if ckpt_path.exists():
99
+ return ckpt_path
100
+ raise FileNotFoundError(f"SPAI_CHECKPOINT points to a missing file: {ckpt_path}")
101
+
102
+ candidates = [
103
+ self.model_dir / "spai.pth",
104
+ self.model_dir / "pytorch_model.bin",
105
+ self.model_dir / "weights" / "spai.pth",
106
+ self.model_dir / "spai" / "weights" / "spai.pth",
107
+ ]
108
+ for candidate in candidates:
109
+ if candidate.exists():
110
+ return candidate
111
+
112
+ pth_files = sorted(self.model_dir.glob("*.pth"))
113
+ if pth_files:
114
+ return pth_files[0]
115
+
116
+ raise FileNotFoundError(
117
+ "Could not locate model checkpoint. Expected one of: "
118
+ "spai.pth, pytorch_model.bin, weights/spai.pth, spai/weights/spai.pth, "
119
+ "or SPAI_CHECKPOINT env var."
120
+ )
121
+
122
+ @staticmethod
123
+ def _resolve_device() -> torch.device:
124
+ force_cpu = os.getenv("SPAI_FORCE_CPU", "0") == "1"
125
+ if (not force_cpu) and torch.cuda.is_available():
126
+ return torch.device("cuda")
127
+ return torch.device("cpu")
128
+
129
+ @staticmethod
130
+ def _load_state_dict(checkpoint_path: Path) -> dict[str, torch.Tensor]:
131
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
132
+ if isinstance(checkpoint, dict) and "model" in checkpoint and isinstance(checkpoint["model"], dict):
133
+ return checkpoint["model"]
134
+
135
+ if isinstance(checkpoint, dict):
136
+ tensor_values = all(isinstance(v, torch.Tensor) for v in checkpoint.values())
137
+ if tensor_values:
138
+ return checkpoint
139
+
140
+ raise RuntimeError(
141
+ "Unsupported checkpoint format. Expected a dict with key 'model' or a raw state_dict."
142
+ )
143
+
144
+ def _load_image(self, raw_input: Any) -> Image.Image:
145
+ if isinstance(raw_input, Image.Image):
146
+ return raw_input.convert("RGB")
147
+
148
+ if isinstance(raw_input, bytes):
149
+ return Image.open(io.BytesIO(raw_input)).convert("RGB")
150
+
151
+ if isinstance(raw_input, dict):
152
+ if "bytes" in raw_input:
153
+ raw_bytes = raw_input["bytes"]
154
+ if isinstance(raw_bytes, str):
155
+ raw_bytes = base64.b64decode(raw_bytes)
156
+ return Image.open(io.BytesIO(raw_bytes)).convert("RGB")
157
+ if "b64" in raw_input:
158
+ return Image.open(io.BytesIO(base64.b64decode(raw_input["b64"]))).convert("RGB")
159
+ if "url" in raw_input:
160
+ return self._load_image_from_url(raw_input["url"])
161
+ if "path" in raw_input:
162
+ return Image.open(Path(raw_input["path"])).convert("RGB")
163
+
164
+ if isinstance(raw_input, str):
165
+ if raw_input.startswith("http://") or raw_input.startswith("https://"):
166
+ return self._load_image_from_url(raw_input)
167
+
168
+ if raw_input.startswith("data:image") and "," in raw_input:
169
+ _, encoded = raw_input.split(",", 1)
170
+ return Image.open(io.BytesIO(base64.b64decode(encoded))).convert("RGB")
171
+
172
+ maybe_path = Path(raw_input)
173
+ if maybe_path.exists():
174
+ return Image.open(maybe_path).convert("RGB")
175
+
176
+ try:
177
+ decoded = base64.b64decode(raw_input, validate=True)
178
+ return Image.open(io.BytesIO(decoded)).convert("RGB")
179
+ except Exception as exc:
180
+ raise ValueError(
181
+ "String input is neither a valid URL, file path, nor base64 image payload."
182
+ ) from exc
183
+
184
+ raise TypeError(
185
+ "Unsupported input type. Use a URL/path/base64 string, bytes, PIL.Image, "
186
+ "or dict with one of keys: bytes, b64, url, path."
187
+ )
188
+
189
+ @staticmethod
190
+ def _load_image_from_url(url: str) -> Image.Image:
191
+ response = requests.get(url, timeout=15)
192
+ response.raise_for_status()
193
+ return Image.open(io.BytesIO(response.content)).convert("RGB")
194
+
195
+
196
def _main() -> None:
    """CLI entry point: score one image with SPAI and print the result dict."""
    cli = argparse.ArgumentParser(description="Run SPAI inference for a single image.")
    cli.add_argument(
        "--image", type=str, required=True, help="Image path/URL/base64 input"
    )
    cli.add_argument(
        "--model-dir", type=str, default=".", help="Directory with config/checkpoint"
    )
    options = cli.parse_args()

    # Build the handler once, then feed it the standard endpoint payload shape.
    endpoint = EndpointHandler(path=options.model_dir)
    prediction = endpoint({"inputs": options.image})
    print(prediction)
requirements.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install through conda
2
+ # pytorch
3
+ # torchvision~=0.18.1
4
+ # tsnecuda
5
+ # Compile from sources
6
+ # apex
7
+ # Install through pip
8
+ opencv-python~=4.10.0.84
9
+ pyyaml~=6.0.1
10
+ scipy~=1.14.0
11
+ tensorboard
12
+ termcolor~=2.4.0
13
+ timm==0.4.12
14
+ yacs~=0.1.8
15
+ numpy~=1.26.4
16
+ torchmetrics~=1.4.0.post0
17
+ tqdm~=4.66.4
18
+ pillow~=10.4.0
19
+ PyYAML
20
+ click~=8.1.7
21
+ neptune~=1.11.1
22
+ albumentations==1.4.14
23
+ albucore==0.0.16
24
+ lmdb~=1.5.1
25
+ networkx~=3.3
26
+ seaborn~=0.13.2
27
+ pandas~=2.2.2
28
+ neptune
29
+ einops~=0.8.0
30
+ git+https://github.com/openai/CLIP.git
31
+ onnx
32
+ onnxscript
33
+ huggingface_hub~=0.21.0
34
+ datasets~=2.19.0
35
+ requests~=2.32.3
36
+ beautifulsoup4~=4.12.3
37
+ imagehash~=4.3.1
spai/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
spai/__main__.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ import os
19
+ import pathlib
20
+ import time
21
+ import datetime
22
+ from pathlib import Path
23
+ from typing import Optional
24
+
25
+ import numpy as np
26
+
27
+ import neptune
28
+ import cv2
29
+ import click
30
+ import torch
31
+ import torch.backends.cudnn as cudnn
32
+ import torch.utils.data
33
+ import yacs
34
+ import filetype
35
+ from torch import nn
36
+ from torch.nn import TripletMarginLoss
37
+ from torch.utils.tensorboard import SummaryWriter
38
+ from timm.utils import AverageMeter
39
+ from yacs.config import CfgNode
40
+
41
+ import spai.data.data_finetune
42
+ from spai.config import get_config
43
+ from spai.models import build_cls_model
44
+ from spai.data import build_loader, build_loader_test
45
+ from spai.lr_scheduler import build_scheduler
46
+ from spai.models.sid import AttentionMask
47
+ from spai.onnx import compare_pytorch_onnx_models
48
+ from spai.optimizer import build_optimizer
49
+ from spai.logger import create_logger
50
+ from spai.utils import (
51
+ load_pretrained,
52
+ save_checkpoint,
53
+ get_grad_norm,
54
+ find_pretrained_checkpoints,
55
+ inf_nan_to_num
56
+ )
57
+ from spai.models import losses
58
+ from spai import metrics
59
+ from spai import data_utils
60
+
61
+
62
+ def _cuda_enabled() -> bool:
63
+ # Allow forcing CPU mode and avoid probing CUDA on incompatible drivers.
64
+ if os.environ.get("SPAI_FORCE_CPU", "0") == "1":
65
+ return False
66
+ if os.environ.get("CUDA_VISIBLE_DEVICES", "") == "":
67
+ return False
68
+ try:
69
+ return torch.cuda.is_available()
70
+ except Exception:
71
+ return False
72
+
73
+ try:
74
+ # noinspection PyUnresolvedReferences
75
+ from apex import amp
76
+ except ImportError:
77
+ amp = None
78
+
79
+ cv2.setNumThreads(1)
80
+ logger: Optional[logging.Logger] = None
81
+
82
+
83
+ @click.group()
84
+ def cli() -> None:
85
+ pass
86
+
87
+
88
+ @cli.command()
89
+ @click.option("--cfg", required=True,
90
+ type=click.Path(exists=True, dir_okay=False, path_type=Path))
91
+ @click.option("--batch-size", type=int,
92
+ help="Batch size for a single GPU.")
93
+ @click.option("--learning-rate", type=float)
94
+ @click.option("--data-path", required=True,
95
+ type=click.Path(exists=True, dir_okay=False, path_type=Path),
96
+ help="path to dataset")
97
+ @click.option("--csv-root-dir",
98
+ type=click.Path(exists=True, file_okay=False, path_type=Path))
99
+ @click.option("--lmdb", "lmdb_path",
100
+ type=click.Path(exists=True, dir_okay=False, path_type=Path),
101
+ help="Path to an LMDB file storage that contains the files defined in the "
102
+ "dataset's CSV file. If this option is not provided, the data will be "
103
+ "loaded from the filesystem.")
104
+ @click.option("--pretrained",
105
+ type=click.Path(exists=True, dir_okay=False),
106
+ help="path to pre-trained model")
107
+ @click.option("--resume", is_flag=True,
108
+ help="resume from checkpoint")
109
+ @click.option("--accumulation-steps", type=int, default=1,
110
+ help="Gradient accumulation steps.")
111
+ @click.option("--use-checkpoint", is_flag=True,
112
+ help="Whether to use gradient checkpointing to save memory.")
113
+ @click.option("--amp-opt-level", type=click.Choice(["O0", "O1", "O2"]), default="O1",
114
+ help="mixed precision opt level, if O0, no amp is used")
115
+ @click.option("--output", type=click.Path(file_okay=False, path_type=Path),
116
+ help="root of output folder, the full path is "
117
+ "<output>/<model_name>/<tag> (default: output)")
118
+ @click.option("--tag", type=str,
119
+ help="tag of experiment")
120
+ @click.option("--local_rank", type=int, default=0,
121
+ help="local_rank for distributed training")
122
+ @click.option("--test-csv", multiple=True,
123
+ type=click.Path(exists=True, dir_okay=False, path_type=Path),
124
+ help="Path to a CSV with test data. If this option is provided after the "
125
+ "validation of each epoch, a testing will also take place. This option "
126
+ "intends to facilitate understanding the progression of the generalization "
127
+ "ability of a model among the epochs and should not be used for selecting "
128
+ "the final model. This option can be repeated several times. For each provided "
129
+ "csv file, a separate testing run is going to take place.")
130
+ @click.option("--test-csv-root-dir", multiple=True,
131
+ type=click.Path(exists=True, file_okay=False, path_type=Path),
132
+ help="Root directory for the relative paths included into the test csv files. "
133
+ "If this option is omitted, the parent directory of each test csv file will "
134
+ "be used as the root dir for the paths it contains. If this option is provided "
135
+ "a single time, it will be used as the root dir for all the test csv files. If "
136
+ "it is provided multiple times, each value will be matched with a corresponding "
137
+ "test csv file. In that case, the number of provided test csv files and the "
138
+ "number of provided root directories should match. The order of the provided "
139
+ "arguments will be used for the matching.")
140
+ @click.option("--data-workers", type=int,
141
+ help="Number of worker processes to be used for data loading.")
142
+ @click.option("--disable-pin-memory", is_flag=True)
143
+ @click.option("--data-prefetch-factor", type=int)
144
+ @click.option("--save-all", is_flag=True)
145
+ @click.option("--opt", "extra_options", type=(str, str), multiple=True)
146
+ def train(
147
+ cfg: Path,
148
+ batch_size: Optional[int],
149
+ learning_rate: Optional[float],
150
+ data_path: Path,
151
+ csv_root_dir: Optional[Path],
152
+ lmdb_path: Optional[Path],
153
+ pretrained: Optional[Path],
154
+ resume: bool,
155
+ accumulation_steps: int,
156
+ use_checkpoint: bool,
157
+ amp_opt_level: str,
158
+ output: Path,
159
+ tag: str,
160
+ local_rank: int,
161
+ test_csv: list[Path],
162
+ test_csv_root_dir: list[Path],
163
+ data_workers: Optional[int],
164
+ disable_pin_memory: bool,
165
+ data_prefetch_factor: Optional[int],
166
+ save_all: bool,
167
+ extra_options: tuple[str, str]
168
+ ) -> None:
169
+ if csv_root_dir is None:
170
+ csv_root_dir = data_path.parent
171
+ config = get_config({
172
+ "cfg": str(cfg),
173
+ "batch_size": batch_size,
174
+ "learning_rate": learning_rate,
175
+ "data_path": str(data_path),
176
+ "csv_root_dir": str(csv_root_dir),
177
+ "lmdb_path": str(lmdb_path),
178
+ "pretrained": str(pretrained) if pretrained is not None else None,
179
+ "resume": resume,
180
+ "accumulation_steps": accumulation_steps,
181
+ "use_checkpoint": use_checkpoint,
182
+ "amp_opt_level": amp_opt_level,
183
+ "output": str(output),
184
+ "tag": tag,
185
+ "local_rank": local_rank,
186
+ "test_csv": [str(p) for p in test_csv],
187
+ "test_csv_root": [str(p) for p in test_csv_root_dir],
188
+ "data_workers": data_workers,
189
+ "disable_pin_memory": disable_pin_memory,
190
+ "data_prefetch_factor": data_prefetch_factor,
191
+ "opts": extra_options
192
+ })
193
+ if 'LOCAL_RANK' not in os.environ:
194
+ os.environ['LOCAL_RANK'] = str(local_rank)
195
+
196
+ if config.AMP_OPT_LEVEL != "O0":
197
+ assert amp is not None, "amp not installed!"
198
+
199
+ # Set a fixed seed to all the random number generators.
200
+ seed = config.SEED
201
+ torch.manual_seed(seed)
202
+ np.random.seed(seed)
203
+ # random.seed(seed)
204
+ cudnn.benchmark = True
205
+
206
+ if config.TRAIN.SCALE_LR:
207
+ # Linear scale the learning rate according to total batch size - may not be optimal.
208
+ linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZ
spai/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (164 Bytes). View file
 
spai/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (144 Bytes). View file
 
spai/__pycache__/__main__.cpython-310.pyc ADDED
Binary file (28.9 kB). View file
 
spai/__pycache__/__main__.cpython-313.pyc ADDED
Binary file (9.31 kB). View file
 
spai/__pycache__/config.cpython-310.pyc ADDED
Binary file (8.44 kB). View file
 
spai/__pycache__/config.cpython-313.pyc ADDED
Binary file (20.1 kB). View file
 
spai/__pycache__/data_utils.cpython-310.pyc ADDED
Binary file (1.47 kB). View file
 
spai/__pycache__/data_utils.cpython-313.pyc ADDED
Binary file (2.15 kB). View file
 
spai/__pycache__/logger.cpython-310.pyc ADDED
Binary file (1.11 kB). View file
 
spai/__pycache__/logger.cpython-313.pyc ADDED
Binary file (1.93 kB). View file
 
spai/__pycache__/lr_scheduler.cpython-310.pyc ADDED
Binary file (5.14 kB). View file
 
spai/__pycache__/lr_scheduler.cpython-313.pyc ADDED
Binary file (7.59 kB). View file
 
spai/__pycache__/metrics.cpython-310.pyc ADDED
Binary file (5.75 kB). View file
 
spai/__pycache__/metrics.cpython-313.pyc ADDED
Binary file (10.9 kB). View file
 
spai/__pycache__/onnx.cpython-310.pyc ADDED
Binary file (3.98 kB). View file
 
spai/__pycache__/onnx.cpython-313.pyc ADDED
Binary file (7.25 kB). View file
 
spai/__pycache__/optimizer.cpython-310.pyc ADDED
Binary file (4.6 kB). View file
 
spai/__pycache__/optimizer.cpython-313.pyc ADDED
Binary file (8.91 kB). View file
 
spai/__pycache__/utils.cpython-310.pyc ADDED
Binary file (14.1 kB). View file
 
spai/__pycache__/utils.cpython-313.pyc ADDED
Binary file (25.4 kB). View file
 
spai/config.py ADDED
@@ -0,0 +1,494 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from typing import Optional, Any
19
+
20
+ import yaml
21
+ from yacs.config import CfgNode as CN
22
+
23
+ _C = CN()
24
+
25
+ # Base config files
26
+ _C.BASE = ['']
27
+
28
+ # -----------------------------------------------------------------------------
29
+ # Data settings
30
+ # -----------------------------------------------------------------------------
31
+ _C.DATA = CN()
32
+ # Batch size for a single GPU, could be overwritten by command line argument
33
+ _C.DATA.BATCH_SIZE = 128
34
+ # Batch size for validation. If it is set to None, DATA.BATCH_SIZE will be used.
35
+ _C.DATA.VAL_BATCH_SIZE = None
36
+ # Batch size for test. If it is set to None, DATA.BATCH_SIZE will be used.
37
+ _C.DATA.TEST_BATCH_SIZE = None
38
+ # Path to dataset, could be overwritten by command line argument
39
+ _C.DATA.DATA_PATH = ''
40
+ # Root path for the relative paths included in a dataset csv file. Not used when
41
+ # the DATA.DATA_PATH does not point to a csv file.
42
+ _C.DATA.CSV_ROOT = ''
43
+ # A list of paths to the test datasets. Can be overwritten by command line argument.
44
+ _C.DATA.TEST_DATA_PATH = []
45
+ # A list of paths that will be used as root directories for the paths in the test csv files.
46
+ _C.DATA.TEST_DATA_CSV_ROOT = []
47
+ # Path to an LMDB filestorage. When this option is not None, the dataset's files are loaded
48
+ # from this one, instead of the filesystem.
49
+ _C.DATA.LMDB_PATH = None
50
+ # Dataset name
51
+ _C.DATA.DATASET = 'imagenet'
52
+ # Input image size
53
+ _C.DATA.IMG_SIZE = 224
54
+ # Minimal crop scale
55
+ _C.DATA.MIN_CROP_SCALE = 0.2
56
+ # Interpolation to resize image (random, bilinear, bicubic)
57
+ _C.DATA.INTERPOLATION = 'bicubic'
58
+ # Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
59
+ _C.DATA.PIN_MEMORY = True
60
+ # Number of data loading threads
61
+ _C.DATA.NUM_WORKERS = 24
62
+ # Number of batches to be prefetched by each worker.
63
+ _C.DATA.PREFETCH_FACTOR = 2
64
+ # Prefetch factor for validation data loaders.
65
+ _C.DATA.VAL_PREFETCH_FACTOR = None
66
+ # Prefetch factor for test data loaders.
67
+ _C.DATA.TEST_PREFETCH_FACTOR = None
68
+
69
+ # Filter type, support 'mfm', 'sr', 'deblur', 'denoise'
70
+ _C.DATA.FILTER_TYPE = 'mfm'
71
+ # [MFM] Sampling ratio for low-pass filters
72
+ _C.DATA.SAMPLE_RATIO = 0.5
73
+ # [MFM] First frequency mask radius
74
+ # should be smaller than half of the image size
75
+ _C.DATA.MASK_RADIUS1 = 16
76
+ # [MFM] Second frequency mask radius
77
+ # should be larger than the first radius
78
+ # only used when masking a frequency band
79
+ # setting a larger value than the image size, e.g., 999, will have no effect
80
+ _C.DATA.MASK_RADIUS2 = 999
81
+ # [SR] SR downsampling scale factor, only used when FILTER_TYPE == 'sr'
82
+ _C.DATA.SR_FACTOR = 8
83
+ # [Deblur] Deblur parameters, only used when FILTER_TYPE == 'deblur'
84
+ _C.DATA.BLUR = CN()
85
+ _C.DATA.BLUR.KERNEL_SIZE = [7, 9, 11, 13, 15, 17, 19, 21]
86
+ _C.DATA.BLUR.KERNEL_LIST = ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso', 'sinc']
87
+ _C.DATA.BLUR.KERNEL_PROB = [0.405, 0.225, 0.108, 0.027, 0.108, 0.027, 0.1]
88
+ _C.DATA.BLUR.SIGMA_X = [0.2, 3]
89
+ _C.DATA.BLUR.SIGMA_Y = [0.2, 3]
90
+ _C.DATA.BLUR.ROTATE_ANGLE = [-3.1416, 3.1416]
91
+ _C.DATA.BLUR.BETA_GAUSSIAN = [0.5, 4]
92
+ _C.DATA.BLUR.BETA_PLATEAU = [1, 2]
93
+ # [Denoise] Denoise parameters, only used when FILTER_TYPE == 'denoise'
94
+ _C.DATA.NOISE = CN()
95
+ _C.DATA.NOISE.TYPE = ['gaussian', 'poisson']
96
+ _C.DATA.NOISE.PROB = [0.5, 0.5]
97
+ _C.DATA.NOISE.GAUSSIAN_SIGMA = [1, 30]
98
+ _C.DATA.NOISE.GAUSSIAN_GRAY_NOISE_PROB = 0.4
99
+ _C.DATA.NOISE.POISSON_SCALE = [0.05, 3]
100
+ _C.DATA.NOISE.POISSON_GRAY_NOISE_PROB = 0.4
101
+ # Number of augmented views for each batch. When SupCon loss is employed, this number
102
+ # should be at least 2.
103
+ _C.DATA.AUGMENTED_VIEWS = 1
104
+
105
+ # -----------------------------------------------------------------------------
106
+ # Model settings
107
+ # -----------------------------------------------------------------------------
108
+ _C.MODEL = CN()
109
+ # Model type
110
+ _C.MODEL.TYPE = 'vit'
111
+ # Type of weights that will be used to initialize the backbone. Supported "mfm", "clip", "dinov2".
112
+ _C.MODEL_WEIGHTS = "mfm"
113
+ # Model name
114
+ _C.MODEL.NAME = 'pretrain'
115
+ # Checkpoint to resume, could be overwritten by command line argument
116
+ _C.MODEL.RESUME = ''
117
+ # Number of classes, overwritten in data preparation
118
+ _C.MODEL.NUM_CLASSES = 1000
119
+ # Dropout rate for the backbone model.
120
+ _C.MODEL.DROP_RATE = 0.0
121
+ # Dropout rate for the trainable SID layers.
122
+ _C.MODEL.SID_DROPOUT = 0.5
123
+ # Drop path rate
124
+ _C.MODEL.DROP_PATH_RATE = 0.1
125
+ # Label Smoothing
126
+ _C.MODEL.LABEL_SMOOTHING = 0.1
127
+ # Required normalization to be applied to the image before provided to the model.
128
+ _C.MODEL.REQUIRED_NORMALIZATION = "imagenet"
129
+ # Approach used for the Synthetic Image Detection task. "single_extraction" and "freq_restoration"
130
+ # are currently supported.
131
+ _C.MODEL.SID_APPROACH = "single_extraction"
132
+ # Whether the model accepts a fixed resolution image to its input or an arbitrary resolution image.
133
+ # Supported values are "fixed" and "arbitrary"
134
+ _C.MODEL.RESOLUTION_MODE = "fixed"
135
+ # Batch size used internally by patched models for feature extraction. If not provided,
136
+ # it is determined by the batch size of the input.
137
+ _C.MODEL.FEATURE_EXTRACTION_BATCH = None
138
+
139
+ # Swin Transformer parameters
140
+ _C.MODEL.SWIN = CN()
141
+ _C.MODEL.SWIN.PATCH_SIZE = 4
142
+ _C.MODEL.SWIN.IN_CHANS = 3
143
+ _C.MODEL.SWIN.EMBED_DIM = 96
144
+ _C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
145
+ _C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
146
+ _C.MODEL.SWIN.WINDOW_SIZE = 7
147
+ _C.MODEL.SWIN.MLP_RATIO = 4.
148
+ _C.MODEL.SWIN.QKV_BIAS = True
149
+ _C.MODEL.SWIN.QK_SCALE = None
150
+ _C.MODEL.SWIN.APE = False
151
+ _C.MODEL.SWIN.PATCH_NORM = True
152
+
153
+ # Vision Transformer parameters
154
+ _C.MODEL.VIT = CN()
155
+ _C.MODEL.VIT.PATCH_SIZE = 16
156
+ _C.MODEL.VIT.IN_CHANS = 3
157
+ _C.MODEL.VIT.EMBED_DIM = 768
158
+ _C.MODEL.VIT.DEPTH = 12
159
+ _C.MODEL.VIT.NUM_HEADS = 12
160
+ _C.MODEL.VIT.MLP_RATIO = 4
161
+ _C.MODEL.VIT.QKV_BIAS = True
162
+ _C.MODEL.VIT.INIT_VALUES = 0.1
163
+ # learnable absolute positional embedding
164
+ _C.MODEL.VIT.USE_APE = True
165
+ # fixed sin-cos positional embedding
166
+ _C.MODEL.VIT.USE_FPE = False
167
+ # relative position bias
168
+ _C.MODEL.VIT.USE_RPB = False
169
+ _C.MODEL.VIT.USE_SHARED_RPB = False
170
+ _C.MODEL.VIT.USE_MEAN_POOLING = False
171
+ # Vision Transformer decoder parameters
172
+ _C.MODEL.VIT.DECODER = CN()
173
+ _C.MODEL.VIT.DECODER.EMBED_DIM = 512
174
+ _C.MODEL.VIT.DECODER.DEPTH = 0
175
+ _C.MODEL.VIT.DECODER.NUM_HEADS = 16
176
+
177
+ # Features processor parameter
178
+ # Supported features processors: "mean_norm", "norm_max", "rine"
179
+ _C.MODEL.VIT.FEATURES_PROCESSOR = "rine"
180
+ _C.MODEL.VIT.USE_INTERMEDIATE_LAYERS = False
181
+ _C.MODEL.VIT.INTERMEDIATE_LAYERS = [2, 5, 8, 11]
182
+ _C.MODEL.VIT.PROJECTION_DIM = 1024
183
+ _C.MODEL.VIT.PROJECTION_LAYERS = 2
184
+ _C.MODEL.VIT.PATCH_PROJECTION = False
185
+ _C.MODEL.VIT.PATCH_PROJECTION_PER_FEATURE = False
186
+ # Supported patch pooling: "mean", "l2_max"
187
+ _C.MODEL.VIT.PATCH_POOLING = "mean"
188
+
189
+ # Frequency Restoration Estimator parameters
190
+ _C.MODEL.FRE = CN()
191
+ _C.MODEL.FRE.MASKING_RADIUS = 16
192
+ _C.MODEL.FRE.PROJECTOR_LAST_LAYER_ACTIVATION_TYPE = "gelu"
193
+ _C.MODEL.FRE.ORIGINAL_IMAGE_FEATURES_BRANCH = False
194
+ _C.MODEL.FRE.DISABLE_RECONSTRUCTION_SIMILARITY = False
195
+
196
+ # PatchBasedMFViT related parameters
197
+ _C.MODEL.PATCH_VIT = CN()
198
+ _C.MODEL.PATCH_VIT.PATCH_STRIDE = 224
199
+ _C.MODEL.PATCH_VIT.NUM_HEADS = 12
200
+ _C.MODEL.PATCH_VIT.ATTN_EMBED_DIM = 1536
201
+ _C.MODEL.PATCH_VIT.MINIMUM_PATCHES = 1
202
+
203
+ # Classification head parameters
204
+ _C.MODEL.CLS_HEAD = CN()
205
+ _C.MODEL.CLS_HEAD.MLP_RATIO = 4
206
+
207
+ # ResNet parameters
208
+ _C.MODEL.RESNET = CN()
209
+ _C.MODEL.RESNET.LAYERS = [3, 4, 6, 3]
210
+ _C.MODEL.RESNET.IN_CHANS = 3
211
+
212
+ # [MFM] Reconstruction target type, support 'normal', 'masked'
213
+ _C.MODEL.RECOVER_TARGET_TYPE = 'normal'
214
+ # [MFM] Frequency loss parameters
215
+ _C.MODEL.FREQ_LOSS = CN()
216
+ _C.MODEL.FREQ_LOSS.LOSS_GAMMA = 1.
217
+ _C.MODEL.FREQ_LOSS.MATRIX_GAMMA = 1.
218
+ _C.MODEL.FREQ_LOSS.PATCH_FACTOR = 1
219
+ _C.MODEL.FREQ_LOSS.AVE_SPECTRUM = False
220
+ _C.MODEL.FREQ_LOSS.WITH_MATRIX = False
221
+ _C.MODEL.FREQ_LOSS.LOG_MATRIX = False
222
+ _C.MODEL.FREQ_LOSS.BATCH_MATRIX = False
223
+
224
+ # -----------------------------------------------------------------------------
225
+ # Training settings
226
+ # -----------------------------------------------------------------------------
227
+ _C.TRAIN = CN()
228
+ _C.TRAIN.START_EPOCH = 0
229
+ _C.TRAIN.EPOCHS = 300
230
+ _C.TRAIN.WARMUP_EPOCHS = 20
231
+ _C.TRAIN.WEIGHT_DECAY = 0.05
232
+ _C.TRAIN.BASE_LR = 3e-4
233
+ _C.TRAIN.WARMUP_LR = 2.5e-7
234
+ _C.TRAIN.MIN_LR = 2.5e-6
235
+ # Clip gradient norm
236
+ _C.TRAIN.CLIP_GRAD = 3.0
237
+ # Auto resume from latest checkpoint
238
+ _C.TRAIN.AUTO_RESUME = True
239
+ # Gradient accumulation steps
240
+ # could be overwritten by command line argument
241
+ _C.TRAIN.ACCUMULATION_STEPS = 1
242
+ # Whether to use gradient checkpointing to save memory
243
+ # could be overwritten by command line argument
244
+ _C.TRAIN.USE_CHECKPOINT = False
245
+
246
+ # LR scheduler
247
+ # Supported modes: "supervised", "contrastive"
248
+ _C.TRAIN.MODE = "supervised"
249
+ _C.TRAIN.LR_SCHEDULER = CN()
250
+ _C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
251
+ # Epoch interval to decay LR, used in StepLRScheduler
252
+ _C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
253
+ # LR decay rate, used in StepLRScheduler
254
+ _C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
255
+ # Gamma / Multi steps value, used in MultiStepLRScheduler
256
+ _C.TRAIN.LR_SCHEDULER.GAMMA = 0.1
257
+ _C.TRAIN.LR_SCHEDULER.MULTISTEPS = []
258
+ # A flag that indicates whether to scale lr according to batch size and grad accumulation steps.
259
+ _C.TRAIN.SCALE_LR = False
260
+ # Optimizer
261
+ _C.TRAIN.OPTIMIZER = CN()
262
+ _C.TRAIN.OPTIMIZER.NAME = 'adamw'
263
+ # Optimizer Epsilon
264
+ _C.TRAIN.OPTIMIZER.EPS = 1e-8
265
+ # Optimizer Betas
266
+ _C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
267
+ # SGD momentum
268
+ _C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
269
+ _C.TRAIN.LOSS = "bce_supcont"
270
+ _C.TRAIN.TRIPLET_LOSS_MARGIN = 0.5
271
+ # Layer decay for fine-tuning
272
+ _C.TRAIN.LAYER_DECAY = 1.0
273
+
274
+ # -----------------------------------------------------------------------------
275
+ # Augmentation settings
276
+ # -----------------------------------------------------------------------------
277
+ _C.AUG = CN()
278
+ # Crop augmentation
279
+ _C.AUG.MIN_CROP_AREA = 0.2
280
+ _C.AUG.MAX_CROP_AREA = 1.0
281
+ # Flip augmentation
282
+ _C.AUG.HORIZONTAL_FLIP_PROB = 0.5
283
+ _C.AUG.VERTICAL_FLIP_PROB = 0.5
284
+ # Rotation augmentation
285
+ _C.AUG.ROTATION_PROB = 0.5
286
+ _C.AUG.ROTATION_DEGREES = 90
287
+ # Gaussian blur augmentation
288
+ _C.AUG.GAUSSIAN_BLUR_PROB = 0.5
289
+ _C.AUG.GAUSSIAN_BLUR_LIMIT = (3, 9)
290
+ _C.AUG.GAUSSIAN_BLUR_SIGMA = (0.01, 0.5)
291
+ # Gaussian noise augmentation
292
+ _C.AUG.GAUSSIAN_NOISE_PROB = 0.5
293
+ # JPEG compression augmentation
294
+ _C.AUG.JPEG_COMPRESSION_PROB = 0.5
295
+ _C.AUG.JPEG_MIN_QUALITY = 50
296
+ _C.AUG.JPEG_MAX_QUALITY = 100
297
+ # WEBP compression augmentation
298
+ _C.AUG.WEBP_COMPRESSION_PROB = .0
299
+ _C.AUG.WEBP_MIN_QUALITY = 50
300
+ _C.AUG.WEBP_MAX_QUALITY = 100
301
+ # Color jitter augmentation
302
+ _C.AUG.COLOR_JITTER = .0
303
+ _C.AUG.COLOR_JITTER_BRIGHTNESS_RANGE = (0.8, 1.2)
304
+ _C.AUG.COLOR_JITTER_CONTRAST_RANGE = (0.8, 1.2)
305
+ _C.AUG.COLOR_JITTER_SATURATION_RANGE = (0.8, 1.2)
306
+ _C.AUG.COLOR_JITTER_HUE_RANGE = (-0.1, 0.1)
307
+ # Sharpen augmentation
308
+ _C.AUG.SHARPEN_PROB = .0
309
+ _C.AUG.SHARPEN_ALPHA_RANGE = (0.01, 0.4)
310
+ _C.AUG.SHARPEN_LIGHTNESS_RANGE = (0.95, 1)
311
+ # Use AutoAugment policy. "v0" or "original"
312
+ _C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
313
+ # Random erase prob
314
+ _C.AUG.REPROB = 0.25
315
+ # Random erase mode
316
+ _C.AUG.REMODE = 'pixel'
317
+ # Random erase count
318
+ _C.AUG.RECOUNT = 1
319
+ # Probability of applying blurring
320
+ _C.AUG.BLUR_PROB = 0.25
321
+ # Mixup alpha, mixup enabled if > 0
322
+ _C.AUG.MIXUP = 0.8
323
+ # Cutmix alpha, cutmix enabled if > 0
324
+ _C.AUG.CUTMIX = 1.0
325
+ # Cutmix min/max ratio, overrides alpha and enables cutmix if set
326
+ _C.AUG.CUTMIX_MINMAX = None
327
+ # Probability of performing mixup or cutmix when either/both is enabled
328
+ _C.AUG.MIXUP_PROB = 1.0
329
+ # Probability of switching to cutmix when both mixup and cutmix enabled
330
+ _C.AUG.MIXUP_SWITCH_PROB = 0.5
331
+ # How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
332
+ _C.AUG.MIXUP_MODE = 'batch'
333
+
334
+ # -----------------------------------------------------------------------------
335
+ # Testing settings
336
+ # -----------------------------------------------------------------------------
337
+ _C.TEST = CN()
338
+ # Whether to use center crop when testing.
339
+ _C.TEST.CROP = True
340
+ # Size for resizing images during testing.
341
+ _C.TEST.MAX_SIZE: Optional[int] = None
342
+ # When this option is set to True, the original resolution image is provided to the model.
343
+ # Setting this option to True automatically sets the batch size for validation/testing to 1.
344
+ _C.TEST.ORIGINAL_RESOLUTION = False
345
+ # Approach that will be used for generating different views of an image during testing.
346
+ # Currently, "tencrop" and None are supported.
347
+ _C.TEST.VIEWS_GENERATION_APPROACH = None
348
+ # Approach that will be used to combine the scores predicted for multiple views of the same
349
+ # image. This value is meaningful only when a view generation approach is used.
350
+ # Currently, "mean" and "max" are supported.
351
+ _C.TEST.VIEWS_REDUCTION_APPROACH = "mean"
352
+ # A flag that when set to True exports the analysis of the Spectral Context Attention.
353
+ _C.TEST.EXPORT_IMAGE_PATCHES = False
354
+ # -----------------------------------------------------------------------------
355
+ # Setting for Test-Time Perturbations
356
+ # -----------------------------------------------------------------------------
357
+ # Gaussian blur perturbation.
358
+ _C.TEST.GAUSSIAN_BLUR = False
359
+ _C.TEST.GAUSSIAN_BLUR_KERNEL_SIZE = 3
360
+ # Gaussian noise perturbation.
361
+ _C.TEST.GAUSSIAN_NOISE = False
362
+ _C.TEST.GAUSSIAN_NOISE_SIGMA = 1.0
363
+ # JPEG compression perturbation.
364
+ _C.TEST.JPEG_COMPRESSION = False
365
+ _C.TEST.JPEG_QUALITY = 100
366
+ # WEBP compression augmentation.
367
+ _C.TEST.WEBP_COMPRESSION = False
368
+ _C.TEST.WEBP_QUALITY = 100
369
+ # Scale perturbation.
370
+ _C.TEST.SCALE = False
371
+ _C.TEST.SCALE_FACTOR = 1.0
372
+
373
+ # -----------------------------------------------------------------------------
374
+ # Misc
375
+ # -----------------------------------------------------------------------------
376
+ # Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
377
+ # overwritten by command line argument
378
+ _C.AMP_OPT_LEVEL = ''
379
+ # Path to output folder, overwritten by command line argument
380
+ _C.OUTPUT = ''
381
+ # Tag of experiment, overwritten by command line argument
382
+ _C.TAG = 'default'
383
+ # Frequency to save checkpoint
384
+ _C.SAVE_FREQ = 10
385
+ # Frequency to logging info
386
+ _C.PRINT_FREQ = 10
387
+ # Fixed random seed
388
+ _C.SEED = 0
389
+ # Perform evaluation only, overwritten by command line argument
390
+ _C.EVAL_MODE = False
391
+ # Test throughput only, overwritten by command line argument
392
+ _C.THROUGHPUT_MODE = False
393
+ # Local rank for DistributedDataParallel, given by command line argument
394
+ _C.LOCAL_RANK = 0
395
+
396
+ # Path to pre-trained model
397
+ _C.PRETRAINED = ''
398
+
399
+
400
def _update_config_from_file(config, cfg_file):
    """Merge settings from a YAML config file into ``config``.

    Files listed under the ``BASE`` key (paths relative to *cfg_file*) are
    merged first, recursively, so that values from *cfg_file* itself take
    precedence over its bases. The config is frozen on return.

    Args:
        config: The yacs ``CfgNode`` to update in place.
        cfg_file: Path to the YAML configuration file.
    """
    config.defrost()
    with open(cfg_file, 'r') as f:
        # safe_load is sufficient for plain config scalars/lists and, unlike
        # FullLoader, cannot construct arbitrary Python objects from the file.
        # An empty YAML file loads as None; normalize it to an empty mapping.
        yaml_cfg = yaml.safe_load(f) or {}

    # Merge base configs first so the current file overrides them.
    for cfg in yaml_cfg.setdefault('BASE', ['']):
        if cfg:
            _update_config_from_file(
                config, os.path.join(os.path.dirname(cfg_file), cfg)
            )
    print('=> merge config from {}'.format(cfg_file))
    config.merge_from_file(cfg_file)
    config.freeze()
413
+
414
+
415
def update_config(config, args):
    """Update a yacs config from a config file plus a dict of CLI arguments.

    The config file named by ``args["cfg"]`` is merged first, then any
    free-form ``args["opts"]`` key/value overrides, then the explicitly
    supported named arguments below. The config is frozen on return.

    Args:
        config: The yacs ``CfgNode`` to update in place.
        args: Mapping of argument names to values. Must contain ``"cfg"``;
            every other key is optional and only applied when truthy.
    """
    import ast

    _update_config_from_file(config, args["cfg"])

    config.defrost()
    if "opts" in args:
        options: list[Any] = []
        for (k, v) in args["opts"]:
            options.append(k)
            # Parse the textual value into a Python literal (numbers, bools,
            # tuples, ...) when possible. Bare strings such as "cosine" are
            # kept verbatim instead of crashing in eval() with a NameError.
            try:
                options.append(ast.literal_eval(v))
            except (ValueError, SyntaxError):
                options.append(v)
        config.merge_from_list(options)

    def _check_args(name):
        # True when the argument was provided and is truthy.
        return name in args and bool(args[name])

    # Merge from specific, explicitly supported arguments.
    if _check_args('batch_size'):
        config.DATA.BATCH_SIZE = args["batch_size"]
    if _check_args('data_path'):
        config.DATA.DATA_PATH = args["data_path"]
    if _check_args('csv_root_dir'):
        config.DATA.CSV_ROOT = args["csv_root_dir"]
    if _check_args("lmdb_path"):
        config.DATA.LMDB_PATH = args["lmdb_path"]
    if _check_args('resume'):
        config.MODEL.RESUME = args["resume"]
    if _check_args('pretrained'):
        config.PRETRAINED = args["pretrained"]
    if _check_args('accumulation_steps'):
        config.TRAIN.ACCUMULATION_STEPS = args["accumulation_steps"]
    if _check_args('use_checkpoint'):
        config.TRAIN.USE_CHECKPOINT = True
    if _check_args('amp_opt_level'):
        config.AMP_OPT_LEVEL = args["amp_opt_level"]
    if _check_args('output'):
        config.OUTPUT = args["output"]
    if _check_args('tag'):
        config.TAG = args["tag"]
    if _check_args('eval'):
        config.EVAL_MODE = True
    if _check_args('throughput'):
        config.THROUGHPUT_MODE = True
    if _check_args('test_csv'):
        config.DATA.TEST_DATA_PATH = args["test_csv"]
    if _check_args('test_csv_root'):
        config.DATA.TEST_DATA_CSV_ROOT = args["test_csv_root"]
    if _check_args('learning_rate'):
        config.TRAIN.BASE_LR = args["learning_rate"]
    if _check_args('resize_to'):
        config.TEST.MAX_SIZE = args["resize_to"]
    if _check_args("local_rank"):
        # set local rank for distributed training
        config.LOCAL_RANK = args["local_rank"]
    if _check_args("data_workers"):
        config.DATA.NUM_WORKERS = args["data_workers"]
    if _check_args("disable_pin_memory"):
        # BUGFIX: this previously created a dangling top-level
        # ``config.PIN_MEMORY`` entry; the option is declared (and read) as
        # ``config.DATA.PIN_MEMORY``, so the flag was silently a no-op.
        config.DATA.PIN_MEMORY = False
    if _check_args("data_prefetch_factor"):
        config.DATA.PREFETCH_FACTOR = args["data_prefetch_factor"]
    # Final output folder: <output>/<model name>/<tag>.
    config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.NAME, config.TAG)

    config.freeze()
479
+
480
+
481
def get_config(args):
    """Build a config from the defaults, a config file, and CLI arguments.

    A clone of the module-level defaults is updated, so repeated calls never
    mutate the shared default ``_C`` node.
    """
    cfg = _C.clone()
    update_config(cfg, args)
    return cfg
489
+
490
+
491
def get_custom_config(cfg):
    """Return a fresh config: the defaults merged with the given config file."""
    custom = _C.clone()
    _update_config_from_file(custom, cfg)
    return custom
spai/data/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .data_mfm import build_loader_mfm
18
+ from .data_finetune import build_loader_finetune, build_loader_test
19
+
20
def build_loader(config, logger, is_pretrain, is_test):
    """Dispatch to the appropriate data-loader builder.

    ``is_pretrain`` takes precedence over ``is_test``; when neither flag is
    set, the fine-tuning loader is built.
    """
    if is_pretrain:
        return build_loader_mfm(config, logger)
    if is_test:
        return build_loader_test(config, logger)
    return build_loader_finetune(config, logger)
spai/data/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (501 Bytes). View file
 
spai/data/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (577 Bytes). View file
 
spai/data/__pycache__/blur_kernels.cpython-310.pyc ADDED
Binary file (14.5 kB). View file
 
spai/data/__pycache__/blur_kernels.cpython-313.pyc ADDED
Binary file (20.2 kB). View file
 
spai/data/__pycache__/data_finetune.cpython-310.pyc ADDED
Binary file (19.4 kB). View file
 
spai/data/__pycache__/data_finetune.cpython-313.pyc ADDED
Binary file (38.6 kB). View file
 
spai/data/__pycache__/data_mfm.cpython-310.pyc ADDED
Binary file (4.89 kB). View file
 
spai/data/__pycache__/data_mfm.cpython-313.pyc ADDED
Binary file (9.35 kB). View file
 
spai/data/__pycache__/filestorage.cpython-310.pyc ADDED
Binary file (10.8 kB). View file
 
spai/data/__pycache__/filestorage.cpython-313.pyc ADDED
Binary file (17.5 kB). View file
 
spai/data/__pycache__/random_degradations.cpython-310.pyc ADDED
Binary file (12.2 kB). View file
 
spai/data/__pycache__/random_degradations.cpython-313.pyc ADDED
Binary file (20.9 kB). View file
 
spai/data/__pycache__/readers.cpython-310.pyc ADDED
Binary file (6.41 kB). View file
 
spai/data/__pycache__/readers.cpython-313.pyc ADDED
Binary file (9.32 kB). View file
 
spai/data/blur_kernels.py ADDED
@@ -0,0 +1,539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is referenced from BasicSR with modifications.
2
+ # Reference: https://github.com/xinntao/BasicSR/blob/master/basicsr/data/degradations.py # noqa
3
+ # Original licence: Copyright (c) 2020 xinntao, under the Apache 2.0 license.
4
+
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+ from scipy import special
10
+
11
+
12
def get_rotated_sigma_matrix(sig_x, sig_y, theta):
    """Build the 2x2 covariance matrix of a rotated anisotropic Gaussian.

    Args:
        sig_x (float): Standard deviation along the horizontal direction.
        sig_y (float): Standard deviation along the vertical direction.
        theta (float): Rotation in radian.

    Returns:
        ndarray: The rotated sigma matrix, i.e. R @ diag(sx^2, sy^2) @ R^T.
    """
    cos_t, sin_t = np.cos(theta), np.sin(theta)
    rotation = np.array([[cos_t, -sin_t],
                         [sin_t, cos_t]], dtype=np.float32)
    variances = np.diag([sig_x**2, sig_y**2]).astype(np.float32)
    return np.matmul(rotation, np.matmul(variances, rotation.T))
29
+
30
+
31
+ def _mesh_grid(kernel_size):
32
+ """Generate the mesh grid, centering at zero.
33
+
34
+ Args:
35
+ kernel_size (int): The size of the kernel.
36
+
37
+ Returns:
38
+ x_grid (ndarray): x-coordinates with shape (kernel_size, kernel_size).
39
+ y_grid (ndarray): y-coordiantes with shape (kernel_size, kernel_size).
40
+ xy_grid (ndarray): stacked coordinates with shape
41
+ (kernel_size, kernel_size, 2).
42
+ """
43
+
44
+ range_ = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
45
+ x_grid, y_grid = np.meshgrid(range_, range_)
46
+ xy_grid = np.hstack((x_grid.reshape((kernel_size * kernel_size, 1)),
47
+ y_grid.reshape(kernel_size * kernel_size,
48
+ 1))).reshape(kernel_size, kernel_size,
49
+ 2)
50
+
51
+ return xy_grid, x_grid, y_grid
52
+
53
+
54
def calculate_gaussian_pdf(sigma_matrix, grid):
    """Evaluate the unnormalized bivariate Gaussian density on a grid.

    Args:
        sigma_matrix (ndarray): The variance matrix with shape (2, 2).
        grid (ndarray): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size.

    Returns:
        ndarray: Un-normalized kernel values of shape (K, K).
    """
    precision = np.linalg.inv(sigma_matrix)
    # Quadratic form x^T Sigma^{-1} x evaluated per grid point.
    mahalanobis = np.sum(np.matmul(grid, precision) * grid, 2)
    return np.exp(-0.5 * mahalanobis)
70
+
71
+
72
def bivariate_gaussian(kernel_size,
                       sig_x,
                       sig_y=None,
                       theta=None,
                       grid=None,
                       is_isotropic=True):
    """Generate a bivariate isotropic or anisotropic Gaussian kernel.

    In isotropic mode, only `sig_x` is used; `sig_y` and `theta` are ignored.

    Args:
        kernel_size (int): The size of the kernel.
        sig_x (float): Standard deviation along the horizontal direction.
        sig_y (float | None, optional): Standard deviation along the vertical
            direction. If it is None, 'is_isotropic' must be set to True.
            Default: None.
        theta (float | None, optional): Rotation in radian. If it is None,
            'is_isotropic' must be set to True. Default: None.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        ndarray: Normalized kernel (sums to 1).

    Raises:
        ValueError: If anisotropic mode is requested without `sig_y`.
    """
    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.diag([sig_x**2, sig_x**2]).astype(np.float32)
    else:
        if sig_y is None:
            raise ValueError('"sig_y" cannot be None if "is_isotropic" is '
                             'False.')
        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    kernel = calculate_gaussian_pdf(sigma_matrix, grid)
    return kernel / np.sum(kernel)
117
+
118
+
119
def bivariate_generalized_gaussian(kernel_size,
                                   sig_x,
                                   sig_y=None,
                                   theta=None,
                                   beta=1,
                                   grid=None,
                                   is_isotropic=True):
    """Generate a bivariate generalized Gaussian kernel.

    Described in `Parameter Estimation For Multivariate Generalized
    Gaussian Distributions` by Pascal et. al (2013). In isotropic mode,
    only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): The size of the kernel
        sig_x (float): Standard deviation along horizontal direction
        sig_y (float | None, optional): Standard deviation along the vertical
            direction. If it is None, 'is_isotropic' must be set to True.
            Default: None.
        theta (float | None, optional): Rotation in radian. If it is None,
            'is_isotropic' must be set to True. Default: None.
        beta (float, optional): Shape parameter, beta = 1 is the normal
            distribution. Default: 1.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): normalized kernel.

    Raises:
        ValueError: If anisotropic mode is requested without `sig_y`.
    """

    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.array([[sig_x**2, 0], [0,
                                                 sig_x**2]]).astype(np.float32)
    else:
        # Consistent with `bivariate_gaussian`: fail early with a clear error
        # instead of crashing inside `get_rotated_sigma_matrix` on None ** 2.
        if sig_y is None:
            raise ValueError('"sig_y" cannot be None if "is_isotropic" is '
                             'False.')
        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    inverse_sigma = np.linalg.inv(sigma_matrix)
    # Generalized Gaussian: exp(-0.5 * (x^T Sigma^{-1} x)^beta).
    kernel = np.exp(
        -0.5 *
        np.power(np.sum(np.matmul(grid, inverse_sigma) * grid, 2), beta))
    kernel = kernel / np.sum(kernel)

    return kernel
168
+
169
+
170
def bivariate_plateau(kernel_size,
                      sig_x,
                      sig_y,
                      theta,
                      beta,
                      grid=None,
                      is_isotropic=True):
    """Generate a plateau-like anisotropic kernel.

    This kernel has a form of 1 / (1+x^(beta)).
    Ref: https://stats.stackexchange.com/questions/203629/is-there-a-plateau-shaped-distribution # noqa
    In the isotropic mode, only `sig_x` is used. `sig_y` and `theta` is ignored.

    Args:
        kernel_size (int): The size of the kernel
        sig_x (float): Standard deviation along horizontal direction
        sig_y (float): Standard deviation along the vertical direction.
        theta (float): Rotation in radian.
        beta (float): Shape parameter, beta = 1 is the normal distribution.
        grid (ndarray, optional): Coordinates generated by :func:`_mesh_grid`,
            with shape (K, K, 2), where K is the kernel size. Default: None
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): normalized kernel (i.e. sum to 1).
    """
    if grid is None:
        grid, _, _ = _mesh_grid(kernel_size)

    if is_isotropic:
        sigma_matrix = np.diag([sig_x**2, sig_x**2]).astype(np.float32)
    else:
        sigma_matrix = get_rotated_sigma_matrix(sig_x, sig_y, theta)

    inverse_sigma = np.linalg.inv(sigma_matrix)
    # Plateau profile: 1 / (1 + (x^T Sigma^{-1} x)^beta).
    quad_form = np.sum(np.matmul(grid, inverse_sigma) * grid, 2)
    kernel = np.reciprocal(np.power(quad_form, beta) + 1)
    return kernel / np.sum(kernel)
211
+
212
+
213
def random_bivariate_gaussian_kernel(kernel_size,
                                     sigma_x_range,
                                     sigma_y_range,
                                     rotation_range,
                                     noise_range=None,
                                     is_isotropic=True):
    """Randomly generate bivariate isotropic or anisotropic Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and
    `rotation_range` is ignored.

    Args:
        kernel_size (int): The size of the kernel; must be odd.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): The kernel whose parameters are sampled from the
            specified range.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])

    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        # Isotropic: vertical sigma mirrors the horizontal one, no rotation.
        sigma_y, rotation = sigma_x, 0

    kernel = bivariate_gaussian(
        kernel_size, sigma_x, sigma_y, rotation, is_isotropic=is_isotropic)

    # Optionally perturb the kernel with multiplicative noise, then renormalize.
    if noise_range is not None:
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.empty(kernel.shape).uniform_(
            noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
266
+
267
+
268
def random_bivariate_generalized_gaussian_kernel(kernel_size,
                                                 sigma_x_range,
                                                 sigma_y_range,
                                                 rotation_range,
                                                 beta_range,
                                                 noise_range=None,
                                                 is_isotropic=True):
    """Randomly generate bivariate generalized Gaussian kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and
    `rotation_range` is ignored.

    Args:
        kernel_size (int): The size of the kernel; must be odd.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        beta_range (float): The range of the shape parameter, beta = 1 is the
            normal distribution.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): The sampled kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])

    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        # Isotropic: vertical sigma mirrors the horizontal one, no rotation.
        sigma_y, rotation = sigma_x, 0

    # Sample beta below or above 1 with equal probability
    # (assumes beta_range[0] <= 1 <= beta_range[1]).
    lo, hi = (beta_range[0], 1) if random.random() <= 0.5 else (1, beta_range[1])
    beta = random.uniform(lo, hi)

    kernel = bivariate_generalized_gaussian(
        kernel_size, sigma_x, sigma_y, rotation, beta, is_isotropic=is_isotropic)

    # Optionally perturb the kernel with multiplicative noise, then renormalize.
    if noise_range is not None:
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.empty(kernel.shape).uniform_(
            noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
334
+
335
+
336
def random_bivariate_plateau_kernel(kernel_size,
                                    sigma_x_range,
                                    sigma_y_range,
                                    rotation_range,
                                    beta_range,
                                    noise_range=None,
                                    is_isotropic=True):
    """Randomly generate bivariate plateau kernels.

    In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and
    `rotation_range` is ignored.

    Args:
        kernel_size (int): The size of the kernel; must be odd.
        sigma_x_range (tuple): The range of the standard deviation along the
            horizontal direction. Default: [0.6, 5]
        sigma_y_range (tuple): The range of the standard deviation along the
            vertical direction. Default: [0.6, 5]
        rotation_range (tuple): Range of rotation in radian.
        beta_range (float): The range of the shape parameter, beta = 1 is the
            normal distribution.
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None.
        is_isotropic (bool, optional): Whether to use an isotropic kernel.
            Default: True.

    Returns:
        kernel (ndarray): The sampled kernel.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    assert sigma_x_range[0] <= sigma_x_range[1], 'Wrong sigma_x_range.'
    sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])

    if is_isotropic is False:
        assert sigma_y_range[0] <= sigma_y_range[1], 'Wrong sigma_y_range.'
        assert rotation_range[0] <= rotation_range[1], 'Wrong rotation_range.'
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        rotation = random.uniform(rotation_range[0], rotation_range[1])
    else:
        # Isotropic: vertical sigma mirrors the horizontal one, no rotation.
        sigma_y, rotation = sigma_x, 0

    # Sample beta below or above 1 with equal probability.
    # TODO: this may be not proper
    lo, hi = (beta_range[0], 1) if random.random() <= 0.5 else (1, beta_range[1])
    beta = random.uniform(lo, hi)

    kernel = bivariate_plateau(
        kernel_size, sigma_x, sigma_y, rotation, beta, is_isotropic=is_isotropic)

    # Optionally perturb the kernel with multiplicative noise, then renormalize.
    if noise_range is not None:
        assert noise_range[0] <= noise_range[1], 'Wrong noise range.'
        noise = torch.empty(kernel.shape).uniform_(
            noise_range[0], noise_range[1]).numpy()
        kernel = kernel * noise
        kernel = kernel / np.sum(kernel)

    return kernel
402
+
403
+
404
def random_circular_lowpass_kernel(omega_range, kernel_size, pad_to=0):
    """Generate a 2D circularly symmetric low-pass (sinc) filter.

    Reference: https://dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter # noqa

    Args:
        omega_range (tuple): The cutoff frequency in radian (pi is max); the
            cutoff is sampled uniformly from this range.
        kernel_size (int): The size of the kernel. It must be an odd number.
        pad_to (int, optional): The size of the padded kernel. It must be odd
            or zero. Default: 0.

    Returns:
        ndarray: The Sinc kernel with specified parameters, normalized to
            sum to 1.
    """
    assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
    omega = random.uniform(omega_range[0], omega_range[-1])

    center = (kernel_size - 1) / 2
    # The radial sinc expression divides by the distance from the center and
    # is 0/0 at the center pixel. Silence the warnings locally with
    # np.errstate (which, unlike a manual np.geterr()/np.seterr() pair,
    # restores the previous settings even if an exception is raised) and
    # patch the center afterwards with the analytic limit omega^2 / (4*pi).
    with np.errstate(divide='ignore', invalid='ignore'):
        kernel = np.fromfunction(
            lambda x, y: omega * special.j1(omega * np.sqrt(
                (x - center)**2 + (y - center)**2)) /
            (2 * np.pi * np.sqrt((x - center)**2 + (y - center)**2)),
            [kernel_size, kernel_size])
    kernel[(kernel_size - 1) // 2,
           (kernel_size - 1) // 2] = omega**2 / (4 * np.pi)
    kernel = kernel / np.sum(kernel)

    # Optionally zero-pad the kernel symmetrically up to pad_to x pad_to.
    if pad_to > kernel_size:
        pad_size = (pad_to - kernel_size) // 2
        kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    return kernel
441
+
442
+
443
def random_mixed_kernels(kernel_list,
                         kernel_prob,
                         kernel_size,
                         sigma_x_range=(0.6, 5),
                         sigma_y_range=(0.6, 5),
                         rotation_range=(-np.pi, np.pi),
                         beta_gaussian_range=(0.5, 8),
                         beta_plateau_range=(1, 2),
                         omega_range=(0, np.pi),
                         noise_range=None):
    """Randomly generate a kernel of a randomly chosen type.

    Args:
        kernel_list (list): A list of kernel types. Choices are
            'iso', 'aniso', 'generalized_iso', 'generalized_aniso',
            'plateau_iso', 'plateau_aniso', 'sinc'.
        kernel_prob (list): The probability of choosing of the corresponding
            kernel.
        kernel_size (int): The size of the kernel.
        sigma_x_range (tuple, optional): The range of the standard deviation
            along the horizontal direction. Default: (0.6, 5).
        sigma_y_range (tuple, optional): The range of the standard deviation
            along the vertical direction. Default: (0.6, 5).
        rotation_range (tuple, optional): Range of rotation in radian.
            Default: (-np.pi, np.pi).
        beta_gaussian_range (tuple, optional): The range of the shape parameter
            for generalized Gaussian. Default: (0.5, 8).
        beta_plateau_range (tuple, optional): The range of the shape parameter
            for plateau kernel. Default: (1, 2).
        omega_range (tuple, optional): The range of omega used in Sinc kernel.
            Default: (0, np.pi).
        noise_range (tuple, optional): Multiplicative kernel noise.
            Default: None. NOTE(review): it is only forwarded to the
            Gaussian-family kernels; plateau kernels are always generated
            noise-free, as in the original implementation — confirm intended.

    Returns:
        kernel (ndarray): The kernel whose parameters are sampled from the
            specified range.

    Raises:
        ValueError: If the chosen kernel type is not supported.
    """
    # Defaults are tuples (not lists) so they cannot be mutated across calls.
    kernel_type = random.choices(kernel_list, weights=kernel_prob)[0]
    if kernel_type in ('iso', 'aniso'):
        kernel = random_bivariate_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            noise_range=noise_range,
            is_isotropic=(kernel_type == 'iso'))
    elif kernel_type in ('generalized_iso', 'generalized_aniso'):
        kernel = random_bivariate_generalized_gaussian_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_gaussian_range,
            noise_range=noise_range,
            is_isotropic=(kernel_type == 'generalized_iso'))
    elif kernel_type in ('plateau_iso', 'plateau_aniso'):
        # Plateau kernels are deliberately generated without noise.
        kernel = random_bivariate_plateau_kernel(
            kernel_size,
            sigma_x_range,
            sigma_y_range,
            rotation_range,
            beta_plateau_range,
            noise_range=None,
            is_isotropic=(kernel_type == 'plateau_iso'))
    elif kernel_type == 'sinc':
        kernel = random_circular_lowpass_kernel(omega_range, kernel_size)
    else:
        # Previously an unknown type crashed with UnboundLocalError at the
        # return statement; fail explicitly instead.
        raise ValueError(f'Unsupported kernel type: {kernel_type}')

    return kernel
spai/data/data_finetune.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import collections
18
+ import os
19
+ import pathlib
20
+ import random
21
+ from functools import partial
22
+ from typing import Any, Union, Optional, Iterable
23
+ from collections.abc import Callable
24
+
25
+ import albumentations as A
26
+ import torchvision.transforms.functional
27
+ from albumentations.augmentations.transforms import ImageCompressionType
28
+ from albumentations.pytorch import ToTensorV2
29
+ import numpy as np
30
+ import torch
31
+ from torch.utils.data import DataLoader
32
+ from PIL import Image
33
+ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
34
+ from timm.data import Mixup
35
+ import cv2
36
+ from torchvision.transforms.v2.functional import ten_crop, pad
37
+ import filetype
38
+
39
+ from spai.data import readers
40
+ from spai.data import filestorage
41
+ from spai import data_utils
42
+
43
+
44
class CSVDataset(torch.utils.data.Dataset):
    """Image classification dataset indexed by a CSV file.

    Each CSV row describes one sample through an image-path column, a split
    column ("train"/"val"/"test") and an integer class column; only the rows
    whose split matches `split` are kept. Images are read either from the
    filesystem (paths relative to `csv_root_path`) or, when `lmdb_storage`
    is given, from an LMDB file storage.
    """

    def __init__(
        self,
        csv_path: pathlib.Path,
        csv_root_path: pathlib.Path,
        split: str,
        transform,
        path_column: str = "image",
        split_column: str = "split",
        class_column: str = "class",
        views: int = 1,
        concatenate_views_horizontally: bool = False,
        lmdb_storage: Optional[pathlib.Path] = None,
        views_generator: Optional[Callable[[Image.Image], tuple[Image.Image, ...]]] = None
    ):
        """Initializes the dataset from a CSV index.

        :param csv_path: Absolute path to the CSV index file.
        :param csv_root_path: Root dir the CSV's relative image paths refer to.
        :param split: One of "train", "val", "test".
        :param transform: Callable invoked as ``transform(image=ndarray)`` and
            expected to return a dict containing an "image" tensor.
        :param path_column: CSV column holding the image path.
        :param split_column: CSV column holding the split name.
        :param class_column: CSV column holding the integer class label.
        :param views: Number of augmented views generated per image when no
            `views_generator` is provided.
        :param concatenate_views_horizontally: When True, views are merged
            into one wide image instead of stacked into a new dimension.
        :param lmdb_storage: Optional path to an LMDB file storage to read
            images from instead of the filesystem.
        :param views_generator: Optional callable producing multiple views of
            a PIL image; overrides the `views`-based augmentation.
        """
        super().__init__()
        self.csv_path: pathlib.Path = csv_path
        self.csv_root_path: pathlib.Path = csv_root_path
        self.split: str = split
        self.path_column: str = path_column
        self.split_column: str = split_column
        self.class_column: str = class_column
        self.transform = transform
        self.views: int = views
        self.views_generator: Optional[
            Callable[[Image.Image], tuple[Image.Image, ...]]] = views_generator
        self.concatenate_views_horizontally: bool = concatenate_views_horizontally
        self.lmdb_storage: Optional[pathlib.Path] = lmdb_storage

        # Reader to be used for data loading. Its creation is deferred until
        # the first read so that every DataLoader worker process builds its
        # own reader object (see _create_data_reader()).
        self.data_reader: Optional[readers.DataReader] = None

        if split not in ["train", "val", "test"]:
            raise RuntimeError(f"Unsupported split: {split}")

        # Path of the CSV file is expected to be absolute.
        reader = readers.FileSystemReader(pathlib.Path("/"))
        self.entries: list[dict[str, Any]] = reader.read_csv_file(str(self.csv_path))
        self.entries = [e for e in self.entries if e[self.split_column] == self.split]

        # Number of distinct values found in the class column of this split.
        self.num_classes: int = len(
            collections.Counter([e[self.class_column] for e in self.entries]).keys()
        )

    def __len__(self):
        """Returns the number of samples in the selected split."""
        return len(self.entries)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, np.ndarray, int]:
        """Returns the requested image sample from the dataset.

        :returns: A tuple containing the image tensor, the labels numpy array and the
            index in dataset.
            Image tensor: (V x 3 x H x W) where V is the number of augmented views.
            Label array: (1, )
            Index
        """
        # Defer the creation of the data reader until the first read operation in order to
        # properly handle the spawning of multiple processes by DataLoader, where each one
        # should contain a separate reader object.
        if self.data_reader is None:
            self._create_data_reader()

        # Load sample.
        img_obj: Image.Image = self.data_reader.load_image(
            self.entries[idx][self.path_column], channels=3
        )
        label: int = int(self.entries[idx][self.class_column])

        # Generate multiple views of an image either through a provided views generation
        # function or through multiple augmentations of the image.
        if self.views_generator is not None:
            augmented_views: tuple[Image.Image, ...] = self.views_generator(img_obj)
            augmented_views: list[np.ndarray] = [np.array(v) for v in augmented_views]
            augmented_views: list[torch.Tensor] = [
                self.transform(image=v)["image"] for v in augmented_views
            ]
        else:
            img: np.ndarray = np.array(img_obj)
            augmented_views: list[torch.Tensor] = []
            for _ in range(self.views):
                augmented_views.append(self.transform(image=img)["image"])

        # Either concatenate the views in a single big image, or provide them stacked
        # into a new tensor dimension.
        if self.concatenate_views_horizontally:
            augmented_img: torch.Tensor = torch.cat(augmented_views, dim=-1)
            augmented_img = augmented_img.unsqueeze(dim=0)
        else:
            augmented_img: torch.Tensor = torch.stack(augmented_views, dim=0)

        # Cleanup resources.
        img_obj.close()

        return augmented_img, np.array(label, dtype=float), idx

    def get_classes_num(self) -> int:
        """Returns the number of distinct classes in the selected split."""
        return self.num_classes

    def get_dataset_root_path(self) -> pathlib.Path:
        """Returns where images are actually loaded from: the LMDB storage
        when configured, the filesystem root otherwise."""
        if self.lmdb_storage is not None:
            return self.lmdb_storage
        else:
            return self.csv_root_path

    def update_dataset_csv(
        self,
        column_name: str,
        values: dict[int, Any],
        export_dir: Optional[pathlib.Path] = None
    ) -> None:
        """Sets `column_name` to `values[idx]` for each entry index in `values`.

        Entries not covered by `values` get an empty string so the column is
        defined for every row. When `export_dir` is given, the updated table
        is written to ``export_dir / <csv file name>``.
        """
        for idx, v in values.items():
            self.entries[idx][column_name] = v

        # Make sure that a valid value for the updated column exists for all entries.
        for e in self.entries:
            if column_name not in e:
                e[column_name] = ""

        if export_dir:
            export_path: pathlib.Path = export_dir / self.csv_path.name
            data_utils.write_csv_file(self.entries, export_path, delimiter=",")

    def _create_data_reader(self) -> None:
        """Creates the per-process reader (filesystem or LMDB-backed)."""
        # Limit the number of OpenCV threads to utilize multiple processes. Otherwise,
        # each process spawns a number of threads equal to the number of logical cores and
        # the overall performance gets worse due to threads congestion.
        cv2.setNumThreads(1)

        if self.lmdb_storage is None:
            self.data_reader: readers.FileSystemReader = readers.FileSystemReader(
                pathlib.Path(self.csv_root_path)
            )
        else:
            self.data_reader: readers.LMDBFileStorageReader = readers.LMDBFileStorageReader(
                filestorage.LMDBFileStorage(self.lmdb_storage, read_only=True)
            )
180
+
181
+
182
class CSVDatasetTriplet(torch.utils.data.Dataset):
    """Triplet-sampling variant of `CSVDataset` for triplet-style losses.

    For every CSV entry (the anchor) one positive sample of the same class and
    one negative sample of a different class are pre-selected by
    `generate_triplets()`; `__getitem__` then returns the three transformed
    images.
    """

    def __init__(
        self,
        csv_path: pathlib.Path,
        csv_root_path: pathlib.Path,
        split: str,
        transform,
        path_column: str = "image",
        split_column: str = "split",
        class_column: str = "class",
        lmdb_storage: Optional[pathlib.Path] = None
    ):
        """Initializes the dataset and pre-computes the triplets.

        :param csv_path: Absolute path to the CSV index file.
        :param csv_root_path: Root dir the CSV's relative image paths refer to.
        :param split: One of "train", "val", "test".
        :param transform: Callable invoked as ``transform(image=ndarray)`` and
            expected to return a dict containing an "image" tensor.
        :param path_column: CSV column holding the image path.
        :param split_column: CSV column holding the split name.
        :param class_column: CSV column holding the integer class label.
        :param lmdb_storage: Optional path to an LMDB file storage to read
            images from instead of the filesystem.
        """
        super().__init__()
        self.csv_path: pathlib.Path = csv_path
        self.csv_root_path: pathlib.Path = csv_root_path
        self.split: str = split
        self.path_column: str = path_column
        self.split_column: str = split_column
        self.class_column: str = class_column
        self.transform = transform
        self.lmdb_storage: Optional[pathlib.Path] = lmdb_storage

        # Reader to be used for data loading. Its creation is deferred until
        # the first read so every DataLoader worker builds its own reader.
        self.data_reader: Optional[readers.DataReader] = None

        if split not in ["train", "val", "test"]:
            raise RuntimeError(f"Unsupported split: {split}")

        # Path of the CSV file is expected to be absolute.
        reader = readers.FileSystemReader(pathlib.Path("/"))
        self.entries: list[dict[str, Any]] = reader.read_csv_file(str(self.csv_path))
        self.entries = [e for e in self.entries if e[self.split_column] == self.split]

        self.num_classes: int = len(
            collections.Counter([e[self.class_column] for e in self.entries]).keys()
        )

        # Save paths that will be accessed by different dataloaders as numpy arrays in
        # order to avoid copy-on-read of python objects, and thus child processes to
        # take huge amounts of memory.
        self.anchor_v: Optional[np.ndarray] = None
        self.anchor_o: Optional[np.ndarray] = None
        self.positive_v: Optional[np.ndarray] = None
        self.positive_o: Optional[np.ndarray] = None
        self.negative_v: Optional[np.ndarray] = None
        self.negative_o: Optional[np.ndarray] = None
        self.triplets_num: Optional[int] = None
        self.generate_triplets()

    def __len__(self) -> int:
        """Returns the number of pre-computed triplets."""
        return self.triplets_num

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Returns the triplet with the specified index.

        :returns: A tuple in the form of (anchor_img, positive_img, negative_img)
        """
        # Defer the creation of the data reader until the first read operation in order to
        # properly handle the spawning of multiple processes by DataLoader, where each one
        # should contain a separate reader object.
        if self.data_reader is None:
            self._create_data_reader()

        # Decode the packed (values, offsets) numpy representation back into
        # path strings.
        anchor_path: str = sequence_to_string(unpack_sequence(self.anchor_v, self.anchor_o, idx))
        positive_path: str = sequence_to_string(
            unpack_sequence(self.positive_v, self.positive_o, idx))
        negative_path: str = sequence_to_string(
            unpack_sequence(self.negative_v, self.negative_o, idx))

        anchor_img_obj: Image.Image = self.data_reader.load_image(anchor_path, channels=3)
        positive_img_obj: Image.Image = self.data_reader.load_image(positive_path, channels=3)
        negative_img_obj: Image.Image = self.data_reader.load_image(negative_path, channels=3)

        anchor_img: np.ndarray = np.array(anchor_img_obj)
        positive_img: np.ndarray = np.array(positive_img_obj)
        negative_img: np.ndarray = np.array(negative_img_obj)

        # Release the PIL images once copied into numpy arrays.
        anchor_img_obj.close()
        positive_img_obj.close()
        negative_img_obj.close()

        return (self.transform(image=anchor_img)["image"],
                self.transform(image=positive_img)["image"],
                self.transform(image=negative_img)["image"])

    def get_classes_num(self) -> int:
        """Returns the number of distinct classes in the selected split."""
        return self.num_classes

    def get_dataset_root_path(self) -> pathlib.Path:
        """Returns where images are actually loaded from: the LMDB storage
        when configured, the filesystem root otherwise."""
        if self.lmdb_storage is not None:
            return self.lmdb_storage
        else:
            return self.csv_root_path

    def generate_triplets(self) -> None:
        """Builds one (anchor, positive, negative) triplet per dataset entry.

        The positive is a random different entry of the anchor's class; the
        negative is a random entry of a random other class. The resulting
        paths are stored packed as numpy (values, offsets) arrays to avoid
        copy-on-read memory growth in DataLoader worker processes.
        """
        # Separate the entries into groups of each class.
        entries_per_class: dict[int, list[dict[str, Any]]] = {
            i: [] for i in range(self.num_classes)
        }
        for e in self.entries:
            entries_per_class[int(e[self.class_column])].append(e)

        triplets: list[tuple[dict[str,Any], dict[str, Any], dict[str, Any]]] = []
        for class_id, class_group in entries_per_class.items():
            class_group: list[dict[str, Any]] = list(class_group)
            rest_groups: list[list[dict[str, Any]]] = list(entries_per_class.values())
            del rest_groups[class_id]

            for i, e in enumerate(class_group):
                negative_sample: dict[str, Any] = random.choice(random.choice(rest_groups))

                # Re-draw until the positive differs from the anchor itself.
                positive_sample: dict[str, Any] = random.choice(class_group)
                while e == positive_sample:
                    positive_sample: dict[str, Any] = random.choice(class_group)

                triplets.append((e, positive_sample, negative_sample))

        self.anchor_v, self.anchor_o = pack_sequences(
            [string_to_sequence(t[0][self.path_column]) for t in triplets]
        )
        self.positive_v, self.positive_o = pack_sequences(
            [string_to_sequence(t[1][self.path_column]) for t in triplets]
        )
        self.negative_v, self.negative_o = pack_sequences(
            [string_to_sequence(t[2][self.path_column]) for t in triplets]
        )
        self.anchor_labels: np.ndarray = np.array([int(t[0][self.class_column]) for t in triplets])
        self.triplets_num = len(triplets)

    def _create_data_reader(self) -> None:
        """Creates the per-process reader (filesystem or LMDB-backed)."""
        # Limit the number of OpenCV threads to utilize multiple processes. Otherwise,
        # each process spawns a number of threads equal to the number of logical cores and
        # the overall performance gets worse due to threads congestion.
        cv2.setNumThreads(1)

        if self.lmdb_storage is None:
            self.data_reader: readers.FileSystemReader = readers.FileSystemReader(
                pathlib.Path(self.csv_root_path)
            )
        else:
            self.data_reader: readers.LMDBFileStorageReader = readers.LMDBFileStorageReader(
                filestorage.LMDBFileStorage(self.lmdb_storage, read_only=True)
            )
325
+
326
+
327
def build_loader_finetune(config, logger):
    """Builds the train/val datasets, data loaders and optional mixup function.

    Side effect: ``config.MODEL.NUM_CLASSES`` is updated from the training
    set (the config is temporarily defrosted for the assignment).

    :returns: A tuple of (dataset_train, dataset_val, data_loader_train,
        data_loader_val, mixup_fn), where `mixup_fn` is None when
        mixup/cutmix is disabled.
    """
    # NUM_CLASSES is derived from the training CSV, so the (frozen) config
    # must be unlocked around the assignment and re-frozen afterwards.
    config.defrost()
    dataset_train, config.MODEL.NUM_CLASSES = build_dataset(
        config.DATA.DATA_PATH,
        config.DATA.CSV_ROOT,
        config=config,
        split_name="train",
        logger=logger
    )
    config.freeze()
    dataset_val, _ = build_dataset(
        config.DATA.DATA_PATH,
        config.DATA.CSV_ROOT,
        config=config,
        split_name="val",
        logger=logger
    )
    logger.info(f"Train images: {len(dataset_train)} | Validation images: {len(dataset_val)}")
    logger.info(f"Train Images Source: {dataset_train.get_dataset_root_path()}")
    logger.info(f"Validation Images Source: {dataset_val.get_dataset_root_path()}")

    data_loader_train = DataLoader(
        dataset_train,
        batch_size=config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=True,
        shuffle=True,
        prefetch_factor=config.DATA.PREFETCH_FACTOR
    )
    # Validation falls back to the training batch size / prefetch factor when
    # no dedicated values are configured. In "arbitrary" resolution mode a
    # list-building collate function is used instead of the default one
    # (presumably because variable-sized images cannot be stacked — see
    # image_enlisting_collate_fn).
    data_loader_val = DataLoader(
        dataset_val,
        batch_size=config.DATA.VAL_BATCH_SIZE or config.DATA.BATCH_SIZE,
        num_workers=config.DATA.NUM_WORKERS,
        pin_memory=config.DATA.PIN_MEMORY,
        drop_last=False,
        prefetch_factor=config.DATA.VAL_PREFETCH_FACTOR or config.DATA.PREFETCH_FACTOR,
        collate_fn=(torch.utils.data.default_collate
                    if not config.MODEL.RESOLUTION_MODE == "arbitrary"
                    else image_enlisting_collate_fn)
    )

    # Setup mixup / cutmix. Active when any of the mixup/cutmix knobs is set.
    mixup_fn = None
    mixup_active: bool = (config.AUG.MIXUP > 0
                          or config.AUG.CUTMIX > 0.
                          or config.AUG.CUTMIX_MINMAX is not None)
    if mixup_active:
        mixup_fn = Mixup(
            mixup_alpha=config.AUG.MIXUP,
            cutmix_alpha=config.AUG.CUTMIX,
            cutmix_minmax=config.AUG.CUTMIX_MINMAX,
            prob=config.AUG.MIXUP_PROB,
            switch_prob=config.AUG.MIXUP_SWITCH_PROB,
            mode=config.AUG.MIXUP_MODE,
            label_smoothing=config.MODEL.LABEL_SMOOTHING,
            num_classes=config.MODEL.NUM_CLASSES
        )

    return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn
387
+
388
+
389
def build_loader_test(
    config,
    logger,
    split: str = "test",
    dummy_csv_dir: Optional[pathlib.Path] = None,
) -> tuple[list[str], list[torch.utils.data.Dataset], list[torch.utils.data.DataLoader]]:
    """Builds a dataset and a data loader for every configured test input.

    Each entry of ``config.DATA.TEST_DATA_PATH`` may be either a CSV index
    file or a directory of images; for directories a dummy CSV (with class
    "1") is generated under `dummy_csv_dir` (default: ./outputs).

    :param config: Experiment config providing the DATA.* options.
    :param logger: Logger for progress messages.
    :param split: Split name used for the generated CSVs and for filtering.
    :param dummy_csv_dir: Where to write generated CSVs for directory inputs.
    :returns: A tuple of (dataset names, datasets, data loaders), one item
        per test input.
    :raises RuntimeError: If the test sets disagree on the number of classes.
    """
    # Obtain the root directory for each test input (either a CSV file or a directory).
    # One root per input, a single shared root, or each input's parent dir.
    input_root_paths: list[pathlib.Path]
    if len(config.DATA.TEST_DATA_CSV_ROOT) > 1:
        input_root_paths = [pathlib.Path(p) for p in config.DATA.TEST_DATA_CSV_ROOT]
    elif len(config.DATA.TEST_DATA_CSV_ROOT) == 1:
        input_root_paths = [pathlib.Path(config.DATA.TEST_DATA_CSV_ROOT[0])
                            for _ in config.DATA.TEST_DATA_PATH]
    else:
        input_root_paths = [pathlib.Path(input_path).parent
                            for input_path in config.DATA.TEST_DATA_PATH]

    # If some input is a directory, create a dummy csv file for it.
    csv_paths: list[pathlib.Path] = []
    csv_root_paths: list[pathlib.Path] = []
    for input_path, input_root_path in zip(config.DATA.TEST_DATA_PATH, input_root_paths):
        input_path: pathlib.Path = pathlib.Path(input_path)
        if input_path.is_dir():
            # Create a dummy csv and point directories
            if dummy_csv_dir is None:
                dummy_csv_dir = pathlib.Path("./outputs")
            entries: list[dict[str, str]] = [
                {
                    "image": str(file_path.name),
                    "split": split,
                    "class": "1"  # TODO: Remove csv requirement for dummy ground-truth.
                }
                for file_path in input_path.iterdir() if filetype.is_image(file_path)
            ]
            dummy_csv_path: pathlib.Path = dummy_csv_dir / f"{input_path.stem}.csv"
            data_utils.write_csv_file(entries, dummy_csv_path, delimiter=",")
            csv_paths.append(dummy_csv_path.absolute())
            csv_root_paths.append(input_path.absolute())  # Paths in CSV are relative to input dir.
        else:
            csv_paths.append(input_path.absolute())
            csv_root_paths.append(input_root_path.absolute())

    # Obtain the separate testing sets and their names.
    test_datasets: list[CSVDataset] = []
    test_datasets_names: list[str] = []
    num_classes_per_dataset: list[int] = []
    for csv_path, csv_root_path in zip(csv_paths, csv_root_paths):
        csv_path: pathlib.Path = pathlib.Path(csv_path)
        dataset: CSVDataset
        dataset, num_classes = build_dataset(csv_path, csv_root_path, config, split, logger)
        test_datasets.append(dataset)
        test_datasets_names.append(csv_path.stem)
        num_classes_per_dataset.append(num_classes)
    # Check that the number of classes match among all test sets.
    unique_number_of_classes: list[int] = list(collections.Counter(num_classes_per_dataset).keys())
    if len(unique_number_of_classes) > 1:
        raise RuntimeError(
            f"Encountered different number of classes among test sets: {unique_number_of_classes}"
        )

    for dataset, dataset_name in zip(test_datasets, test_datasets_names):
        logger.info(f"Dataset \'{dataset_name}\' | Split: {split} | Total images: {len(dataset)} | "
                    f"Source: {dataset.get_dataset_root_path()}")

    # Create the corresponding data loaders. On CPU-only runs (forced via
    # SPAI_FORCE_CPU or no visible CUDA devices) workers are capped and
    # pinned memory is disabled.
    force_cpu: bool = os.environ.get("SPAI_FORCE_CPU", "0") == "1"
    cuda_visible: str = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    use_cuda: bool = (not force_cpu) and (cuda_visible != "")
    test_num_workers: int = config.DATA.NUM_WORKERS if use_cuda else min(config.DATA.NUM_WORKERS, 2)
    test_pin_memory: bool = config.DATA.PIN_MEMORY if use_cuda else False

    test_data_loaders: list[torch.utils.data.DataLoader] = [
        DataLoader(
            dataset,
            batch_size=config.DATA.TEST_BATCH_SIZE or config.DATA.BATCH_SIZE,
            num_workers=test_num_workers,
            pin_memory=test_pin_memory,
            drop_last=False,
            prefetch_factor=config.DATA.TEST_PREFETCH_FACTOR or config.DATA.PREFETCH_FACTOR,
            collate_fn=(torch.utils.data.default_collate
                        if not config.MODEL.RESOLUTION_MODE == "arbitrary"
                        else image_enlisting_collate_fn)
        )
        for dataset in test_datasets
    ]

    return test_datasets_names, test_datasets, test_data_loaders
476
+
477
+
478
+ def build_dataset(
479
+ csv_path: pathlib.Path,
480
+ csv_root_dir: pathlib.Path,
481
+ config,
482
+ split_name: str,
483
+ logger,
484
+ ) -> tuple[Union[CSVDataset, CSVDatasetTriplet], int]:
485
+ if split_name not in ["train", "val", "test"]:
486
+ raise RuntimeError(f"Unsupported split: {split_name}")
487
+
488
+ transform = build_transform(split_name == "train", config)
489
+ logger.info(f"Data transform | mode: {config.TRAIN.MODE} | split: {split_name}:\n{transform}")
490
+
491
+ if split_name == "train" and config.TRAIN.LOSS == "triplet":
492
+ dataset = CSVDatasetTriplet(
493
+ csv_path,
494
+ csv_root_dir,
495
+ split=split_name,
496
+ transform=transform,
497
+ lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
498
+ )
499
+ elif split_name == "train" and config.TRAIN.LOSS == "supcont":
500
+ assert config.DATA.AUGMENTED_VIEWS > 1, "SupCon loss requires at least 2 views."
501
+ dataset = CSVDataset(
502
+ csv_path,
503
+ csv_root_dir,
504
+ split=split_name,
505
+ transform=transform,
506
+ views=config.DATA.AUGMENTED_VIEWS,
507
+ lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
508
+ )
509
+ elif split_name == "train" and config.MODEL.RESOLUTION_MODE == "arbitrary":
510
+ dataset = CSVDataset(
511
+ csv_path,
512
+ csv_root_dir,
513
+ split=split_name,
514
+ transform=transform,
515
+ views=config.DATA.AUGMENTED_VIEWS,
516
+ concatenate_views_horizontally=True,
517
+ lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None
518
+ )
519
+ else:
520
+ views_generator: Optional[Callable[[Image.Image], tuple[Image.Image, ...]]]
521
+ if config.TEST.VIEWS_GENERATION_APPROACH == "tencrop":
522
+ def safe_ten_crop(img: Image.Image) -> tuple[Image.Image, ...]:
523
+ width = img.width
524
+ height = img.height
525
+ left_padding: int = max((config.DATA.IMG_SIZE - width) // 2, 0)
526
+ right_padding: int = max(
527
+ (config.DATA.IMG_SIZE - width) // 2
528
+ + (((config.DATA.IMG_SIZE - width) % 2) if config.DATA.IMG_SIZE > width else 0),
529
+ 0
530
+ )
531
+ top_padding: int = max((config.DATA.IMG_SIZE - height) // 2, 0)
532
+ bottom_padding: int = max(
533
+ (config.DATA.IMG_SIZE - height) // 2
534
+ + (((config.DATA.IMG_SIZE - height) % 2) if config.DATA.IMG_SIZE > height else 0),
535
+ 0
536
+ )
537
+ img = pad(img, [left_padding, top_padding, right_padding, bottom_padding])
538
+ return ten_crop(img, size=config.DATA.IMG_SIZE)
539
+
540
+ views_generator = safe_ten_crop
541
+ elif config.TEST.VIEWS_GENERATION_APPROACH is None:
542
+ views_generator = None
543
+ else:
544
+ raise TypeError(f"{config.TEST.VIEW_GENERATION_APPROACH} is not a supported "
545
+ f"view generation approach.")
546
+
547
+ dataset = CSVDataset(
548
+ csv_path,
549
+ csv_root_dir,
550
+ split=split_name,
551
+ transform=transform,
552
+ lmdb_storage=pathlib.Path(config.DATA.LMDB_PATH) if config.DATA.LMDB_PATH else None,
553
+ views_generator=views_generator
554
+ )
555
+ num_classes: int = dataset.get_classes_num()
556
+
557
+ return dataset, num_classes
558
+
559
+
560
+ def build_transform(is_train, config) -> Callable[[np.ndarray], np.ndarray]:
561
+ # resize_im: bool = config.DATA.IMG_SIZE > 32
562
+ # # this should always dispatch to transforms_imagenet_train
563
+ # transform = create_transform(
564
+ # input_size=config.DATA.IMG_SIZE,
565
+ # is_training=True,
566
+ # color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None,
567
+ # auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None,
568
+ # re_prob=config.AUG.REPROB,
569
+ # re_mode=config.AUG.REMODE,
570
+ # re_count=config.AUG.RECOUNT,
571
+ # interpolation=config.DATA.INTERPOLATION,
572
+ # )
573
+ # if not resize_im:
574
+ # # replace RandomResizedCropAndInterpolation with
575
+ # # RandomCrop
576
+ # transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4)
577
+ # transform.transforms.insert(0, torchvision.transforms.v2.JPEG((50, 100)))
578
+ # transform.transforms.insert(4, torchvision.transforms.GaussianBlur(kernel_size=(3, 9), sigma=(0.01, 0.5)))
579
+
580
+ if is_train: # Training augmentations
581
+ transforms_list = []
582
+
583
+ if config.AUG.MIN_CROP_AREA == config.AUG.MAX_CROP_AREA:
584
+ transforms_list.append(
585
+ A.PadIfNeeded(min_height=config.DATA.IMG_SIZE, min_width=config.DATA.IMG_SIZE)
586
+ )
587
+ transforms_list.append(
588
+ A.RandomCrop(height=config.DATA.IMG_SIZE, width=config.DATA.IMG_SIZE)
589
+ )
590
+ else:
591
+ transforms_list.append(
592
+ A.RandomResizedCrop(size=(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE),
593
+ scale=(config.AUG.MIN_CROP_AREA, config.AUG.MAX_CROP_AREA))
594
+ )
595
+ transforms_list.extend([
596
+ A.HorizontalFlip(p=config.AUG.HORIZONTAL_FLIP_PROB),
597
+ A.VerticalFlip(p=config.AUG.VERTICAL_FLIP_PROB),
598
+ A.Rotate(limit=config.AUG.ROTATION_DEGREES,
599
+ crop_border=True,
600
+ p=config.AUG.ROTATION_PROB)
601
+ ])
602
+ if config.AUG.ROTATION_PROB > .0:
603
+ # Rotation with crop_border set to True leads to images smaller than the target
604
+ # size. So, restore the target size.
605
+ transforms_list.append(
606
+ A.Resize(height=config.DATA.IMG_SIZE, width=config.DATA.IMG_SIZE)
607
+ )
608
+ transforms_list.extend([
609
+ A.GaussianBlur(blur_limit=(3, 9),
610
+ sigma_limit=(0.01, 0.5),
611
+ p=config.AUG.GAUSSIAN_BLUR_PROB),
612
+ A.GaussNoise(p=config.AUG.GAUSSIAN_NOISE_PROB),
613
+ A.ColorJitter(
614
+ p=config.AUG.COLOR_JITTER,
615
+ brightness=config.AUG.COLOR_JITTER_BRIGHTNESS_RANGE,
616
+ contrast=config.AUG.COLOR_JITTER_CONTRAST_RANGE,
617
+ saturation=config.AUG.COLOR_JITTER_SATURATION_RANGE,
618
+ hue=config.AUG.COLOR_JITTER_HUE_RANGE,
619
+ ),
620
+ A.Sharpen(p=config.AUG.SHARPEN_PROB,
621
+ alpha=config.AUG.SHARPEN_ALPHA_RANGE,
622
+ lightness=config.AUG.SHARPEN_LIGHTNESS_RANGE),
623
+ A.ImageCompression(quality_lower=config.AUG.JPEG_MIN_QUALITY,
624
+ quality_upper=config.AUG.JPEG_MAX_QUALITY,
625
+ compression_type=ImageCompressionType.JPEG,
626
+ p=config.AUG.JPEG_COMPRESSION_PROB),
627
+ A.ImageCompression(quality_lower=config.AUG.WEBP_MIN_QUALITY,
628
+ quality_upper=config.AUG.WEBP_MAX_QUALITY,
629
+ compression_type=ImageCompressionType.WEBP,
630
+ p=config.AUG.WEBP_COMPRESSION_PROB),
631
+ ])
632
+ if config.MODEL.REQUIRED_NORMALIZATION == "imagenet":
633
+ transforms_list.append(
634
+ A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)
635
+ )
636
+ elif config.MODEL.REQUIRED_NORMALIZATION == "positive_0_1":
637
+ transforms_list.append(
638
+ A.Normalize(mean=0., std=1.)
639
+ )
640
+ else:
641
+ raise RuntimeError(f"Unsupported Normalization: {config.MODEL.REQUIRED_NORMALIZATION}")
642
+ transforms_list.append(ToTensorV2())
643
+ transform = A.Compose(transforms_list)
644
+
645
+ else: # Inference augmentations
646
+ transforms_list = [
647
+ A.ImageCompression(quality_lower=config.TEST.JPEG_QUALITY,
648
+ quality_upper=config.TEST.JPEG_QUALITY,
649
+ compression_type=ImageCompressionType.JPEG,
650
+ p=1.0 if config.TEST.JPEG_COMPRESSION else .0),
651
+ A.ImageCompression(quality_lower=config.TEST.WEBP_QUALITY,
652
+ quality_upper=config.TEST.WEBP_QUALITY,
653
+ compression_type=ImageCompressionType.WEBP,
654
+ p=1.0 if config.TEST.WEBP_COMPRESSION else .0),
655
+ A.GaussianBlur(blur_limit=(config.TEST.GAUSSIAN_BLUR_KERNEL_SIZE,
656
+ config.TEST.GAUSSIAN_BLUR_KERNEL_SIZE),
657
+ sigma_limit=0,
658
+ p=1.0 if config.TEST.GAUSSIAN_BLUR else .0),
659
+ A.GaussNoise(var_limit=(config.TEST.GAUSSIAN_NOISE_SIGMA**2,
660
+ config.TEST.GAUSSIAN_NOISE_SIGMA**2),
661
+ p=1.0 if config.TEST.GAUSSIAN_NOISE else .0),
662
+ A.RandomScale(scale_limit=(config.TEST.SCALE_FACTOR-1, config.TEST.SCALE_FACTOR-1),
663
+ p=1.0 if config.TEST.SCALE else .0)
664
+ ]
665
+ if config.TEST.MAX_SIZE is not None:
666
+ transforms_list.append(A.SmallestMaxSize(max_size=config.TEST.MAX_SIZE))
667
+
668
+ if config.TEST.ORIGINAL_RESOLUTION:
669
+ transforms_list.append(A.PadIfNeeded(min_height=config.DATA.IMG_SIZE,
670
+ min_width=config.DATA.IMG_SIZE))
671
+ elif config.TEST.CROP:
672
+ transforms_list.append(A.PadIfNeeded(min_height=config.DATA.IMG_SIZE,
673
+ min_width=config.DATA.IMG_SIZE))
674
+ transforms_list.append(A.CenterCrop(height=config.DATA.IMG_SIZE,
675
+ width=config.DATA.IMG_SIZE))
676
+ else:
677
+ transforms_list.append(A.Resize(config.DATA.IMG_SIZE, config.DATA.IMG_SIZE))
678
+ if config.MODEL.REQUIRED_NORMALIZATION == "imagenet":
679
+ transforms_list.append(A.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD))
680
+ elif config.MODEL.REQUIRED_NORMALIZATION == "positive_0_1":
681
+ transforms_list.append(A.Normalize(mean=0., std=1.))
682
+ else:
683
+ raise RuntimeError(f"Unsupported Normalization: {config.MODEL.REQUIRED_NORMALIZATION}")
684
+ transforms_list.append(ToTensorV2())
685
+ transform = A.Compose(transforms_list)
686
+
687
+ return transform
688
+
689
+
690
+ def string_to_sequence(s: str, dtype=np.int32) -> np.ndarray:
691
+ return np.array([ord(c) for c in s], dtype=dtype)
692
+
693
+
694
+ def sequence_to_string(seq: np.ndarray) -> str:
695
+ return ''.join([chr(c) for c in seq])
696
+
697
+
698
+ def pack_sequences(seqs: Union[np.ndarray, list]) -> (np.ndarray, np.ndarray):
699
+ values = np.concatenate(seqs, axis=0)
700
+ offsets = np.cumsum([len(s) for s in seqs])
701
+ return values, offsets
702
+
703
+
704
+ def unpack_sequence(values: np.ndarray, offsets: np.ndarray, index: int) -> np.ndarray:
705
+ off1 = offsets[index]
706
+ if index > 0:
707
+ off0 = offsets[index - 1]
708
+ elif index == 0:
709
+ off0 = 0
710
+ else:
711
+ raise ValueError(index)
712
+ return values[off0:off1]
713
+
714
+
715
+ def image_enlisting_collate_fn(
716
+ batch: Iterable[tuple[torch.Tensor, np.ndarray, int]]
717
+ ) -> tuple[list[torch.Tensor], torch.Tensor, torch.Tensor]:
718
+ """Collate function that enlists its entries."""
719
+ return (
720
+ [torch.utils.data.default_collate([s[0]]) for s in batch],
721
+ torch.utils.data.default_collate([s[1] for s in batch]),
722
+ torch.utils.data.default_collate([s[2] for s in batch]),
723
+ )
spai/data/data_mfm.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.distributed as dist
4
+ import torchvision.transforms as T
5
+ from torch.utils.data import DataLoader, DistributedSampler
6
+ from torch.utils.data._utils.collate import default_collate
7
+ from torchvision.datasets import ImageFolder
8
+ from timm.data.transforms import _pil_interp
9
+
10
+ from .random_degradations import RandomBlur, RandomNoise
11
+
12
+
13
+ class FreqMaskGenerator:
14
+ def __init__(self,
15
+ input_size=224,
16
+ mask_radius1=16,
17
+ mask_radius2=999,
18
+ sample_ratio=0.5):
19
+ self.input_size = input_size
20
+ self.mask_radius1 = mask_radius1
21
+ self.mask_radius2 = mask_radius2
22
+ self.sample_ratio = sample_ratio
23
+ self.mask = np.ones((self.input_size, self.input_size), dtype=int)
24
+ for y in range(self.input_size):
25
+ for x in range(self.input_size):
26
+ if ((x - self.input_size // 2) ** 2 + (y - self.input_size // 2) ** 2) >= self.mask_radius1 ** 2 \
27
+ and ((x - self.input_size // 2) ** 2 + (y - self.input_size // 2) ** 2) < self.mask_radius2 ** 2:
28
+ self.mask[y, x] = 0
29
+
30
+ def __call__(self):
31
+ rnd = torch.bernoulli(torch.tensor(self.sample_ratio, dtype=torch.float)).item()
32
+ if rnd == 0: # high-pass
33
+ return 1 - self.mask
34
+ elif rnd == 1: # low-pass
35
+ return self.mask
36
+ else:
37
+ raise ValueError
38
+
39
+
40
+ class MFMTransform:
41
+ def __init__(self, config):
42
+ self.transform_img = T.Compose([
43
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
44
+ T.RandomResizedCrop(config.DATA.IMG_SIZE, scale=(config.DATA.MIN_CROP_SCALE, 1.), interpolation=_pil_interp(config.DATA.INTERPOLATION)),
45
+ T.RandomHorizontalFlip(),
46
+ ])
47
+
48
+ self.filter_type = config.DATA.FILTER_TYPE
49
+
50
+ if config.MODEL.TYPE == 'swin':
51
+ model_patch_size = config.MODEL.SWIN.PATCH_SIZE
52
+ elif config.MODEL.TYPE == 'vit':
53
+ model_patch_size = config.MODEL.VIT.PATCH_SIZE
54
+ elif config.MODEL.TYPE == 'resnet':
55
+ model_patch_size = 1
56
+ else:
57
+ raise NotImplementedError
58
+
59
+ if config.DATA.FILTER_TYPE == 'deblur':
60
+ self.degrade_transform = RandomBlur(
61
+ params=dict(
62
+ kernel_size=config.DATA.BLUR.KERNEL_SIZE,
63
+ kernel_list=config.DATA.BLUR.KERNEL_LIST,
64
+ kernel_prob=config.DATA.BLUR.KERNEL_PROB,
65
+ sigma_x=config.DATA.BLUR.SIGMA_X,
66
+ sigma_y=config.DATA.BLUR.SIGMA_Y,
67
+ rotate_angle=config.DATA.BLUR.ROTATE_ANGLE,
68
+ beta_gaussian=config.DATA.BLUR.BETA_GAUSSIAN,
69
+ beta_plateau=config.DATA.BLUR.BETA_PLATEAU),
70
+ )
71
+ elif config.DATA.FILTER_TYPE == 'denoise':
72
+ self.degrade_transform = RandomNoise(
73
+ params=dict(
74
+ noise_type=config.DATA.NOISE.TYPE,
75
+ noise_prob=config.DATA.NOISE.PROB,
76
+ gaussian_sigma=config.DATA.NOISE.GAUSSIAN_SIGMA,
77
+ gaussian_gray_noise_prob=config.DATA.NOISE.GAUSSIAN_GRAY_NOISE_PROB,
78
+ poisson_scale=config.DATA.NOISE.POISSON_SCALE,
79
+ poisson_gray_noise_prob=config.DATA.NOISE.POISSON_GRAY_NOISE_PROB),
80
+ )
81
+ elif config.DATA.FILTER_TYPE == 'mfm':
82
+ self.freq_mask_generator = FreqMaskGenerator(
83
+ input_size=config.DATA.IMG_SIZE,
84
+ mask_radius1=config.DATA.MASK_RADIUS1,
85
+ mask_radius2=config.DATA.MASK_RADIUS2,
86
+ sample_ratio=config.DATA.SAMPLE_RATIO
87
+ )
88
+
89
+ def __call__(self, img):
90
+ img = self.transform_img(img) # PIL Image (HxWxC, 0-255), no normalization
91
+ if self.filter_type in ['deblur', 'denoise']:
92
+ img_lq = np.array(img).astype(np.float32) / 255.
93
+ img_lq = self.degrade_transform(img_lq)
94
+ img_lq = torch.from_numpy(img_lq.transpose(2, 0, 1))
95
+ else:
96
+ img_lq = None
97
+ img = T.ToTensor()(img) # Tensor (CxHxW, 0-1)
98
+ if self.filter_type == 'mfm':
99
+ mask = self.freq_mask_generator()
100
+ else:
101
+ mask = None
102
+
103
+ return img, img_lq, mask
104
+
105
+
106
+ def collate_fn(batch):
107
+ if not isinstance(batch[0][0], tuple):
108
+ return default_collate(batch)
109
+ else:
110
+ batch_num = len(batch)
111
+ ret = []
112
+ for item_idx in range(len(batch[0][0])):
113
+ if batch[0][0][item_idx] is None:
114
+ ret.append(None)
115
+ else:
116
+ ret.append(default_collate([batch[i][0][item_idx] for i in range(batch_num)]))
117
+ ret.append(default_collate([batch[i][1] for i in range(batch_num)]))
118
+ return ret
119
+
120
+
121
+ def build_loader_mfm(config, logger):
122
+ transform = MFMTransform(config)
123
+ logger.info(f'Pre-train data transform:\n{transform}')
124
+
125
+ dataset = ImageFolder(config.DATA.DATA_PATH, transform)
126
+ logger.info(f'Build dataset: train images = {len(dataset)}')
127
+
128
+ sampler = DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank(), shuffle=True)
129
+ dataloader = DataLoader(dataset, config.DATA.BATCH_SIZE, sampler=sampler, num_workers=config.DATA.NUM_WORKERS, pin_memory=True, drop_last=True, collate_fn=collate_fn)
130
+
131
+ return dataloader
spai/data/filestorage.py ADDED
@@ -0,0 +1,387 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import csv
18
+ import hashlib
19
+ import io
20
+ import logging
21
+ import pathlib
22
+ from collections import Counter
23
+ from typing import Union, Optional
24
+
25
+ import click
26
+ import lmdb
27
+ import tqdm
28
+ import networkx as nx
29
+
30
+
31
+ __version__: str = "0.1.0-alpha"
32
+ __revision__: int = 2
33
+ __author__: str = "Dimitrios Karageorgiou"
34
+ __email__: str = "dkarageo@iti.gr"
35
+
36
+
37
+ class LMDBFileStorage:
38
+ """A file storage for handling large datasets based on LMDB."""
39
+ def __init__(self,
40
+ db_path: pathlib.Path,
41
+ map_size: int = 1024*1024*1024*1024, # 1TB
42
+ read_only: bool = False,
43
+ max_readers: int = 128):
44
+ self.db: lmdb.Environment = lmdb.open(
45
+ str(db_path),
46
+ map_size=map_size,
47
+ subdir=False,
48
+ readonly=read_only,
49
+ max_readers=max_readers,
50
+ lock=False,
51
+ sync=False
52
+ )
53
+
54
+ def open_file(self, file_id: str, mode: str = "r") -> Union[io.TextIOWrapper, io.BytesIO]:
55
+ """Returns a file-like stream of a file in the database."""
56
+ # with self.db.begin() as trans:
57
+ # data: bytes = trans.get(file_id.encode("ascii"))
58
+ with self.db.begin(buffers=True) as trans:
59
+ data = trans.get(file_id.encode("utf-8"))
60
+ stream: io.BytesIO = io.BytesIO(data)
61
+
62
+ if mode == "r":
63
+ reader: io.TextIOWrapper = io.TextIOWrapper(stream)
64
+ elif mode == "b":
65
+ reader: io.BytesIO = stream
66
+ else:
67
+ raise RuntimeError(f"Unsupported file mode: '{mode}'. Only 'r' and 'b' are supported.")
68
+
69
+ return reader
70
+
71
+ def write_file(self, file_id: str, file_data: bytes) -> None:
72
+ with self.db.begin(write=True) as trans:
73
+ trans.put(file_id.encode("utf-8"), file_data)
74
+
75
+ def get_all_ids(self) -> list[str]:
76
+ with self.db.begin() as trans:
77
+ cursor = trans.cursor()
78
+ ids: list[str] = [k for k, _ in cursor]
79
+ return ids
80
+
81
+ def close(self) -> None:
82
+ self.db.close()
83
+
84
+
85
+ @click.group()
86
+ def cli() -> None:
87
+ pass
88
+
89
+
90
+ @cli.command()
91
+ @click.option("-c", "--csv_file", required=True,
92
+ type=click.Path(dir_okay=False, exists=True, path_type=pathlib.Path),
93
+ help="Path to a CSV file containing relative paths to the dataset files.")
94
+ @click.option("-b", "--base_dir",
95
+ type=click.Path(file_okay=False, exists=True, path_type=pathlib.Path),
96
+ help="Base directory of the dataset. Paths inside the CSV should be relative "
97
+ "to that path. When not provided, the directory of the CSV file is "
98
+ "considered as the base directory.")
99
+ @click.option("-o", "--output_file", required=True,
100
+ type=click.Path(dir_okay=False, path_type=pathlib.Path),
101
+ help="Path to the database. If the file does not "
102
+ "exist, a new database is generated. Otherwise, it should point to a "
103
+ "previous instance of the LMDB, where data will be added.")
104
+ def add_csv(
105
+ csv_file: pathlib.Path,
106
+ base_dir: Optional[pathlib.Path],
107
+ output_file: pathlib.Path
108
+ ) -> None:
109
+ if base_dir is None:
110
+ base_dir = csv_file.parent
111
+ db: LMDBFileStorage = LMDBFileStorage(output_file)
112
+ add_csv_to_db(csv_file, db, base_dir)
113
+ db.close()
114
+
115
+
116
+ @cli.command()
117
+ @click.option("-s", "--src", required=True,
118
+ type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
119
+ help="Database whose files will be added to the destination database.")
120
+ @click.option("-d", "--dest", required=True,
121
+ type=click.Path(dir_okay=False, path_type=pathlib.Path),
122
+ help="Database where file from source database will be added.")
123
+ def add_db(
124
+ src: pathlib.Path,
125
+ dest: pathlib.Path
126
+ ) -> None:
127
+ """Adds all the contents of a database to another."""
128
+ src_db: LMDBFileStorage = LMDBFileStorage(src, read_only=True)
129
+ dest_db: LMDBFileStorage = LMDBFileStorage(dest)
130
+
131
+ for k in tqdm.tqdm(src_db.get_all_ids(), desc="Copying files", unit="file"):
132
+ k = str(k, 'UTF-8')
133
+ src_file: io.BytesIO = src_db.open_file(k, mode="b")
134
+ dest_db.write_file(k, src_file.read())
135
+
136
+ src_db.close()
137
+ dest_db.close()
138
+
139
+
140
+ @cli.command()
141
+ @click.option("-c", "--csv_file", required=True,
142
+ type=click.Path(dir_okay=False, exists=True, path_type=pathlib.Path),
143
+ help="Path to a CSV file containing relative paths to the dataset files.")
144
+ @click.option("-b", "--base_dir",
145
+ type=click.Path(file_okay=False, exists=True, path_type=pathlib.Path),
146
+ help="Base directory of the dataset. Paths inside the CSV should be relative "
147
+ "to that path. When not provided, the directory of the CSV file is "
148
+ "considered as the base directory.")
149
+ @click.option("-o", "--output_file", required=True,
150
+ type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
151
+ help="Path to the database to verify.")
152
+ def verify_csv(
153
+ csv_file: pathlib.Path,
154
+ base_dir: Optional[pathlib.Path],
155
+ output_file: pathlib.Path
156
+ ) -> None:
157
+ if base_dir is None:
158
+ base_dir = csv_file.parent
159
+ db: LMDBFileStorage = LMDBFileStorage(output_file, read_only=True)
160
+ verify_csv_in_db(csv_file, db, base_dir)
161
+ db.close()
162
+
163
+
164
+ @cli.command()
165
+ @click.option("-d", "--database", required=True,
166
+ type=click.Path(dir_okay=False, path_type=pathlib.Path, exists=True),
167
+ help="Database whose keys will be printed.")
168
+ @click.option("-h", "--hierarchical", is_flag=True,
169
+ help="List files in DB according to directories hierarchy.")
170
+ def list_db(
171
+ database: pathlib.Path,
172
+ hierarchical: bool
173
+ ) -> None:
174
+ """Lists the contents of a file storage."""
175
+ db: LMDBFileStorage = LMDBFileStorage(database, read_only=True)
176
+
177
+ if not hierarchical:
178
+ for k in db.get_all_ids():
179
+ print(k)
180
+ else:
181
+ # The db contains filenames as keys, so their parents will always be the dir names.
182
+ ids: list[str] = [str(pathlib.Path(str(k, 'UTF-8')).parent) for k in db.get_all_ids()]
183
+ counts: Counter = Counter(ids)
184
+
185
+ dir_graph: nx.DiGraph = nx.DiGraph()
186
+ for k in counts.keys():
187
+ dir_graph.add_edge(str(pathlib.Path(k).parent), k)
188
+ dir_graph.nodes[k]["items_num"] = counts[k]
189
+
190
+ top_level_nodes: list[str] = [n for n in dir_graph.nodes if dir_graph.in_degree(n) == 0]
191
+ top_level_nodes = sorted(top_level_nodes)
192
+
193
+ for n in top_level_nodes:
194
+ print_dirs_from_graph(dir_graph, n)
195
+
196
+
197
+ def add_csv_to_db(
198
+ csv_file: pathlib.Path,
199
+ db: LMDBFileStorage,
200
+ base_dir: pathlib.Path,
201
+ key_base_dir: Optional[pathlib.Path] = None,
202
+ verbose: bool = True
203
+ ) -> int:
204
+ """Adds the contents of the file paths included in a CSV file into an LMDB File Storage.
205
+
206
+ Paths of the files, relative to the base dir, are utilized as keys into the storage.
207
+ Thus, the maximum allowed path length is 511 bytes.
208
+
209
+ The contents of nested CSV files are recursively added into the LMDB File Storage.
210
+ In that case, keys represent the file structure relative to the base dir.
211
+
212
+ :param csv_file: Path to a CSV file describing a dataset.
213
+ :param db: An instance of LMDB File Storage, where files will be added.
214
+ :param base_dir: Directory where paths included into the CSV file are relative to.
215
+ :param key_base_dir: Directory where paths encoded into the keys of the LMDB File
216
+ Storage will be relative to. It should be either the same or an upper directory
217
+ compared to base dir. When this argument is omitted, the value of base dir
218
+ is used.
219
+ :param verbose: When set to False, progress messages will not be printed.
220
+ """
221
+ entries: list[dict[str, str]] = read_csv_file(csv_file, verbose=verbose)
222
+
223
+ if key_base_dir is None:
224
+ key_base_dir = base_dir
225
+
226
+ if verbose:
227
+ pbar = tqdm.tqdm(entries, desc="Writing CSV data to database", unit="file")
228
+ else:
229
+ pbar = entries
230
+
231
+ files_written: int = 0
232
+ for e in pbar:
233
+ # Generate key-path pairs for each path in the CSV.
234
+ files_to_write: dict[str, pathlib.Path] = find_files(
235
+ list(e.values()),
236
+ base_dir,
237
+ key_base_dir
238
+ )
239
+
240
+ files_written += write_files_to_db(files_to_write, db)
241
+
242
+ # Recursively add the contents of the encountered CSV files.
243
+ for p in files_to_write.values():
244
+ if p.suffix == ".csv":
245
+ files_written += add_csv_to_db(
246
+ p, db, p.parent, key_base_dir=key_base_dir, verbose=False
247
+ )
248
+
249
+ if verbose:
250
+ pbar.set_postfix({"Files Written": files_written})
251
+
252
+ return files_written
253
+
254
+
255
+ def verify_csv_in_db(
256
+ csv_file: pathlib.Path,
257
+ db: LMDBFileStorage,
258
+ base_dir: pathlib.Path,
259
+ key_base_dir: Optional[pathlib.Path] = None,
260
+ verbose: bool = True
261
+ ) -> int:
262
+ entries: list[dict[str, str]] = read_csv_file(csv_file, verbose=verbose)
263
+
264
+ if key_base_dir is None:
265
+ key_base_dir = base_dir
266
+
267
+ if verbose:
268
+ pbar = tqdm.tqdm(entries, desc="Verifying CSV data in database", unit="file")
269
+ else:
270
+ pbar = entries
271
+
272
+ files_verified: int = 0
273
+ for e in pbar:
274
+ # Generate key-path pairs for each path in the CSV.
275
+ files: dict[str, pathlib.Path] = find_files(
276
+ list(e.values()),
277
+ base_dir,
278
+ key_base_dir
279
+ )
280
+
281
+ files_verified += verify_files_in_db(files, db)
282
+
283
+ # Recursively verify the contents of the encountered CSV files.
284
+ for p in files.values():
285
+ if p.suffix == ".csv":
286
+ files_verified += verify_csv_in_db(
287
+ p, db, p.parent, key_base_dir=key_base_dir, verbose=False
288
+ )
289
+
290
+ if verbose:
291
+ pbar.set_postfix({"Files Verified": files_verified})
292
+
293
+ return files_verified
294
+
295
+
296
+ def find_files(
297
+ candidates: list[str],
298
+ base_dir: pathlib.Path,
299
+ key_base_dir: pathlib.Path
300
+ ) -> dict[str, pathlib.Path]:
301
+ files: dict[str, pathlib.Path] = {}
302
+ for c in candidates:
303
+ p: pathlib.Path = base_dir / c
304
+ key: str = str(p.relative_to(key_base_dir))
305
+ if p.exists() and p.is_file():
306
+ files[key] = p
307
+ return files
308
+
309
+
310
+ def write_files_to_db(files: dict[str, pathlib.Path], db: LMDBFileStorage) -> int:
311
+ for k, p in files.items():
312
+ data: bytes = read_raw_file(p)
313
+ db.write_file(k, data)
314
+ return len(files)
315
+
316
+
317
+ def verify_files_in_db(files: dict[str, pathlib.Path], db: LMDBFileStorage) -> int:
318
+ verified: int = 0
319
+ for k, p in files.items():
320
+ # Calculate md5 hash of the file in csv.
321
+ with p.open("rb") as f:
322
+ csv_file_hash: str = md5(f)
323
+ # Calculate md5 hash of the file in db.
324
+ db_file: io.BytesIO = db.open_file(k, mode="b")
325
+ db_file_hash: str = md5(db_file)
326
+ if csv_file_hash == db_file_hash:
327
+ verified += 1
328
+ else:
329
+ logging.error(f"File in DB not matching file in CSV: {str(p)}")
330
+ return verified
331
+
332
+
333
+ def read_csv_file(csv_file: pathlib.Path, verbose: bool = True) -> list[dict[str, str]]:
334
+ # Read the whole csv file.
335
+ if verbose:
336
+ logging.info(f"READING CSV: {str(csv_file)}")
337
+
338
+ entries: list[dict[str, str]] = []
339
+ with csv_file.open() as f:
340
+ reader: csv.DictReader = csv.DictReader(f, delimiter=",")
341
+ if verbose:
342
+ pbar = tqdm.tqdm(reader, desc="Reading CSV entries", unit="entry")
343
+ else:
344
+ pbar = reader
345
+ for row in pbar:
346
+ entries.append(row)
347
+
348
+ if verbose:
349
+ logging.info(f"TOTAL ENTRIES: {len(entries)}")
350
+
351
+ return entries
352
+
353
+
354
+ def read_raw_file(p: pathlib.Path) -> bytes:
355
+ with p.open("rb") as f:
356
+ data: bytes = f.read()
357
+ return data
358
+
359
+
360
+ def md5(stream) -> str:
361
+ """Calculates md5 hash of a file-like stream."""
362
+ hash_md5 = hashlib.md5()
363
+ for chunk in iter(lambda: stream.read(4096), b""):
364
+ hash_md5.update(chunk)
365
+ return hash_md5.hexdigest()
366
+
367
+
368
+ def print_dirs_from_graph(g: nx.DiGraph, n: str, depth: int = 0) -> None:
369
+ if depth > 0:
370
+ init_text: str = " " * (depth - 1) + "|.. "
371
+ else:
372
+ init_text: str = ""
373
+ text: str = f"{init_text}{n}"
374
+ print(text)
375
+
376
+ for s in sorted(g.successors(n)):
377
+ print_dirs_from_graph(g, s, depth+1)
378
+
379
+ if "items_num" in g.nodes[n]:
380
+ init_text = " " * (depth+1)
381
+ text = f"{init_text}({g.nodes[n]['items_num']} files)"
382
+ print(text)
383
+
384
+
385
+ if __name__ == "__main__":
386
+ logging.getLogger().setLevel(logging.INFO)
387
+ cli()
spai/data/random_degradations.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import random
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+
8
+ from . import blur_kernels as blur_kernels
9
+
10
+
11
class RandomBlur:
    """Apply random blur to the input.

    Args:
        params (dict): A dictionary specifying the degradation settings:
            kernel families and sizes, sigma/angle/beta/omega ranges, the
            per-kernel random-walk ``*_step`` values, and an optional
            application probability ``prob``.
    """

    def __init__(self, params):
        self.params = params

    def get_kernel(self, num_kernels):
        """Samples ``num_kernels`` blur kernels with slowly drifting parameters.

        A kernel family and size are drawn once; the continuous parameters are
        sampled from their configured ranges and then updated by a clipped
        random walk between consecutive kernels (useful for multi-frame input).
        """
        kernel_type = random.choices(
            self.params['kernel_list'], weights=self.params['kernel_prob'])[0]
        kernel_size = random.choice(self.params['kernel_size'])

        # Initial parameter values and their random-walk step sizes.
        sigma_x_range = self.params.get('sigma_x', [0, 0])
        sigma_x = random.uniform(sigma_x_range[0], sigma_x_range[1])
        sigma_x_step = self.params.get('sigma_x_step', 0)

        sigma_y_range = self.params.get('sigma_y', [0, 0])
        sigma_y = random.uniform(sigma_y_range[0], sigma_y_range[1])
        sigma_y_step = self.params.get('sigma_y_step', 0)

        rotate_angle_range = self.params.get('rotate_angle', [-np.pi, np.pi])
        rotate_angle = random.uniform(rotate_angle_range[0],
                                      rotate_angle_range[1])
        rotate_angle_step = self.params.get('rotate_angle_step', 0)

        beta_gau_range = self.params.get('beta_gaussian', [0.5, 4])
        beta_gau = random.uniform(beta_gau_range[0], beta_gau_range[1])
        beta_gau_step = self.params.get('beta_gaussian_step', 0)

        beta_pla_range = self.params.get('beta_plateau', [1, 2])
        beta_pla = random.uniform(beta_pla_range[0], beta_pla_range[1])
        beta_pla_step = self.params.get('beta_plateau_step', 0)

        omega_range = self.params.get('omega', None)
        omega_step = self.params.get('omega_step', 0)
        if omega_range is None:  # follow Real-ESRGAN settings if not specified
            if kernel_size < 13:
                omega_range = [np.pi / 3., np.pi]
            else:
                omega_range = [np.pi / 5., np.pi]
        omega = random.uniform(omega_range[0], omega_range[1])

        # determine blurring kernel
        kernels = []
        for _ in range(0, num_kernels):
            kernel = blur_kernels.random_mixed_kernels(
                [kernel_type],
                [1],
                kernel_size,
                [sigma_x, sigma_x],
                [sigma_y, sigma_y],
                [rotate_angle, rotate_angle],
                [beta_gau, beta_gau],
                [beta_pla, beta_pla],
                [omega, omega],
                None,
            )
            kernels.append(kernel)

            # update kernel parameters (random walk, clipped back into the
            # configured ranges so the values stay valid)
            sigma_x += random.uniform(-sigma_x_step, sigma_x_step)
            sigma_y += random.uniform(-sigma_y_step, sigma_y_step)
            rotate_angle += random.uniform(-rotate_angle_step,
                                           rotate_angle_step)
            beta_gau += random.uniform(-beta_gau_step, beta_gau_step)
            beta_pla += random.uniform(-beta_pla_step, beta_pla_step)
            omega += random.uniform(-omega_step, omega_step)

            sigma_x = np.clip(sigma_x, sigma_x_range[0], sigma_x_range[1])
            sigma_y = np.clip(sigma_y, sigma_y_range[0], sigma_y_range[1])
            rotate_angle = np.clip(rotate_angle, rotate_angle_range[0],
                                   rotate_angle_range[1])
            beta_gau = np.clip(beta_gau, beta_gau_range[0], beta_gau_range[1])
            beta_pla = np.clip(beta_pla, beta_pla_range[0], beta_pla_range[1])
            omega = np.clip(omega, omega_range[0], omega_range[1])

        return kernels

    def _apply_random_blur(self, imgs):
        """Blurs one image or a list of images, one sampled kernel each."""
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        # get kernel and blur the input
        kernels = self.get_kernel(num_kernels=len(imgs))
        imgs = [
            cv2.filter2D(img, -1, kernel)
            for img, kernel in zip(imgs, kernels)
        ]

        if is_single_image:
            imgs = imgs[0]

        return imgs

    def __call__(self, results):
        # "prob" controls whether the degradation is applied at all.
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_blur(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
126
+
127
+
128
class RandomResize:
    """Randomly resize the input.

    Args:
        params (dict): A dictionary specifying the degradation settings:
            interpolation options and their probabilities, either a fixed
            ``target_size`` or up/down scale ranges, and an optional
            application probability ``prob``.
    """

    def __init__(self, params):
        self.params = params

        # Map config names to OpenCV interpolation flags.
        self.resize_dict = dict(
            bilinear=cv2.INTER_LINEAR,
            bicubic=cv2.INTER_CUBIC,
            area=cv2.INTER_AREA,
            lanczos=cv2.INTER_LANCZOS4)

    def _random_resize(self, imgs):
        """Resizes one image or a list of images with a sampled scale/interp."""
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        h, w = imgs[0].shape[:2]

        resize_opt = self.params['resize_opt']
        resize_prob = self.params['resize_prob']
        resize_opt = random.choices(resize_opt, weights=resize_prob)[0].lower()
        if resize_opt not in self.resize_dict:
            raise NotImplementedError(f'resize_opt [{resize_opt}] is not '
                                      'implemented')
        resize_opt = self.resize_dict[resize_opt]

        resize_step = self.params.get('resize_step', 0)

        # determine the target size, if not provided
        target_size = self.params.get('target_size', None)
        if target_size is None:
            resize_mode = random.choices(['up', 'down', 'keep'],
                                         weights=self.params['resize_mode_prob'])[0]
            resize_scale = self.params['resize_scale']
            if resize_mode == 'up':
                scale_factor = random.uniform(1, resize_scale[1])
            elif resize_mode == 'down':
                scale_factor = random.uniform(resize_scale[0], 1)
            else:
                scale_factor = 1

            # determine output size
            h_out, w_out = h * scale_factor, w * scale_factor
            if self.params.get('is_size_even', False):
                # Round both dimensions down to even numbers.
                h_out, w_out = 2 * (h_out // 2), 2 * (w_out // 2)
            target_size = (int(h_out), int(w_out))
        else:
            # A fixed target size disables the per-image scale random walk.
            resize_step = 0

        # resize the input
        if resize_step == 0:  # same target_size for all input images
            outputs = [
                cv2.resize(img, target_size[::-1], interpolation=resize_opt)
                for img in imgs
            ]
        else:  # different target_size for each input image
            outputs = []
            for img in imgs:
                img = cv2.resize(
                    img, target_size[::-1], interpolation=resize_opt)
                outputs.append(img)

                # update scale (random walk, clipped to the configured range)
                scale_factor += random.uniform(-resize_step, resize_step)
                scale_factor = np.clip(scale_factor, resize_scale[0],
                                       resize_scale[1])

                # determine output size
                h_out, w_out = h * scale_factor, w * scale_factor
                if self.params.get('is_size_even', False):
                    h_out, w_out = 2 * (h_out // 2), 2 * (w_out // 2)
                target_size = (int(h_out), int(w_out))

        if is_single_image:
            outputs = outputs[0]

        return outputs

    def __call__(self, results):
        # "prob" controls whether the degradation is applied at all.
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._random_resize(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
228
+
229
+
230
class RandomNoise:
    """Apply random noise to the input.

    Currently support Gaussian noise and Poisson noise.

    Args:
        params (dict): A dictionary specifying the degradation settings:
            noise types and their probabilities, per-type level ranges and
            gray-noise probabilities, and an optional application
            probability ``prob``.
    """

    def __init__(self, params):
        self.params = params

    def _apply_gaussian_noise(self, imgs):
        # Sigma is configured in 8-bit units and converted to the [0, 1]
        # float image scale here.
        sigma_range = self.params['gaussian_sigma']
        sigma = random.uniform(sigma_range[0], sigma_range[1]) / 255.

        sigma_step = self.params.get('gaussian_sigma_step', 0)

        gray_noise_prob = self.params['gaussian_gray_noise_prob']
        is_gray_noise = random.random() < gray_noise_prob

        outputs = []
        for img in imgs:
            noise = torch.randn(*(img.shape)).numpy() * sigma
            if is_gray_noise:
                # Keep a single channel so identical noise is broadcast
                # across all color channels. Assumes HWC input — TODO confirm.
                noise = noise[:, :, :1]
            outputs.append(img + noise)

            # update noise level (random walk between consecutive images,
            # clipped back into the configured range)
            sigma += random.uniform(-sigma_step, sigma_step) / 255.
            sigma = np.clip(sigma, sigma_range[0] / 255.,
                            sigma_range[1] / 255.)

        return outputs

    def _apply_poisson_noise(self, imgs):
        scale_range = self.params['poisson_scale']
        scale = random.uniform(scale_range[0], scale_range[1])

        scale_step = self.params.get('poisson_scale_step', 0)

        gray_noise_prob = self.params['poisson_gray_noise_prob']
        is_gray_noise = random.random() < gray_noise_prob

        outputs = []
        for img in imgs:
            noise = img.copy()
            if is_gray_noise:
                # NOTE(review): channels are flipped with [2, 1, 0] before the
                # BGR2GRAY conversion — presumably inputs are RGB; confirm
                # against the data pipeline.
                noise = cv2.cvtColor(noise[..., [2, 1, 0]], cv2.COLOR_BGR2GRAY)
                noise = noise[..., np.newaxis]
            noise = np.clip((noise * 255.0).round(), 0, 255) / 255.
            # Number of distinct intensity levels, rounded up to a power of
            # two; used to scale the Poisson sampling.
            unique_val = 2**np.ceil(np.log2(len(np.unique(noise))))
            noise = torch.poisson(torch.from_numpy(noise * unique_val)).numpy() / unique_val - noise

            outputs.append(img + noise * scale)

            # update noise level (random walk, clipped to range)
            scale += random.uniform(-scale_step, scale_step)
            scale = np.clip(scale, scale_range[0], scale_range[1])

        return outputs

    def _apply_random_noise(self, imgs):
        """Dispatches to the sampled noise type for one image or a list."""
        noise_type = random.choices(
            self.params['noise_type'], weights=self.params['noise_prob'])[0]

        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        if noise_type.lower() == 'gaussian':
            imgs = self._apply_gaussian_noise(imgs)
        elif noise_type.lower() == 'poisson':
            imgs = self._apply_poisson_noise(imgs)
        else:
            raise NotImplementedError(f'"noise_type" [{noise_type}] is '
                                      'not implemented.')

        if is_single_image:
            imgs = imgs[0]

        return imgs

    def __call__(self, results):
        # "prob" controls whether the degradation is applied at all.
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_noise(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
330
+
331
+
332
class RandomJPEGCompression:
    """Apply random JPEG compression to the input.

    Args:
        params (dict): A dictionary specifying the degradation settings:
            a ``quality`` range, an optional per-image ``quality_step``
            random walk, and an optional application probability ``prob``.
    """

    def __init__(self, params):
        self.params = params

    def _apply_random_compression(self, imgs):
        """JPEG round-trips one image or a list of images in memory."""
        is_single_image = False
        if isinstance(imgs, np.ndarray):
            is_single_image = True
            imgs = [imgs]

        # determine initial compression level and the step size
        quality = self.params['quality']
        quality_step = self.params.get('quality_step', 0)
        jpeg_param = round(random.uniform(quality[0], quality[1]))

        # apply jpeg compression
        outputs = []
        for img in imgs:
            encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_param]
            # Images are scaled by 255 before encoding and back to [0, 1]
            # floats afterwards — assumes float input in [0, 1]; TODO confirm.
            _, img_encoded = cv2.imencode('.jpg', img * 255., encode_param)
            outputs.append(np.float32(cv2.imdecode(img_encoded, 1)) / 255.)

            # update compression level (random walk, clipped to range)
            jpeg_param += random.uniform(-quality_step, quality_step)
            jpeg_param = round(np.clip(jpeg_param, quality[0], quality[1]))

        if is_single_image:
            outputs = outputs[0]

        return outputs

    def __call__(self, results):
        # "prob" controls whether the degradation is applied at all.
        if random.random() > self.params.get('prob', 1):
            return results

        results = self._apply_random_compression(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(params={self.params})')
        return repr_str
385
+
386
+
387
# Registry mapping config "type" names to their degradation classes; used by
# DegradationsWithShuffle to instantiate degradations from config dicts.
allowed_degradations = {
    'RandomBlur': RandomBlur,
    'RandomResize': RandomResize,
    'RandomNoise': RandomNoise,
    'RandomJPEGCompression': RandomJPEGCompression,
}
393
+
394
+
395
class DegradationsWithShuffle:
    """Apply random degradations to input, with degradations being shuffled.

    Degradation groups are supported. The order of degradations within the same
    group is preserved. For example, if we have degradations = [a, b, [c, d]]
    and shuffle_idx = None, then the possible orders are

    ::

        [a, b, [c, d]]
        [a, [c, d], b]
        [b, a, [c, d]]
        [b, [c, d], a]
        [[c, d], a, b]
        [[c, d], b, a]

    Args:
        degradations (list[dict]): The list of degradation configs. Each
            config is a dict with a "type" key (a name registered in
            ``allowed_degradations``) and a "params" dict; nested lists or
            tuples of configs form order-preserving groups.
        shuffle_idx (list | None, optional): The degradations corresponding to
            these indices are shuffled. If None, all degradations are shuffled.
    """

    def __init__(self, degradations, shuffle_idx=None):
        self.degradations = self._build_degradations(degradations)

        if shuffle_idx is None:
            self.shuffle_idx = list(range(0, len(self.degradations)))
        else:
            self.shuffle_idx = shuffle_idx

    def _build_degradations(self, degradations):
        """Instantiates degradation objects from their config dicts.

        Builds and returns a new (possibly nested) list instead of
        overwriting the caller's config list in place — the previous
        implementation mutated its argument as a side effect, corrupting
        any config reused to build a second instance.
        """
        built = []
        for degradation in degradations:
            if isinstance(degradation, (list, tuple)):
                # Recurse into a group; its internal order is preserved.
                built.append(self._build_degradations(degradation))
            else:
                degradation_cls = allowed_degradations[degradation['type']]
                built.append(degradation_cls(degradation['params']))
        return built

    def __call__(self, results):
        # Shuffle the selected degradations in place; the new order persists
        # into subsequent calls (matching the original behavior). This is
        # safe now that self.degradations is owned by this instance.
        if len(self.shuffle_idx) > 0:
            shuffle_list = [self.degradations[i] for i in self.shuffle_idx]
            random.shuffle(shuffle_list)
            for i, idx in enumerate(self.shuffle_idx):
                self.degradations[idx] = shuffle_list[i]

        # apply degradations to input
        for degradation in self.degradations:
            if isinstance(degradation, (tuple, list)):
                for sub_degradation in degradation:
                    results = sub_degradation(results)
            else:
                results = degradation(results)

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += (f'(degradations={self.degradations}, '
                     f'shuffle_idx={self.shuffle_idx})')
        return repr_str
spai/data/readers.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import csv
18
+ import io
19
+ import pathlib
20
+ from typing import Any, Union, Optional
21
+
22
+ import numpy as np
23
+ import torch
24
+ from PIL import Image
25
+ from torchvision.io import read_image
26
+
27
+ from spai.data import filestorage
28
+
29
+
30
class DataReader:
    """Abstract interface for reading dataset files from a storage backend.

    Concrete subclasses resolve the storage-relative paths passed to these
    methods against their own backend (e.g. a filesystem directory or an
    LMDB-based file storage).
    """

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a comma-delimited CSV file into one dict per row."""
        raise NotImplementedError

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image, converted to grayscale ("L") when channels == 1,
        otherwise to "RGB"."""
        raise NotImplementedError

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple."""
        raise NotImplementedError

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        The default values for the column name and the number of channels have
        been specified for the CSV containing the instance segmentation maps of
        an image.

        NOTE(review): the concrete implementations in this module currently
        return PIL images (their numpy conversion is commented out), so the
        declared list[np.ndarray] return type may not reflect runtime
        behavior — confirm before relying on it.
        """
        raise NotImplementedError

    def load_file_path_or_stream(self, path: str) -> Union[pathlib.Path, io.FileIO, io.BytesIO]:
        """Returns either a concrete filesystem path or an open stream for
        the file, depending on the backend."""
        raise NotImplementedError
58
+
59
+
60
class FileSystemReader(DataReader):
    """Reader that maps relative paths to absolute paths of the filesystem."""

    def __init__(self, root_path: pathlib.Path):
        super().__init__()
        # All paths handed to this reader are resolved relative to root_path.
        self.root_path: pathlib.Path = root_path

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a comma-delimited CSV file under root_path into row dicts."""
        with (self.root_path/path).open("r") as f:
            reader = csv.DictReader(f, delimiter=",")
            contents: list[dict[str, Any]] = [row for row in reader]
        return contents

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple without
        decoding the pixel data."""
        with Image.open(self.root_path/path) as image:
            image_size: tuple[int, int] = image.size
        return image_size

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image under root_path.

        Converts to grayscale ("L") when channels == 1, otherwise to "RGB".
        Read failures are reported to stdout and re-raised.
        """
        try:
            image = Image.open(self.root_path/path)
            if channels == 1:
                image = image.convert("L")
            else:
                image = image.convert("RGB")
        except Exception as e:
            print(f"Failed to read: {path}")
            raise e
        # NOTE(review): the numpy conversion below is deliberately disabled,
        # so callers receive a PIL image even where sibling annotations
        # mention np.ndarray.
        # image = np.array(image)
        #
        # if len(image.shape) == 2:
        #     image = np.expand_dims(image, axis=2)
        return image

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        Signal paths in the CSV are resolved relative to the CSV's own
        directory. Rows not matching ``data_specifier`` (if given) are
        skipped.
        """
        csv_data: list[dict[str, Any]] = self.read_csv_file(csv_path)

        signals: list[np.ndarray] = []
        for row in csv_data:
            # Ignore entries that do not match with the given data specifier.
            if data_specifier is not None and not data_specifier_matches_entry(row, data_specifier):
                continue

            signal_path: pathlib.Path = (self.root_path / csv_path).parent / row[column_name]
            signal: np.ndarray = self.load_image(str(signal_path.relative_to(self.root_path)),
                                                 channels=channels)
            signals.append(signal)

        return signals

    def load_file_path_or_stream(self, path: str) -> Union[pathlib.Path, io.FileIO, io.BytesIO]:
        """Returns the absolute filesystem path for a storage-relative path."""
        return self.root_path / path
117
+
118
+
119
class LMDBFileStorageReader(DataReader):
    """Reader that maps relative paths into an LMDBFileStorage.

    NOTE(review): load_file_path_or_stream is not overridden here, so it
    raises NotImplementedError from the base class — confirm whether callers
    depend on it for this backend.
    """

    def __init__(self, storage: filestorage.LMDBFileStorage):
        super().__init__()
        self.storage: filestorage.LMDBFileStorage = storage

    def read_csv_file(self, path: str) -> list[dict[str, Any]]:
        """Parses a comma-delimited CSV file stored in the LMDB storage."""
        stream = self.storage.open_file(path)
        reader = csv.DictReader(stream, delimiter=",")
        contents: list[dict[str, Any]] = [row for row in reader]
        return contents

    def get_image_size(self, path: str) -> tuple[int, int]:
        """Returns the size of an image as a (width, height) tuple without
        decoding the pixel data."""
        stream = self.storage.open_file(path, mode="b")
        with Image.open(stream) as image:
            image_size: tuple[int, int] = image.size
        return image_size

    def load_image(self, path: str, channels: int) -> Image.Image:
        """Loads an image from the LMDB storage.

        Converts to grayscale ("L") when channels == 1, otherwise to "RGB".
        The backing stream is closed before returning.
        """
        stream = self.storage.open_file(path, mode="b")
        with Image.open(stream) as image:
            if channels == 1:
                image = image.convert("L")
            else:
                image = image.convert("RGB")
        # NOTE(review): the numpy conversion below is deliberately disabled,
        # so callers receive a PIL image even where sibling annotations
        # mention np.ndarray.
        # image = np.array(image)
        stream.close()

        # if len(image.shape) == 2:
        #     image = np.expand_dims(image, axis=2)
        return image

    def load_signals_from_csv(
        self,
        csv_path: str,
        column_name: str = "seg_map",
        channels: int = 1,
        data_specifier: Optional[dict[str, str]] = None
    ) -> list[np.ndarray]:
        """Loads all the signals specified in a column of a CSV file.

        Signal paths in the CSV are resolved relative to the CSV's own
        directory inside the storage. Rows not matching ``data_specifier``
        (if given) are skipped.
        """
        csv_data: list[dict[str, Any]] = self.read_csv_file(csv_path)

        signals: list[np.ndarray] = []
        for row in csv_data:
            # Ignore entries that do not match with the given data specifier.
            if data_specifier is not None and not data_specifier_matches_entry(row, data_specifier):
                continue

            signal_path: str = str(pathlib.Path(csv_path).parent / row[column_name])
            signal: np.ndarray = self.load_image(signal_path, channels=channels)
            signals.append(signal)

        return signals
171
+
172
+
173
def data_specifier_matches_entry(entry: dict[str, str], specifier: dict[str, str]) -> bool:
    """Checks whether a CSV entry matches a data specifier.

    An entry matches when every key of the specifier is present in the entry
    with an equal value. An empty specifier matches any entry.
    """
    return all(
        key in entry and entry[key] == value
        for key, value in specifier.items()
    )
spai/data_utils.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 Centre for Research and Technology Hellas
2
+ # and University of Amsterdam. All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import pathlib
18
+ import csv
19
+ import hashlib
20
+ from typing import Any, Optional
21
+
22
+
23
+ def read_csv_file(
24
+ path: pathlib.Path,
25
+ delimiter: str = ","
26
+ ) -> list[dict[str, Any]]:
27
+ with path.open("r") as f:
28
+ reader = csv.DictReader(f, delimiter=delimiter)
29
+ contents: list[dict[str, Any]] = [row for row in reader]
30
+ return contents
31
+
32
+
33
+ def write_csv_file(
34
+ data: list[dict[str, Any]],
35
+ output_file: pathlib.Path,
36
+ fieldnames: Optional[list[str]] = None,
37
+ delimiter: str = "|"
38
+ ) -> None:
39
+ if fieldnames is None:
40
+ fieldnames = list(data[0].keys())
41
+ with output_file.open("w", newline="") as f:
42
+ writer: csv.DictWriter = csv.DictWriter(f, fieldnames=fieldnames, delimiter=delimiter)
43
+ writer.writeheader()
44
+ for r in data:
45
+ writer.writerow(r)
46
+
47
+
48
+ def compute_file_md5(path: pathlib.Path) -> str:
49
+ with path.open("rb") as f:
50
+ return hashlib.md5(f.read()).hexdigest()