Image Segmentation
English
antoine.carreaud67 committed on
Commit
eea6ebc
·
1 Parent(s): 9ffd7b5

Stop tracking grayscale mixed local files

Browse files
.gitignore ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ paper/
2
+ configs/config_mixed_domain_ft_segformer_reliable5_grayscale.yaml
3
+ dataset/definition_dataset_grayscale_mixed.py
4
+ dataset/flairhub_grayscale_mixed.py
5
+ dataset/prepare_swisstlm3d_patches.py
6
+ dataset/rasterize_swisstlm3d.py
7
+ dataset/split_swisstlm3d_geographic.py
8
+ dataset/stats_swisstlm3d_masks.py
9
+ dataset/swissimage_grayscale_mixed.py
configs/config_mixed_domain_ft_segformer_reliable5_grayscale.yaml DELETED
@@ -1,74 +0,0 @@
1
- paths:
2
- flairhub_path: /mnt/CalcShare/datasets/FLAIR1024_merged
3
- swiss_path: /mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_reliable5_forest_center20_patches_2048to1024_split
4
-
5
- flairhub_train_img_subdir: train/img
6
- flairhub_train_msk_subdir: train/msk
7
- flairhub_val_img_subdir: valid/img
8
- flairhub_val_msk_subdir: valid/msk
9
- flairhub_test_img_subdir: test/img
10
- flairhub_test_msk_subdir: test/msk
11
-
12
- swiss_train_img_subdir: train/img
13
- swiss_train_msk_subdir: train/msk
14
- swiss_val_img_subdir: val/img
15
- swiss_val_msk_subdir: val/msk
16
- swiss_test_img_subdir: test/img
17
- swiss_test_msk_subdir: test/msk
18
-
19
- save_dir: weights
20
- pretrained_path: weights/CASWiT-Base-SSL-aug_FLAIRHUB_SF.pth
21
-
22
- model:
23
- model_name: openmmlab/upernet-swin-base
24
- num_classes: 15 # Keep the 15-class head for exact FLAIR-HUB checkpoint compatibility; forest merge is handled by label remapping below.
25
- cross_attention_heads: 1
26
- ignore_index: 255
27
- fusion_mlp_ratio: 4.0
28
- fusion_drop_path: 0.1
29
- lr_supervision_weight: 0.5
30
- head: segformer
31
-
32
- training:
33
- batch_size: 1
34
- num_workers: 8
35
- num_epochs: 10
36
- learning_rate: 8.0e-06
37
- amp: true
38
- seed: 42
39
- eta_min: 1.0e-06
40
-
41
- mixed:
42
- # Conservative domain adaptation with sparse high-confidence SwissTLM3D labels only.
43
- swiss_ratio: 0.20
44
- epoch_length: 0
45
- best_on: mean
46
- simulate_grayscale_inputs: true
47
-
48
- labels:
49
- flair_label_remap:
50
- 13: 12
51
- swiss_label_remap:
52
- 13: 12
53
-
54
- wandb:
55
- use_wandb: true
56
- project: CASWiT-SwissTLM3D
57
- entity: soloo
58
- run_name: CASWiT-Base-SSL-aug_FLAIRHUB_SF_mixed_ft_reliable5_forest_center20_gray_r02
59
-
60
- print_device: true
61
-
62
- augmentations:
63
- enable: true
64
- p_hflip: 0.5
65
- p_vflip: 0.5
66
- p_rot90: 0.5
67
- color_jitter:
68
- brightness: 0.05
69
- contrast: 0.05
70
- saturation: 0.05
71
- hue: 0.02
72
- blur:
73
- p: 0.0
74
- kernel: 3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/definition_dataset_grayscale_mixed.py DELETED
@@ -1,114 +0,0 @@
1
- """
2
- Dataset definitions dedicated to grayscale-simulated mixed-domain fine-tuning.
3
-
4
- This variant keeps the exact same HR/LR extraction logic as the standard fusion
5
- dataset, but converts RGB imagery to simulated grayscale replicated on 3 bands
6
- directly inside the dataset definition.
7
- """
8
-
9
- from __future__ import annotations
10
-
11
- import os
12
- from pathlib import Path
13
- from typing import Dict, Optional
14
-
15
- import numpy as np
16
- import torch
17
- from PIL import Image
18
- from torch.utils.data import Dataset
19
- from torchvision import transforms
20
-
21
- from dataset.definition_dataset import (
22
- apply_label_remap,
23
- load_image,
24
- load_mask,
25
- to_pil_uint8,
26
- to_tensor_img,
27
- )
28
-
29
-
30
def simulate_grayscale_rgb(image_hwc_float01: np.ndarray) -> np.ndarray:
    """
    Collapse an RGB image to a single luminance band and copy it onto 3 channels.

    Args:
        image_hwc_float01: HWC array whose last axis holds R, G, B. Values are
            expected to lie in [0, 1]; the range itself is not checked here.

    Returns:
        A float32 array of shape (H, W, 3) whose three channels are identical.

    Raises:
        ValueError: if the input is not 3-dimensional with a last axis of size 3.
    """
    is_hwc_rgb = image_hwc_float01.ndim == 3 and image_hwc_float01.shape[-1] == 3
    if not is_hwc_rgb:
        raise ValueError(
            f"simulate_grayscale_rgb expects HWC RGB input, got shape={image_hwc_float01.shape}"
        )

    red = image_hwc_float01[..., 0]
    green = image_hwc_float01[..., 1]
    blue = image_hwc_float01[..., 2]

    # Weighted sum of the three bands (0.299 / 0.587 / 0.114 luma weights).
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    luma = luma.astype(np.float32, copy=False)

    # Replicate the single band so downstream code still sees a 3-channel image.
    return np.repeat(luma[..., np.newaxis], 3, axis=-1)
