antoine.carreaud67 commited on
Commit ·
eea6ebc
1
Parent(s): 9ffd7b5
Stop tracking grayscale mixed local files
Browse files- .gitignore +9 -0
- configs/config_mixed_domain_ft_segformer_reliable5_grayscale.yaml +0 -74
- dataset/definition_dataset_grayscale_mixed.py +0 -114
- dataset/flairhub_grayscale_mixed.py +0 -24
- dataset/prepare_swisstlm3d_patches.py +0 -233
- dataset/rasterize_swisstlm3d.py +0 -403
- dataset/split_swisstlm3d_geographic.py +0 -314
- dataset/stats_swisstlm3d_masks.py +0 -136
- dataset/swissimage_grayscale_mixed.py +0 -24
.gitignore
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
paper/
|
| 2 |
+
configs/config_mixed_domain_ft_segformer_reliable5_grayscale.yaml
|
| 3 |
+
dataset/definition_dataset_grayscale_mixed.py
|
| 4 |
+
dataset/flairhub_grayscale_mixed.py
|
| 5 |
+
dataset/prepare_swisstlm3d_patches.py
|
| 6 |
+
dataset/rasterize_swisstlm3d.py
|
| 7 |
+
dataset/split_swisstlm3d_geographic.py
|
| 8 |
+
dataset/stats_swisstlm3d_masks.py
|
| 9 |
+
dataset/swissimage_grayscale_mixed.py
|
configs/config_mixed_domain_ft_segformer_reliable5_grayscale.yaml
DELETED
|
@@ -1,74 +0,0 @@
|
|
| 1 |
-
paths:
|
| 2 |
-
flairhub_path: /mnt/CalcShare/datasets/FLAIR1024_merged
|
| 3 |
-
swiss_path: /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_reliable5_forest_center20_patches_2048to1024_split
|
| 4 |
-
|
| 5 |
-
flairhub_train_img_subdir: train/img
|
| 6 |
-
flairhub_train_msk_subdir: train/msk
|
| 7 |
-
flairhub_val_img_subdir: valid/img
|
| 8 |
-
flairhub_val_msk_subdir: valid/msk
|
| 9 |
-
flairhub_test_img_subdir: test/img
|
| 10 |
-
flairhub_test_msk_subdir: test/msk
|
| 11 |
-
|
| 12 |
-
swiss_train_img_subdir: train/img
|
| 13 |
-
swiss_train_msk_subdir: train/msk
|
| 14 |
-
swiss_val_img_subdir: val/img
|
| 15 |
-
swiss_val_msk_subdir: val/msk
|
| 16 |
-
swiss_test_img_subdir: test/img
|
| 17 |
-
swiss_test_msk_subdir: test/msk
|
| 18 |
-
|
| 19 |
-
save_dir: weights
|
| 20 |
-
pretrained_path: weights/CASWiT-Base-SSL-aug_FLAIRHUB_SF.pth
|
| 21 |
-
|
| 22 |
-
model:
|
| 23 |
-
model_name: openmmlab/upernet-swin-base
|
| 24 |
-
num_classes: 15 # Keep the 15-class head for exact FLAIR-HUB checkpoint compatibility; forest merge is handled by label remapping below.
|
| 25 |
-
cross_attention_heads: 1
|
| 26 |
-
ignore_index: 255
|
| 27 |
-
fusion_mlp_ratio: 4.0
|
| 28 |
-
fusion_drop_path: 0.1
|
| 29 |
-
lr_supervision_weight: 0.5
|
| 30 |
-
head: segformer
|
| 31 |
-
|
| 32 |
-
training:
|
| 33 |
-
batch_size: 1
|
| 34 |
-
num_workers: 8
|
| 35 |
-
num_epochs: 10
|
| 36 |
-
learning_rate: 8.0e-06
|
| 37 |
-
amp: true
|
| 38 |
-
seed: 42
|
| 39 |
-
eta_min: 1.0e-06
|
| 40 |
-
|
| 41 |
-
mixed:
|
| 42 |
-
# Conservative domain adaptation with sparse high-confidence SwissTLM3D labels only.
|
| 43 |
-
swiss_ratio: 0.20
|
| 44 |
-
epoch_length: 0
|
| 45 |
-
best_on: mean
|
| 46 |
-
simulate_grayscale_inputs: true
|
| 47 |
-
|
| 48 |
-
labels:
|
| 49 |
-
flair_label_remap:
|
| 50 |
-
13: 12
|
| 51 |
-
swiss_label_remap:
|
| 52 |
-
13: 12
|
| 53 |
-
|
| 54 |
-
wandb:
|
| 55 |
-
use_wandb: true
|
| 56 |
-
project: CASWiT-SwissTLM3D
|
| 57 |
-
entity: soloo
|
| 58 |
-
run_name: CASWiT-Base-SSL-aug_FLAIRHUB_SF_mixed_ft_reliable5_forest_center20_gray_r02
|
| 59 |
-
|
| 60 |
-
print_device: true
|
| 61 |
-
|
| 62 |
-
augmentations:
|
| 63 |
-
enable: true
|
| 64 |
-
p_hflip: 0.5
|
| 65 |
-
p_vflip: 0.5
|
| 66 |
-
p_rot90: 0.5
|
| 67 |
-
color_jitter:
|
| 68 |
-
brightness: 0.05
|
| 69 |
-
contrast: 0.05
|
| 70 |
-
saturation: 0.05
|
| 71 |
-
hue: 0.02
|
| 72 |
-
blur:
|
| 73 |
-
p: 0.0
|
| 74 |
-
kernel: 3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/definition_dataset_grayscale_mixed.py
DELETED
|
@@ -1,114 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Dataset definitions dedicated to grayscale-simulated mixed-domain fine-tuning.
|
| 3 |
-
|
| 4 |
-
This variant keeps the exact same HR/LR extraction logic as the standard fusion
|
| 5 |
-
dataset, but converts RGB imagery to simulated grayscale replicated on 3 bands
|
| 6 |
-
directly inside the dataset definition.
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
-
from __future__ import annotations
|
| 10 |
-
|
| 11 |
-
import os
|
| 12 |
-
from pathlib import Path
|
| 13 |
-
from typing import Dict, Optional
|
| 14 |
-
|
| 15 |
-
import numpy as np
|
| 16 |
-
import torch
|
| 17 |
-
from PIL import Image
|
| 18 |
-
from torch.utils.data import Dataset
|
| 19 |
-
from torchvision import transforms
|
| 20 |
-
|
| 21 |
-
from dataset.definition_dataset import (
|
| 22 |
-
apply_label_remap,
|
| 23 |
-
load_image,
|
| 24 |
-
load_mask,
|
| 25 |
-
to_pil_uint8,
|
| 26 |
-
to_tensor_img,
|
| 27 |
-
)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def simulate_grayscale_rgb(image_hwc_float01: np.ndarray) -> np.ndarray:
|
| 31 |
-
"""
|
| 32 |
-
Convert an RGB image in [0,1] to simulated grayscale and replicate it on 3 channels.
|
| 33 |
-
"""
|
| 34 |
-
if image_hwc_float01.ndim != 3 or image_hwc_float01.shape[-1] != 3:
|
| 35 |
-
raise ValueError(
|
| 36 |
-
f"simulate_grayscale_rgb expects HWC RGB input, got shape={image_hwc_float01.shape}"
|
| 37 |
-
)
|
| 38 |
-
gray = (
|
| 39 |
-
0.299 * image_hwc_float01[..., 0]
|
| 40 |
-
+ 0.587 * image_hwc_float01[..., 1]
|
| 41 |
-
+ 0.114 * image_hwc_float01[..., 2]
|
| 42 |
-
).astype(np.float32, copy=False)
|
| 43 |
-
return np.repeat(gray[..., None], 3, axis=-1)
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
class SemanticSegmentationDatasetFusionGrayscaleMixed(Dataset):
|
| 47 |
-
"""
|
| 48 |
-
Mixed-domain HR/LR fusion dataset with grayscale simulation replicated on 3 bands.
|
| 49 |
-
|
| 50 |
-
The geometry, crops, downsampling and augmentations are intentionally kept
|
| 51 |
-
identical to the standard fusion dataset so that only the color information
|
| 52 |
-
is removed during training/validation/test.
|
| 53 |
-
"""
|
| 54 |
-
|
| 55 |
-
def __init__(
|
| 56 |
-
self,
|
| 57 |
-
image_dir: Path,
|
| 58 |
-
mask_dir: Path,
|
| 59 |
-
transform: Optional[transforms.Compose] = None,
|
| 60 |
-
augment=None,
|
| 61 |
-
label_remap: Optional[Dict[int, int]] = None,
|
| 62 |
-
):
|
| 63 |
-
self.image_dir = Path(image_dir)
|
| 64 |
-
self.mask_dir = Path(mask_dir)
|
| 65 |
-
self.image_filenames = sorted(os.listdir(self.image_dir))
|
| 66 |
-
self.mask_filenames = sorted(os.listdir(self.mask_dir))
|
| 67 |
-
assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"
|
| 68 |
-
self.transform = transform
|
| 69 |
-
self.augment = augment
|
| 70 |
-
self.label_remap = label_remap or {}
|
| 71 |
-
|
| 72 |
-
def __len__(self):
|
| 73 |
-
return len(self.image_filenames)
|
| 74 |
-
|
| 75 |
-
def __getitem__(self, idx):
|
| 76 |
-
image_path = self.image_dir / self.image_filenames[idx]
|
| 77 |
-
mask_path = self.mask_dir / self.mask_filenames[idx]
|
| 78 |
-
|
| 79 |
-
image = simulate_grayscale_rgb(load_image(image_path))
|
| 80 |
-
mask = load_mask(mask_path)
|
| 81 |
-
mask = apply_label_remap(mask, self.label_remap)
|
| 82 |
-
mask[mask >= 15] = 255
|
| 83 |
-
|
| 84 |
-
hr_crop_size = 512
|
| 85 |
-
crop_x, crop_y = 256, 256
|
| 86 |
-
|
| 87 |
-
image_hr = image[crop_x:crop_x + hr_crop_size, crop_y:crop_y + hr_crop_size]
|
| 88 |
-
mask_hr = mask[crop_x:crop_x + hr_crop_size, crop_y:crop_y + hr_crop_size]
|
| 89 |
-
|
| 90 |
-
image_lr = image[::2, ::2, :]
|
| 91 |
-
mask_lr = mask[::2, ::2]
|
| 92 |
-
|
| 93 |
-
img_hr_pil = to_pil_uint8(image_hr)
|
| 94 |
-
img_lr_pil = to_pil_uint8(image_lr)
|
| 95 |
-
|
| 96 |
-
m_hr_pil = Image.fromarray(mask_hr.astype(np.uint8), mode="L")
|
| 97 |
-
m_lr_pil = Image.fromarray(mask_lr.astype(np.uint8), mode="L")
|
| 98 |
-
|
| 99 |
-
if self.augment is not None:
|
| 100 |
-
img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil = self.augment(
|
| 101 |
-
img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil
|
| 102 |
-
)
|
| 103 |
-
|
| 104 |
-
if self.transform:
|
| 105 |
-
image_hr = self.transform(img_hr_pil)
|
| 106 |
-
image_lr = self.transform(img_lr_pil)
|
| 107 |
-
else:
|
| 108 |
-
image_hr = to_tensor_img(np.array(img_hr_pil))
|
| 109 |
-
image_lr = to_tensor_img(np.array(img_lr_pil))
|
| 110 |
-
|
| 111 |
-
mask_hr = torch.as_tensor(np.array(m_hr_pil, dtype=np.uint8), dtype=torch.long)
|
| 112 |
-
mask_lr = torch.as_tensor(np.array(m_lr_pil, dtype=np.uint8), dtype=torch.long)
|
| 113 |
-
|
| 114 |
-
return image_hr, mask_hr, image_lr, mask_lr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/flairhub_grayscale_mixed.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
from typing import Dict, Optional
|
| 3 |
-
|
| 4 |
-
from torchvision import transforms
|
| 5 |
-
|
| 6 |
-
from dataset.definition_dataset_grayscale_mixed import (
|
| 7 |
-
SemanticSegmentationDatasetFusionGrayscaleMixed,
|
| 8 |
-
)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def build(
|
| 12 |
-
image_dir: Path,
|
| 13 |
-
mask_dir: Path,
|
| 14 |
-
transform: Optional[transforms.Compose],
|
| 15 |
-
augment=None,
|
| 16 |
-
label_remap: Optional[Dict[int, int]] = None,
|
| 17 |
-
):
|
| 18 |
-
return SemanticSegmentationDatasetFusionGrayscaleMixed(
|
| 19 |
-
image_dir,
|
| 20 |
-
mask_dir,
|
| 21 |
-
transform=transform,
|
| 22 |
-
augment=augment,
|
| 23 |
-
label_remap=label_remap,
|
| 24 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/prepare_swisstlm3d_patches.py
DELETED
|
@@ -1,233 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Prepare SwissImage + SwissTLM3D patches for training.
|
| 3 |
-
|
| 4 |
-
Optimized version:
|
| 5 |
-
- processes source tiles in parallel across multiple CPU workers
|
| 6 |
-
- reads only 2048x2048 windows with rasterio instead of loading whole 10000x10000 tiles
|
| 7 |
-
- resizes image patches to 1024x1024 with bilinear interpolation
|
| 8 |
-
- resizes mask patches to 1024x1024 with nearest-neighbor interpolation
|
| 9 |
-
- optionally skips patches whose mask is entirely 255
|
| 10 |
-
"""
|
| 11 |
-
|
| 12 |
-
from __future__ import annotations
|
| 13 |
-
|
| 14 |
-
import argparse
|
| 15 |
-
import os
|
| 16 |
-
import sys
|
| 17 |
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
|
| 20 |
-
import numpy as np
|
| 21 |
-
import rasterio
|
| 22 |
-
from PIL import Image
|
| 23 |
-
from tqdm import tqdm
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
project_root = Path(__file__).resolve().parent.parent
|
| 27 |
-
sys.path.insert(0, str(project_root))
|
| 28 |
-
|
| 29 |
-
Image.MAX_IMAGE_PIXELS = None
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
def parse_args() -> argparse.Namespace:
|
| 33 |
-
parser = argparse.ArgumentParser(description="Prepare 2048->1024 SwissTLM3D training patches.")
|
| 34 |
-
parser.add_argument(
|
| 35 |
-
"--image_dir",
|
| 36 |
-
type=Path,
|
| 37 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/image"),
|
| 38 |
-
help="Directory containing full SwissImage tiles.",
|
| 39 |
-
)
|
| 40 |
-
parser.add_argument(
|
| 41 |
-
"--mask_dir",
|
| 42 |
-
type=Path,
|
| 43 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
|
| 44 |
-
help="Directory containing full SwissTLM3D masks aligned with image_dir.",
|
| 45 |
-
)
|
| 46 |
-
parser.add_argument(
|
| 47 |
-
"--output_dir",
|
| 48 |
-
type=Path,
|
| 49 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024"),
|
| 50 |
-
help="Directory where prepared patches will be written.",
|
| 51 |
-
)
|
| 52 |
-
parser.add_argument(
|
| 53 |
-
"--tile_size",
|
| 54 |
-
type=int,
|
| 55 |
-
default=2048,
|
| 56 |
-
help="Patch size extracted from the original SwissImage tile.",
|
| 57 |
-
)
|
| 58 |
-
parser.add_argument(
|
| 59 |
-
"--output_size",
|
| 60 |
-
type=int,
|
| 61 |
-
default=1024,
|
| 62 |
-
help="Final patch size after resizing.",
|
| 63 |
-
)
|
| 64 |
-
parser.add_argument(
|
| 65 |
-
"--ignore_index",
|
| 66 |
-
type=int,
|
| 67 |
-
default=255,
|
| 68 |
-
help="Ignore label value in masks.",
|
| 69 |
-
)
|
| 70 |
-
parser.add_argument(
|
| 71 |
-
"--drop_all_ignore",
|
| 72 |
-
action="store_true",
|
| 73 |
-
help="Skip patches whose mask is entirely equal to ignore_index.",
|
| 74 |
-
)
|
| 75 |
-
parser.add_argument(
|
| 76 |
-
"--overwrite",
|
| 77 |
-
action="store_true",
|
| 78 |
-
help="Overwrite existing patch files if they already exist.",
|
| 79 |
-
)
|
| 80 |
-
parser.add_argument(
|
| 81 |
-
"--workers",
|
| 82 |
-
type=int,
|
| 83 |
-
default=os.cpu_count() or 1,
|
| 84 |
-
help="Number of worker processes. Default: all visible CPU cores.",
|
| 85 |
-
)
|
| 86 |
-
return parser.parse_args()
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
def ensure_output_dirs(output_dir: Path) -> tuple[Path, Path]:
|
| 90 |
-
img_dir = output_dir / "img"
|
| 91 |
-
msk_dir = output_dir / "msk"
|
| 92 |
-
img_dir.mkdir(parents=True, exist_ok=True)
|
| 93 |
-
msk_dir.mkdir(parents=True, exist_ok=True)
|
| 94 |
-
return img_dir, msk_dir
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
def resize_image_patch(image_hwc_uint8: np.ndarray, output_size: int) -> np.ndarray:
|
| 98 |
-
img = Image.fromarray(image_hwc_uint8, mode="RGB")
|
| 99 |
-
img = img.resize((output_size, output_size), resample=Image.BILINEAR)
|
| 100 |
-
return np.asarray(img, dtype=np.uint8)
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
def resize_mask_patch(mask_hw: np.ndarray, output_size: int) -> np.ndarray:
|
| 104 |
-
img = Image.fromarray(mask_hw.astype(np.uint8), mode="L")
|
| 105 |
-
img = img.resize((output_size, output_size), resample=Image.NEAREST)
|
| 106 |
-
return np.asarray(img, dtype=np.uint8)
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
def iter_non_overlapping_windows(height: int, width: int, tile_size: int):
|
| 110 |
-
for y0 in range(0, height - tile_size + 1, tile_size):
|
| 111 |
-
for x0 in range(0, width - tile_size + 1, tile_size):
|
| 112 |
-
yield y0, x0
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
def patch_name(stem: str, y0: int, x0: int) -> str:
|
| 116 |
-
return f"{stem}_y{y0:05d}_x{x0:05d}.tif"
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
def read_image_window(src: rasterio.DatasetReader, y0: int, x0: int, tile_size: int) -> np.ndarray:
|
| 120 |
-
window = rasterio.windows.Window(col_off=x0, row_off=y0, width=tile_size, height=tile_size)
|
| 121 |
-
arr = src.read(indexes=[1, 2, 3], window=window) # CHW
|
| 122 |
-
return np.transpose(arr, (1, 2, 0)).astype(np.uint8, copy=False)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
def read_mask_window(src: rasterio.DatasetReader, y0: int, x0: int, tile_size: int) -> np.ndarray:
|
| 126 |
-
window = rasterio.windows.Window(col_off=x0, row_off=y0, width=tile_size, height=tile_size)
|
| 127 |
-
return src.read(1, window=window)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
def process_tile(
|
| 131 |
-
image_path: str,
|
| 132 |
-
mask_path: str,
|
| 133 |
-
out_img_dir: str,
|
| 134 |
-
out_msk_dir: str,
|
| 135 |
-
tile_size: int,
|
| 136 |
-
output_size: int,
|
| 137 |
-
ignore_index: int,
|
| 138 |
-
drop_all_ignore: bool,
|
| 139 |
-
overwrite: bool,
|
| 140 |
-
) -> tuple[int, int, int]:
|
| 141 |
-
image_path = Path(image_path)
|
| 142 |
-
mask_path = Path(mask_path)
|
| 143 |
-
out_img_dir = Path(out_img_dir)
|
| 144 |
-
out_msk_dir = Path(out_msk_dir)
|
| 145 |
-
|
| 146 |
-
kept = 0
|
| 147 |
-
dropped_all_ignore_count = 0
|
| 148 |
-
total_windows = 0
|
| 149 |
-
stem = image_path.stem
|
| 150 |
-
|
| 151 |
-
with rasterio.open(image_path) as img_src, rasterio.open(mask_path) as msk_src:
|
| 152 |
-
if (img_src.height, img_src.width) != (msk_src.height, msk_src.width):
|
| 153 |
-
raise ValueError(
|
| 154 |
-
f"Shape mismatch for {image_path.name}: "
|
| 155 |
-
f"image={(img_src.height, img_src.width)} mask={(msk_src.height, msk_src.width)}"
|
| 156 |
-
)
|
| 157 |
-
|
| 158 |
-
for y0, x0 in iter_non_overlapping_windows(msk_src.height, msk_src.width, tile_size):
|
| 159 |
-
total_windows += 1
|
| 160 |
-
out_name = patch_name(stem, y0, x0)
|
| 161 |
-
out_img = out_img_dir / out_name
|
| 162 |
-
out_msk = out_msk_dir / out_name
|
| 163 |
-
|
| 164 |
-
if not overwrite and out_img.exists() and out_msk.exists():
|
| 165 |
-
kept += 1
|
| 166 |
-
continue
|
| 167 |
-
|
| 168 |
-
mask_patch = read_mask_window(msk_src, y0, x0, tile_size)
|
| 169 |
-
if drop_all_ignore and np.all(mask_patch == ignore_index):
|
| 170 |
-
dropped_all_ignore_count += 1
|
| 171 |
-
continue
|
| 172 |
-
|
| 173 |
-
image_patch = read_image_window(img_src, y0, x0, tile_size)
|
| 174 |
-
image_patch_resized = resize_image_patch(image_patch, output_size)
|
| 175 |
-
mask_patch_resized = resize_mask_patch(mask_patch, output_size)
|
| 176 |
-
|
| 177 |
-
Image.fromarray(image_patch_resized, mode="RGB").save(out_img)
|
| 178 |
-
Image.fromarray(mask_patch_resized, mode="L").save(out_msk)
|
| 179 |
-
kept += 1
|
| 180 |
-
|
| 181 |
-
return kept, dropped_all_ignore_count, total_windows
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def main() -> None:
|
| 185 |
-
args = parse_args()
|
| 186 |
-
out_img_dir, out_msk_dir = ensure_output_dirs(args.output_dir)
|
| 187 |
-
|
| 188 |
-
image_paths = {p.name: p for p in sorted(args.image_dir.glob("*.tif"))}
|
| 189 |
-
mask_paths = {p.name: p for p in sorted(args.mask_dir.glob("*.tif"))}
|
| 190 |
-
common_names = sorted(image_paths.keys() & mask_paths.keys())
|
| 191 |
-
|
| 192 |
-
if not common_names:
|
| 193 |
-
raise RuntimeError("No matching .tif image/mask pairs found.")
|
| 194 |
-
|
| 195 |
-
workers = max(1, int(args.workers))
|
| 196 |
-
kept = 0
|
| 197 |
-
dropped_all_ignore = 0
|
| 198 |
-
total_windows = 0
|
| 199 |
-
|
| 200 |
-
with ProcessPoolExecutor(max_workers=workers) as executor:
|
| 201 |
-
futures = [
|
| 202 |
-
executor.submit(
|
| 203 |
-
process_tile,
|
| 204 |
-
str(image_paths[name]),
|
| 205 |
-
str(mask_paths[name]),
|
| 206 |
-
str(out_img_dir),
|
| 207 |
-
str(out_msk_dir),
|
| 208 |
-
int(args.tile_size),
|
| 209 |
-
int(args.output_size),
|
| 210 |
-
int(args.ignore_index),
|
| 211 |
-
bool(args.drop_all_ignore),
|
| 212 |
-
bool(args.overwrite),
|
| 213 |
-
)
|
| 214 |
-
for name in common_names
|
| 215 |
-
]
|
| 216 |
-
|
| 217 |
-
for future in tqdm(as_completed(futures), total=len(futures), desc="Tiles", unit="tile"):
|
| 218 |
-
tile_kept, tile_dropped, tile_windows = future.result()
|
| 219 |
-
kept += tile_kept
|
| 220 |
-
dropped_all_ignore += tile_dropped
|
| 221 |
-
total_windows += tile_windows
|
| 222 |
-
|
| 223 |
-
print(f"matched_tiles: {len(common_names)}")
|
| 224 |
-
print(f"workers: {workers}")
|
| 225 |
-
print(f"total_windows: {total_windows}")
|
| 226 |
-
print(f"kept_patches: {kept}")
|
| 227 |
-
print(f"dropped_all_ignore: {dropped_all_ignore}")
|
| 228 |
-
print(f"output_img_dir: {out_img_dir}")
|
| 229 |
-
print(f"output_msk_dir: {out_msk_dir}")
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
if __name__ == "__main__":
|
| 233 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/rasterize_swisstlm3d.py
DELETED
|
@@ -1,403 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Rasterize SwissTLM3D features onto SwissImage tiles.
|
| 3 |
-
|
| 4 |
-
This script writes masks that are compatible with the 15-class FLAIR-HUB
|
| 5 |
-
land-cover setup used in this repository. SwissTLM3D does not expose every
|
| 6 |
-
FLAIR-HUB class with the same level of granularity, so the script only
|
| 7 |
-
projects a reliable subset of classes and leaves the rest to ignore_index=255.
|
| 8 |
-
|
| 9 |
-
The implementation intentionally relies on GDAL/OGR command-line tools
|
| 10 |
-
(`gdal_rasterize`) plus `rasterio`, because the current `caswit` environment
|
| 11 |
-
does not ship a full vector Python stack.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import argparse
|
| 17 |
-
import os
|
| 18 |
-
import subprocess
|
| 19 |
-
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 20 |
-
from dataclasses import dataclass
|
| 21 |
-
from pathlib import Path
|
| 22 |
-
from typing import Iterable
|
| 23 |
-
|
| 24 |
-
import numpy as np
|
| 25 |
-
import rasterio
|
| 26 |
-
from rasterio.errors import RasterioIOError
|
| 27 |
-
from tqdm import tqdm
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
IGNORE_INDEX = 255
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
FLAIR_HUB_CLASSES = {
|
| 34 |
-
0: "building",
|
| 35 |
-
1: "greenhouse",
|
| 36 |
-
2: "swimming_pool",
|
| 37 |
-
3: "impervious surface",
|
| 38 |
-
4: "pervious surface",
|
| 39 |
-
5: "bare soil",
|
| 40 |
-
6: "water",
|
| 41 |
-
7: "snow",
|
| 42 |
-
8: "herbaceous vegetation",
|
| 43 |
-
9: "agricultural land",
|
| 44 |
-
10: "plowed land",
|
| 45 |
-
11: "vineyard",
|
| 46 |
-
12: "deciduous",
|
| 47 |
-
13: "coniferous",
|
| 48 |
-
14: "brushwood",
|
| 49 |
-
}
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
def sql_in(values: Iterable[str]) -> str:
|
| 53 |
-
return "(" + ", ".join(f"'{value}'" for value in values) + ")"
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
@dataclass(frozen=True)
|
| 57 |
-
class BurnSpec:
|
| 58 |
-
name: str
|
| 59 |
-
layer: str
|
| 60 |
-
class_id: int
|
| 61 |
-
where: str
|
| 62 |
-
sql_expr: str = "geom"
|
| 63 |
-
|
| 64 |
-
@property
|
| 65 |
-
def rtree(self) -> str:
|
| 66 |
-
return f"rtree_{self.layer}_geom"
|
| 67 |
-
|
| 68 |
-
def sql(self, minx: float, miny: float, maxx: float, maxy: float) -> str:
|
| 69 |
-
spatial_filter = (
|
| 70 |
-
f"id IN (SELECT id FROM {self.rtree} "
|
| 71 |
-
f"WHERE maxx >= {minx} AND minx <= {maxx} "
|
| 72 |
-
f"AND maxy >= {miny} AND miny <= {maxy})"
|
| 73 |
-
)
|
| 74 |
-
return (
|
| 75 |
-
f"SELECT {self.sql_expr} AS geom "
|
| 76 |
-
f"FROM {self.layer} "
|
| 77 |
-
f"WHERE ({self.where}) AND {spatial_filter}"
|
| 78 |
-
)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
ROAD_WIDTH_CASE = """
|
| 82 |
-
CASE objektart
|
| 83 |
-
WHEN '1m Weg' THEN 0.5
|
| 84 |
-
WHEN '1m Wegfragment' THEN 0.5
|
| 85 |
-
WHEN '2m Weg' THEN 1.0
|
| 86 |
-
WHEN '2m Wegfragment' THEN 1.0
|
| 87 |
-
WHEN '3m Strasse' THEN 1.5
|
| 88 |
-
WHEN '4m Strasse' THEN 2.0
|
| 89 |
-
WHEN '6m Strasse' THEN 3.0
|
| 90 |
-
WHEN '8m Strasse' THEN 4.0
|
| 91 |
-
WHEN '10m Strasse' THEN 5.0
|
| 92 |
-
WHEN 'Autobahn' THEN 10.0
|
| 93 |
-
WHEN 'Autostrasse' THEN 6.0
|
| 94 |
-
WHEN 'Ausfahrt' THEN 3.5
|
| 95 |
-
WHEN 'Einfahrt' THEN 3.0
|
| 96 |
-
WHEN 'Zufahrt' THEN 3.0
|
| 97 |
-
WHEN 'Dienstzufahrt' THEN 3.0
|
| 98 |
-
WHEN 'Raststaette' THEN 8.0
|
| 99 |
-
WHEN 'Platz' THEN 6.0
|
| 100 |
-
WHEN 'Markierte Spur' THEN 2.0
|
| 101 |
-
WHEN 'Verbindung' THEN 2.0
|
| 102 |
-
ELSE 0.0
|
| 103 |
-
END
|
| 104 |
-
""".strip()
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
BUILDING_OBJECTS = (
|
| 108 |
-
"Gebaeude",
|
| 109 |
-
"Hochhaus",
|
| 110 |
-
"Historische Baute",
|
| 111 |
-
"Im Bau",
|
| 112 |
-
"Kapelle",
|
| 113 |
-
"Offenes Gebaeude",
|
| 114 |
-
"Sakraler Turm",
|
| 115 |
-
"Sakrales Gebaeude",
|
| 116 |
-
"Turm",
|
| 117 |
-
"Unterirdisches Gebaeude",
|
| 118 |
-
"Einhausung",
|
| 119 |
-
)
|
| 120 |
-
|
| 121 |
-
GREENHOUSE_OBJECTS = ("Treibhaus",)
|
| 122 |
-
|
| 123 |
-
BARE_SOIL_OBJECTS = (
|
| 124 |
-
"Fels",
|
| 125 |
-
"Fels locker",
|
| 126 |
-
"Felsbloecke",
|
| 127 |
-
"Felsbloecke locker",
|
| 128 |
-
"Lockergestein",
|
| 129 |
-
"Lockergestein locker",
|
| 130 |
-
)
|
| 131 |
-
|
| 132 |
-
WATER_OBJECTS = (
|
| 133 |
-
"Fliessgewaesser",
|
| 134 |
-
"Stehende Gewaesser",
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
SNOW_OBJECTS = (
|
| 138 |
-
"Gletscher",
|
| 139 |
-
"Schneefeld Toteis",
|
| 140 |
-
)
|
| 141 |
-
|
| 142 |
-
BRUSHWOOD_OBJECTS = ("Gebueschwald",)
|
| 143 |
-
|
| 144 |
-
AGRICULTURAL_OBJECTS = (
|
| 145 |
-
"Baumschule",
|
| 146 |
-
"Obstanlage",
|
| 147 |
-
"Schrebergartenareal",
|
| 148 |
-
)
|
| 149 |
-
|
| 150 |
-
VINEYARD_OBJECTS = ("Reben",)
|
| 151 |
-
|
| 152 |
-
TRAFFIC_AREA_OBJECTS = (
|
| 153 |
-
"Flugfeldareal",
|
| 154 |
-
"Flughafenareal",
|
| 155 |
-
"Flugplatzareal",
|
| 156 |
-
"Gleisareal",
|
| 157 |
-
"Heliport",
|
| 158 |
-
"Oeffentliches Parkplatzareal",
|
| 159 |
-
"Privates Fahrareal",
|
| 160 |
-
"Privates Parkplatzareal",
|
| 161 |
-
"Rastplatzareal",
|
| 162 |
-
"Verkehrsflaeche",
|
| 163 |
-
)
|
| 164 |
-
|
| 165 |
-
ROAD_OBJECTS = (
|
| 166 |
-
"10m Strasse",
|
| 167 |
-
"1m Weg",
|
| 168 |
-
"1m Wegfragment",
|
| 169 |
-
"2m Weg",
|
| 170 |
-
"2m Wegfragment",
|
| 171 |
-
"3m Strasse",
|
| 172 |
-
"4m Strasse",
|
| 173 |
-
"6m Strasse",
|
| 174 |
-
"8m Strasse",
|
| 175 |
-
"Ausfahrt",
|
| 176 |
-
"Autobahn",
|
| 177 |
-
"Autostrasse",
|
| 178 |
-
"Dienstzufahrt",
|
| 179 |
-
"Einfahrt",
|
| 180 |
-
"Markierte Spur",
|
| 181 |
-
"Platz",
|
| 182 |
-
"Raststaette",
|
| 183 |
-
"Verbindung",
|
| 184 |
-
"Zufahrt",
|
| 185 |
-
)
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
BURN_SPECS = (
|
| 189 |
-
BurnSpec(
|
| 190 |
-
name="bare_soil",
|
| 191 |
-
layer="tlm_bb_bodenbedeckung",
|
| 192 |
-
class_id=5,
|
| 193 |
-
where=f"objektart IN {sql_in(BARE_SOIL_OBJECTS)}",
|
| 194 |
-
),
|
| 195 |
-
BurnSpec(
|
| 196 |
-
name="snow",
|
| 197 |
-
layer="tlm_bb_bodenbedeckung",
|
| 198 |
-
class_id=7,
|
| 199 |
-
where=f"objektart IN {sql_in(SNOW_OBJECTS)}",
|
| 200 |
-
),
|
| 201 |
-
BurnSpec(
|
| 202 |
-
name="brushwood",
|
| 203 |
-
layer="tlm_bb_bodenbedeckung",
|
| 204 |
-
class_id=14,
|
| 205 |
-
where=f"objektart IN {sql_in(BRUSHWOOD_OBJECTS)}",
|
| 206 |
-
),
|
| 207 |
-
BurnSpec(
|
| 208 |
-
name="agricultural_land",
|
| 209 |
-
layer="tlm_areale_nutzungsareal",
|
| 210 |
-
class_id=9,
|
| 211 |
-
where=f"objektart IN {sql_in(AGRICULTURAL_OBJECTS)}",
|
| 212 |
-
),
|
| 213 |
-
BurnSpec(
|
| 214 |
-
name="vineyard",
|
| 215 |
-
layer="tlm_areale_nutzungsareal",
|
| 216 |
-
class_id=11,
|
| 217 |
-
where=f"objektart IN {sql_in(VINEYARD_OBJECTS)}",
|
| 218 |
-
),
|
| 219 |
-
BurnSpec(
|
| 220 |
-
name="impervious_polygons",
|
| 221 |
-
layer="tlm_areale_verkehrsareal",
|
| 222 |
-
class_id=3,
|
| 223 |
-
where=f"objektart IN {sql_in(TRAFFIC_AREA_OBJECTS)}",
|
| 224 |
-
),
|
| 225 |
-
BurnSpec(
|
| 226 |
-
name="impervious_roads",
|
| 227 |
-
layer="tlm_strassen_strasse",
|
| 228 |
-
class_id=3,
|
| 229 |
-
where=f"objektart IN {sql_in(ROAD_OBJECTS)}",
|
| 230 |
-
sql_expr=f"ST_Buffer(geom, {ROAD_WIDTH_CASE})",
|
| 231 |
-
),
|
| 232 |
-
BurnSpec(
|
| 233 |
-
name="water",
|
| 234 |
-
layer="tlm_bb_bodenbedeckung",
|
| 235 |
-
class_id=6,
|
| 236 |
-
where=f"objektart IN {sql_in(WATER_OBJECTS)}",
|
| 237 |
-
),
|
| 238 |
-
BurnSpec(
|
| 239 |
-
name="building",
|
| 240 |
-
layer="tlm_bauten_gebaeude_footprint",
|
| 241 |
-
class_id=0,
|
| 242 |
-
where=f"objektart IN {sql_in(BUILDING_OBJECTS)}",
|
| 243 |
-
),
|
| 244 |
-
BurnSpec(
|
| 245 |
-
name="greenhouse",
|
| 246 |
-
layer="tlm_bauten_gebaeude_footprint",
|
| 247 |
-
class_id=1,
|
| 248 |
-
where=f"objektart IN {sql_in(GREENHOUSE_OBJECTS)}",
|
| 249 |
-
),
|
| 250 |
-
)
|
| 251 |
-
|
| 252 |
-
|
| 253 |
-
def init_mask_from_image(image_path: Path, output_path: Path) -> tuple[float, float, float, float, int, int]:
|
| 254 |
-
if output_path.exists():
|
| 255 |
-
output_path.unlink()
|
| 256 |
-
|
| 257 |
-
with rasterio.open(image_path) as src:
|
| 258 |
-
bounds = src.bounds
|
| 259 |
-
width = src.width
|
| 260 |
-
height = src.height
|
| 261 |
-
transform = src.transform
|
| 262 |
-
crs = src.crs
|
| 263 |
-
|
| 264 |
-
profile = {
|
| 265 |
-
"driver": "GTiff",
|
| 266 |
-
"width": width,
|
| 267 |
-
"height": height,
|
| 268 |
-
"count": 1,
|
| 269 |
-
"dtype": "uint8",
|
| 270 |
-
"crs": crs,
|
| 271 |
-
"transform": transform,
|
| 272 |
-
"nodata": IGNORE_INDEX,
|
| 273 |
-
"compress": "lzw",
|
| 274 |
-
"tiled": True,
|
| 275 |
-
"blockxsize": 512,
|
| 276 |
-
"blockysize": 512,
|
| 277 |
-
}
|
| 278 |
-
with rasterio.open(output_path, "w", **profile) as dst:
|
| 279 |
-
dst.write(np.full((1, height, width), IGNORE_INDEX, dtype=np.uint8))
|
| 280 |
-
|
| 281 |
-
return bounds.left, bounds.bottom, bounds.right, bounds.top, width, height
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def rasterize_spec(
|
| 285 |
-
gpkg_path: Path,
|
| 286 |
-
output_path: Path,
|
| 287 |
-
spec: BurnSpec,
|
| 288 |
-
minx: float,
|
| 289 |
-
miny: float,
|
| 290 |
-
maxx: float,
|
| 291 |
-
maxy: float,
|
| 292 |
-
width: int,
|
| 293 |
-
height: int,
|
| 294 |
-
) -> None:
|
| 295 |
-
sql = spec.sql(minx=minx, miny=miny, maxx=maxx, maxy=maxy)
|
| 296 |
-
command = [
|
| 297 |
-
"gdal_rasterize",
|
| 298 |
-
"-b",
|
| 299 |
-
"1",
|
| 300 |
-
"-burn",
|
| 301 |
-
str(spec.class_id),
|
| 302 |
-
"-sql",
|
| 303 |
-
sql,
|
| 304 |
-
"-dialect",
|
| 305 |
-
"SQLITE",
|
| 306 |
-
str(gpkg_path),
|
| 307 |
-
str(output_path),
|
| 308 |
-
]
|
| 309 |
-
subprocess.run(command, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
def process_one(image_path: Path, output_dir: Path, gpkg_path: Path, overwrite: bool) -> tuple[str, str]:
|
| 313 |
-
output_path = output_dir / image_path.name
|
| 314 |
-
if output_path.exists() and not overwrite:
|
| 315 |
-
return image_path.name, "skipped"
|
| 316 |
-
|
| 317 |
-
try:
|
| 318 |
-
minx, miny, maxx, maxy, width, height = init_mask_from_image(image_path, output_path)
|
| 319 |
-
for spec in BURN_SPECS:
|
| 320 |
-
rasterize_spec(
|
| 321 |
-
gpkg_path=gpkg_path,
|
| 322 |
-
output_path=output_path,
|
| 323 |
-
spec=spec,
|
| 324 |
-
minx=minx,
|
| 325 |
-
miny=miny,
|
| 326 |
-
maxx=maxx,
|
| 327 |
-
maxy=maxy,
|
| 328 |
-
width=width,
|
| 329 |
-
height=height,
|
| 330 |
-
)
|
| 331 |
-
except (subprocess.CalledProcessError, RasterioIOError) as exc:
|
| 332 |
-
return image_path.name, f"error: {exc}"
|
| 333 |
-
|
| 334 |
-
return image_path.name, "ok"
|
| 335 |
-
|
| 336 |
-
|
| 337 |
-
def parse_args() -> argparse.Namespace:
|
| 338 |
-
parser = argparse.ArgumentParser(description="Rasterize SwissTLM3D onto SwissImage tiles.")
|
| 339 |
-
parser.add_argument(
|
| 340 |
-
"--gpkg",
|
| 341 |
-
type=Path,
|
| 342 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/SWISSTLM3D_2026_LV95_LN02.gpkg"),
|
| 343 |
-
help="Path to the extracted SwissTLM3D GeoPackage.",
|
| 344 |
-
)
|
| 345 |
-
parser.add_argument(
|
| 346 |
-
"--image_dir",
|
| 347 |
-
type=Path,
|
| 348 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/image"),
|
| 349 |
-
help="Directory containing SwissImage GeoTIFF tiles.",
|
| 350 |
-
)
|
| 351 |
-
parser.add_argument(
|
| 352 |
-
"--output_dir",
|
| 353 |
-
type=Path,
|
| 354 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
|
| 355 |
-
help="Directory where rasterized masks will be written.",
|
| 356 |
-
)
|
| 357 |
-
parser.add_argument("--workers", type=int, default=max(1, os.cpu_count() // 2))
|
| 358 |
-
parser.add_argument("--limit", type=int, default=None, help="Optional number of image tiles to process.")
|
| 359 |
-
parser.add_argument("--overwrite", action="store_true", help="Overwrite existing masks.")
|
| 360 |
-
return parser.parse_args()
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
def main() -> None:
|
| 364 |
-
args = parse_args()
|
| 365 |
-
args.output_dir.mkdir(parents=True, exist_ok=True)
|
| 366 |
-
|
| 367 |
-
if not args.gpkg.exists():
|
| 368 |
-
raise FileNotFoundError(f"Missing SwissTLM3D GeoPackage: {args.gpkg}")
|
| 369 |
-
if not args.image_dir.exists():
|
| 370 |
-
raise FileNotFoundError(f"Missing SwissImage directory: {args.image_dir}")
|
| 371 |
-
|
| 372 |
-
image_paths = sorted(args.image_dir.glob("*.tif"))
|
| 373 |
-
if args.limit is not None:
|
| 374 |
-
image_paths = image_paths[: args.limit]
|
| 375 |
-
|
| 376 |
-
if not image_paths:
|
| 377 |
-
raise RuntimeError(f"No .tif files found in {args.image_dir}")
|
| 378 |
-
|
| 379 |
-
failures = []
|
| 380 |
-
with ProcessPoolExecutor(max_workers=args.workers) as executor:
|
| 381 |
-
futures = {
|
| 382 |
-
executor.submit(process_one, image_path, args.output_dir, args.gpkg, args.overwrite): image_path
|
| 383 |
-
for image_path in image_paths
|
| 384 |
-
}
|
| 385 |
-
for future in tqdm(as_completed(futures), total=len(futures), desc="Rasterizing SwissTLM3D"):
|
| 386 |
-
image_name, status = future.result()
|
| 387 |
-
if status != "ok" and status != "skipped":
|
| 388 |
-
failures.append((image_name, status))
|
| 389 |
-
|
| 390 |
-
if failures:
|
| 391 |
-
print("Failures:")
|
| 392 |
-
for image_name, status in failures[:20]:
|
| 393 |
-
print(f"- {image_name}: {status}")
|
| 394 |
-
raise SystemExit(1)
|
| 395 |
-
|
| 396 |
-
print(f"Done. Wrote masks to {args.output_dir}")
|
| 397 |
-
print("Projected classes:")
|
| 398 |
-
for spec in BURN_SPECS:
|
| 399 |
-
print(f"- {spec.class_id:2d}: {FLAIR_HUB_CLASSES[spec.class_id]} ({spec.name})")
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
if __name__ == "__main__":
|
| 403 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/split_swisstlm3d_geographic.py
DELETED
|
@@ -1,314 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Create a deterministic random split for SwissImage + SwissTLM3D masks.
|
| 3 |
-
|
| 4 |
-
Supported naming conventions:
|
| 5 |
-
1. Full tiles:
|
| 6 |
-
swissimage-dop10_<year>_<x>-<y>_0.1_2056.tif
|
| 7 |
-
2. Prepared patches:
|
| 8 |
-
swissimage-dop10_<year>_<x>-<y>_0.1_2056_y<row>_x<col>.tif
|
| 9 |
-
|
| 10 |
-
For prepared patches, the split is performed at source-tile level `(year, x, y)`,
|
| 11 |
-
so all patches from the same original tile stay in the same split.
|
| 12 |
-
"""
|
| 13 |
-
|
| 14 |
-
from __future__ import annotations
|
| 15 |
-
|
| 16 |
-
import argparse
|
| 17 |
-
import csv
|
| 18 |
-
import os
|
| 19 |
-
import random
|
| 20 |
-
import re
|
| 21 |
-
import shutil
|
| 22 |
-
from dataclasses import dataclass
|
| 23 |
-
from pathlib import Path
|
| 24 |
-
from typing import Iterable
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
FULL_TILE_RE = re.compile(
|
| 28 |
-
r"^swissimage-dop10_(?P<year>\d{4})_(?P<x>\d+)-(?P<y>\d+)_0\.1_2056\.tif$"
|
| 29 |
-
)
|
| 30 |
-
PATCH_RE = re.compile(
|
| 31 |
-
r"^swissimage-dop10_(?P<year>\d{4})_(?P<x>\d+)-(?P<y>\d+)_0\.1_2056_y(?P<yoff>\d+)_x(?P<xoff>\d+)\.tif$"
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
@dataclass(frozen=True)
|
| 36 |
-
class TileRecord:
|
| 37 |
-
name: str
|
| 38 |
-
year: int
|
| 39 |
-
x: int
|
| 40 |
-
y: int
|
| 41 |
-
x_offset: int
|
| 42 |
-
y_offset: int
|
| 43 |
-
x_geo: int
|
| 44 |
-
y_geo: int
|
| 45 |
-
image_path: Path
|
| 46 |
-
mask_path: Path
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
def parse_args() -> argparse.Namespace:
|
| 50 |
-
parser = argparse.ArgumentParser(description="Create a random 60/20/20 split for SwissTLM3D fine-tuning.")
|
| 51 |
-
parser.add_argument(
|
| 52 |
-
"--image_dir",
|
| 53 |
-
type=Path,
|
| 54 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/image"),
|
| 55 |
-
help="Directory containing SwissImage tiles.",
|
| 56 |
-
)
|
| 57 |
-
parser.add_argument(
|
| 58 |
-
"--mask_dir",
|
| 59 |
-
type=Path,
|
| 60 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
|
| 61 |
-
help="Directory containing SwissTLM3D masks with matching filenames.",
|
| 62 |
-
)
|
| 63 |
-
parser.add_argument(
|
| 64 |
-
"--output_dir",
|
| 65 |
-
type=Path,
|
| 66 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_ft"),
|
| 67 |
-
help="Output root with train/val/test subfolders.",
|
| 68 |
-
)
|
| 69 |
-
parser.add_argument(
|
| 70 |
-
"--train_ratio",
|
| 71 |
-
type=float,
|
| 72 |
-
default=0.6,
|
| 73 |
-
help="Train ratio.",
|
| 74 |
-
)
|
| 75 |
-
parser.add_argument(
|
| 76 |
-
"--val_ratio",
|
| 77 |
-
type=float,
|
| 78 |
-
default=0.2,
|
| 79 |
-
help="Validation ratio.",
|
| 80 |
-
)
|
| 81 |
-
parser.add_argument(
|
| 82 |
-
"--test_ratio",
|
| 83 |
-
type=float,
|
| 84 |
-
default=0.2,
|
| 85 |
-
help="Test ratio.",
|
| 86 |
-
)
|
| 87 |
-
parser.add_argument(
|
| 88 |
-
"--seed",
|
| 89 |
-
type=int,
|
| 90 |
-
default=42,
|
| 91 |
-
help="Random seed used for deterministic split assignment.",
|
| 92 |
-
)
|
| 93 |
-
parser.add_argument(
|
| 94 |
-
"--link_mode",
|
| 95 |
-
choices=["symlink", "hardlink", "copy"],
|
| 96 |
-
default="symlink",
|
| 97 |
-
help="How to materialize split files.",
|
| 98 |
-
)
|
| 99 |
-
parser.add_argument(
|
| 100 |
-
"--overwrite",
|
| 101 |
-
action="store_true",
|
| 102 |
-
help="Delete an existing output_dir before writing the split.",
|
| 103 |
-
)
|
| 104 |
-
return parser.parse_args()
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def parse_tile_name(path: Path) -> tuple[int, int, int, int, int]:
|
| 108 |
-
patch_match = PATCH_RE.match(path.name)
|
| 109 |
-
if patch_match is not None:
|
| 110 |
-
year = int(patch_match.group("year"))
|
| 111 |
-
x = int(patch_match.group("x"))
|
| 112 |
-
y = int(patch_match.group("y"))
|
| 113 |
-
y_offset = int(patch_match.group("yoff"))
|
| 114 |
-
x_offset = int(patch_match.group("xoff"))
|
| 115 |
-
return year, x, y, x_offset, y_offset
|
| 116 |
-
|
| 117 |
-
full_match = FULL_TILE_RE.match(path.name)
|
| 118 |
-
if full_match is not None:
|
| 119 |
-
year = int(full_match.group("year"))
|
| 120 |
-
x = int(full_match.group("x"))
|
| 121 |
-
y = int(full_match.group("y"))
|
| 122 |
-
return year, x, y, 0, 0
|
| 123 |
-
|
| 124 |
-
raise ValueError(f"Unexpected SwissImage filename format: {path.name}")
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
def collect_records(image_dir: Path, mask_dir: Path) -> list[TileRecord]:
|
| 128 |
-
image_files = {p.name: p for p in sorted(image_dir.glob("*.tif"))}
|
| 129 |
-
mask_files = {p.name: p for p in sorted(mask_dir.glob("*.tif"))}
|
| 130 |
-
common_names = sorted(image_files.keys() & mask_files.keys())
|
| 131 |
-
|
| 132 |
-
if not common_names:
|
| 133 |
-
raise RuntimeError("No matching image/mask filenames found.")
|
| 134 |
-
|
| 135 |
-
records = []
|
| 136 |
-
for name in common_names:
|
| 137 |
-
year, x, y, x_offset, y_offset = parse_tile_name(Path(name))
|
| 138 |
-
records.append(
|
| 139 |
-
TileRecord(
|
| 140 |
-
name=name,
|
| 141 |
-
year=year,
|
| 142 |
-
x=x,
|
| 143 |
-
y=y,
|
| 144 |
-
x_offset=x_offset,
|
| 145 |
-
y_offset=y_offset,
|
| 146 |
-
x_geo=x * 10000 + x_offset,
|
| 147 |
-
y_geo=y * 10000 + y_offset,
|
| 148 |
-
image_path=image_files[name],
|
| 149 |
-
mask_path=mask_files[name],
|
| 150 |
-
)
|
| 151 |
-
)
|
| 152 |
-
return records
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
def split_random_by_source_tile(
|
| 156 |
-
records: list[TileRecord],
|
| 157 |
-
train_ratio: float,
|
| 158 |
-
val_ratio: float,
|
| 159 |
-
test_ratio: float,
|
| 160 |
-
seed: int,
|
| 161 |
-
) -> tuple[list[TileRecord], list[TileRecord], list[TileRecord]]:
|
| 162 |
-
tile_to_records: dict[tuple[int, int, int], list[TileRecord]] = {}
|
| 163 |
-
for record in records:
|
| 164 |
-
tile_to_records.setdefault((record.year, record.x, record.y), []).append(record)
|
| 165 |
-
|
| 166 |
-
groups = list(tile_to_records.items())
|
| 167 |
-
rng = random.Random(seed)
|
| 168 |
-
rng.shuffle(groups)
|
| 169 |
-
|
| 170 |
-
total = len(groups)
|
| 171 |
-
n_train = round(total * train_ratio)
|
| 172 |
-
n_val = round(total * val_ratio)
|
| 173 |
-
n_test = total - n_train - n_val
|
| 174 |
-
|
| 175 |
-
if min(n_train, n_val, n_test) <= 0:
|
| 176 |
-
raise RuntimeError("Random split produced an empty split. Adjust ratios or input size.")
|
| 177 |
-
|
| 178 |
-
train_groups = groups[:n_train]
|
| 179 |
-
val_groups = groups[n_train:n_train + n_val]
|
| 180 |
-
test_groups = groups[n_train + n_val:]
|
| 181 |
-
|
| 182 |
-
train_records = [record for _, tile_records in train_groups for record in tile_records]
|
| 183 |
-
val_records = [record for _, tile_records in val_groups for record in tile_records]
|
| 184 |
-
test_records = [record for _, tile_records in test_groups for record in tile_records]
|
| 185 |
-
|
| 186 |
-
return train_records, val_records, test_records
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
def ensure_empty_dir(path: Path, overwrite: bool) -> None:
|
| 190 |
-
if path.exists():
|
| 191 |
-
if not overwrite:
|
| 192 |
-
raise FileExistsError(f"{path} already exists. Use --overwrite to recreate it.")
|
| 193 |
-
shutil.rmtree(path)
|
| 194 |
-
path.mkdir(parents=True, exist_ok=True)
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
def materialize(src: Path, dst: Path, mode: str) -> None:
|
| 198 |
-
if mode == "symlink":
|
| 199 |
-
os.symlink(src, dst)
|
| 200 |
-
return
|
| 201 |
-
if mode == "hardlink":
|
| 202 |
-
os.link(src, dst)
|
| 203 |
-
return
|
| 204 |
-
shutil.copy2(src, dst)
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
def write_split(
|
| 208 |
-
split_name: str,
|
| 209 |
-
records: Iterable[TileRecord],
|
| 210 |
-
output_dir: Path,
|
| 211 |
-
link_mode: str,
|
| 212 |
-
) -> None:
|
| 213 |
-
img_dir = output_dir / split_name / "img"
|
| 214 |
-
msk_dir = output_dir / split_name / "msk"
|
| 215 |
-
img_dir.mkdir(parents=True, exist_ok=True)
|
| 216 |
-
msk_dir.mkdir(parents=True, exist_ok=True)
|
| 217 |
-
|
| 218 |
-
for record in records:
|
| 219 |
-
materialize(record.image_path, img_dir / record.name, link_mode)
|
| 220 |
-
materialize(record.mask_path, msk_dir / record.name, link_mode)
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
def write_manifest(
|
| 224 |
-
output_dir: Path,
|
| 225 |
-
mode: str,
|
| 226 |
-
train_records: list[TileRecord],
|
| 227 |
-
val_records: list[TileRecord],
|
| 228 |
-
test_records: list[TileRecord],
|
| 229 |
-
) -> None:
|
| 230 |
-
manifest_path = output_dir / "split_manifest.csv"
|
| 231 |
-
with manifest_path.open("w", newline="") as f:
|
| 232 |
-
writer = csv.writer(f)
|
| 233 |
-
writer.writerow(["split", "name", "year", "x", "y", "x_offset", "y_offset", "x_geo", "y_geo", "image_path", "mask_path"])
|
| 234 |
-
for split_name, records in (
|
| 235 |
-
("train", train_records),
|
| 236 |
-
("val", val_records),
|
| 237 |
-
("test", test_records),
|
| 238 |
-
):
|
| 239 |
-
for record in records:
|
| 240 |
-
writer.writerow(
|
| 241 |
-
[
|
| 242 |
-
split_name,
|
| 243 |
-
record.name,
|
| 244 |
-
record.year,
|
| 245 |
-
record.x,
|
| 246 |
-
record.y,
|
| 247 |
-
record.x_offset,
|
| 248 |
-
record.y_offset,
|
| 249 |
-
record.x_geo,
|
| 250 |
-
record.y_geo,
|
| 251 |
-
str(record.image_path),
|
| 252 |
-
str(record.mask_path),
|
| 253 |
-
]
|
| 254 |
-
)
|
| 255 |
-
|
| 256 |
-
summary_path = output_dir / "split_summary.txt"
|
| 257 |
-
with summary_path.open("w") as f:
|
| 258 |
-
f.write(f"mode={mode}\n")
|
| 259 |
-
f.write(f"train={len(train_records)}\n")
|
| 260 |
-
f.write(f"val={len(val_records)}\n")
|
| 261 |
-
f.write(f"test={len(test_records)}\n")
|
| 262 |
-
f.write(
|
| 263 |
-
f"train_coords=({min(r.x_geo for r in train_records)}, {max(r.x_geo for r in train_records)}) x "
|
| 264 |
-
f"({min(r.y_geo for r in train_records)}, {max(r.y_geo for r in train_records)})\n"
|
| 265 |
-
)
|
| 266 |
-
f.write(
|
| 267 |
-
f"val_coords=({min(r.x_geo for r in val_records)}, {max(r.x_geo for r in val_records)}) x "
|
| 268 |
-
f"({min(r.y_geo for r in val_records)}, {max(r.y_geo for r in val_records)})\n"
|
| 269 |
-
)
|
| 270 |
-
f.write(
|
| 271 |
-
f"test_coords=({min(r.x_geo for r in test_records)}, {max(r.x_geo for r in test_records)}) x "
|
| 272 |
-
f"({min(r.y_geo for r in test_records)}, {max(r.y_geo for r in test_records)})\n"
|
| 273 |
-
)
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
def main() -> None:
|
| 277 |
-
args = parse_args()
|
| 278 |
-
|
| 279 |
-
ratio_sum = args.train_ratio + args.val_ratio + args.test_ratio
|
| 280 |
-
if abs(ratio_sum - 1.0) > 1e-8:
|
| 281 |
-
raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum:.6f}")
|
| 282 |
-
|
| 283 |
-
if not args.image_dir.exists():
|
| 284 |
-
raise FileNotFoundError(f"Missing image_dir: {args.image_dir}")
|
| 285 |
-
if not args.mask_dir.exists():
|
| 286 |
-
raise FileNotFoundError(f"Missing mask_dir: {args.mask_dir}")
|
| 287 |
-
|
| 288 |
-
records = collect_records(args.image_dir, args.mask_dir)
|
| 289 |
-
train_records, val_records, test_records = split_random_by_source_tile(
|
| 290 |
-
records=records,
|
| 291 |
-
train_ratio=args.train_ratio,
|
| 292 |
-
val_ratio=args.val_ratio,
|
| 293 |
-
test_ratio=args.test_ratio,
|
| 294 |
-
seed=args.seed,
|
| 295 |
-
)
|
| 296 |
-
|
| 297 |
-
ensure_empty_dir(args.output_dir, args.overwrite)
|
| 298 |
-
write_split("train", train_records, args.output_dir, args.link_mode)
|
| 299 |
-
write_split("val", val_records, args.output_dir, args.link_mode)
|
| 300 |
-
write_split("test", test_records, args.output_dir, args.link_mode)
|
| 301 |
-
write_manifest(args.output_dir, "random_by_source_tile", train_records, val_records, test_records)
|
| 302 |
-
|
| 303 |
-
total = len(records)
|
| 304 |
-
print(f"Matched tiles: {total}")
|
| 305 |
-
print(f"Split mode: random_by_source_tile")
|
| 306 |
-
print(f"Seed: {args.seed}")
|
| 307 |
-
print(f"Train: {len(train_records)} ({100.0 * len(train_records) / total:.2f}%)")
|
| 308 |
-
print(f"Val: {len(val_records)} ({100.0 * len(val_records) / total:.2f}%)")
|
| 309 |
-
print(f"Test: {len(test_records)} ({100.0 * len(test_records) / total:.2f}%)")
|
| 310 |
-
print(f"Output: {args.output_dir}")
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
if __name__ == "__main__":
|
| 314 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/stats_swisstlm3d_masks.py
DELETED
|
@@ -1,136 +0,0 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Compute mask statistics for SwissTLM3D-generated masks with a progress bar.
|
| 3 |
-
|
| 4 |
-
Outputs:
|
| 5 |
-
- global class histogram
|
| 6 |
-
- class proportions
|
| 7 |
-
- count of masks entirely equal to 255
|
| 8 |
-
- count of masks containing at least one non-255 pixel
|
| 9 |
-
- CSV list of fully-255 masks
|
| 10 |
-
- text summary
|
| 11 |
-
"""
|
| 12 |
-
|
| 13 |
-
from __future__ import annotations
|
| 14 |
-
|
| 15 |
-
import argparse
|
| 16 |
-
import csv
|
| 17 |
-
from collections import Counter
|
| 18 |
-
from pathlib import Path
|
| 19 |
-
|
| 20 |
-
import numpy as np
|
| 21 |
-
import rasterio
|
| 22 |
-
from tqdm import tqdm
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
CLASS_NAMES = {
|
| 26 |
-
0: "building",
|
| 27 |
-
1: "greenhouse",
|
| 28 |
-
2: "swimming_pool",
|
| 29 |
-
3: "impervious_surface",
|
| 30 |
-
4: "pervious_surface",
|
| 31 |
-
5: "bare_soil",
|
| 32 |
-
6: "water",
|
| 33 |
-
7: "snow",
|
| 34 |
-
8: "herbaceous_vegetation",
|
| 35 |
-
9: "agricultural_land",
|
| 36 |
-
10: "plowed_land",
|
| 37 |
-
11: "vineyard",
|
| 38 |
-
12: "deciduous",
|
| 39 |
-
13: "coniferous",
|
| 40 |
-
14: "brushwood",
|
| 41 |
-
255: "ignore",
|
| 42 |
-
}
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def parse_args() -> argparse.Namespace:
|
| 46 |
-
parser = argparse.ArgumentParser(description="Compute class statistics for SwissTLM3D masks.")
|
| 47 |
-
parser.add_argument(
|
| 48 |
-
"--mask_dir",
|
| 49 |
-
type=Path,
|
| 50 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
|
| 51 |
-
help="Directory containing rasterized SwissTLM3D masks.",
|
| 52 |
-
)
|
| 53 |
-
parser.add_argument(
|
| 54 |
-
"--out_dir",
|
| 55 |
-
type=Path,
|
| 56 |
-
default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d_stats"),
|
| 57 |
-
help="Directory where summary files will be written.",
|
| 58 |
-
)
|
| 59 |
-
return parser.parse_args()
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
def main() -> None:
|
| 63 |
-
args = parse_args()
|
| 64 |
-
args.out_dir.mkdir(parents=True, exist_ok=True)
|
| 65 |
-
|
| 66 |
-
mask_paths = sorted(args.mask_dir.glob("*.tif"))#[:100]
|
| 67 |
-
if not mask_paths:
|
| 68 |
-
raise RuntimeError(f"No .tif masks found in {args.mask_dir}")
|
| 69 |
-
|
| 70 |
-
pixel_counts: Counter[int] = Counter()
|
| 71 |
-
all_255_masks: list[str] = []
|
| 72 |
-
non_255_masks = 0
|
| 73 |
-
partial_255_masks = 0
|
| 74 |
-
|
| 75 |
-
for mask_path in tqdm(mask_paths, desc="Scanning masks", unit="mask"):
|
| 76 |
-
with rasterio.open(mask_path) as src:
|
| 77 |
-
arr = src.read(1)
|
| 78 |
-
|
| 79 |
-
values, counts = np.unique(arr, return_counts=True)
|
| 80 |
-
local = dict(zip(values.tolist(), counts.tolist()))
|
| 81 |
-
pixel_counts.update(local)
|
| 82 |
-
|
| 83 |
-
if set(local.keys()) == {255}:
|
| 84 |
-
all_255_masks.append(mask_path.name)
|
| 85 |
-
else:
|
| 86 |
-
non_255_masks += 1
|
| 87 |
-
if 255 in local:
|
| 88 |
-
partial_255_masks += 1
|
| 89 |
-
|
| 90 |
-
total_pixels = sum(pixel_counts.values())
|
| 91 |
-
|
| 92 |
-
summary_txt = args.out_dir / "summary.txt"
|
| 93 |
-
class_csv = args.out_dir / "class_stats.csv"
|
| 94 |
-
all_255_csv = args.out_dir / "all_255_masks.csv"
|
| 95 |
-
|
| 96 |
-
with summary_txt.open("w") as f:
|
| 97 |
-
f.write(f"mask_count={len(mask_paths)}\n")
|
| 98 |
-
f.write(f"all_255_count={len(all_255_masks)}\n")
|
| 99 |
-
f.write(f"non_255_count={non_255_masks}\n")
|
| 100 |
-
f.write(f"partial_255_count={partial_255_masks}\n")
|
| 101 |
-
f.write(f"total_pixels={total_pixels}\n")
|
| 102 |
-
|
| 103 |
-
with class_csv.open("w", newline="") as f:
|
| 104 |
-
writer = csv.writer(f)
|
| 105 |
-
writer.writerow(["value", "class_name", "pixel_count", "proportion"])
|
| 106 |
-
for value in sorted(pixel_counts):
|
| 107 |
-
writer.writerow(
|
| 108 |
-
[
|
| 109 |
-
value,
|
| 110 |
-
CLASS_NAMES.get(value, "unknown"),
|
| 111 |
-
pixel_counts[value],
|
| 112 |
-
pixel_counts[value] / total_pixels if total_pixels > 0 else 0.0,
|
| 113 |
-
]
|
| 114 |
-
)
|
| 115 |
-
|
| 116 |
-
with all_255_csv.open("w", newline="") as f:
|
| 117 |
-
writer = csv.writer(f)
|
| 118 |
-
writer.writerow(["filename"])
|
| 119 |
-
for name in all_255_masks:
|
| 120 |
-
writer.writerow([name])
|
| 121 |
-
|
| 122 |
-
print(f"mask_count: {len(mask_paths)}")
|
| 123 |
-
print(f"all_255_count: {len(all_255_masks)}")
|
| 124 |
-
print(f"non_255_count: {non_255_masks}")
|
| 125 |
-
print(f"partial_255_count: {partial_255_masks}")
|
| 126 |
-
print("class_stats:")
|
| 127 |
-
for value in sorted(pixel_counts):
|
| 128 |
-
proportion = pixel_counts[value] / total_pixels if total_pixels > 0 else 0.0
|
| 129 |
-
print(f" {value:>3} {CLASS_NAMES.get(value, 'unknown'):<24} {pixel_counts[value]} ({proportion:.8f})")
|
| 130 |
-
print(f"wrote: {summary_txt}")
|
| 131 |
-
print(f"wrote: {class_csv}")
|
| 132 |
-
print(f"wrote: {all_255_csv}")
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
if __name__ == "__main__":
|
| 136 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dataset/swissimage_grayscale_mixed.py
DELETED
|
@@ -1,24 +0,0 @@
|
|
| 1 |
-
from pathlib import Path
|
| 2 |
-
from typing import Dict, Optional
|
| 3 |
-
|
| 4 |
-
from torchvision import transforms
|
| 5 |
-
|
| 6 |
-
from dataset.definition_dataset_grayscale_mixed import (
|
| 7 |
-
SemanticSegmentationDatasetFusionGrayscaleMixed,
|
| 8 |
-
)
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
def build(
|
| 12 |
-
image_dir: Path,
|
| 13 |
-
mask_dir: Path,
|
| 14 |
-
transform: Optional[transforms.Compose],
|
| 15 |
-
augment=None,
|
| 16 |
-
label_remap: Optional[Dict[int, int]] = None,
|
| 17 |
-
):
|
| 18 |
-
return SemanticSegmentationDatasetFusionGrayscaleMixed(
|
| 19 |
-
image_dir,
|
| 20 |
-
mask_dir,
|
| 21 |
-
transform=transform,
|
| 22 |
-
augment=augment,
|
| 23 |
-
label_remap=label_remap,
|
| 24 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|