CircleStar commited on
Commit
db426cb
Β·
1 Parent(s): 4d3ef90

Upload 7 files

Browse files
Files changed (7) hide show
  1. app.py +176 -0
  2. config.py +40 -0
  3. data.py +140 -0
  4. metrics.py +49 -0
  5. model.py +61 -0
  6. train.py +294 -0
  7. visualize.py +44 -0
app.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from config import APP_TITLE, set_seed, SEED
4
+ from train import (
5
+ load_dataset_action,
6
+ update_explorer_sample,
7
+ update_compare_sample,
8
+ train_experiment,
9
+ handle_click_dataset,
10
+ handle_click_exp_a,
11
+ handle_click_exp_b,
12
+ handle_click_exp_c,
13
+ )
14
+
15
+ set_seed(SEED)
16
+
17
+ custom_css = """
18
+ #compare-a img, #compare-b img, #compare-c img, #explorer img {
19
+ image-rendering: pixelated;
20
+ }
21
+ .small-note { font-size: 0.9rem; opacity: 0.85; }
22
+ """
23
+
24
+ with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
25
+ gr.Markdown(f"# {APP_TITLE}\nInteractive teaching app for multispectral semantic segmentation.")
26
+
27
+ dataset_state = gr.State(None)
28
+ experiments_state = gr.State([])
29
+
30
+ # ── Tab 1: Image Explorer ────────────────────────────────
31
+ with gr.Tab("1) Image explorer"):
32
+ with gr.Row():
33
+ with gr.Column(scale=1):
34
+ train_size = gr.Slider(60, 2000, value=240, step=20, label="Train subset size")
35
+ val_size = gr.Slider(20, 500, value=60, step=10, label="Validation subset size")
36
+ image_size = gr.Slider(64, 256, value=128, step=32, label="Image size")
37
+ load_btn = gr.Button("Load / rebuild dataset", variant="primary")
38
+ dataset_info = gr.Markdown("### No dataset loaded yet")
39
+ gr.Markdown(
40
+ "<div class='small-note'>Uses procedural synthetic data. "
41
+ "See <code>data.py β†’ load_data()</code> to plug in a real dataset.</div>"
42
+ )
43
+
44
+ with gr.Column(scale=2, elem_id="explorer"):
45
+ explorer_sample_index = gr.Slider(0, 59, value=0, step=1, label="Validation sample index")
46
+ with gr.Row():
47
+ explorer_rgb = gr.Image(label="RGB / false-color", type="numpy", height=400)
48
+ explorer_gt = gr.Image(label="Ground truth mask", type="numpy", height=400)
49
+ explorer_overlay = gr.Image(label="Ground truth overlay",type="numpy", height=400)
50
+ explorer_click_info = gr.Markdown("### Click the RGB image to inspect a pixel")
51
+
52
+ # ── Tab 2: Model Trainer ─────────────────────────────────
53
+ with gr.Tab("2) Model trainer"):
54
+ with gr.Row():
55
+ with gr.Column(scale=1):
56
+ run_name = gr.Textbox(label="Experiment name", placeholder="e.g. lr-1e-3_ep-5")
57
+ slot_label = gr.Radio(choices=["A", "B", "C"], value="A", label="Save to slot")
58
+ learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, label="Learning rate")
59
+ batch_size = gr.Slider(2, 32, value=8, step=2, label="Batch size")
60
+ epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
61
+ base_channels = gr.Slider(8, 64, value=16, step=8, label="Model width (base channels)")
62
+ train_btn = gr.Button("Train experiment", variant="primary")
63
+ with gr.Column(scale=1):
64
+ train_summary = gr.Markdown("### No training run yet")
65
+ gr.Markdown(
66
+ "<div class='small-note'>Each slot (A / B / C) stores one run independently. "
67
+ "Overwrite a slot to update it. Results appear in the <b>Result comparison</b> tab.</div>"
68
+ )
69
+
70
+ # ── Tab 3: Result Comparison ─────────────────────────────
71
+ with gr.Tab("3) Result comparison"):
72
+ compare_sample_index = gr.Slider(0, 59, value=0, step=1, label="Validation sample index")
73
+ with gr.Row():
74
+ with gr.Column(scale=1, elem_id="compare-a"):
75
+ gr.Markdown("## Slot A")
76
+ compare_a_rgb = gr.Image(label="Reference RGB", type="numpy", height=380)
77
+ compare_a_pred = gr.Image(label="Prediction mask", type="numpy", height=380)
78
+ compare_a_overlay = gr.Image(label="Prediction overlay", type="numpy", height=380)
79
+ compare_a_metrics = gr.Markdown("### No experiment")
80
+ compare_a_error = gr.Image(label="Correctness map", type="numpy", height=380)
81
+ compare_a_click = gr.Markdown("### Click overlay to inspect pixel")
82
+
83
+ with gr.Column(scale=1, elem_id="compare-b"):
84
+ gr.Markdown("## Slot B")
85
+ compare_b_rgb = gr.Image(label="Reference RGB", type="numpy", height=380)
86
+ compare_b_pred = gr.Image(label="Prediction mask", type="numpy", height=380)
87
+ compare_b_overlay = gr.Image(label="Prediction overlay", type="numpy", height=380)
88
+ compare_b_metrics = gr.Markdown("### No experiment")
89
+ compare_b_error = gr.Image(label="Correctness map", type="numpy", height=380)
90
+ compare_b_click = gr.Markdown("### Click overlay to inspect pixel")
91
+
92
+ with gr.Column(scale=1, elem_id="compare-c"):
93
+ gr.Markdown("## Slot C")
94
+ compare_c_rgb = gr.Image(label="Reference RGB", type="numpy", height=380)
95
+ compare_c_pred = gr.Image(label="Prediction mask", type="numpy", height=380)
96
+ compare_c_overlay = gr.Image(label="Prediction overlay", type="numpy", height=380)
97
+ compare_c_metrics = gr.Markdown("### No experiment")
98
+ compare_c_error = gr.Image(label="Correctness map", type="numpy", height=380)
99
+ compare_c_click = gr.Markdown("### Click overlay to inspect pixel")
100
+
101
+ # ── Shared output lists ───────────────────────────────────
102
+ _compare_outputs = [
103
+ compare_a_rgb, compare_a_pred, compare_a_overlay, compare_a_metrics, compare_a_error, compare_a_click,
104
+ compare_b_rgb, compare_b_pred, compare_b_overlay, compare_b_metrics, compare_b_error, compare_b_click,
105
+ compare_c_rgb, compare_c_pred, compare_c_overlay, compare_c_metrics, compare_c_error, compare_c_click,
106
+ ]
107
+
108
+ # ── Event bindings ────────────────────────────────────────
109
+
110
+ # Load dataset β†’ reset experiments, update explorer, reset compare slider
111
+ load_btn.click(
112
+ fn=load_dataset_action,
113
+ inputs=[train_size, val_size, image_size],
114
+ outputs=[
115
+ dataset_state,
116
+ experiments_state,
117
+ dataset_info,
118
+ explorer_rgb, explorer_gt, explorer_overlay,
119
+ explorer_click_info,
120
+ explorer_sample_index,
121
+ compare_sample_index,
122
+ ],
123
+ )
124
+
125
+ # Explorer sample slider β†’ update Tab 1 images
126
+ explorer_sample_index.change(
127
+ fn=update_explorer_sample,
128
+ inputs=[dataset_state, explorer_sample_index],
129
+ outputs=[explorer_rgb, explorer_gt, explorer_overlay, explorer_click_info],
130
+ )
131
+
132
+ # Click on explorer image β†’ pixel info
133
+ explorer_rgb.select(
134
+ fn=handle_click_dataset,
135
+ inputs=[dataset_state, explorer_sample_index],
136
+ outputs=[explorer_click_info],
137
+ )
138
+
139
+ # Train β†’ update experiments + Tab 3
140
+ train_btn.click(
141
+ fn=train_experiment,
142
+ inputs=[
143
+ dataset_state, experiments_state,
144
+ slot_label, learning_rate, batch_size, epochs, base_channels,
145
+ run_name,
146
+ ],
147
+ outputs=[experiments_state, train_summary, compare_sample_index, *_compare_outputs],
148
+ )
149
+
150
+ # Compare sample slider β†’ update Tab 3
151
+ compare_sample_index.change(
152
+ fn=update_compare_sample,
153
+ inputs=[dataset_state, experiments_state, compare_sample_index],
154
+ outputs=_compare_outputs,
155
+ )
156
+
157
+ # Click on overlay images β†’ pixel info
158
+ compare_a_overlay.select(
159
+ fn=handle_click_exp_a,
160
+ inputs=[dataset_state, experiments_state, compare_sample_index],
161
+ outputs=[compare_a_click],
162
+ )
163
+ compare_b_overlay.select(
164
+ fn=handle_click_exp_b,
165
+ inputs=[dataset_state, experiments_state, compare_sample_index],
166
+ outputs=[compare_b_click],
167
+ )
168
+ compare_c_overlay.select(
169
+ fn=handle_click_exp_c,
170
+ inputs=[dataset_state, experiments_state, compare_sample_index],
171
+ outputs=[compare_c_click],
172
+ )
173
+
174
+
175
+ if __name__ == "__main__":
176
+ demo.launch()
config.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+
4
+ APP_TITLE = "Multispectral Segmentation Lab"
5
+ SEED = 42
6
+ DEFAULT_IMAGE_SIZE = 128
7
+ NUM_CHANNELS = 7
8
+ NUM_CLASSES = 8
9
+ BAND_NAMES = ["B02", "B03", "B04", "B05", "B06", "B08", "B11"]
10
+ CLASS_NAMES = [
11
+ "Forest",
12
+ "Shrubland",
13
+ "Grassland",
14
+ "Wetland",
15
+ "Cropland",
16
+ "Urban/Built-up",
17
+ "Barren",
18
+ "Water",
19
+ ]
20
+ CLASS_COLORS = np.array(
21
+ [
22
+ [34, 139, 34], # Forest
23
+ [154, 205, 50], # Shrubland
24
+ [124, 252, 0], # Grassland
25
+ [0, 128, 128], # Wetland
26
+ [255, 215, 0], # Cropland
27
+ [178, 34, 34], # Urban/Built-up
28
+ [210, 180, 140], # Barren
29
+ [30, 144, 255], # Water
30
+ ],
31
+ dtype=np.uint8,
32
+ )
33
+
34
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
35
+
36
+
37
+ def set_seed(seed: int = SEED):
38
+ np.random.seed(seed)
39
+ torch.manual_seed(seed)
40
+ torch.cuda.manual_seed_all(seed)
data.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Tuple, Dict, Optional
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
+ from config import SEED, DEFAULT_IMAGE_SIZE, NUM_CHANNELS, NUM_CLASSES
9
+
10
+
11
+ def _draw_disk(mask: np.ndarray, center_y: int, center_x: int, radius: int, value: int):
12
+ h, w = mask.shape
13
+ yy, xx = np.ogrid[:h, :w]
14
+ mask[(yy - center_y) ** 2 + (xx - center_x) ** 2 <= radius ** 2] = value
15
+
16
+
17
+ def _draw_rect(mask: np.ndarray, y0: int, x0: int, y1: int, x1: int, value: int):
18
+ mask[max(0, y0):min(mask.shape[0], y1), max(0, x0):min(mask.shape[1], x1)] = value
19
+
20
+
21
+ def generate_synthetic_sample(size: int = DEFAULT_IMAGE_SIZE, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
22
+ rng = np.random.default_rng(seed)
23
+ h = w = size
24
+
25
+ mask = np.full((h, w), 2, dtype=np.int64)
26
+
27
+ for _ in range(rng.integers(1, 3)):
28
+ cy, cx = rng.integers(h // 6, 5 * h // 6, size=2)
29
+ _draw_disk(mask, int(cy), int(cx), int(rng.integers(h // 10, h // 5)), 7)
30
+
31
+ for cls in [0, 1, 0, 1]:
32
+ cy, cx = rng.integers(h // 8, 7 * h // 8, size=2)
33
+ _draw_disk(mask, int(cy), int(cx), int(rng.integers(h // 12, h // 6)), cls)
34
+
35
+ water = mask == 7
36
+ wet = np.zeros_like(water)
37
+ for dy in [-1, 0, 1]:
38
+ for dx in [-1, 0, 1]:
39
+ wet |= np.roll(np.roll(water, dy, axis=0), dx, axis=1)
40
+ wet &= ~water
41
+ mask[wet & (rng.random((h, w)) > 0.25)] = 3
42
+
43
+ for _ in range(rng.integers(1, 3)):
44
+ y0 = int(rng.integers(0, h - h // 4))
45
+ x0 = int(rng.integers(0, w - w // 4))
46
+ hh = int(rng.integers(h // 8, h // 4))
47
+ ww = int(rng.integers(w // 8, w // 3))
48
+ _draw_rect(mask, y0, x0, y0 + hh, x0 + ww, 4)
49
+ for row in range(y0, min(h, y0 + hh), 6):
50
+ mask[row: min(h, row + 2), x0: min(w, x0 + ww)] = 2
51
+
52
+ for _ in range(rng.integers(1, 4)):
53
+ y0 = int(rng.integers(0, h - h // 5))
54
+ x0 = int(rng.integers(0, w - w // 5))
55
+ _draw_rect(mask, y0, x0, y0 + int(rng.integers(h // 10, h // 5)), x0 + int(rng.integers(w // 10, w // 5)), 5)
56
+
57
+ if rng.random() > 0.3:
58
+ road_y = int(rng.integers(h // 5, 4 * h // 5))
59
+ mask[max(0, road_y - 1):min(h, road_y + 2), :] = 5
60
+ if rng.random() > 0.5:
61
+ road_x = int(rng.integers(w // 5, 4 * w // 5))
62
+ mask[:, max(0, road_x - 1):min(w, road_x + 2)] = 5
63
+
64
+ for _ in range(rng.integers(1, 3)):
65
+ cy, cx = rng.integers(h // 8, 7 * h // 8, size=2)
66
+ _draw_disk(mask, int(cy), int(cx), int(rng.integers(h // 14, h // 8)), 6)
67
+
68
+ signatures = np.array([
69
+ [0.10, 0.14, 0.10, 0.25, 0.36, 0.60, 0.24], # Forest
70
+ [0.13, 0.18, 0.14, 0.24, 0.30, 0.47, 0.23], # Shrubland
71
+ [0.16, 0.22, 0.17, 0.26, 0.32, 0.50, 0.20], # Grassland
72
+ [0.09, 0.13, 0.11, 0.18, 0.22, 0.30, 0.10], # Wetland
73
+ [0.18, 0.24, 0.20, 0.30, 0.36, 0.52, 0.18], # Cropland
74
+ [0.24, 0.26, 0.28, 0.30, 0.31, 0.33, 0.36], # Urban
75
+ [0.28, 0.30, 0.32, 0.34, 0.35, 0.36, 0.38], # Barren
76
+ [0.05, 0.04, 0.03, 0.02, 0.02, 0.01, 0.00], # Water
77
+ ], dtype=np.float32)
78
+
79
+ img = np.zeros((NUM_CHANNELS, h, w), dtype=np.float32)
80
+ for c in range(NUM_CLASSES):
81
+ region = mask == c
82
+ for b in range(NUM_CHANNELS):
83
+ img[b][region] = signatures[c, b]
84
+
85
+ yy, xx = np.mgrid[0:h, 0:w]
86
+ grad1 = (xx / max(1, w - 1)).astype(np.float32)
87
+ grad2 = (yy / max(1, h - 1)).astype(np.float32)
88
+ for b in range(NUM_CHANNELS):
89
+ img[b] += 0.03 * np.sin((b + 1) * grad1 * math.pi)
90
+ img[b] += 0.02 * np.cos((b + 2) * grad2 * math.pi)
91
+ img[b] += rng.normal(0, 0.02, size=(h, w)).astype(np.float32)
92
+
93
+ return np.clip(img, 0.0, 1.0), mask
94
+
95
+
96
+ class MultiSpectralDataset(Dataset):
97
+ def __init__(self, images: np.ndarray, masks: np.ndarray):
98
+ self.images = images.astype(np.float32)
99
+ self.masks = masks.astype(np.int64)
100
+
101
+ def __len__(self):
102
+ return len(self.images)
103
+
104
+ def __getitem__(self, idx: int):
105
+ return torch.from_numpy(self.images[idx]), torch.from_numpy(self.masks[idx])
106
+
107
+
108
+ def build_synthetic_dataset(
109
+ train_size: int, val_size: int, image_size: int
110
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, str]:
111
+ total = train_size + val_size
112
+ images, masks = [], []
113
+ for i in range(total):
114
+ img, mask = generate_synthetic_sample(size=image_size, seed=SEED + i)
115
+ images.append(img)
116
+ masks.append(mask)
117
+ images = np.stack(images)
118
+ masks = np.stack(masks)
119
+ status = f"Synthetic data | Train: {train_size} | Val: {val_size} | Size: {image_size}Γ—{image_size}"
120
+ return images[:train_size], masks[:train_size], images[train_size:], masks[train_size:], status
121
+
122
+
123
+ def load_data(train_size: int, val_size: int, image_size: int) -> Dict[str, object]:
124
+ """
125
+ Load dataset. Currently uses procedural synthetic data.
126
+
127
+ TODO: To plug in your own real dataset, replace the call below with a
128
+ custom loader that returns numpy arrays:
129
+ - images: (N, 7, H, W) float32, values in [0, 1]
130
+ - masks: (N, H, W) int64, class indices in [0, NUM_CLASSES)
131
+ Then assign tr_x, tr_y, va_x, va_y accordingly and update `status`.
132
+ """
133
+ tr_x, tr_y, va_x, va_y, status = build_synthetic_dataset(train_size, val_size, image_size)
134
+ return {
135
+ "train_images": tr_x,
136
+ "train_masks": tr_y,
137
+ "val_images": va_x,
138
+ "val_masks": va_y,
139
+ "status": status,
140
+ }
metrics.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional
2
+ import numpy as np
3
+ from config import NUM_CLASSES, CLASS_NAMES
4
+
5
+
6
+ def compute_metrics(pred: np.ndarray, gt: np.ndarray, num_classes: int = NUM_CLASSES) -> Dict[str, object]:
7
+ pred = pred.astype(np.int64)
8
+ gt = gt.astype(np.int64)
9
+ cm = np.zeros((num_classes, num_classes), dtype=np.int64)
10
+ flat_gt = gt.reshape(-1)
11
+ flat_pred = pred.reshape(-1)
12
+ for g, p in zip(flat_gt, flat_pred):
13
+ if 0 <= g < num_classes and 0 <= p < num_classes:
14
+ cm[g, p] += 1
15
+
16
+ overall_acc = float((flat_gt == flat_pred).mean())
17
+ per_class_acc = []
18
+ per_class_iou = []
19
+ for c in range(num_classes):
20
+ tp = cm[c, c]
21
+ gt_total = cm[c, :].sum()
22
+ pred_total = cm[:, c].sum()
23
+ union = gt_total + pred_total - tp
24
+ acc = float(tp / gt_total) if gt_total > 0 else None
25
+ iou = float(tp / union) if union > 0 else None
26
+ per_class_acc.append(acc)
27
+ per_class_iou.append(iou)
28
+ miou = float(np.nanmean([x if x is not None else np.nan for x in per_class_iou]))
29
+ return {
30
+ "overall_acc": overall_acc,
31
+ "miou": miou,
32
+ "per_class_acc": per_class_acc,
33
+ "per_class_iou": per_class_iou,
34
+ "confusion_matrix": cm.tolist(),
35
+ }
36
+
37
+
38
+ def metrics_markdown(metrics: Dict[str, object], title: str = "Metrics") -> str:
39
+ lines = [f"### {title}"]
40
+ lines.append(f"- Overall accuracy: **{metrics['overall_acc'] * 100:.2f}%**")
41
+ lines.append(f"- Mean IoU: **{metrics['miou'] * 100:.2f}%**")
42
+ lines.append("")
43
+ lines.append("| Class | Accuracy | IoU |")
44
+ lines.append("|---|---:|---:|")
45
+ for name, acc, iou in zip(CLASS_NAMES, metrics["per_class_acc"], metrics["per_class_iou"]):
46
+ acc_s = "β€”" if acc is None else f"{acc * 100:.1f}%"
47
+ iou_s = "β€”" if iou is None else f"{iou * 100:.1f}%"
48
+ lines.append(f"| {name} | {acc_s} | {iou_s} |")
49
+ return "\n".join(lines)
model.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from config import NUM_CHANNELS, NUM_CLASSES
4
+
5
+
6
+ class DoubleConv(nn.Module):
7
+ def __init__(self, in_ch: int, out_ch: int):
8
+ super().__init__()
9
+ self.net = nn.Sequential(
10
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
11
+ nn.BatchNorm2d(out_ch),
12
+ nn.ReLU(inplace=True),
13
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
14
+ nn.BatchNorm2d(out_ch),
15
+ nn.ReLU(inplace=True),
16
+ )
17
+
18
+ def forward(self, x):
19
+ return self.net(x)
20
+
21
+
22
+ class SmallUNet(nn.Module):
23
+ def __init__(self, in_channels: int = NUM_CHANNELS, num_classes: int = NUM_CLASSES, base_channels: int = 16):
24
+ super().__init__()
25
+ self.enc1 = DoubleConv(in_channels, base_channels)
26
+ self.pool1 = nn.MaxPool2d(2)
27
+ self.enc2 = DoubleConv(base_channels, base_channels * 2)
28
+ self.pool2 = nn.MaxPool2d(2)
29
+ self.enc3 = DoubleConv(base_channels * 2, base_channels * 4)
30
+ self.pool3 = nn.MaxPool2d(2)
31
+
32
+ self.bottleneck = DoubleConv(base_channels * 4, base_channels * 8)
33
+
34
+ self.up3 = nn.ConvTranspose2d(base_channels * 8, base_channels * 4, kernel_size=2, stride=2)
35
+ self.dec3 = DoubleConv(base_channels * 8, base_channels * 4)
36
+ self.up2 = nn.ConvTranspose2d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
37
+ self.dec2 = DoubleConv(base_channels * 4, base_channels * 2)
38
+ self.up1 = nn.ConvTranspose2d(base_channels * 2, base_channels, kernel_size=2, stride=2)
39
+ self.dec1 = DoubleConv(base_channels * 2, base_channels)
40
+
41
+ self.head = nn.Conv2d(base_channels, num_classes, kernel_size=1)
42
+
43
+ def forward(self, x):
44
+ e1 = self.enc1(x)
45
+ e2 = self.enc2(self.pool1(e1))
46
+ e3 = self.enc3(self.pool2(e2))
47
+ b = self.bottleneck(self.pool3(e3))
48
+
49
+ d3 = self.up3(b)
50
+ d3 = torch.cat([d3, e3], dim=1)
51
+ d3 = self.dec3(d3)
52
+
53
+ d2 = self.up2(d3)
54
+ d2 = torch.cat([d2, e2], dim=1)
55
+ d2 = self.dec2(d2)
56
+
57
+ d1 = self.up1(d2)
58
+ d1 = torch.cat([d1, e1], dim=1)
59
+ d1 = self.dec1(d1)
60
+
61
+ return self.head(d1)
train.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, List, Optional, Tuple
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.utils.data import DataLoader
8
+ import gradio as gr
9
+ from PIL import Image
10
+
11
+ from config import DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_IMAGE_SIZE, BAND_NAMES, CLASS_NAMES
12
+ from data import MultiSpectralDataset, load_data
13
+ from model import SmallUNet
14
+ from visualize import multispectral_to_rgb, mask_to_color, overlay_mask, correctness_overlay
15
+ from metrics import compute_metrics, metrics_markdown
16
+
17
+
18
+ # ── Inference ────────────────────────────────────────────────
19
+
20
+ def build_prediction_cache(
21
+ model: nn.Module, images: np.ndarray, batch_size: int = 8
22
+ ) -> Tuple[np.ndarray, np.ndarray]:
23
+ dummy_masks = np.zeros((len(images), images.shape[-2], images.shape[-1]), dtype=np.int64)
24
+ ds = MultiSpectralDataset(images, dummy_masks)
25
+ loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
26
+ preds, probs = [], []
27
+ model.eval()
28
+ with torch.no_grad():
29
+ for xb, _ in loader:
30
+ xb = xb.to(DEVICE)
31
+ pb = F.softmax(model(xb), dim=1)
32
+ preds.append(torch.argmax(pb, dim=1).cpu().numpy())
33
+ probs.append(pb.cpu().numpy())
34
+ return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
35
+
36
+
37
+ # ── Render helpers ───────────────────────────────────────────
38
+
39
+ def _blank(size: int = DEFAULT_IMAGE_SIZE) -> Image.Image:
40
+ return Image.fromarray(np.full((size, size, 3), 245, dtype=np.uint8))
41
+
42
+
43
+ def pixel_info_markdown(
44
+ x: int, y: int,
45
+ img7: np.ndarray, gt: np.ndarray,
46
+ pred: Optional[np.ndarray], probs: Optional[np.ndarray],
47
+ ) -> str:
48
+ h, w = gt.shape
49
+ x = int(np.clip(x, 0, w - 1))
50
+ y = int(np.clip(y, 0, h - 1))
51
+ lines = [f"### Pixel ({x}, {y})", f"- Ground truth: **{CLASS_NAMES[int(gt[y, x])]}**"]
52
+ if pred is not None:
53
+ pred_class = int(pred[y, x])
54
+ lines.append(f"- Prediction: **{CLASS_NAMES[pred_class]}**")
55
+ lines.append(f"- Correct: **{'Yes' if pred_class == int(gt[y, x]) else 'No'}**")
56
+ if probs is not None:
57
+ top_ids = np.argsort(probs[:, y, x])[::-1][:3]
58
+ lines.append("- Top probabilities: " + ", ".join(
59
+ f"{CLASS_NAMES[i]} {probs[i, y, x] * 100:.1f}%" for i in top_ids
60
+ ))
61
+ else:
62
+ lines.append("- Prediction: β€”")
63
+ lines += ["", "**Band values**"] + [f"- {n}: {float(img7[b, y, x]):.3f}" for b, n in enumerate(BAND_NAMES)]
64
+ return "\n".join(lines)
65
+
66
+
67
+ def render_experiment_panel(
68
+ dataset_state: Dict, exp: Optional[Dict], sample_idx: int
69
+ ) -> Tuple:
70
+ """Returns (rgb, pred_color, overlay, metrics_md, error_map, click_md)."""
71
+ b = _blank()
72
+ no_data = (b, b, b, "### No data loaded", b, "### Click info")
73
+ if dataset_state is None or "val_images" not in dataset_state:
74
+ return no_data
75
+ val_images = dataset_state["val_images"]
76
+ val_masks = dataset_state["val_masks"]
77
+ if len(val_images) == 0:
78
+ return no_data
79
+
80
+ idx = max(0, min(int(sample_idx), len(val_images) - 1))
81
+ rgb = multispectral_to_rgb(val_images[idx])
82
+ gt = val_masks[idx]
83
+
84
+ if exp is None:
85
+ return (
86
+ rgb, mask_to_color(gt), overlay_mask(rgb, gt),
87
+ "### No experiment selected",
88
+ _blank(),
89
+ pixel_info_markdown(0, 0, val_images[idx], gt, None, None),
90
+ )
91
+
92
+ # Guard: experiment predictions might be from a different dataset
93
+ if idx >= len(exp["val_preds"]):
94
+ return (
95
+ rgb, mask_to_color(gt), overlay_mask(rgb, gt),
96
+ "### Dataset reloaded β€” retrain to refresh",
97
+ _blank(),
98
+ "### Retrain needed",
99
+ )
100
+
101
+ pred = exp["val_preds"][idx].astype(np.uint8)
102
+ probs = exp["val_probs"][idx].astype(np.float32)
103
+ sample_metrics = compute_metrics(pred, gt, num_classes=NUM_CLASSES)
104
+ return (
105
+ rgb,
106
+ mask_to_color(pred),
107
+ overlay_mask(rgb, pred),
108
+ metrics_markdown(sample_metrics, title=f"Slot {exp['slot']} β€” {exp['name']} (sample {idx})"),
109
+ correctness_overlay(rgb, pred, gt),
110
+ pixel_info_markdown(0, 0, val_images[idx], gt, pred, probs),
111
+ )
112
+
113
+
114
+ def render_compare_view(dataset_state, experiments, sample_idx: int) -> Tuple:
115
+ """Returns 18 values: 6 outputs Γ— 3 slots (A, B, C)."""
116
+ slot_map = {e["slot"]: e for e in experiments}
117
+ return (
118
+ *render_experiment_panel(dataset_state, slot_map.get("A"), sample_idx),
119
+ *render_experiment_panel(dataset_state, slot_map.get("B"), sample_idx),
120
+ *render_experiment_panel(dataset_state, slot_map.get("C"), sample_idx),
121
+ )
122
+
123
+
124
+ # ── Gradio action functions ──────────────────────────────────
125
+
126
+ def load_dataset_action(train_size: int, val_size: int, image_size: int):
127
+ """
128
+ Loads a fresh dataset and resets all experiment state.
129
+ Returns 9 values for Gradio outputs.
130
+ """
131
+ train_size, val_size, image_size = int(train_size), int(val_size), int(image_size)
132
+ dataset_state = load_data(train_size, val_size, image_size)
133
+ val_count = len(dataset_state["val_images"])
134
+
135
+ rgb = multispectral_to_rgb(dataset_state["val_images"][0])
136
+ gt = dataset_state["val_masks"][0]
137
+ dataset_info = "\n".join([
138
+ "### Dataset loaded (synthetic)",
139
+ f"- {dataset_state['status']}",
140
+ f"- Channels: **{NUM_CHANNELS}** ({', '.join(BAND_NAMES)})",
141
+ f"- Classes: **{NUM_CLASSES}** ({', '.join(CLASS_NAMES)})",
142
+ "",
143
+ "_Using procedural synthetic data. See `data.py β†’ load_data()` to plug in a real dataset._",
144
+ ])
145
+
146
+ return (
147
+ dataset_state,
148
+ [], # reset experiments_state
149
+ dataset_info,
150
+ rgb,
151
+ mask_to_color(gt),
152
+ overlay_mask(rgb, gt),
153
+ pixel_info_markdown(0, 0, dataset_state["val_images"][0], gt, None, None),
154
+ gr.update(maximum=max(0, val_count - 1), value=0), # explorer_sample_index
155
+ gr.update(maximum=max(0, val_count - 1), value=0), # compare_sample_index
156
+ )
157
+
158
+
159
+ def update_explorer_sample(dataset_state, sample_idx: int):
160
+ """Updates the Tab 1 explorer images when the sample index slider changes."""
161
+ if dataset_state is None or "val_images" not in dataset_state:
162
+ b = _blank()
163
+ return b, b, b, "### No dataset loaded"
164
+ val_images = dataset_state["val_images"]
165
+ val_masks = dataset_state["val_masks"]
166
+ idx = max(0, min(int(sample_idx), len(val_images) - 1))
167
+ rgb = multispectral_to_rgb(val_images[idx])
168
+ gt = val_masks[idx]
169
+ return (
170
+ rgb,
171
+ mask_to_color(gt),
172
+ overlay_mask(rgb, gt),
173
+ pixel_info_markdown(0, 0, val_images[idx], gt, None, None),
174
+ )
175
+
176
+
177
+ def update_compare_sample(dataset_state, experiments, sample_idx: int):
178
+ """Updates Tab 3 when the compare sample index slider changes."""
179
+ if dataset_state is None or "val_images" not in dataset_state:
180
+ raise gr.Error("Load a dataset first.")
181
+ return render_compare_view(dataset_state, experiments, int(sample_idx))
182
+
183
+
184
+ def train_experiment(
185
+ dataset_state: Dict,
186
+ experiments: List[Dict],
187
+ slot_label: str,
188
+ learning_rate: float,
189
+ batch_size: int,
190
+ epochs: int,
191
+ base_channels: int,
192
+ run_name: str,
193
+ ):
194
+ """
195
+ Trains a SmallUNet and stores results in the given slot.
196
+ Returns 21 values: experiments, summary, compare_sample_index update, + 18 compare outputs.
197
+ """
198
+ if dataset_state is None or "train_images" not in dataset_state:
199
+ raise gr.Error("Load a dataset first.")
200
+
201
+ train_images = dataset_state["train_images"]
202
+ train_masks = dataset_state["train_masks"]
203
+ val_images = dataset_state["val_images"]
204
+ val_masks = dataset_state["val_masks"]
205
+
206
+ loader = DataLoader(
207
+ MultiSpectralDataset(train_images, train_masks),
208
+ batch_size=int(batch_size), shuffle=True,
209
+ )
210
+ model = SmallUNet(NUM_CHANNELS, NUM_CLASSES, int(base_channels)).to(DEVICE)
211
+ optimizer = torch.optim.Adam(model.parameters(), lr=float(learning_rate))
212
+ criterion = nn.CrossEntropyLoss()
213
+
214
+ history = []
215
+ for _ in range(int(epochs)):
216
+ model.train()
217
+ total_loss, n = 0.0, 0
218
+ for xb, yb in loader:
219
+ xb, yb = xb.to(DEVICE), yb.to(DEVICE)
220
+ optimizer.zero_grad(set_to_none=True)
221
+ loss = criterion(model(xb), yb)
222
+ loss.backward()
223
+ optimizer.step()
224
+ total_loss += float(loss.item())
225
+ n += 1
226
+ history.append(total_loss / max(1, n))
227
+
228
+ val_preds, val_probs = build_prediction_cache(model, val_images, batch_size=max(1, int(batch_size)))
229
+ global_metrics = compute_metrics(val_preds.reshape(-1), val_masks.reshape(-1), num_classes=NUM_CLASSES)
230
+
231
+ experiment = {
232
+ "name": (run_name or f"Run {len(experiments) + 1}").strip(),
233
+ "slot": slot_label,
234
+ "config": {
235
+ "learning_rate": float(learning_rate),
236
+ "batch_size": int(batch_size),
237
+ "epochs": int(epochs),
238
+ "base_channels": int(base_channels),
239
+ },
240
+ "train_loss_history": history,
241
+ "global_metrics": global_metrics,
242
+ "val_preds": val_preds.astype(np.uint8),
243
+ "val_probs": val_probs.astype(np.float32),
244
+ }
245
+
246
+ slot_map = {e["slot"]: e for e in experiments}
247
+ slot_map[slot_label] = experiment
248
+ experiments = [slot_map[s] for s in ["A", "B", "C"] if s in slot_map]
249
+
250
+ summary = "\n".join([
251
+ f"### Training finished β€” Slot {slot_label}",
252
+ f"- Experiment: **{experiment['name']}**",
253
+ f"- Device: **{DEVICE}** | Epochs: **{int(epochs)}**",
254
+ f"- Final loss: **{history[-1]:.4f}**",
255
+ f"- Val accuracy: **{global_metrics['overall_acc'] * 100:.2f}%**",
256
+ f"- Val mIoU: **{global_metrics['miou'] * 100:.2f}%**",
257
+ ])
258
+
259
+ compare_slider = gr.update(maximum=max(0, len(val_images) - 1), value=0)
260
+ compare_outputs = render_compare_view(dataset_state, experiments, 0)
261
+ return experiments, summary, compare_slider, *compare_outputs
262
+
263
+
264
+ # ── Click handlers ───────────────────────────────────────────
265
+
266
+ def handle_click_dataset(evt: gr.SelectData, dataset_state, sample_idx: int):
267
+ if dataset_state is None or "val_images" not in dataset_state:
268
+ return "### No dataset"
269
+ idx = max(0, min(int(sample_idx), len(dataset_state["val_images"]) - 1))
270
+ x, y = evt.index
271
+ return pixel_info_markdown(int(x), int(y), dataset_state["val_images"][idx], dataset_state["val_masks"][idx], None, None)
272
+
273
+
274
+ def _handle_click_experiment(evt: gr.SelectData, dataset_state, experiments, slot: str, sample_idx: int):
275
+ if dataset_state is None or "val_images" not in dataset_state:
276
+ return "### No dataset"
277
+ idx = max(0, min(int(sample_idx), len(dataset_state["val_images"]) - 1))
278
+ exp = next((e for e in experiments if e["slot"] == slot), None)
279
+ x, y = evt.index
280
+ img7 = dataset_state["val_images"][idx]
281
+ gt = dataset_state["val_masks"][idx]
282
+ if exp is None or idx >= len(exp["val_preds"]):
283
+ return pixel_info_markdown(int(x), int(y), img7, gt, None, None)
284
+ return pixel_info_markdown(int(x), int(y), img7, gt, exp["val_preds"][idx], exp["val_probs"][idx])
285
+
286
+
287
+ def handle_click_exp_a(evt, dataset_state, experiments, sample_idx):
288
+ return _handle_click_experiment(evt, dataset_state, experiments, "A", sample_idx)
289
+
290
+ def handle_click_exp_b(evt, dataset_state, experiments, sample_idx):
291
+ return _handle_click_experiment(evt, dataset_state, experiments, "B", sample_idx)
292
+
293
+ def handle_click_exp_c(evt, dataset_state, experiments, sample_idx):
294
+ return _handle_click_experiment(evt, dataset_state, experiments, "C", sample_idx)
visualize.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from config import CLASS_COLORS
3
+
4
+
5
+ def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
6
+ x = x.astype(np.float32)
7
+ lo = np.percentile(x, low)
8
+ hi = np.percentile(x, high)
9
+ if hi <= lo:
10
+ hi = lo + 1e-6
11
+ return np.clip((x - lo) / (hi - lo), 0, 1)
12
+
13
+
14
+ def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
15
+ """img7 shape: (7, H, W) β€” uses B04, B03, B02 for natural RGB view."""
16
+ r = percentile_stretch(img7[2])
17
+ g = percentile_stretch(img7[1])
18
+ b = percentile_stretch(img7[0])
19
+ rgb = np.stack([r, g, b], axis=-1)
20
+ return (rgb * 255).astype(np.uint8)
21
+
22
+
23
+ def mask_to_color(mask: np.ndarray) -> np.ndarray:
24
+ return CLASS_COLORS[mask]
25
+
26
+
27
+ def overlay_mask(rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
28
+ color_mask = mask_to_color(mask)
29
+ out = ((1 - alpha) * rgb.astype(np.float32) + alpha * color_mask.astype(np.float32)).clip(0, 255)
30
+ return out.astype(np.uint8)
31
+
32
+
33
+ def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
34
+ correct = pred == gt
35
+ out = np.zeros((pred.shape[0], pred.shape[1], 3), dtype=np.uint8)
36
+ out[correct] = np.array([0, 220, 0], dtype=np.uint8)
37
+ out[~correct] = np.array([220, 0, 0], dtype=np.uint8)
38
+ return out
39
+
40
+
41
+ def correctness_overlay(rgb: np.ndarray, pred: np.ndarray, gt: np.ndarray, alpha: float = 0.38) -> np.ndarray:
42
+ cm = correctness_map(pred, gt)
43
+ out = ((1 - alpha) * rgb.astype(np.float32) + alpha * cm.astype(np.float32)).clip(0, 255)
44
+ return out.astype(np.uint8)