44
-
45
-
46
class SemanticSegmentationDatasetFusionGrayscaleMixed(Dataset):
    """
    HR/LR fusion dataset that serves grayscale-simulated imagery on 3 bands.

    Every item is built from one image/mask pair:

    * the image is converted with ``simulate_grayscale_rgb`` right after loading;
    * the mask goes through ``apply_label_remap`` and every value >= 15 is then
      set to 255;
    * the high-resolution (HR) view is the fixed window [256:768, 256:768];
    * the low-resolution (LR) view keeps every second pixel of the full frame.

    Crops, downsampling and augmentation hooks mirror the standard fusion
    dataset on purpose, so that colour removal is the only difference.
    """

    def __init__(
        self,
        image_dir: Path,
        mask_dir: Path,
        transform: Optional[transforms.Compose] = None,
        augment=None,
        label_remap: Optional[Dict[int, int]] = None,
    ):
        """
        Index the image and mask folders.

        Args:
            image_dir: folder listing the input images.
            mask_dir: folder listing the masks.
            transform: optional callable applied to each PIL image (HR and LR).
            augment: optional callable taking and returning
                ``(img_hr, mask_hr, img_lr, mask_lr)`` as PIL images.
            label_remap: optional mapping forwarded to ``apply_label_remap``.
        """
        self.image_dir = Path(image_dir)
        self.mask_dir = Path(mask_dir)

        # Images and masks are paired purely by their position in the sorted
        # directory listings; only the counts are compared.
        self.image_filenames = sorted(os.listdir(self.image_dir))
        self.mask_filenames = sorted(os.listdir(self.mask_dir))
        assert len(self.image_filenames) == len(self.mask_filenames), "Images/Masks count mismatch"

        self.transform = transform
        self.augment = augment
        self.label_remap = label_remap if label_remap else {}

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``(image_hr, mask_hr, image_lr, mask_lr)`` for sample *idx*."""
        image_name = self.image_filenames[idx]
        mask_name = self.mask_filenames[idx]

        image = simulate_grayscale_rgb(load_image(self.image_dir / image_name))
        mask = apply_label_remap(load_mask(self.mask_dir / mask_name), self.label_remap)
        # Anything outside the 15 known class ids becomes the ignore value.
        mask[mask >= 15] = 255

        # Fixed HR window: 512 pixels starting at offset 256 on both axes.
        hr_window = slice(256, 256 + 512)
        image_hr = image[hr_window, hr_window]
        mask_hr = mask[hr_window, hr_window]

        # LR view: plain stride-2 subsampling of the whole frame.
        image_lr = image[::2, ::2, :]
        mask_lr = mask[::2, ::2]

        img_hr_pil = to_pil_uint8(image_hr)
        img_lr_pil = to_pil_uint8(image_lr)
        m_hr_pil = Image.fromarray(mask_hr.astype(np.uint8), mode="L")
        m_lr_pil = Image.fromarray(mask_lr.astype(np.uint8), mode="L")

        if self.augment is not None:
            augmented = self.augment(img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil)
            img_hr_pil, m_hr_pil, img_lr_pil, m_lr_pil = augmented

        if self.transform:
            image_hr = self.transform(img_hr_pil)
            image_lr = self.transform(img_lr_pil)
        else:
            image_hr = to_tensor_img(np.array(img_hr_pil))
            image_lr = to_tensor_img(np.array(img_lr_pil))

        mask_hr = torch.as_tensor(np.array(m_hr_pil, dtype=np.uint8), dtype=torch.long)
        mask_lr = torch.as_tensor(np.array(m_lr_pil, dtype=np.uint8), dtype=torch.long)

        return image_hr, mask_hr, image_lr, mask_lr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/flairhub_grayscale_mixed.py DELETED
@@ -1,24 +0,0 @@
1
- from pathlib import Path
2
- from typing import Dict, Optional
3
-
4
- from torchvision import transforms
5
-
6
- from dataset.definition_dataset_grayscale_mixed import (
7
- SemanticSegmentationDatasetFusionGrayscaleMixed,
8
- )
9
-
10
-
11
def build(
    image_dir: Path,
    mask_dir: Path,
    transform: Optional[transforms.Compose],
    augment=None,
    label_remap: Optional[Dict[int, int]] = None,
):
    """
    Build the grayscale mixed-domain fusion dataset for one image/mask folder pair.

    All arguments are forwarded unchanged to
    ``SemanticSegmentationDatasetFusionGrayscaleMixed``.
    """
    dataset_options = {
        "transform": transform,
        "augment": augment,
        "label_remap": label_remap,
    }
    return SemanticSegmentationDatasetFusionGrayscaleMixed(image_dir, mask_dir, **dataset_options)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/prepare_swisstlm3d_patches.py DELETED
@@ -1,233 +0,0 @@
1
- """
2
- Prepare SwissImage + SwissTLM3D patches for training.
3
-
4
- Optimized version:
5
- - processes source tiles in parallel across multiple CPU workers
6
- - reads only 2048x2048 windows with rasterio instead of loading whole 10000x10000 tiles
7
- - resizes image patches to 1024x1024 with bilinear interpolation
8
- - resizes mask patches to 1024x1024 with nearest-neighbor interpolation
9
- - optionally skips patches whose mask is entirely 255
10
- """
11
-
12
- from __future__ import annotations
13
-
14
- import argparse
15
- import os
16
- import sys
17
- from concurrent.futures import ProcessPoolExecutor, as_completed
18
- from pathlib import Path
19
-
20
- import numpy as np
21
- import rasterio
22
- from PIL import Image
23
- from tqdm import tqdm
24
-
25
-
26
- project_root = Path(__file__).resolve().parent.parent
27
- sys.path.insert(0, str(project_root))
28
-
29
- Image.MAX_IMAGE_PIXELS = None
30
-
31
-
32
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of the patch preparation script.

    Options are registered in the same order as before (paths, sizes, flags,
    workers) so the generated ``--help`` output is unchanged.
    """
    parser = argparse.ArgumentParser(description="Prepare 2048->1024 SwissTLM3D training patches.")

    # Directory options: (flag, default location, help text).
    directory_options = (
        (
            "--image_dir",
            "/mnt/CalcShare/datasets/Suisse_full/data/image",
            "Directory containing full SwissImage tiles.",
        ),
        (
            "--mask_dir",
            "/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d",
            "Directory containing full SwissTLM3D masks aligned with image_dir.",
        ),
        (
            "--output_dir",
            "/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_patches_2048to1024",
            "Directory where prepared patches will be written.",
        ),
    )
    for flag, location, help_text in directory_options:
        parser.add_argument(flag, type=Path, default=Path(location), help=help_text)

    # Integer options: (flag, default value, help text).
    integer_options = (
        ("--tile_size", 2048, "Patch size extracted from the original SwissImage tile."),
        ("--output_size", 1024, "Final patch size after resizing."),
        ("--ignore_index", 255, "Ignore label value in masks."),
    )
    for flag, value, help_text in integer_options:
        parser.add_argument(flag, type=int, default=value, help=help_text)

    # Boolean switches, off unless given on the command line.
    switch_options = (
        ("--drop_all_ignore", "Skip patches whose mask is entirely equal to ignore_index."),
        ("--overwrite", "Overwrite existing patch files if they already exist."),
    )
    for flag, help_text in switch_options:
        parser.add_argument(flag, action="store_true", help=help_text)

    # os.cpu_count() may return None, hence the fallback to a single worker.
    parser.add_argument(
        "--workers",
        type=int,
        default=os.cpu_count() or 1,
        help="Number of worker processes. Default: all visible CPU cores.",
    )
    return parser.parse_args()
87
-
88
-
89
def ensure_output_dirs(output_dir: Path) -> tuple[Path, Path]:
    """
    Create the ``img`` and ``msk`` sub-folders of *output_dir* if they are missing.

    Returns:
        ``(img_dir, msk_dir)`` as paths below *output_dir*.
    """
    img_dir, msk_dir = (output_dir / leaf for leaf in ("img", "msk"))
    for folder in (img_dir, msk_dir):
        # parents=True also creates output_dir itself when needed.
        folder.mkdir(parents=True, exist_ok=True)
    return img_dir, msk_dir
95
-
96
-
97
def resize_image_patch(image_hwc_uint8: np.ndarray, output_size: int) -> np.ndarray:
    """
    Resize an RGB patch to ``output_size x output_size`` with bilinear resampling.

    The array is interpreted as an RGB PIL image and returned as a uint8 array.
    """
    target_size = (output_size, output_size)
    resized = Image.fromarray(image_hwc_uint8, mode="RGB").resize(target_size, resample=Image.BILINEAR)
    return np.asarray(resized, dtype=np.uint8)
101
-
102
-
103
def resize_mask_patch(mask_hw: np.ndarray, output_size: int) -> np.ndarray:
    """
    Resize a label mask to ``output_size x output_size`` with nearest-neighbour resampling.

    The mask is cast to uint8 first; nearest-neighbour keeps label values
    intact instead of blending them.
    """
    target_size = (output_size, output_size)
    mask_as_uint8 = mask_hw.astype(np.uint8)
    resized = Image.fromarray(mask_as_uint8, mode="L").resize(target_size, resample=Image.NEAREST)
    return np.asarray(resized, dtype=np.uint8)
107
-
108
-
109
def iter_non_overlapping_windows(height: int, width: int, tile_size: int):
    """
    Yield the ``(y0, x0)`` top-left corner of every full tile, row by row.

    Only windows that fit entirely inside ``height x width`` are produced, so a
    remainder strip narrower than ``tile_size`` at the bottom/right is skipped.
    """
    row_starts = range(0, height - tile_size + 1, tile_size)
    col_starts = range(0, width - tile_size + 1, tile_size)
    yield from ((y0, x0) for y0 in row_starts for x0 in col_starts)
113
-
114
-
115
def patch_name(stem: str, y0: int, x0: int) -> str:
    """Return the patch file name ``<stem>_y<row>_x<col>.tif`` with 5-digit zero-padded offsets."""
    row_tag = f"y{y0:05d}"
    col_tag = f"x{x0:05d}"
    return f"{stem}_{row_tag}_{col_tag}.tif"
117
-
118
-
119
def read_image_window(src: rasterio.DatasetReader, y0: int, x0: int, tile_size: int) -> np.ndarray:
    """
    Read bands 1-3 of a square window and return them as an HWC uint8 array.

    Only the requested ``tile_size x tile_size`` window starting at row ``y0``
    and column ``x0`` is read from *src*.
    """
    window = rasterio.windows.Window(row_off=y0, col_off=x0, height=tile_size, width=tile_size)
    bands_chw = src.read(indexes=[1, 2, 3], window=window)
    # rasterio returns band-first data; move the band axis last (CHW -> HWC).
    pixels_hwc = np.moveaxis(bands_chw, 0, -1)
    return pixels_hwc.astype(np.uint8, copy=False)
123
-
124
-
125
def read_mask_window(src: rasterio.DatasetReader, y0: int, x0: int, tile_size: int) -> np.ndarray:
    """Read band 1 of the ``tile_size x tile_size`` window at row ``y0`` / column ``x0``."""
    square = rasterio.windows.Window(row_off=y0, col_off=x0, height=tile_size, width=tile_size)
    mask_hw = src.read(1, window=square)
    return mask_hw
128
-
129
-
130
def process_tile(
    image_path: str,
    mask_path: str,
    out_img_dir: str,
    out_msk_dir: str,
    tile_size: int,
    output_size: int,
    ignore_index: int,
    drop_all_ignore: bool,
    overwrite: bool,
) -> tuple[int, int, int]:
    """
    Cut one image/mask tile pair into resized patches and write them to disk.

    Paths arrive as plain strings (the function is submitted to worker
    processes) and are converted to ``Path`` objects here.

    Args:
        image_path: source image tile.
        mask_path: mask tile; must have the same height/width as the image.
        out_img_dir: folder receiving image patches.
        out_msk_dir: folder receiving mask patches.
        tile_size: side length of the windows read from the source tile.
        output_size: side length of the patches written to disk.
        ignore_index: mask value treated as "no label".
        drop_all_ignore: when true, windows whose mask is entirely
            ``ignore_index`` are not written.
        overwrite: when false, windows whose two output files already exist
            are left alone and counted as kept.

    Returns:
        ``(kept, dropped_all_ignore, total_windows)`` for this tile.

    Raises:
        ValueError: if image and mask do not have the same height and width.
    """
    image_path, mask_path = Path(image_path), Path(mask_path)
    out_img_dir, out_msk_dir = Path(out_img_dir), Path(out_msk_dir)
    stem = image_path.stem

    n_kept = 0
    n_dropped = 0
    n_windows = 0

    with rasterio.open(image_path) as img_src, rasterio.open(mask_path) as msk_src:
        image_shape = (img_src.height, img_src.width)
        mask_shape = (msk_src.height, msk_src.width)
        if image_shape != mask_shape:
            raise ValueError(
                f"Shape mismatch for {image_path.name}: "
                f"image={image_shape} mask={mask_shape}"
            )

        for y0, x0 in iter_non_overlapping_windows(msk_src.height, msk_src.width, tile_size):
            n_windows += 1
            filename = patch_name(stem, y0, x0)
            img_target = out_img_dir / filename
            msk_target = out_msk_dir / filename

            # Existing pairs are reused unless the caller asked to overwrite.
            must_write = overwrite or not (img_target.exists() and msk_target.exists())
            if not must_write:
                n_kept += 1
                continue

            # Read the mask first: an all-ignore window can be dropped without
            # reading the image at all.
            mask_patch = read_mask_window(msk_src, y0, x0, tile_size)
            if drop_all_ignore and np.all(mask_patch == ignore_index):
                n_dropped += 1
                continue

            image_patch = read_image_window(img_src, y0, x0, tile_size)
            resized_image = resize_image_patch(image_patch, output_size)
            resized_mask = resize_mask_patch(mask_patch, output_size)

            Image.fromarray(resized_image, mode="RGB").save(img_target)
            Image.fromarray(resized_mask, mode="L").save(msk_target)
            n_kept += 1

    return n_kept, n_dropped, n_windows
182
-
183
-
184
def main() -> None:
    """
    Prepare patches for every image tile that has a mask with the same file name.

    Tiles are processed in parallel worker processes; per-tile counters are
    summed and printed once all tiles are done.

    Raises:
        RuntimeError: if no ``.tif`` file name is present in both input folders.
    """
    args = parse_args()
    out_img_dir, out_msk_dir = ensure_output_dirs(args.output_dir)

    image_paths = {path.name: path for path in sorted(args.image_dir.glob("*.tif"))}
    mask_paths = {path.name: path for path in sorted(args.mask_dir.glob("*.tif"))}
    # Only file names available on both sides can be processed.
    common_names = sorted(set(image_paths) & set(mask_paths))
    if not common_names:
        raise RuntimeError("No matching .tif image/mask pairs found.")

    workers = max(1, int(args.workers))

    # Arguments shared by every task; passed as plain str/int/bool values.
    shared_args = (
        str(out_img_dir),
        str(out_msk_dir),
        int(args.tile_size),
        int(args.output_size),
        int(args.ignore_index),
        bool(args.drop_all_ignore),
        bool(args.overwrite),
    )

    # Running sums of (kept, dropped_all_ignore, total_windows).
    totals = [0, 0, 0]
    with ProcessPoolExecutor(max_workers=workers) as executor:
        futures = [
            executor.submit(process_tile, str(image_paths[name]), str(mask_paths[name]), *shared_args)
            for name in common_names
        ]
        for future in tqdm(as_completed(futures), total=len(futures), desc="Tiles", unit="tile"):
            for position, count in enumerate(future.result()):
                totals[position] += count

    kept, dropped_all_ignore, total_windows = totals
    print(f"matched_tiles: {len(common_names)}")
    print(f"workers: {workers}")
    print(f"total_windows: {total_windows}")
    print(f"kept_patches: {kept}")
    print(f"dropped_all_ignore: {dropped_all_ignore}")
    print(f"output_img_dir: {out_img_dir}")
    print(f"output_msk_dir: {out_msk_dir}")
230
-
231
-
232
- if __name__ == "__main__":
233
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/rasterize_swisstlm3d.py DELETED
@@ -1,403 +0,0 @@
1
- """
2
- Rasterize SwissTLM3D features onto SwissImage tiles.
3
-
4
- This script writes masks that are compatible with the 15-class FLAIR-HUB
5
- land-cover setup used in this repository. SwissTLM3D does not expose every
6
- FLAIR-HUB class with the same level of granularity, so the script only
7
- projects a reliable subset of classes and leaves the rest to ignore_index=255.
8
-
9
- The implementation intentionally relies on GDAL/OGR command-line tools
10
- (`gdal_rasterize`) plus `rasterio`, because the current `caswit` environment
11
- does not ship a full vector Python stack.
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import argparse
17
- import os
18
- import subprocess
19
- from concurrent.futures import ProcessPoolExecutor, as_completed
20
- from dataclasses import dataclass
21
- from pathlib import Path
22
- from typing import Iterable
23
-
24
- import numpy as np
25
- import rasterio
26
- from rasterio.errors import RasterioIOError
27
- from tqdm import tqdm
28
-
29
-
30
- IGNORE_INDEX = 255
31
-
32
-
33
- FLAIR_HUB_CLASSES = {
34
- 0: "building",
35
- 1: "greenhouse",
36
- 2: "swimming_pool",
37
- 3: "impervious surface",
38
- 4: "pervious surface",
39
- 5: "bare soil",
40
- 6: "water",
41
- 7: "snow",
42
- 8: "herbaceous vegetation",
43
- 9: "agricultural land",
44
- 10: "plowed land",
45
- 11: "vineyard",
46
- 12: "deciduous",
47
- 13: "coniferous",
48
- 14: "brushwood",
49
- }
50
-
51
-
52
def sql_in(values: Iterable[str]) -> str:
    """
    Render *values* as a parenthesised, comma-separated list of SQL string literals.

    Each value is wrapped in single quotes and any single quote inside a value
    is doubled, so a name containing an apostrophe still yields a valid literal.
    Values without quotes are rendered exactly as before.

    Args:
        values: the strings to list; may be any iterable, including a generator.

    Returns:
        Text such as ``('a', 'b')`` suitable for use after ``IN``.
    """
    literals = ("'" + str(value).replace("'", "''") + "'" for value in values)
    return "(" + ", ".join(literals) + ")"
54
-
55
-
56
@dataclass(frozen=True)
class BurnSpec:
    """
    One rasterization rule: which features of which layer get which class id.

    ``sql`` turns the rule into a SELECT statement restricted to a bounding box.
    """

    # Human-readable label of the rule.
    name: str
    # Table (layer) the features are selected from.
    layer: str
    # Value burned into the mask for the selected features.
    class_id: int
    # SQL condition selecting the features of interest within the layer.
    where: str
    # SQL expression producing the geometry to burn (the raw column by default).
    sql_expr: str = "geom"

    @property
    def rtree(self) -> str:
        """Name of the spatial index table, built as ``rtree_<layer>_geom``."""
        return "rtree_" + self.layer + "_geom"

    def sql(self, minx: float, miny: float, maxx: float, maxy: float) -> str:
        """
        Build the SELECT statement for features whose index box overlaps the given extent.

        The bounding-box test is a sub-query on the ``rtree`` table; the
        ``where`` condition is applied on top of it.
        """
        bbox_overlap = " AND ".join(
            (
                f"maxx >= {minx}",
                f"minx <= {maxx}",
                f"maxy >= {miny}",
                f"miny <= {maxy}",
            )
        )
        candidate_ids = f"SELECT id FROM {self.rtree} WHERE {bbox_overlap}"
        return (
            f"SELECT {self.sql_expr} AS geom FROM {self.layer} "
            f"WHERE ({self.where}) AND id IN ({candidate_ids})"
        )
79
-
80
-
81
- ROAD_WIDTH_CASE = """
82
- CASE objektart
83
- WHEN '1m Weg' THEN 0.5
84
- WHEN '1m Wegfragment' THEN 0.5
85
- WHEN '2m Weg' THEN 1.0
86
- WHEN '2m Wegfragment' THEN 1.0
87
- WHEN '3m Strasse' THEN 1.5
88
- WHEN '4m Strasse' THEN 2.0
89
- WHEN '6m Strasse' THEN 3.0
90
- WHEN '8m Strasse' THEN 4.0
91
- WHEN '10m Strasse' THEN 5.0
92
- WHEN 'Autobahn' THEN 10.0
93
- WHEN 'Autostrasse' THEN 6.0
94
- WHEN 'Ausfahrt' THEN 3.5
95
- WHEN 'Einfahrt' THEN 3.0
96
- WHEN 'Zufahrt' THEN 3.0
97
- WHEN 'Dienstzufahrt' THEN 3.0
98
- WHEN 'Raststaette' THEN 8.0
99
- WHEN 'Platz' THEN 6.0
100
- WHEN 'Markierte Spur' THEN 2.0
101
- WHEN 'Verbindung' THEN 2.0
102
- ELSE 0.0
103
- END
104
- """.strip()
105
-
106
-
107
- BUILDING_OBJECTS = (
108
- "Gebaeude",
109
- "Hochhaus",
110
- "Historische Baute",
111
- "Im Bau",
112
- "Kapelle",
113
- "Offenes Gebaeude",
114
- "Sakraler Turm",
115
- "Sakrales Gebaeude",
116
- "Turm",
117
- "Unterirdisches Gebaeude",
118
- "Einhausung",
119
- )
120
-
121
- GREENHOUSE_OBJECTS = ("Treibhaus",)
122
-
123
- BARE_SOIL_OBJECTS = (
124
- "Fels",
125
- "Fels locker",
126
- "Felsbloecke",
127
- "Felsbloecke locker",
128
- "Lockergestein",
129
- "Lockergestein locker",
130
- )
131
-
132
- WATER_OBJECTS = (
133
- "Fliessgewaesser",
134
- "Stehende Gewaesser",
135
- )
136
-
137
- SNOW_OBJECTS = (
138
- "Gletscher",
139
- "Schneefeld Toteis",
140
- )
141
-
142
- BRUSHWOOD_OBJECTS = ("Gebueschwald",)
143
-
144
- AGRICULTURAL_OBJECTS = (
145
- "Baumschule",
146
- "Obstanlage",
147
- "Schrebergartenareal",
148
- )
149
-
150
- VINEYARD_OBJECTS = ("Reben",)
151
-
152
- TRAFFIC_AREA_OBJECTS = (
153
- "Flugfeldareal",
154
- "Flughafenareal",
155
- "Flugplatzareal",
156
- "Gleisareal",
157
- "Heliport",
158
- "Oeffentliches Parkplatzareal",
159
- "Privates Fahrareal",
160
- "Privates Parkplatzareal",
161
- "Rastplatzareal",
162
- "Verkehrsflaeche",
163
- )
164
-
165
- ROAD_OBJECTS = (
166
- "10m Strasse",
167
- "1m Weg",
168
- "1m Wegfragment",
169
- "2m Weg",
170
- "2m Wegfragment",
171
- "3m Strasse",
172
- "4m Strasse",
173
- "6m Strasse",
174
- "8m Strasse",
175
- "Ausfahrt",
176
- "Autobahn",
177
- "Autostrasse",
178
- "Dienstzufahrt",
179
- "Einfahrt",
180
- "Markierte Spur",
181
- "Platz",
182
- "Raststaette",
183
- "Verbindung",
184
- "Zufahrt",
185
- )
186
-
187
-
188
# Rasterization rules, applied in this order by process_one().
# Every rule is burned into the same mask file one after another, so where
# features overlap the rule listed later is expected to win (water, buildings
# and greenhouses end up on top). NOTE(review): relies on gdal_rasterize
# writing into the existing raster — confirm before reordering entries.
BURN_SPECS = (
    # Land-cover classes from the ground-cover layer.
    BurnSpec(
        name="bare_soil",
        layer="tlm_bb_bodenbedeckung",
        class_id=5,
        where=f"objektart IN {sql_in(BARE_SOIL_OBJECTS)}",
    ),
    BurnSpec(
        name="snow",
        layer="tlm_bb_bodenbedeckung",
        class_id=7,
        where=f"objektart IN {sql_in(SNOW_OBJECTS)}",
    ),
    BurnSpec(
        name="brushwood",
        layer="tlm_bb_bodenbedeckung",
        class_id=14,
        where=f"objektart IN {sql_in(BRUSHWOOD_OBJECTS)}",
    ),
    # Land-use areas.
    BurnSpec(
        name="agricultural_land",
        layer="tlm_areale_nutzungsareal",
        class_id=9,
        where=f"objektart IN {sql_in(AGRICULTURAL_OBJECTS)}",
    ),
    BurnSpec(
        name="vineyard",
        layer="tlm_areale_nutzungsareal",
        class_id=11,
        where=f"objektart IN {sql_in(VINEYARD_OBJECTS)}",
    ),
    # Impervious surface (class 3) comes from two sources: traffic-area
    # polygons and road features.
    BurnSpec(
        name="impervious_polygons",
        layer="tlm_areale_verkehrsareal",
        class_id=3,
        where=f"objektart IN {sql_in(TRAFFIC_AREA_OBJECTS)}",
    ),
    BurnSpec(
        name="impervious_roads",
        layer="tlm_strassen_strasse",
        class_id=3,
        where=f"objektart IN {sql_in(ROAD_OBJECTS)}",
        # Roads are buffered by a per-object-type distance before burning;
        # the distances look like half the nominal road width — TODO confirm.
        sql_expr=f"ST_Buffer(geom, {ROAD_WIDTH_CASE})",
    ),
    BurnSpec(
        name="water",
        layer="tlm_bb_bodenbedeckung",
        class_id=6,
        where=f"objektart IN {sql_in(WATER_OBJECTS)}",
    ),
    # Building footprints are burned last.
    BurnSpec(
        name="building",
        layer="tlm_bauten_gebaeude_footprint",
        class_id=0,
        where=f"objektart IN {sql_in(BUILDING_OBJECTS)}",
    ),
    BurnSpec(
        name="greenhouse",
        layer="tlm_bauten_gebaeude_footprint",
        class_id=1,
        where=f"objektart IN {sql_in(GREENHOUSE_OBJECTS)}",
    ),
)
251
-
252
-
253
def init_mask_from_image(image_path: Path, output_path: Path) -> tuple[float, float, float, float, int, int]:
    """
    Create a blank single-band mask with the same grid as the image tile.

    Any existing file at *output_path* is removed first. The new GeoTIFF copies
    width, height, CRS and transform from the image and is filled with
    ``IGNORE_INDEX`` (also declared as its nodata value).

    Returns:
        ``(left, bottom, right, top, width, height)`` of the image tile.
    """
    if output_path.exists():
        output_path.unlink()

    with rasterio.open(image_path) as src:
        bounds, transform, crs = src.bounds, src.transform, src.crs
        width, height = src.width, src.height

    # uint8 mask, LZW-compressed and stored in 512x512 internal tiles.
    profile = dict(
        driver="GTiff",
        width=width,
        height=height,
        count=1,
        dtype="uint8",
        crs=crs,
        transform=transform,
        nodata=IGNORE_INDEX,
        compress="lzw",
        tiled=True,
        blockxsize=512,
        blockysize=512,
    )
    with rasterio.open(output_path, "w", **profile) as dst:
        # Start from "everything ignored"; burn steps fill in known classes.
        dst.write(np.full((1, height, width), IGNORE_INDEX, dtype=np.uint8))

    return bounds.left, bounds.bottom, bounds.right, bounds.top, width, height
282
-
283
-
284
def rasterize_spec(
    gpkg_path: Path,
    output_path: Path,
    spec: BurnSpec,
    minx: float,
    miny: float,
    maxx: float,
    maxy: float,
    width: int,
    height: int,
) -> None:
    """
    Burn the features selected by *spec* into band 1 of an existing mask.

    Runs the ``gdal_rasterize`` command-line tool with the SQL produced by
    ``spec.sql`` for the given extent. ``width`` and ``height`` are accepted
    but not used by the command.

    Raises:
        subprocess.CalledProcessError: if ``gdal_rasterize`` exits with a
            non-zero status (its output is captured on the exception).
    """
    query = spec.sql(minx=minx, miny=miny, maxx=maxx, maxy=maxy)
    options = ["-b", "1", "-burn", str(spec.class_id), "-sql", query, "-dialect", "SQLITE"]
    command = ["gdal_rasterize", *options, str(gpkg_path), str(output_path)]
    # Argument list (no shell), output captured as text, failure raises.
    subprocess.run(command, check=True, capture_output=True, text=True)
310
-
311
-
312
def process_one(image_path: Path, output_dir: Path, gpkg_path: Path, overwrite: bool) -> tuple[str, str]:
    """
    Rasterize every burn spec for one image tile and report the outcome.

    The mask is built under a temporary ``<name>.partial`` file and only moved
    to its final name once all specs were burned. This way a failed or
    interrupted run never leaves an incomplete mask that a later run without
    ``overwrite`` would mistake for a finished one and skip.

    Args:
        image_path: image tile providing the grid and the output file name.
        output_dir: folder receiving the mask (same file name as the image).
        gpkg_path: GeoPackage holding the vector features.
        overwrite: rebuild the mask even if it already exists.

    Returns:
        ``(image file name, status)`` where status is ``"skipped"``, ``"ok"``
        or a string starting with ``"error: "``.
    """
    output_path = output_dir / image_path.name
    if output_path.exists() and not overwrite:
        return image_path.name, "skipped"

    partial_path = output_path.with_name(output_path.name + ".partial")
    try:
        minx, miny, maxx, maxy, width, height = init_mask_from_image(image_path, partial_path)
        for spec in BURN_SPECS:
            rasterize_spec(
                gpkg_path=gpkg_path,
                output_path=partial_path,
                spec=spec,
                minx=minx,
                miny=miny,
                maxx=maxx,
                maxy=maxy,
                width=width,
                height=height,
            )
        # Same directory, so the rename replaces the final file in one step.
        partial_path.replace(output_path)
    except (subprocess.CalledProcessError, RasterioIOError) as exc:
        # CalledProcessError carries gdal's captured stderr; str(exc) alone
        # only shows the command and exit code.
        stderr_text = (getattr(exc, "stderr", None) or "").strip()
        detail = f"{exc} | {stderr_text}" if stderr_text else f"{exc}"
        return image_path.name, f"error: {detail}"
    finally:
        # After a successful rename nothing is left; otherwise drop the
        # unfinished file.
        if partial_path.exists():
            partial_path.unlink()

    return image_path.name, "ok"
335
-
336
-
337
def parse_args() -> argparse.Namespace:
    """
    Parse the command-line options of the rasterization script.

    ``--workers`` defaults to half of the visible CPU cores, with a minimum of
    one. ``os.cpu_count()`` may return ``None``; that case now falls back to a
    single worker instead of failing on ``None // 2``.
    """
    parser = argparse.ArgumentParser(description="Rasterize SwissTLM3D onto SwissImage tiles.")
    parser.add_argument(
        "--gpkg",
        type=Path,
        default=Path("/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d/SWISSTLM3D_2026_LV95_LN02.gpkg"),
        help="Path to the extracted SwissTLM3D GeoPackage.",
    )
    parser.add_argument(
        "--image_dir",
        type=Path,
        default=Path("/mnt/CalcShare/datasets/Suisse_full/data/image"),
        help="Directory containing SwissImage GeoTIFF tiles.",
    )
    parser.add_argument(
        "--output_dir",
        type=Path,
        default=Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
        help="Directory where rasterized masks will be written.",
    )
    # Guard against os.cpu_count() returning None before halving it.
    default_workers = max(1, (os.cpu_count() or 1) // 2)
    parser.add_argument("--workers", type=int, default=default_workers)
    parser.add_argument("--limit", type=int, default=None, help="Optional number of image tiles to process.")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing masks.")
    return parser.parse_args()
361
-
362
-
363
def main() -> None:
    """
    Rasterize SwissTLM3D masks for all (or the first ``--limit``) image tiles.

    Tiles are handled in parallel worker processes. If any tile reports a
    status other than ``ok``/``skipped``, up to 20 failures are printed and the
    script exits with status 1.

    Raises:
        FileNotFoundError: if the GeoPackage or the image directory is missing.
        RuntimeError: if the image directory contains no ``.tif`` file.
        SystemExit: with code 1 when at least one tile failed.
    """
    args = parse_args()
    args.output_dir.mkdir(parents=True, exist_ok=True)

    required_inputs = (
        ("SwissTLM3D GeoPackage", args.gpkg),
        ("SwissImage directory", args.image_dir),
    )
    for label, location in required_inputs:
        if not location.exists():
            raise FileNotFoundError(f"Missing {label}: {location}")

    image_paths = sorted(args.image_dir.glob("*.tif"))
    if args.limit is not None:
        image_paths = image_paths[: args.limit]
    if not image_paths:
        raise RuntimeError(f"No .tif files found in {args.image_dir}")

    failures = []
    with ProcessPoolExecutor(max_workers=args.workers) as executor:
        pending = [
            executor.submit(process_one, image_path, args.output_dir, args.gpkg, args.overwrite)
            for image_path in image_paths
        ]
        for finished in tqdm(as_completed(pending), total=len(pending), desc="Rasterizing SwissTLM3D"):
            image_name, status = finished.result()
            if status not in ("ok", "skipped"):
                failures.append((image_name, status))

    if failures:
        print("Failures:")
        # Cap the report so a systematic failure does not flood the console.
        for image_name, status in failures[:20]:
            print(f"- {image_name}: {status}")
        raise SystemExit(1)

    print(f"Done. Wrote masks to {args.output_dir}")
    print("Projected classes:")
    for spec in BURN_SPECS:
        class_label = FLAIR_HUB_CLASSES[spec.class_id]
        print(f"- {spec.class_id:2d}: {class_label} ({spec.name})")
400
-
401
-
402
- if __name__ == "__main__":
403
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/split_swisstlm3d_geographic.py DELETED
@@ -1,314 +0,0 @@
1
- """
2
- Create a deterministic random split for SwissImage + SwissTLM3D masks.
3
-
4
- Supported naming conventions:
5
- 1. Full tiles:
6
- swissimage-dop10_<year>_<x>-<y>_0.1_2056.tif
7
- 2. Prepared patches:
8
- swissimage-dop10_<year>_<x>-<y>_0.1_2056_y<row>_x<col>.tif
9
-
10
- For prepared patches, the split is performed at source-tile level `(year, x, y)`,
11
- so all patches from the same original tile stay in the same split.
12
- """
13
-
14
- from __future__ import annotations
15
-
16
- import argparse
17
- import csv
18
- import os
19
- import random
20
- import re
21
- import shutil
22
- from dataclasses import dataclass
23
- from pathlib import Path
24
- from typing import Iterable
25
-
26
-
27
# Name of an unsplit SwissImage tile:
#   swissimage-dop10_<year>_<x>-<y>_0.1_2056.tif
FULL_TILE_RE = re.compile(r"^swissimage-dop10_(?P<year>\d{4})_(?P<x>\d+)-(?P<y>\d+)_0\.1_2056\.tif$")

# Name of a prepared patch; same stem as a full tile plus a `_y<row>_x<col>`
# suffix, captured as the `yoff` / `xoff` groups:
#   swissimage-dop10_<year>_<x>-<y>_0.1_2056_y<row>_x<col>.tif
PATCH_RE = re.compile(r"^swissimage-dop10_(?P<year>\d{4})_(?P<x>\d+)-(?P<y>\d+)_0\.1_2056_y(?P<yoff>\d+)_x(?P<xoff>\d+)\.tif$")
33
-
34
-
35
@dataclass(frozen=True)
class TileRecord:
    """Immutable description of one image file and the mask that shares its name."""

    # File name common to the image and the mask.
    name: str

    # Values parsed from the file name: acquisition year and the tile's
    # <x>-<y> identifier.
    year: int
    x: int
    y: int

    # Patch offsets taken from the `_y<row>_x<col>` suffix; 0 for full tiles.
    x_offset: int
    y_offset: int

    # Combined coordinate keys. collect_records fills them in as
    # `x * 10000 + x_offset` and `y * 10000 + y_offset`.
    x_geo: int
    y_geo: int

    # Locations of the source files on disk.
    image_path: Path
    mask_path: Path
47
-
48
-
49
def parse_args() -> argparse.Namespace:
    """Parse the command line of the split script.

    Returns:
        Namespace with ``image_dir``, ``mask_dir``, ``output_dir`` (paths),
        ``train_ratio``, ``val_ratio``, ``test_ratio`` (floats), ``seed`` (int),
        ``link_mode`` (``"symlink"``, ``"hardlink"`` or ``"copy"``) and
        ``overwrite`` (bool).
    """
    # Options are registered in this order, which is also the order of --help.
    options = (
        (
            "--image_dir",
            {
                "type": Path,
                "default": Path("/mnt/CalcShare/datasets/Suisse_full/data/image"),
                "help": "Directory containing SwissImage tiles.",
            },
        ),
        (
            "--mask_dir",
            {
                "type": Path,
                "default": Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
                "help": "Directory containing SwissTLM3D masks with matching filenames.",
            },
        ),
        (
            "--output_dir",
            {
                "type": Path,
                "default": Path("/mnt/CalcShare/datasets/Suisse_full/data/swisstlm3d_ft"),
                "help": "Output root with train/val/test subfolders.",
            },
        ),
        ("--train_ratio", {"type": float, "default": 0.6, "help": "Train ratio."}),
        ("--val_ratio", {"type": float, "default": 0.2, "help": "Validation ratio."}),
        ("--test_ratio", {"type": float, "default": 0.2, "help": "Test ratio."}),
        (
            "--seed",
            {
                "type": int,
                "default": 42,
                "help": "Random seed used for deterministic split assignment.",
            },
        ),
        (
            "--link_mode",
            {
                "choices": ["symlink", "hardlink", "copy"],
                "default": "symlink",
                "help": "How to materialize split files.",
            },
        ),
        (
            "--overwrite",
            {
                "action": "store_true",
                "help": "Delete an existing output_dir before writing the split.",
            },
        ),
    )

    cli = argparse.ArgumentParser(description="Create a random 60/20/20 split for SwissTLM3D fine-tuning.")
    for flag, settings in options:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
105
-
106
-
107
def parse_tile_name(path: Path) -> tuple[int, int, int, int, int]:
    """Extract ``(year, x, y, x_offset, y_offset)`` from a SwissImage file name.

    Only ``path.name`` is inspected. Prepared-patch names are tried first, then
    full-tile names; a full tile reports both offsets as 0.

    Raises:
        ValueError: if the name matches neither naming convention.
    """
    filename = path.name

    as_patch = PATCH_RE.match(filename)
    if as_patch is not None:
        fields = as_patch.groupdict()
        # The suffix is written `_y<row>_x<col>`, but the x offset is returned first.
        return (
            int(fields["year"]),
            int(fields["x"]),
            int(fields["y"]),
            int(fields["xoff"]),
            int(fields["yoff"]),
        )

    as_tile = FULL_TILE_RE.match(filename)
    if as_tile is None:
        raise ValueError(f"Unexpected SwissImage filename format: {path.name}")

    year, x, y = (int(as_tile.group(key)) for key in ("year", "x", "y"))
    return year, x, y, 0, 0
125
-
126
-
127
def collect_records(image_dir: Path, mask_dir: Path) -> list[TileRecord]:
    """Pair each ``*.tif`` image with the mask that has the same file name.

    Files present in only one of the two directories are left out. The result
    is ordered by file name.

    Raises:
        RuntimeError: if no file name occurs in both directories.
        ValueError: propagated from ``parse_tile_name`` for an unexpected name.
    """
    images_by_name = {}
    for image_path in sorted(image_dir.glob("*.tif")):
        images_by_name[image_path.name] = image_path

    masks_by_name = {}
    for mask_path in sorted(mask_dir.glob("*.tif")):
        masks_by_name[mask_path.name] = mask_path

    shared_names = sorted(set(images_by_name) & set(masks_by_name))
    if len(shared_names) == 0:
        raise RuntimeError("No matching image/mask filenames found.")

    def to_record(name):
        year, x, y, x_offset, y_offset = parse_tile_name(Path(name))
        return TileRecord(
            name=name,
            year=year,
            x=x,
            y=y,
            x_offset=x_offset,
            y_offset=y_offset,
            # Single integer key per axis: tile identifier scaled by 10000 plus the patch offset.
            x_geo=x * 10000 + x_offset,
            y_geo=y * 10000 + y_offset,
            image_path=images_by_name[name],
            mask_path=masks_by_name[name],
        )

    return [to_record(name) for name in shared_names]
153
-
154
-
155
def split_random_by_source_tile(
    records: list[TileRecord],
    train_ratio: float,
    val_ratio: float,
    test_ratio: float,
    seed: int,
) -> tuple[list[TileRecord], list[TileRecord], list[TileRecord]]:
    """Shuffle source tiles with a fixed seed and cut them into train/val/test.

    Records are grouped by ``(year, x, y)`` so that every record of one source
    tile lands in the same split. The ratios apply to the number of groups,
    not to the number of records.

    Args:
        records: Records to distribute; only ``year``, ``x`` and ``y`` are read.
        train_ratio: Fraction of groups assigned to train (rounded).
        val_ratio: Fraction of groups assigned to validation (rounded).
        test_ratio: Accepted for symmetry but not used; the test split
            receives whatever groups remain after train and validation.
        seed: Seed of the private ``random.Random`` used for shuffling.

    Returns:
        ``(train_records, val_records, test_records)``.

    Raises:
        RuntimeError: if any of the three splits would contain no group.
    """
    by_tile: dict[tuple[int, int, int], list[TileRecord]] = {}
    for rec in records:
        tile_key = (rec.year, rec.x, rec.y)
        if tile_key not in by_tile:
            by_tile[tile_key] = []
        by_tile[tile_key].append(rec)

    # Groups start in first-seen order, so the shuffle is reproducible for a
    # given input order and seed.
    shuffled = list(by_tile.values())
    random.Random(seed).shuffle(shuffled)

    group_count = len(shuffled)
    train_count = round(group_count * train_ratio)
    val_count = round(group_count * val_ratio)
    test_count = group_count - train_count - val_count

    if train_count <= 0 or val_count <= 0 or test_count <= 0:
        raise RuntimeError("Random split produced an empty split. Adjust ratios or input size.")

    val_end = train_count + val_count

    def flatten(groups):
        return [rec for group in groups for rec in group]

    return (
        flatten(shuffled[:train_count]),
        flatten(shuffled[train_count:val_end]),
        flatten(shuffled[val_end:]),
    )
187
-
188
-
189
def ensure_empty_dir(path: Path, overwrite: bool) -> None:
    """Leave ``path`` as an existing directory, wiping previous content on request.

    Args:
        path: Directory to create.
        overwrite: When ``path`` already exists, delete it recursively first.

    Raises:
        FileExistsError: if ``path`` exists and ``overwrite`` is false.
    """
    already_there = path.exists()
    if already_there and not overwrite:
        raise FileExistsError(f"{path} already exists. Use --overwrite to recreate it.")
    if already_there:
        shutil.rmtree(path)
    path.mkdir(parents=True, exist_ok=True)
195
-
196
-
197
def materialize(src: Path, dst: Path, mode: str) -> None:
    """Make ``src`` available at ``dst`` as a symlink, a hard link or a copy.

    Args:
        src: Existing source file.
        dst: Path to create; it must not exist yet for the link modes.
        mode: ``"symlink"`` or ``"hardlink"``; any other value copies the file
            with ``shutil.copy2``.
    """
    if mode == "symlink":
        # A relative link target is resolved against the directory holding the
        # link, not the current working directory, so a relative `src` would
        # produce a dangling link. abspath anchors it without following
        # symlinks that may be part of `src`.
        os.symlink(os.path.abspath(src), dst)
        return
    if mode == "hardlink":
        os.link(src, dst)
        return
    shutil.copy2(src, dst)
205
-
206
-
207
def write_split(
    split_name: str,
    records: Iterable[TileRecord],
    output_dir: Path,
    link_mode: str,
) -> None:
    """Populate ``<output_dir>/<split_name>/img`` and ``.../msk`` for one split.

    Args:
        split_name: Sub-folder name under ``output_dir``.
        records: Records whose image and mask files are materialized.
        output_dir: Root of the split tree.
        link_mode: Passed through to ``materialize`` for every file.
    """
    split_root = output_dir / split_name
    image_out = split_root / "img"
    mask_out = split_root / "msk"
    for folder in (image_out, mask_out):
        folder.mkdir(parents=True, exist_ok=True)

    # Image and mask keep the same file name in their respective folders.
    for rec in records:
        materialize(rec.image_path, image_out / rec.name, link_mode)
        materialize(rec.mask_path, mask_out / rec.name, link_mode)
221
-
222
-
223
def _format_coord_range(records: list[TileRecord]) -> str:
    """Return ``(xmin, xmax) x (ymin, ymax)`` over ``x_geo``/``y_geo``, or ``n/a`` if empty."""
    if not records:
        # min()/max() raise ValueError on an empty sequence.
        return "n/a"
    xs = [record.x_geo for record in records]
    ys = [record.y_geo for record in records]
    return f"({min(xs)}, {max(xs)}) x ({min(ys)}, {max(ys)})"


def write_manifest(
    output_dir: Path,
    mode: str,
    train_records: list[TileRecord],
    val_records: list[TileRecord],
    test_records: list[TileRecord],
) -> None:
    """Write ``split_manifest.csv`` and ``split_summary.txt`` into ``output_dir``.

    The CSV holds one row per record, prefixed with its split name. The summary
    lists the mode, the record count per split and the ``x_geo``/``y_geo``
    range per split. An empty split is reported as ``<split>_coords=n/a``
    instead of raising.

    Args:
        output_dir: Existing directory that receives both files.
        mode: Free-form label written as ``mode=<mode>`` in the summary.
        train_records: Records of the train split.
        val_records: Records of the validation split.
        test_records: Records of the test split.
    """
    splits = (
        ("train", train_records),
        ("val", val_records),
        ("test", test_records),
    )

    manifest_path = output_dir / "split_manifest.csv"
    with manifest_path.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["split", "name", "year", "x", "y", "x_offset", "y_offset", "x_geo", "y_geo", "image_path", "mask_path"])
        for split_name, records in splits:
            for record in records:
                writer.writerow(
                    [
                        split_name,
                        record.name,
                        record.year,
                        record.x,
                        record.y,
                        record.x_offset,
                        record.y_offset,
                        record.x_geo,
                        record.y_geo,
                        str(record.image_path),
                        str(record.mask_path),
                    ]
                )

    summary_path = output_dir / "split_summary.txt"
    with summary_path.open("w") as f:
        f.write(f"mode={mode}\n")
        # All counts come first, then all ranges, matching the established layout.
        for split_name, records in splits:
            f.write(f"{split_name}={len(records)}\n")
        for split_name, records in splits:
            f.write(f"{split_name}_coords={_format_coord_range(records)}\n")
274
-
275
-
276
def main() -> None:
    """Run the split: validate arguments, assign tiles, write files, print a report.

    Raises:
        ValueError: if the three ratios do not sum to 1.0 (tolerance 1e-8).
        FileNotFoundError: if the image or mask directory does not exist.
    """
    args = parse_args()

    ratio_sum = args.train_ratio + args.val_ratio + args.test_ratio
    if abs(ratio_sum - 1.0) > 1e-8:
        raise ValueError(f"Ratios must sum to 1.0, got {ratio_sum:.6f}")

    if not args.image_dir.exists():
        raise FileNotFoundError(f"Missing image_dir: {args.image_dir}")
    if not args.mask_dir.exists():
        raise FileNotFoundError(f"Missing mask_dir: {args.mask_dir}")

    records = collect_records(args.image_dir, args.mask_dir)
    parts = split_random_by_source_tile(
        records=records,
        train_ratio=args.train_ratio,
        val_ratio=args.val_ratio,
        test_ratio=args.test_ratio,
        seed=args.seed,
    )
    train_records, val_records, test_records = parts

    # The output directory is only touched once the split has been computed.
    ensure_empty_dir(args.output_dir, args.overwrite)
    for split_name, split_records in zip(("train", "val", "test"), parts):
        write_split(split_name, split_records, args.output_dir, args.link_mode)
    write_manifest(args.output_dir, "random_by_source_tile", train_records, val_records, test_records)

    total = len(records)
    print(f"Matched tiles: {total}")
    print("Split mode: random_by_source_tile")
    print(f"Seed: {args.seed}")
    for label, split_records in zip(("Train", "Val", "Test"), parts):
        count = len(split_records)
        print(f"{label}: {count} ({100.0 * count / total:.2f}%)")
    print(f"Output: {args.output_dir}")


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/stats_swisstlm3d_masks.py DELETED
@@ -1,136 +0,0 @@
1
- """
2
- Compute mask statistics for SwissTLM3D-generated masks with a progress bar.
3
-
4
- Outputs:
5
- - global class histogram
6
- - class proportions
7
- - count of masks entirely equal to 255
8
- - count of masks containing at least one non-255 pixel
9
- - CSV list of fully-255 masks
10
- - text summary
11
- """
12
-
13
- from __future__ import annotations
14
-
15
- import argparse
16
- import csv
17
- from collections import Counter
18
- from pathlib import Path
19
-
20
- import numpy as np
21
- import rasterio
22
- from tqdm import tqdm
23
-
24
-
25
# Label printed and written for each mask value. Values 0..14 follow the order
# of the tuple below; a value missing from this mapping is reported as
# "unknown" where the mapping is consulted.
CLASS_NAMES = dict(
    enumerate(
        (
            "building",
            "greenhouse",
            "swimming_pool",
            "impervious_surface",
            "pervious_surface",
            "bare_soil",
            "water",
            "snow",
            "herbaceous_vegetation",
            "agricultural_land",
            "plowed_land",
            "vineyard",
            "deciduous",
            "coniferous",
            "brushwood",
        )
    )
)
CLASS_NAMES[255] = "ignore"
43
-
44
-
45
def parse_args() -> argparse.Namespace:
    """Parse the command line of the mask statistics script.

    Returns:
        Namespace with ``mask_dir`` (input directory) and ``out_dir``
        (directory for the summary files), both as ``Path``.
    """
    options = (
        (
            "--mask_dir",
            Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d"),
            "Directory containing rasterized SwissTLM3D masks.",
        ),
        (
            "--out_dir",
            Path("/mnt/CalcShare/datasets/Suisse_full/data/mask_tlm3d_stats"),
            "Directory where summary files will be written.",
        ),
    )

    cli = argparse.ArgumentParser(description="Compute class statistics for SwissTLM3D masks.")
    for flag, default_path, help_text in options:
        cli.add_argument(flag, type=Path, default=default_path, help=help_text)
    return cli.parse_args()
60
-
61
-
62
def main() -> None:
    """Scan every ``*.tif`` mask, then write and print the value statistics.

    Three files are written to ``out_dir``: ``summary.txt`` (counts),
    ``class_stats.csv`` (per-value pixel count and proportion) and
    ``all_255_masks.csv`` (names of masks whose only value is 255).

    Raises:
        RuntimeError: if ``mask_dir`` contains no ``*.tif`` file.
    """
    args = parse_args()
    args.out_dir.mkdir(parents=True, exist_ok=True)

    mask_paths = sorted(args.mask_dir.glob("*.tif"))
    if len(mask_paths) == 0:
        raise RuntimeError(f"No .tif masks found in {args.mask_dir}")

    pixel_counts: Counter[int] = Counter()
    all_255_masks: list[str] = []
    non_255_masks = 0
    partial_255_masks = 0

    for mask_path in tqdm(mask_paths, desc="Scanning masks", unit="mask"):
        with rasterio.open(mask_path) as src:
            band = src.read(1)

        values, counts = np.unique(band, return_counts=True)
        histogram = dict(zip(values.tolist(), counts.tolist()))
        pixel_counts.update(histogram)

        if set(histogram) == {255}:
            all_255_masks.append(mask_path.name)
            continue

        # The mask holds at least one value other than 255; it counts as
        # "partial" when 255 occurs in it as well.
        non_255_masks += 1
        if 255 in histogram:
            partial_255_masks += 1

    total_pixels = sum(pixel_counts.values())

    # One (value, label, pixel count, proportion) row per value, shared by the
    # CSV and the console report.
    class_rows = [
        (
            value,
            CLASS_NAMES.get(value, "unknown"),
            pixel_counts[value],
            pixel_counts[value] / total_pixels if total_pixels > 0 else 0.0,
        )
        for value in sorted(pixel_counts)
    ]
    headline = (
        ("mask_count", len(mask_paths)),
        ("all_255_count", len(all_255_masks)),
        ("non_255_count", non_255_masks),
        ("partial_255_count", partial_255_masks),
    )

    summary_txt = args.out_dir / "summary.txt"
    class_csv = args.out_dir / "class_stats.csv"
    all_255_csv = args.out_dir / "all_255_masks.csv"

    with summary_txt.open("w") as f:
        for key, number in headline:
            f.write(f"{key}={number}\n")
        f.write(f"total_pixels={total_pixels}\n")

    with class_csv.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["value", "class_name", "pixel_count", "proportion"])
        for row in class_rows:
            writer.writerow(list(row))

    with all_255_csv.open("w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["filename"])
        for name in all_255_masks:
            writer.writerow([name])

    for key, number in headline:
        print(f"{key}: {number}")
    print("class_stats:")
    for value, label, pixels, proportion in class_rows:
        print(f" {value:>3} {label:<24} {pixels} ({proportion:.8f})")
    for written in (summary_txt, class_csv, all_255_csv):
        print(f"wrote: {written}")


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dataset/swissimage_grayscale_mixed.py DELETED
@@ -1,24 +0,0 @@
1
- from pathlib import Path
2
- from typing import Dict, Optional
3
-
4
- from torchvision import transforms
5
-
6
- from dataset.definition_dataset_grayscale_mixed import (
7
- SemanticSegmentationDatasetFusionGrayscaleMixed,
8
- )
9
-
10
-
11
def build(
    image_dir: Path,
    mask_dir: Path,
    transform: Optional[transforms.Compose],
    augment=None,
    label_remap: Optional[Dict[int, int]] = None,
):
    """Instantiate ``SemanticSegmentationDatasetFusionGrayscaleMixed`` for one image/mask folder pair.

    Args:
        image_dir: Passed as the first positional argument.
        mask_dir: Passed as the second positional argument.
        transform: Forwarded unchanged as ``transform``.
        augment: Forwarded unchanged as ``augment``.
        label_remap: Forwarded unchanged as ``label_remap``.

    Returns:
        The newly constructed dataset object.
    """
    options = {
        "transform": transform,
        "augment": augment,
        "label_remap": label_remap,
    }
    return SemanticSegmentationDatasetFusionGrayscaleMixed(image_dir, mask_dir, **options)