functionNormally committed on
Commit
089078d
·
1 Parent(s): 809b793

Redesign: five-step pedagogical flow with spectral baseline

Browse files

Step 1 - Know Your Data: full-scene composite viewer (natural/false/SWIR/single-band)
with labelled-pixel markers overlaid, clickable pixel inspection.

Step 2 - Spectral Signatures: per-class mean±sigma chart and NDVI/NDWI heatmaps
computed from raw (unnormalised) bands for correct ratio semantics.

Step 3 - Spectral Baseline: pure-numpy chunked KNN classifier that predicts the
full 501x1001 scene from 7-band spectral features alone (no spatial context).

Step 4 - Deep Learning: UNet training with side-by-side ground-truth / KNN /
UNet patch comparison to make the spatial-context benefit concrete.

Step 5 - Experiment Lab: compare up to 3 UNet experiments; max capped at 3.

Also fixes natural-colour composite to use H4/H3/H2 (Red/Green/Blue) instead of
the previous H3/H2/H1, and adds matplotlib + scipy to requirements.

Files changed (7) hide show
  1. app.py +220 -101
  2. baseline.py +79 -0
  3. config.py +21 -4
  4. data.py +84 -39
  5. requirements.txt +2 -0
  6. train.py +266 -174
  7. visualize.py +168 -8
app.py CHANGED
@@ -1,151 +1,270 @@
1
  import gradio as gr
2
 
3
- from config import APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE
 
 
 
4
  from train import (
5
  load_dataset_action,
6
- update_explorer_sample,
7
- update_compare_sample,
 
 
 
8
  train_experiment,
9
- handle_click_dataset,
10
- handle_click_exp_a,
11
- handle_click_exp_b,
12
  )
13
 
14
  set_seed(SEED)
15
 
 
 
16
  custom_css = """
17
- #compare-a img, #compare-b img, #explorer img {
18
- image-rendering: pixelated;
19
- }
20
- .small-note { font-size: 0.9rem; opacity: 0.85; }
21
  """
22
 
23
  with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
24
- gr.Markdown(f"# {APP_TITLE}\nInteractive teaching app for multispectral semantic segmentation.")
 
 
 
 
25
 
26
  dataset_state = gr.State(None)
 
27
  experiments_state = gr.State([])
28
 
29
- # ── Tab 1: Image Explorer ────────────────────────────────
30
- with gr.Tab("1) Image explorer"):
 
 
 
 
 
 
 
31
  with gr.Row():
32
  with gr.Column(scale=1):
33
- patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, label="Patch size")
34
  load_btn = gr.Button("Load dataset", variant="primary")
35
- dataset_info = gr.Markdown("### No dataset loaded yet")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  gr.Markdown(
37
- "<div class='small-note'>Downloads the satellite dataset from HuggingFace "
38
- "and extracts image patches for training and validation.</div>"
 
 
 
 
39
  )
40
- with gr.Column(scale=2, elem_id="explorer"):
41
- explorer_sample_index = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
42
- with gr.Row():
43
- explorer_rgb = gr.Image(label="RGB composite", type="numpy", height=400)
44
- explorer_gt = gr.Image(label="Ground truth mask", type="numpy", height=400)
45
- explorer_overlay = gr.Image(label="Ground truth overlay", type="numpy", height=400)
46
- explorer_click_info = gr.Markdown("### Click the RGB image to inspect a pixel")
47
 
48
- # ── Tab 2: Model Trainer ─────────────────────────────────
49
- with gr.Tab("2) Model trainer"):
 
 
 
 
 
 
 
 
 
 
 
 
50
  with gr.Row():
51
  with gr.Column(scale=1):
52
- run_name = gr.Textbox(label="Experiment name", placeholder="e.g. lr-1e-3_ep-5")
53
- learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, label="Learning rate")
54
- batch_size = gr.Slider(2, 32, value=8, step=2, label="Batch size")
55
- epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
56
- base_channels = gr.Slider(8, 64, value=16, step=8, label="Model width (base channels)")
57
- train_btn = gr.Button("Train experiment", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  with gr.Column(scale=1):
59
- train_summary = gr.Markdown("### No training run yet")
60
- model_table = gr.Markdown("### No models trained yet")
 
 
 
 
61
  gr.Markdown(
62
- "<div class='small-note'>Accuracy and mIoU are computed on labeled pixels only "
63
- "(unlabeled pixels are ignored during training and evaluation).</div>"
 
64
  )
 
65
 
66
- # ── Tab 3: Result Comparison ─────────────────────────────
67
- with gr.Tab("3) Result comparison"):
68
- compare_sample_index = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  with gr.Row():
70
- compare_sel_a = gr.Dropdown(choices=[], value=None, label="Left model", interactive=True)
71
- compare_sel_b = gr.Dropdown(choices=[], value=None, label="Right model", interactive=True)
 
 
 
 
 
 
 
 
 
 
 
72
  with gr.Row():
73
- with gr.Column(scale=1, elem_id="compare-a"):
74
  gr.Markdown("## Left")
75
- compare_a_rgb = gr.Image(label="RGB β€” click to inspect pixel", type="numpy", height=280)
76
- compare_a_click = gr.Markdown("### Click the RGB or overlay image to inspect a pixel")
77
- compare_a_pred = gr.Image(label="Prediction mask", type="numpy", height=280)
78
- compare_a_overlay = gr.Image(label="Prediction overlay β€” click to inspect pixel", type="numpy", height=280)
79
- compare_a_metrics = gr.Markdown("### No model selected")
80
- compare_a_error = gr.Image(label="Correctness map", type="numpy", height=280)
81
-
82
- with gr.Column(scale=1, elem_id="compare-b"):
83
  gr.Markdown("## Right")
84
- compare_b_rgb = gr.Image(label="RGB β€” click to inspect pixel", type="numpy", height=280)
85
- compare_b_click = gr.Markdown("### Click the RGB or overlay image to inspect a pixel")
86
- compare_b_pred = gr.Image(label="Prediction mask", type="numpy", height=280)
87
- compare_b_overlay = gr.Image(label="Prediction overlay β€” click to inspect pixel", type="numpy", height=280)
88
- compare_b_metrics = gr.Markdown("### No model selected")
89
- compare_b_error = gr.Image(label="Correctness map", type="numpy", height=280)
90
-
91
- # ── Shared lists ─────────────────────────────────────────
92
- # Order matches render_experiment_panel: (rgb, pred, overlay, metrics, error, click)
93
- _compare_outputs = [
94
- compare_a_rgb, compare_a_pred, compare_a_overlay, compare_a_metrics, compare_a_error, compare_a_click,
95
- compare_b_rgb, compare_b_pred, compare_b_overlay, compare_b_metrics, compare_b_error, compare_b_click,
 
 
 
 
 
 
 
 
 
96
  ]
97
- _compare_inputs = [dataset_state, experiments_state, compare_sel_a, compare_sel_b, compare_sample_index]
98
-
99
- # ── Event bindings ────────────────────────────────────────
100
-
101
- load_btn.click(
102
- fn=load_dataset_action,
103
- inputs=[patch_size],
104
- outputs=[
105
- dataset_state, experiments_state,
106
- dataset_info,
107
- explorer_rgb, explorer_gt, explorer_overlay,
108
- explorer_click_info,
109
- explorer_sample_index,
110
- compare_sample_index,
111
- compare_sel_a, compare_sel_b,
112
- ],
113
  )
114
 
115
- explorer_sample_index.change(
116
- fn=update_explorer_sample,
117
- inputs=[dataset_state, explorer_sample_index],
118
- outputs=[explorer_rgb, explorer_gt, explorer_overlay, explorer_click_info],
119
  )
120
 
121
- explorer_rgb.select(
122
- fn=handle_click_dataset,
123
- inputs=[dataset_state, explorer_sample_index],
124
- outputs=[explorer_click_info],
125
  )
126
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  train_btn.click(
128
  fn=train_experiment,
129
- inputs=[dataset_state, experiments_state, learning_rate, batch_size, epochs, base_channels, run_name],
130
- outputs=[experiments_state, train_summary, model_table, compare_sel_a, compare_sel_b],
 
 
 
131
  )
132
 
133
- for sel in [compare_sel_a, compare_sel_b]:
134
- sel.change(fn=update_compare_sample, inputs=_compare_inputs, outputs=_compare_outputs)
 
 
 
135
 
136
- compare_sample_index.change(fn=update_compare_sample, inputs=_compare_inputs, outputs=_compare_outputs)
 
 
 
 
137
 
138
- for img in [compare_a_rgb, compare_a_overlay]:
139
- img.select(
140
- fn=handle_click_exp_a,
141
- inputs=[dataset_state, experiments_state, compare_sel_a, compare_sample_index],
142
- outputs=[compare_a_click],
143
- )
144
- for img in [compare_b_rgb, compare_b_overlay]:
145
  img.select(
146
- fn=handle_click_exp_b,
147
- inputs=[dataset_state, experiments_state, compare_sel_b, compare_sample_index],
148
- outputs=[compare_b_click],
149
  )
150
 
151
 
 
1
  import gradio as gr
2
 
3
+ from config import (
4
+ APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE,
5
+ COMPOSITE_PRESETS, BAND_DESCRIPTIONS, MAX_EXPERIMENTS,
6
+ )
7
  from train import (
8
  load_dataset_action,
9
+ update_step1_composite,
10
+ handle_click_step1,
11
+ update_step2_index,
12
+ run_baseline_action,
13
+ update_step4_patch,
14
  train_experiment,
15
+ update_step5_comparison,
16
+ handle_click_step5,
 
17
  )
18
 
19
  set_seed(SEED)
20
 
21
+ _COMPOSITE_CHOICES = list(COMPOSITE_PRESETS.keys()) + BAND_DESCRIPTIONS
22
+
23
  custom_css = """
24
+ #step1-img img, #step3-pred img, #step3-correct img { image-rendering: pixelated; }
25
+ .step-header { font-size: 1.05rem; font-weight: 600; margin-bottom: 4px; }
26
+ .hint { font-size: 0.88rem; color: #666; }
 
27
  """
28
 
29
  with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
30
+ gr.Markdown(f"# {APP_TITLE}")
31
+ gr.Markdown(
32
+ "A five-step journey from raw satellite pixels to deep-learning segmentation. "
33
+ "Work through the tabs in order β€” each step builds on the previous one."
34
+ )
35
 
36
  dataset_state = gr.State(None)
37
+ baseline_state = gr.State(None)
38
  experiments_state = gr.State([])
39
 
40
+ # ────────────────────────────────────────────────────────
41
+ # Step 1 β€” Know Your Data
42
+ # ────────────────────────────────────────────────────────
43
+ with gr.Tab("Step 1 Β· Know Your Data"):
44
+ gr.Markdown(
45
+ "**Start here.** Load the dataset, then explore the 7 spectral bands. "
46
+ "Squares on the image are training labels; circles are validation labels.",
47
+ elem_classes="hint",
48
+ )
49
  with gr.Row():
50
  with gr.Column(scale=1):
51
+ patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, label="Patch size (for training)")
52
  load_btn = gr.Button("Load dataset", variant="primary")
53
+ composite_dd = gr.Dropdown(
54
+ choices=_COMPOSITE_CHOICES,
55
+ value="Natural Color (R/G/B)",
56
+ label="View mode",
57
+ )
58
+ step1_info = gr.Markdown("*Load the dataset to begin.*")
59
+
60
+ with gr.Column(scale=3, elem_id="step1-img"):
61
+ step1_image = gr.Image(label="Full scene β€” click to inspect a pixel", type="numpy")
62
+ step1_click = gr.Markdown("*Click anywhere on the image.*")
63
+
64
+ # ────────────────────────────────────────────────────────
65
+ # Step 2 β€” Spectral Signatures
66
+ # ────────────────────────────────────────────────────────
67
+ with gr.Tab("Step 2 Β· Spectral Signatures"):
68
+ gr.Markdown(
69
+ "Each land cover type has a characteristic pattern of brightness across the 7 bands β€” "
70
+ "its **spectral signature**. Notice how H_5 (NIR) separates vegetation from water. "
71
+ "NDVI and NDWI are hand-crafted indices that exploit this difference.",
72
+ elem_classes="hint",
73
+ )
74
+ with gr.Row():
75
+ with gr.Column(scale=1):
76
+ index_radio = gr.Radio(
77
+ choices=["NDVI", "NDWI"],
78
+ value="NDVI",
79
+ label="Spectral index map",
80
+ )
81
  gr.Markdown(
82
+ "**NDVI** = (NIR βˆ’ Red) / (NIR + Red) \n"
83
+ "High values β†’ dense vegetation (Forest, Agriculture) \n"
84
+ "Low / negative β†’ water, urban, bare soil\n\n"
85
+ "**NDWI** = (Green βˆ’ NIR) / (Green + NIR) \n"
86
+ "Positive β†’ water; negative β†’ land",
87
+ elem_classes="hint",
88
  )
 
 
 
 
 
 
 
89
 
90
+ with gr.Column(scale=3):
91
+ step2_sig_chart = gr.Image(label="Spectral signatures (training labels)", type="numpy")
92
+ step2_index_map = gr.Image(label="Index map", type="numpy")
93
+
94
+ # ────────────────────────────────────────────────────────
95
+ # Step 3 β€” Spectral Baseline (KNN)
96
+ # ────────────────────────────────────────────────────────
97
+ with gr.Tab("Step 3 Β· Spectral Baseline"):
98
+ gr.Markdown(
99
+ "**No convolutions here.** We classify every pixel using only its 7 band values, "
100
+ "finding the k nearest training pixels in spectral space. "
101
+ "This shows you what's achievable without any spatial context.",
102
+ elem_classes="hint",
103
+ )
104
  with gr.Row():
105
  with gr.Column(scale=1):
106
+ k_slider = gr.Slider(1, 5, value=3, step=2, label="k (number of neighbours)")
107
+ baseline_btn = gr.Button("Run KNN baseline", variant="primary")
108
+ step3_metrics = gr.Markdown("*Run the baseline to see results.*")
109
+
110
+ with gr.Column(scale=3, elem_id="step3-pred"):
111
+ step3_full_pred = gr.Image(
112
+ label="Full-scene prediction Β· overlaid on natural-colour image Β· coloured dots = val labels (green=correct, red=wrong)",
113
+ type="numpy",
114
+ )
115
+
116
+ # ────────────────────────────────────────────────────────
117
+ # Step 4 β€” Deep Learning (UNet)
118
+ # ────────────────────────────────────────────────────────
119
+ with gr.Tab("Step 4 Β· Deep Learning"):
120
+ gr.Markdown(
121
+ "A **U-Net** sees a patch of pixels at once, not just one pixel. "
122
+ "Its encoder captures local texture; skip connections preserve spatial detail. "
123
+ "Train a model and compare it patch-by-patch against the KNN baseline.",
124
+ elem_classes="hint",
125
+ )
126
+ with gr.Row():
127
  with gr.Column(scale=1):
128
+ run_name = gr.Textbox(label="Experiment name", placeholder="e.g. lr-1e-3_ch-16")
129
+ learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, label="Learning rate")
130
+ batch_size = gr.Slider(2, 32, value=8, step=2, label="Batch size")
131
+ epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
132
+ base_channels = gr.Slider(8, 64, value=16, step=8, label="Model width (base channels)")
133
+ train_btn = gr.Button("Train model", variant="primary")
134
  gr.Markdown(
135
+ f"*Max {MAX_EXPERIMENTS} experiments total. "
136
+ "Reload data to reset.*",
137
+ elem_classes="hint",
138
  )
139
+ step4_summary = gr.Markdown("*Train a model to see results.*")
140
 
141
+ with gr.Column(scale=3):
142
+ step4_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
143
+ with gr.Row():
144
+ step4_gt_img = gr.Image(label="Ground truth overlay", type="numpy", height=280)
145
+ step4_bl_img = gr.Image(label="KNN Baseline prediction", type="numpy", height=280)
146
+ step4_un_img = gr.Image(label="UNet prediction", type="numpy", height=280)
147
+
148
+ # ────────────────────────────────────────────────────────
149
+ # Step 5 β€” Experiment Lab
150
+ # ────────────────────────────────────────────────────────
151
+ with gr.Tab("Step 5 Β· Experiment Lab"):
152
+ gr.Markdown(
153
+ f"Compare up to **{MAX_EXPERIMENTS}** UNet experiments side by side. "
154
+ "Try different learning rates, epochs, or model widths and see what changes.",
155
+ elem_classes="hint",
156
+ )
157
  with gr.Row():
158
+ step5_sel_a = gr.Dropdown(choices=[], value=None, label="Left model", interactive=True)
159
+ step5_sel_b = gr.Dropdown(choices=[], value=None, label="Right model", interactive=True)
160
+ step5_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
161
+ step5_table = gr.Markdown("*No experiments yet.*")
162
+
163
+ gr.Markdown(
164
+ "**Guiding questions**\n"
165
+ "- Double the epochs β€” does mIoU keep improving or plateau?\n"
166
+ "- Halve the learning rate β€” does training become more stable?\n"
167
+ "- Increase base channels from 16 to 32 β€” worth the extra time?",
168
+ elem_classes="hint",
169
+ )
170
+
171
  with gr.Row():
172
+ with gr.Column(scale=1):
173
  gr.Markdown("## Left")
174
+ s5_a_rgb = gr.Image(label="RGB", type="numpy", height=240)
175
+ s5_a_pred = gr.Image(label="Prediction", type="numpy", height=240)
176
+ s5_a_overlay = gr.Image(label="Overlay", type="numpy", height=240)
177
+ s5_a_metrics = gr.Markdown("*No model selected.*")
178
+ s5_a_error = gr.Image(label="Correctness map",type="numpy", height=240)
179
+
180
+ with gr.Column(scale=1):
 
181
  gr.Markdown("## Right")
182
+ s5_b_rgb = gr.Image(label="RGB", type="numpy", height=240)
183
+ s5_b_pred = gr.Image(label="Prediction", type="numpy", height=240)
184
+ s5_b_overlay = gr.Image(label="Overlay", type="numpy", height=240)
185
+ s5_b_metrics = gr.Markdown("*No model selected.*")
186
+ s5_b_error = gr.Image(label="Correctness map",type="numpy", height=240)
187
+
188
+ # ── Event wiring ─────────────────────────────────────────
189
+
190
+ _load_outputs = [
191
+ dataset_state, baseline_state, experiments_state,
192
+ # Tab 1
193
+ step1_info, step1_image, step1_click,
194
+ # Tab 2
195
+ step2_sig_chart, step2_index_map,
196
+ # Tab 3
197
+ step3_metrics, step3_full_pred,
198
+ # Tab 4
199
+ step4_summary, step4_patch_slider,
200
+ step4_gt_img, step4_bl_img, step4_un_img,
201
+ # Tab 5
202
+ step5_table, step5_sel_a, step5_sel_b,
203
  ]
204
+
205
+ load_btn.click(fn=load_dataset_action, inputs=[patch_size], outputs=_load_outputs)
206
+
207
+ composite_dd.change(
208
+ fn=update_step1_composite,
209
+ inputs=[dataset_state, composite_dd],
210
+ outputs=[step1_image, step1_click],
 
 
 
 
 
 
 
 
 
211
  )
212
 
213
+ step1_image.select(
214
+ fn=handle_click_step1,
215
+ inputs=[dataset_state],
216
+ outputs=[step1_click],
217
  )
218
 
219
+ index_radio.change(
220
+ fn=update_step2_index,
221
+ inputs=[dataset_state, index_radio],
222
+ outputs=[step2_index_map],
223
  )
224
 
225
+ baseline_btn.click(
226
+ fn=run_baseline_action,
227
+ inputs=[dataset_state, k_slider],
228
+ outputs=[baseline_state, step3_metrics, step3_full_pred],
229
+ )
230
+
231
+ _train_outputs = [
232
+ experiments_state,
233
+ step4_summary, step4_patch_slider,
234
+ step4_gt_img, step4_bl_img, step4_un_img,
235
+ step5_table, step5_sel_a, step5_sel_b,
236
+ ]
237
+
238
  train_btn.click(
239
  fn=train_experiment,
240
+ inputs=[
241
+ dataset_state, baseline_state, experiments_state,
242
+ learning_rate, batch_size, epochs, base_channels, run_name,
243
+ ],
244
+ outputs=_train_outputs,
245
  )
246
 
247
+ step4_patch_slider.change(
248
+ fn=update_step4_patch,
249
+ inputs=[dataset_state, baseline_state, experiments_state, step4_patch_slider],
250
+ outputs=[step4_gt_img, step4_bl_img, step4_un_img],
251
+ )
252
 
253
+ _s5_inputs = [dataset_state, experiments_state, step5_sel_a, step5_sel_b, step5_patch_slider]
254
+ _s5_outputs = [
255
+ s5_a_rgb, s5_a_pred, s5_a_overlay, s5_a_metrics, s5_a_error,
256
+ s5_b_rgb, s5_b_pred, s5_b_overlay, s5_b_metrics, s5_b_error,
257
+ ]
258
 
259
+ for trigger in [step5_sel_a, step5_sel_b, step5_patch_slider]:
260
+ trigger.change(fn=update_step5_comparison, inputs=_s5_inputs, outputs=_s5_outputs)
261
+
262
+ for img, sel in [(s5_a_rgb, step5_sel_a), (s5_a_overlay, step5_sel_a),
263
+ (s5_b_rgb, step5_sel_b), (s5_b_overlay, step5_sel_b)]:
 
 
264
  img.select(
265
+ fn=handle_click_step5,
266
+ inputs=[dataset_state, experiments_state, sel, step5_patch_slider],
267
+ outputs=[s5_a_metrics if sel == step5_sel_a else s5_b_metrics],
268
  )
269
 
270
 
baseline.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Spectral baseline classifier: KNN on raw 7-band pixel values, no spatial context."""
2
+ import numpy as np
3
+
4
+ from config import NUM_CHANNELS, NUM_CLASSES, IGNORE_INDEX
5
+ from metrics import compute_metrics, metrics_markdown
6
+
7
+
8
+ def _knn_predict(
9
+ train_X: np.ndarray,
10
+ train_y: np.ndarray,
11
+ query_X: np.ndarray,
12
+ k: int,
13
+ chunk: int = 50_000,
14
+ ) -> np.ndarray:
15
+ """Chunked nearest-neighbour prediction to keep peak RAM reasonable."""
16
+ N = len(query_X)
17
+ preds = np.empty(N, dtype=np.int64)
18
+ k = min(k, len(train_X))
19
+
20
+ for start in range(0, N, chunk):
21
+ end = min(start + chunk, N)
22
+ block = query_X[start:end] # (B, 7)
23
+ dists = np.sum((block[:, None, :] - train_X[None, :, :]) ** 2, axis=2) # (B, N_tr)
24
+ nn_idx = np.argpartition(dists, k - 1, axis=1)[:, :k] # (B, k)
25
+ labels = train_y[nn_idx] # (B, k)
26
+ if k == 1:
27
+ preds[start:end] = labels[:, 0]
28
+ else:
29
+ # Vectorised majority vote
30
+ votes = (labels[:, :, None] == np.arange(NUM_CLASSES)[None, None, :]).sum(axis=1)
31
+ preds[start:end] = votes.argmax(axis=1)
32
+
33
+ return preds
34
+
35
+
36
+ def run_knn_baseline(
37
+ full_image: np.ndarray,
38
+ full_train_mask: np.ndarray,
39
+ full_val_mask: np.ndarray,
40
+ val_images: np.ndarray,
41
+ k: int = 3,
42
+ ):
43
+ """
44
+ Train KNN on labeled training pixels; predict (a) the full scene and (b) each
45
+ validation patch. Evaluate against the full validation mask.
46
+
47
+ Returns
48
+ -------
49
+ full_pred : (H, W) – class index for every pixel in the scene
50
+ val_preds : (N, ph, pw) – patch-level predictions for step-4 comparison
51
+ metrics : dict
52
+ metrics_md : str
53
+ """
54
+ C, H, W = full_image.shape
55
+
56
+ labeled = full_train_mask != IGNORE_INDEX
57
+ if not labeled.any():
58
+ raise ValueError("No labeled training pixels found in TRAINING.tif.")
59
+
60
+ train_X = full_image[:, labeled].T # (N_tr, 7)
61
+ train_y = full_train_mask[labeled] # (N_tr,)
62
+
63
+ # --- Full scene prediction ---
64
+ all_X = full_image.reshape(C, H * W).T # (H*W, 7)
65
+ full_pred = _knn_predict(train_X, train_y, all_X, k).reshape(H, W)
66
+
67
+ # --- Validation patch predictions (same patches as UNet) ---
68
+ N_val, ph, pw = val_images.shape[0], val_images.shape[2], val_images.shape[3]
69
+ all_patch_X = np.concatenate(
70
+ [p.reshape(C, -1).T for p in val_images], axis=0
71
+ ) # (N_val * ph * pw, 7)
72
+ patch_preds_flat = _knn_predict(train_X, train_y, all_patch_X, k)
73
+ val_preds = patch_preds_flat.reshape(N_val, ph, pw).astype(np.int64)
74
+
75
+ # --- Metrics on full val mask ---
76
+ metrics = compute_metrics(full_pred.ravel(), full_val_mask.ravel())
77
+ metrics_md = metrics_markdown(metrics, title=f"KNN Baseline (k={k})")
78
+
79
+ return full_pred.astype(np.int64), val_preds, metrics, metrics_md
config.py CHANGED
@@ -7,19 +7,36 @@ DEFAULT_PATCH_SIZE = 128
7
  NUM_CHANNELS = 7
8
  NUM_CLASSES = 4
9
  IGNORE_INDEX = 255
 
10
 
11
  BAND_NAMES = ["H_1", "H_2", "H_3", "H_4", "H_5", "H_6", "H_7"]
 
 
 
 
 
 
 
 
 
12
  CLASS_NAMES = ["Water", "Urban", "Agriculture", "Forest"]
13
  CLASS_COLORS = np.array(
14
  [
15
- [30, 144, 255], # Water β€” blue
16
- [220, 50, 50], # Urban β€” red
17
- [255, 215, 0], # Agriculture β€” yellow
18
- [34, 139, 34], # Forest β€” green
19
  ],
20
  dtype=np.uint8,
21
  )
22
 
 
 
 
 
 
 
 
23
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
24
 
25
 
 
7
  NUM_CHANNELS = 7
8
  NUM_CLASSES = 4
9
  IGNORE_INDEX = 255
10
+ MAX_EXPERIMENTS = 3
11
 
12
  BAND_NAMES = ["H_1", "H_2", "H_3", "H_4", "H_5", "H_6", "H_7"]
13
+ BAND_DESCRIPTIONS = [
14
+ "H_1 (Coastal/Aerosol)",
15
+ "H_2 (Blue)",
16
+ "H_3 (Green)",
17
+ "H_4 (Red)",
18
+ "H_5 (NIR)",
19
+ "H_6 (SWIR-1)",
20
+ "H_7 (SWIR-2)",
21
+ ]
22
  CLASS_NAMES = ["Water", "Urban", "Agriculture", "Forest"]
23
  CLASS_COLORS = np.array(
24
  [
25
+ [30, 144, 255], # Water - blue
26
+ [220, 50, 50], # Urban - red
27
+ [255, 215, 0], # Agriculture - yellow
28
+ [34, 139, 34], # Forest - green
29
  ],
30
  dtype=np.uint8,
31
  )
32
 
33
+ # (R-band-index, G-band-index, B-band-index) for composite presets
34
+ COMPOSITE_PRESETS = {
35
+ "Natural Color (R/G/B)": (3, 2, 1),
36
+ "False Color NIR (NIR/R/G)": (4, 3, 2),
37
+ "SWIR Composite (SWIR2/NIR/R)": (6, 4, 3),
38
+ }
39
+
40
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
41
 
42
 
data.py CHANGED
@@ -19,9 +19,14 @@ VAL_MASK_FILE = "GROUND TRUTH.tif"
19
 
20
  # ── File helpers ─────────────────────────────────────────────
21
 
22
- def _download(filename: str) -> str:
 
 
 
23
  if not DATASET_REPO:
24
- raise EnvironmentError("DATASET_REPO not set in Space secrets.")
 
 
25
  return hf_hub_download(
26
  repo_id=DATASET_REPO,
27
  filename=filename,
@@ -39,11 +44,10 @@ def _read_band(path: str) -> np.ndarray:
39
 
40
 
41
  def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
42
- """Returns (raw_array, nodata_value, info_string)."""
43
  with rasterio.open(path) as src:
44
  data = src.read(1)
45
  nodata = src.nodata
46
- info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata} bands={src.count}"
47
  return data, nodata, info
48
 
49
 
@@ -59,22 +63,33 @@ def _normalize(image: np.ndarray) -> np.ndarray:
59
  return out
60
 
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
63
  """
64
- Map raw pixel values to 0..NUM_CLASSES-1.
65
- Value 0 is treated as unlabeled background β†’ IGNORE_INDEX.
66
- Nodata pixels β†’ IGNORE_INDEX.
67
- Returns (remapped_mask, sorted_raw_class_values_used).
68
  """
69
  if nodata_val is not None:
70
  nodata_px = raw == int(nodata_val)
71
  else:
72
  nodata_px = np.zeros(raw.shape, dtype=bool)
73
 
74
- # Treat pixel value 0 as unlabeled background
75
- background_px = raw == 0
76
- ignore_px = nodata_px | background_px
77
-
78
  valid = ~ignore_px
79
  raw_unique = sorted(int(v) for v in np.unique(raw[valid]))
80
 
@@ -97,7 +112,6 @@ def _extract_patches(
97
  stride = patch_size // 2
98
  imgs, masks = [], []
99
 
100
- # Build step lists that always include the last valid position (covers edges)
101
  def steps(size):
102
  s = list(range(0, size - patch_size + 1, stride))
103
  if not s:
@@ -110,12 +124,10 @@ def _extract_patches(
110
  for x in steps(W):
111
  pm = mask[y : y + patch_size, x : x + patch_size]
112
  pi = image[:, y : y + patch_size, x : x + patch_size]
113
- # Include any patch that contains at least one labeled pixel
114
  if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
115
  imgs.append(pi)
116
  masks.append(pm)
117
 
118
- # Last resort: pad with zeros/IGNORE if image is smaller than patch_size
119
  if not imgs:
120
  ph = min(patch_size, H)
121
  pw = min(patch_size, W)
@@ -129,6 +141,29 @@ def _extract_patches(
129
  return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
130
 
131
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
132
  # ── Dataset class ─────────────────────────────────────────────
133
 
134
  class MultiSpectralDataset(Dataset):
@@ -146,28 +181,38 @@ class MultiSpectralDataset(Dataset):
146
  # ── Public API ────────────────────────────���───────────────────
147
 
148
  def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
149
- # Download and stack bands
150
- band_arrays = [_read_band(_download(f)) for f in BAND_FILES]
151
- image = _normalize(np.stack(band_arrays, axis=0)) # (7, H, W) float32
 
 
 
 
152
 
153
- # Read raw masks + metadata
154
- raw_train, nd_train, info_train = _read_mask_raw(_download(TRAIN_MASK_FILE))
155
- raw_val, nd_val, info_val = _read_mask_raw(_download(VAL_MASK_FILE))
156
 
157
- # Remap to 0-indexed classes
 
 
 
 
158
  train_mask, train_vals = _remap_mask(raw_train, nd_train)
159
  val_mask, val_vals = _remap_mask(raw_val, nd_val)
160
 
161
  if not train_vals:
162
  raise ValueError(
163
- f"TRAINING.tif has no labeled pixels after nodata removal. "
164
- f"File info: {info_train} | Unique raw values: {np.unique(raw_train).tolist()}"
165
  )
166
 
167
- # Extract patches
168
  tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
169
  va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
170
 
 
 
 
171
  train_labeled = int((train_mask != IGNORE_INDEX).sum())
172
  val_labeled = int((val_mask != IGNORE_INDEX).sum())
173
 
@@ -179,21 +224,21 @@ def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
179
  return " | ".join(parts)
180
 
181
  status = "\n".join([
182
- f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}Γ—{patch_size}**",
183
- "",
184
- f"**TRAINING.tif** `{info_train}`",
185
- f"Raw values β†’ classes: `{dict(zip(train_vals, CLASS_NAMES[:len(train_vals)]))}`",
186
- f"Labeled pixels: **{train_labeled:,}** β€” {_class_dist(train_mask, train_labeled)}",
187
- "",
188
- f"**GROUND TRUTH.tif** `{info_val}`",
189
- f"Raw values β†’ classes: `{dict(zip(val_vals, CLASS_NAMES[:len(val_vals)]))}`",
190
- f"Labeled pixels: **{val_labeled:,}** β€” {_class_dist(val_mask, val_labeled)}",
191
  ])
192
 
193
  return {
194
- "train_images": tr_imgs,
195
- "train_masks": tr_masks,
196
- "val_images": va_imgs,
197
- "val_masks": va_masks,
198
- "status": status,
 
 
 
 
 
 
199
  }
 
19
 
20
  # ── File helpers ─────────────────────────────────────────────
21
 
22
+ def _get_path(filename: str) -> str:
23
+ """Use local file if it exists, otherwise download from HuggingFace."""
24
+ if os.path.exists(filename):
25
+ return filename
26
  if not DATASET_REPO:
27
+ raise EnvironmentError(
28
+ f"'{filename}' not found locally and DATASET_REPO is not set."
29
+ )
30
  return hf_hub_download(
31
  repo_id=DATASET_REPO,
32
  filename=filename,
 
44
 
45
 
46
  def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
 
47
  with rasterio.open(path) as src:
48
  data = src.read(1)
49
  nodata = src.nodata
50
+ info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata}"
51
  return data, nodata, info
52
 
53
 
 
63
  return out
64
 
65
 
66
+ def _compute_ndvi(raw_image: np.ndarray) -> np.ndarray:
67
+ """NDVI = (NIR - Red) / (NIR + Red). H_5=NIR (idx 4), H_4=Red (idx 3)."""
68
+ nir = raw_image[4].astype(np.float32)
69
+ red = raw_image[3].astype(np.float32)
70
+ denom = nir + red
71
+ return np.where(np.abs(denom) > 1e-6, (nir - red) / denom, 0.0).clip(-1.0, 1.0)
72
+
73
+
74
+ def _compute_ndwi(raw_image: np.ndarray) -> np.ndarray:
75
+ """NDWI = (Green - NIR) / (Green + NIR). H_3=Green (idx 2), H_5=NIR (idx 4)."""
76
+ green = raw_image[2].astype(np.float32)
77
+ nir = raw_image[4].astype(np.float32)
78
+ denom = green + nir
79
+ return np.where(np.abs(denom) > 1e-6, (green - nir) / denom, 0.0).clip(-1.0, 1.0)
80
+
81
+
82
  def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
83
  """
84
+ Map raw pixel values -> 0..NUM_CLASSES-1.
85
+ Value 0 and nodata -> IGNORE_INDEX.
 
 
86
  """
87
  if nodata_val is not None:
88
  nodata_px = raw == int(nodata_val)
89
  else:
90
  nodata_px = np.zeros(raw.shape, dtype=bool)
91
 
92
+ ignore_px = nodata_px | (raw == 0)
 
 
 
93
  valid = ~ignore_px
94
  raw_unique = sorted(int(v) for v in np.unique(raw[valid]))
95
 
 
112
  stride = patch_size // 2
113
  imgs, masks = [], []
114
 
 
115
  def steps(size):
116
  s = list(range(0, size - patch_size + 1, stride))
117
  if not s:
 
124
  for x in steps(W):
125
  pm = mask[y : y + patch_size, x : x + patch_size]
126
  pi = image[:, y : y + patch_size, x : x + patch_size]
 
127
  if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
128
  imgs.append(pi)
129
  masks.append(pm)
130
 
 
131
  if not imgs:
132
  ph = min(patch_size, H)
133
  pw = min(patch_size, W)
 
141
  return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
142
 
143
 
144
+ # ── Spectral analysis ─────────────────────────────────────────
145
+
146
+ def compute_spectral_signatures(full_image: np.ndarray, full_mask: np.ndarray) -> Dict:
147
+ """Per-class mean and std across the 7 normalized bands, from labeled pixels."""
148
+ sigs = {}
149
+ for cls_idx in range(NUM_CLASSES):
150
+ px = full_mask == cls_idx
151
+ if px.sum() == 0:
152
+ sigs[cls_idx] = {
153
+ "mean": np.zeros(NUM_CHANNELS, dtype=np.float32),
154
+ "std": np.zeros(NUM_CHANNELS, dtype=np.float32),
155
+ "n": 0,
156
+ }
157
+ else:
158
+ vals = full_image[:, px] # (7, N)
159
+ sigs[cls_idx] = {
160
+ "mean": vals.mean(axis=1).astype(np.float32),
161
+ "std": vals.std(axis=1).astype(np.float32),
162
+ "n": int(px.sum()),
163
+ }
164
+ return sigs
165
+
166
+
167
  # ── Dataset class ─────────────────────────────────────────────
168
 
169
  class MultiSpectralDataset(Dataset):
 
181
  # ── Public API ────────────────────────────���───────────────────
182
 
183
  def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
184
+ # Read raw bands
185
+ raw_bands = [_read_band(_get_path(f)) for f in BAND_FILES]
186
+ raw_image = np.stack(raw_bands, axis=0) # (7, H, W) raw float32
187
+
188
+ # Compute spectral indices from raw values (ratio, so must use raw)
189
+ ndvi = _compute_ndvi(raw_image)
190
+ ndwi = _compute_ndwi(raw_image)
191
 
192
+ # Normalize
193
+ image = _normalize(raw_image) # (7, H, W) normalized [0,1]
 
194
 
195
+ # Read masks
196
+ raw_train, nd_train, info_train = _read_mask_raw(_get_path(TRAIN_MASK_FILE))
197
+ raw_val, nd_val, info_val = _read_mask_raw(_get_path(VAL_MASK_FILE))
198
+
199
+ # Remap
200
  train_mask, train_vals = _remap_mask(raw_train, nd_train)
201
  val_mask, val_vals = _remap_mask(raw_val, nd_val)
202
 
203
  if not train_vals:
204
  raise ValueError(
205
+ f"TRAINING.tif has no labeled pixels. Info: {info_train} | "
206
+ f"Unique raw values: {np.unique(raw_train).tolist()}"
207
  )
208
 
209
+ # Patches
210
  tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
211
  va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
212
 
213
+ # Spectral signatures from training labels
214
+ signatures = compute_spectral_signatures(image, train_mask)
215
+
216
  train_labeled = int((train_mask != IGNORE_INDEX).sum())
217
  val_labeled = int((val_mask != IGNORE_INDEX).sum())
218
 
 
224
  return " | ".join(parts)
225
 
226
  status = "\n".join([
227
+ f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}x{patch_size}**",
228
+ f"Training labels: **{train_labeled:,}** px β€” {_class_dist(train_mask, train_labeled)}",
229
+ f"Validation labels: **{val_labeled:,}** px β€” {_class_dist(val_mask, val_labeled)}",
 
 
 
 
 
 
230
  ])
231
 
232
  return {
233
+ "full_image": image,
234
+ "full_train_mask": train_mask,
235
+ "full_val_mask": val_mask,
236
+ "ndvi": ndvi,
237
+ "ndwi": ndwi,
238
+ "signatures": signatures,
239
+ "train_images": tr_imgs,
240
+ "train_masks": tr_masks,
241
+ "val_images": va_imgs,
242
+ "val_masks": va_masks,
243
+ "status": status,
244
  }
requirements.txt CHANGED
@@ -4,3 +4,5 @@ Pillow
4
  torch
5
  rasterio
6
  huggingface_hub
 
 
 
4
  torch
5
  rasterio
6
  huggingface_hub
7
+ matplotlib
8
+ scipy
train.py CHANGED
@@ -9,11 +9,20 @@ import gradio as gr
9
 
10
  from config import (
11
  DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
12
- BAND_NAMES, CLASS_NAMES, IGNORE_INDEX,
 
13
  )
14
  from data import MultiSpectralDataset, load_data
15
  from model import SmallUNet
16
- from visualize import multispectral_to_rgb, mask_to_color, overlay_mask, correctness_overlay
 
 
 
 
 
 
 
 
17
  from metrics import compute_metrics, metrics_markdown
18
 
19
 
@@ -22,8 +31,8 @@ from metrics import compute_metrics, metrics_markdown
22
  def build_prediction_cache(
23
  model: nn.Module, images: np.ndarray, batch_size: int = 8
24
  ) -> Tuple[np.ndarray, np.ndarray]:
25
- dummy_masks = np.zeros((len(images), images.shape[-2], images.shape[-1]), dtype=np.int64)
26
- ds = MultiSpectralDataset(images, dummy_masks)
27
  loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
28
  preds, probs = [], []
29
  model.eval()
@@ -36,16 +45,19 @@ def build_prediction_cache(
36
  return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
37
 
38
 
39
- # ── Render helpers ───────────────────────────────────────────
40
 
41
- def _blank(size: int = DEFAULT_PATCH_SIZE) -> np.ndarray:
42
- return np.full((size, size, 3), 200, dtype=np.uint8)
 
 
43
 
44
 
45
  def pixel_info_markdown(
46
  x: int, y: int,
47
  img7: np.ndarray, gt: np.ndarray,
48
- pred: Optional[np.ndarray], probs: Optional[np.ndarray],
 
49
  ) -> str:
50
  h, w = gt.shape
51
  x = int(np.clip(x, 0, w - 1))
@@ -53,165 +65,244 @@ def pixel_info_markdown(
53
 
54
  gt_class = int(gt[y, x])
55
  gt_name = CLASS_NAMES[gt_class] if gt_class != IGNORE_INDEX else "Unlabeled"
56
- lines = [f"### Pixel ({x}, {y})", f"- Ground truth: **{gt_name}**"]
57
 
58
  if pred is not None:
 
 
59
  if gt_class != IGNORE_INDEX:
60
- pred_class = int(pred[y, x])
61
- lines.append(f"- Prediction: **{CLASS_NAMES[pred_class]}**")
62
- lines.append(f"- Correct: **{'Yes' if pred_class == gt_class else 'No'}**")
63
- if probs is not None:
64
- top_ids = np.argsort(probs[:, y, x])[::-1][:3]
65
- lines.append("- Top probabilities: " + ", ".join(
66
- f"{CLASS_NAMES[i]} {probs[i, y, x] * 100:.1f}%" for i in top_ids
67
- ))
68
- else:
69
- lines.append("- Prediction: β€” *(unlabeled pixel)*")
70
- else:
71
- lines.append("- Prediction: β€”")
72
 
73
  lines += ["", "**Band values**"] + [
74
- f"- {n}: {float(img7[b, y, x]):.3f}" for b, n in enumerate(BAND_NAMES)
 
 
 
 
 
 
 
 
 
 
 
75
  ]
 
 
 
 
 
 
 
 
76
  return "\n".join(lines)
77
 
78
 
79
- def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]:
80
- if not name:
81
- return None
82
- return next((e for e in experiments if e["name"] == name), None)
 
 
 
 
 
 
 
 
83
 
 
84
 
85
- def render_experiment_panel(
86
- dataset_state: Dict, exp: Optional[Dict], sample_idx: int
87
- ) -> Tuple:
88
- """Returns (rgb, pred_color, overlay, metrics_md, error_map, click_md)."""
89
- b = _blank()
90
- no_data = (b, b, b, "### No data loaded", b, "### Load a dataset first")
91
- if dataset_state is None or "val_images" not in dataset_state:
92
- return no_data
93
  val_images = dataset_state["val_images"]
94
  val_masks = dataset_state["val_masks"]
95
- if len(val_images) == 0:
96
- return no_data
97
 
98
- idx = max(0, min(int(sample_idx), len(val_images) - 1))
99
  rgb = multispectral_to_rgb(val_images[idx])
100
  gt = val_masks[idx]
 
101
 
102
- if exp is None:
103
- return (
104
- rgb, mask_to_color(gt), overlay_mask(rgb, gt),
105
- "### No model selected",
106
- _blank(),
107
- pixel_info_markdown(0, 0, val_images[idx], gt, None, None),
108
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
 
110
  if idx >= len(exp["val_preds"]):
111
- return (
112
- rgb, mask_to_color(gt), overlay_mask(rgb, gt),
113
- "### Dataset reloaded β€” retrain to refresh",
114
- _blank(),
115
- "### Retrain needed",
116
- )
117
 
118
  pred = exp["val_preds"][idx].astype(np.int64)
119
  probs = exp["val_probs"][idx].astype(np.float32)
120
- sample_metrics = compute_metrics(pred, gt, num_classes=NUM_CLASSES)
121
  return (
122
  rgb,
123
  mask_to_color(pred),
124
  overlay_mask(rgb, pred),
125
- metrics_markdown(sample_metrics, title=f"{exp['name']} (sample {idx})"),
126
  correctness_overlay(rgb, pred, gt),
127
- pixel_info_markdown(0, 0, val_images[idx], gt, pred, probs),
128
- )
129
-
130
-
131
- def render_compare_view(
132
- dataset_state, experiments, name_a, name_b, sample_idx: int
133
- ) -> Tuple:
134
- return (
135
- *render_experiment_panel(dataset_state, _get_exp_by_name(experiments, name_a), sample_idx),
136
- *render_experiment_panel(dataset_state, _get_exp_by_name(experiments, name_b), sample_idx),
137
  )
138
 
139
 
140
- def experiments_table_markdown(experiments: List[Dict]) -> str:
141
- if not experiments:
142
- return "### No models trained yet"
143
- lines = [
144
- "### Trained models", "",
145
- "| # | Name | LR | Epochs | Base Ch | Val Acc | mIoU |",
146
- "|---|---|---:|---:|---:|---:|---:|",
147
- ]
148
- for i, e in enumerate(experiments):
149
- cfg = e["config"]
150
- lines.append(
151
- f"| {i + 1} | {e['name']} | {cfg['learning_rate']:.4f} | {cfg['epochs']} "
152
- f"| {cfg['base_channels']} | {e['global_metrics']['overall_acc'] * 100:.1f}% "
153
- f"| {e['global_metrics']['miou'] * 100:.1f}% |"
154
- )
155
- return "\n".join(lines)
156
-
157
-
158
  # ── Gradio action functions ─────���────────────────────────────
159
 
160
  def load_dataset_action(patch_size: int):
161
- """Downloads the HF dataset and builds patches. Returns 9 values."""
162
- patch_size = int(patch_size)
163
  dataset_state = load_data(patch_size)
164
- val_count = len(dataset_state["val_images"])
165
-
166
- rgb = multispectral_to_rgb(dataset_state["val_images"][0])
167
- gt = dataset_state["val_masks"][0]
168
- dataset_info = "\n".join([
169
- "### Dataset loaded",
170
- f"- {dataset_state['status']}",
171
- f"- Channels: **{NUM_CHANNELS}** ({', '.join(BAND_NAMES)})",
172
- f"- Classes: **{NUM_CLASSES}** ({', '.join(CLASS_NAMES)})",
 
 
 
 
 
 
173
  ])
174
 
 
 
175
  return (
176
  dataset_state,
177
- [],
 
 
178
  dataset_info,
179
- rgb,
180
- mask_to_color(gt),
181
- overlay_mask(rgb, gt),
182
- pixel_info_markdown(0, 0, dataset_state["val_images"][0], gt, None, None),
183
- gr.update(maximum=max(0, val_count - 1), value=0), # explorer_sample_index
184
- gr.update(maximum=max(0, val_count - 1), value=0), # compare_sample_index
185
- gr.update(choices=[], value=None), # compare_sel_a
186
- gr.update(choices=[], value=None), # compare_sel_b
 
 
 
 
 
 
 
 
187
  )
188
 
189
 
190
- def update_explorer_sample(dataset_state, sample_idx: int):
191
- if dataset_state is None or "val_images" not in dataset_state:
192
- b = _blank()
193
- return b, b, b, "### No dataset loaded"
194
- val_images = dataset_state["val_images"]
195
- val_masks = dataset_state["val_masks"]
196
- idx = max(0, min(int(sample_idx), len(val_images) - 1))
197
- rgb = multispectral_to_rgb(val_images[idx])
198
- gt = val_masks[idx]
199
- return (
200
- rgb,
201
- mask_to_color(gt),
202
- overlay_mask(rgb, gt),
203
- pixel_info_markdown(0, 0, val_images[idx], gt, None, None),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  )
205
 
206
 
207
- def update_compare_sample(dataset_state, experiments, sel_a, sel_b, sample_idx: int):
208
- if dataset_state is None or "val_images" not in dataset_state:
209
- raise gr.Error("Load a dataset first.")
210
- return render_compare_view(dataset_state, experiments, sel_a, sel_b, int(sample_idx))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  def train_experiment(
214
  dataset_state: Dict,
 
215
  experiments: List[Dict],
216
  learning_rate: float,
217
  batch_size: int,
@@ -222,6 +313,11 @@ def train_experiment(
222
  ):
223
  if dataset_state is None or "train_images" not in dataset_state:
224
  raise gr.Error("Load a dataset first.")
 
 
 
 
 
225
 
226
  loader = DataLoader(
227
  MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
@@ -233,39 +329,39 @@ def train_experiment(
233
 
234
  n_epochs = int(epochs)
235
  history = []
236
- for epoch_i in range(n_epochs):
237
- progress(epoch_i / n_epochs, desc=f"Epoch {epoch_i + 1}/{n_epochs}")
238
  model.train()
239
- total_loss, n = 0.0, 0
240
  for xb, yb in loader:
241
  xb, yb = xb.to(DEVICE), yb.to(DEVICE)
242
  optimizer.zero_grad(set_to_none=True)
243
  loss = criterion(model(xb), yb)
244
  loss.backward()
245
  optimizer.step()
246
- total_loss += float(loss.item())
247
  n += 1
248
- history.append(total_loss / max(1, n))
249
 
250
  progress(0.95, desc="Running validation inference...")
251
  val_preds, val_probs = build_prediction_cache(
252
  model, dataset_state["val_images"], batch_size=max(1, int(batch_size))
253
  )
254
  global_metrics = compute_metrics(
255
- val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1), num_classes=NUM_CLASSES
256
  )
257
- progress(1.0, desc="Done!")
258
 
259
- base = (run_name or f"Run {len(experiments) + 1}").strip()
260
  existing = {e["name"] for e in experiments}
261
- name, counter = base, 2
262
  while name in existing:
263
- name = f"{base} ({counter})"
264
- counter += 1
265
 
266
  experiment = {
267
- "name": name,
268
- "config": {
269
  "learning_rate": float(learning_rate),
270
  "batch_size": int(batch_size),
271
  "epochs": int(epochs),
@@ -273,63 +369,59 @@ def train_experiment(
273
  },
274
  "train_loss_history": history,
275
  "global_metrics": global_metrics,
276
- "val_preds": val_preds.astype(np.int64),
277
- "val_probs": val_probs.astype(np.float32),
278
  }
279
-
280
  experiments = experiments + [experiment]
281
- summary = "\n".join([
282
- f"### Training finished β€” **{name}**",
283
- f"- Device: **{DEVICE}** | Epochs: **{n_epochs}**",
284
- f"- Final loss: **{history[-1]:.4f}**",
285
- f"- Val accuracy: **{global_metrics['overall_acc'] * 100:.2f}%** (labeled pixels only)",
286
- f"- Val mIoU: **{global_metrics['miou'] * 100:.2f}%**",
287
  ])
288
 
289
- choices = [e["name"] for e in experiments]
 
 
 
 
290
  return (
291
- experiments, summary,
292
- experiments_table_markdown(experiments),
293
- gr.update(choices=choices),
294
- gr.update(choices=choices),
 
 
 
295
  )
296
 
297
 
298
- # ── Click handlers ───────────────────────────────────────────
299
-
300
- def handle_click_dataset(evt: gr.SelectData, dataset_state, sample_idx: int):
301
- if dataset_state is None or "val_images" not in dataset_state:
302
- return "### No dataset"
303
- idx = max(0, min(int(sample_idx), len(dataset_state["val_images"]) - 1))
304
- x, y = evt.index
305
- return pixel_info_markdown(
306
- int(x), int(y),
307
- dataset_state["val_images"][idx], dataset_state["val_masks"][idx],
308
- None, None,
309
- )
310
 
311
 
312
- def _handle_click_experiment(
313
- evt: gr.SelectData, dataset_state, experiments,
314
- model_name: Optional[str], sample_idx: int,
315
  ) -> str:
316
  try:
317
- if dataset_state is None or "val_images" not in dataset_state:
318
- return "### No dataset loaded"
319
- idx = max(0, min(int(sample_idx), len(dataset_state["val_images"]) - 1))
320
  exp = _get_exp_by_name(experiments, model_name)
321
  x, y = evt.index
322
  img7 = dataset_state["val_images"][idx]
323
  gt = dataset_state["val_masks"][idx]
324
- if exp is None or idx >= len(exp["val_preds"]):
325
- return pixel_info_markdown(int(x), int(y), img7, gt, None, None)
326
- return pixel_info_markdown(int(x), int(y), img7, gt, exp["val_preds"][idx], exp["val_probs"][idx])
327
  except Exception as e:
328
- return f"### Click error: `{type(e).__name__}: {e}`"
329
-
330
-
331
- def handle_click_exp_a(evt, dataset_state, experiments, sel_a, sample_idx):
332
- return _handle_click_experiment(evt, dataset_state, experiments, sel_a, sample_idx)
333
-
334
- def handle_click_exp_b(evt, dataset_state, experiments, sel_b, sample_idx):
335
- return _handle_click_experiment(evt, dataset_state, experiments, sel_b, sample_idx)
 
9
 
10
  from config import (
11
  DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
12
+ BAND_NAMES, BAND_DESCRIPTIONS, CLASS_NAMES, IGNORE_INDEX,
13
+ COMPOSITE_PRESETS, MAX_EXPERIMENTS,
14
  )
15
  from data import MultiSpectralDataset, load_data
16
  from model import SmallUNet
17
+ from baseline import run_knn_baseline
18
+ from visualize import (
19
+ render_composite, render_single_band,
20
+ add_labels_overlay, multispectral_to_rgb,
21
+ mask_to_color, overlay_mask, correctness_overlay,
22
+ render_full_prediction_overlay,
23
+ render_spectral_signatures_chart, render_index_map,
24
+ _blank_rgb,
25
+ )
26
  from metrics import compute_metrics, metrics_markdown
27
 
28
 
 
31
  def build_prediction_cache(
32
  model: nn.Module, images: np.ndarray, batch_size: int = 8
33
  ) -> Tuple[np.ndarray, np.ndarray]:
34
+ dummy = np.zeros((len(images), images.shape[-2], images.shape[-1]), dtype=np.int64)
35
+ ds = MultiSpectralDataset(images, dummy)
36
  loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
37
  preds, probs = [], []
38
  model.eval()
 
45
  return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
46
 
47
 
48
+ # ── Shared render helpers ────────────────────────────────────
49
 
50
+ def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]:
51
+ if not name:
52
+ return None
53
+ return next((e for e in experiments if e["name"] == name), None)
54
 
55
 
56
  def pixel_info_markdown(
57
  x: int, y: int,
58
  img7: np.ndarray, gt: np.ndarray,
59
+ pred: Optional[np.ndarray] = None,
60
+ probs: Optional[np.ndarray] = None,
61
  ) -> str:
62
  h, w = gt.shape
63
  x = int(np.clip(x, 0, w - 1))
 
65
 
66
  gt_class = int(gt[y, x])
67
  gt_name = CLASS_NAMES[gt_class] if gt_class != IGNORE_INDEX else "Unlabeled"
68
+ lines = [f"**Pixel ({x}, {y})**", f"Ground truth: **{gt_name}**"]
69
 
70
  if pred is not None:
71
+ pred_class = int(pred[y, x])
72
+ lines.append(f"Prediction: **{CLASS_NAMES[pred_class]}**")
73
  if gt_class != IGNORE_INDEX:
74
+ lines.append(f"Correct: **{'Yes' if pred_class == gt_class else 'No'}**")
75
+ if probs is not None:
76
+ top = np.argsort(probs[:, y, x])[::-1][:3]
77
+ lines.append("Top probs: " + ", ".join(
78
+ f"{CLASS_NAMES[i]} {probs[i, y, x]*100:.1f}%" for i in top
79
+ ))
 
 
 
 
 
 
80
 
81
  lines += ["", "**Band values**"] + [
82
+ f"{BAND_DESCRIPTIONS[b]}: **{float(img7[b, y, x]):.3f}**"
83
+ for b in range(img7.shape[0])
84
+ ]
85
+ return "\n\n".join(lines)
86
+
87
+
88
+ def experiments_table_markdown(experiments: List[Dict]) -> str:
89
+ if not experiments:
90
+ return "No experiments trained yet."
91
+ lines = [
92
+ "| # | Name | LR | Epochs | Channels | Val Acc | mIoU |",
93
+ "|---|---|---:|---:|---:|---:|---:|",
94
  ]
95
+ for i, e in enumerate(experiments):
96
+ cfg = e["config"]
97
+ lines.append(
98
+ f"| {i+1} | {e['name']} | {cfg['learning_rate']:.4f} | {cfg['epochs']} "
99
+ f"| {cfg['base_channels']} "
100
+ f"| {e['global_metrics']['overall_acc']*100:.1f}% "
101
+ f"| {e['global_metrics']['miou']*100:.1f}% |"
102
+ )
103
  return "\n".join(lines)
104
 
105
 
106
+ # ── Step 1 render helpers ──────────────────────────��─────────
107
+
108
+ def _render_step1_image(dataset_state: Dict, composite_choice: str) -> np.ndarray:
109
+ full = dataset_state["full_image"]
110
+ if composite_choice in COMPOSITE_PRESETS:
111
+ r, g, b = COMPOSITE_PRESETS[composite_choice]
112
+ base = render_composite(full, r, g, b)
113
+ else:
114
+ band_idx = BAND_DESCRIPTIONS.index(composite_choice)
115
+ base = render_single_band(full, band_idx)
116
+ return add_labels_overlay(base, dataset_state["full_train_mask"], dataset_state["full_val_mask"])
117
+
118
 
119
+ # ── Step 4 render helpers ────────────────────────────────────
120
 
121
+ def _render_step4_row(
122
+ dataset_state: Dict,
123
+ baseline_state: Optional[Dict],
124
+ experiments: List[Dict],
125
+ patch_idx: int,
126
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
127
+ """Returns (rgb, gt_overlay, baseline_overlay, unet_overlay)."""
 
128
  val_images = dataset_state["val_images"]
129
  val_masks = dataset_state["val_masks"]
130
+ idx = max(0, min(patch_idx, len(val_images) - 1))
 
131
 
 
132
  rgb = multispectral_to_rgb(val_images[idx])
133
  gt = val_masks[idx]
134
+ gt_ov = overlay_mask(rgb, gt)
135
 
136
+ if baseline_state is not None and idx < len(baseline_state["val_preds"]):
137
+ bl_ov = overlay_mask(rgb, baseline_state["val_preds"][idx])
138
+ else:
139
+ bl_ov = _blank_rgb(*rgb.shape[:2])
140
+
141
+ if experiments and idx < len(experiments[-1]["val_preds"]):
142
+ un_ov = overlay_mask(rgb, experiments[-1]["val_preds"][idx])
143
+ else:
144
+ un_ov = _blank_rgb(*rgb.shape[:2])
145
+
146
+ return rgb, gt_ov, bl_ov, un_ov
147
+
148
+
149
+ # ── Step 5 render helpers ────────────────────────────────────
150
+
151
+ def render_step5_panel(
152
+ dataset_state: Dict,
153
+ exp: Optional[Dict],
154
+ patch_idx: int,
155
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str, np.ndarray]:
156
+ """Returns (rgb, pred_color, overlay, metrics_md, error_map)."""
157
+ blank = _blank_rgb()
158
+ if dataset_state is None or exp is None:
159
+ return blank, blank, blank, "No model selected.", blank
160
+
161
+ val_images = dataset_state["val_images"]
162
+ val_masks = dataset_state["val_masks"]
163
+ idx = max(0, min(patch_idx, len(val_images) - 1))
164
+ rgb = multispectral_to_rgb(val_images[idx])
165
+ gt = val_masks[idx]
166
 
167
  if idx >= len(exp["val_preds"]):
168
+ return rgb, mask_to_color(gt), overlay_mask(rgb, gt), "Dataset reloaded β€” retrain.", blank
 
 
 
 
 
169
 
170
  pred = exp["val_preds"][idx].astype(np.int64)
171
  probs = exp["val_probs"][idx].astype(np.float32)
172
+ m = compute_metrics(pred, gt)
173
  return (
174
  rgb,
175
  mask_to_color(pred),
176
  overlay_mask(rgb, pred),
177
+ metrics_markdown(m, title=f"{exp['name']} Β· patch {idx}"),
178
  correctness_overlay(rgb, pred, gt),
 
 
 
 
 
 
 
 
 
 
179
  )
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  # ── Gradio action functions ─────���────────────────────────────
183
 
184
  def load_dataset_action(patch_size: int):
185
+ patch_size = int(patch_size)
 
186
  dataset_state = load_data(patch_size)
187
+ val_count = len(dataset_state["val_images"])
188
+
189
+ step1_img = _render_step1_image(dataset_state, "Natural Color (R/G/B)")
190
+ sig_chart = render_spectral_signatures_chart(dataset_state["signatures"])
191
+ ndvi_map = render_index_map(
192
+ dataset_state["ndvi"], "NDVI",
193
+ dataset_state["full_train_mask"], dataset_state["full_val_mask"],
194
+ )
195
+ blank = _blank_rgb()
196
+
197
+ dataset_info = "\n\n".join([
198
+ "**Dataset loaded.**",
199
+ dataset_state["status"],
200
+ f"Bands: {', '.join(BAND_NAMES)} | Classes: {', '.join(CLASS_NAMES)}",
201
+ "**Squares** = training labels Β· **Circles** = validation labels",
202
  ])
203
 
204
+ slider_upd = gr.update(maximum=max(0, val_count - 1), value=0)
205
+
206
  return (
207
  dataset_state,
208
+ None, # baseline_state reset
209
+ [], # experiments_state reset
210
+ # Tab 1
211
  dataset_info,
212
+ step1_img,
213
+ "Click the image to inspect a pixel.",
214
+ # Tab 2
215
+ sig_chart,
216
+ ndvi_map,
217
+ # Tab 3
218
+ "Run KNN baseline after loading the dataset.",
219
+ blank,
220
+ # Tab 4
221
+ "Train a model in Step 4.",
222
+ slider_upd,
223
+ blank, blank, blank,
224
+ # Tab 5
225
+ "No experiments yet.",
226
+ gr.update(choices=[], value=None),
227
+ gr.update(choices=[], value=None),
228
  )
229
 
230
 
231
+ def update_step1_composite(dataset_state, composite_choice: str):
232
+ if dataset_state is None:
233
+ return _blank_rgb(), "Load the dataset first."
234
+ img = _render_step1_image(dataset_state, composite_choice)
235
+ return img, "Click the image to inspect a pixel."
236
+
237
+
238
+ def handle_click_step1(evt: gr.SelectData, dataset_state):
239
+ if dataset_state is None:
240
+ return "Load the dataset first."
241
+ x, y = evt.index
242
+ full = dataset_state["full_image"]
243
+ fmask = dataset_state["full_val_mask"]
244
+ H, W = fmask.shape
245
+ x, y = int(np.clip(x, 0, W-1)), int(np.clip(y, 0, H-1))
246
+
247
+ cls = int(fmask[y, x])
248
+ label = CLASS_NAMES[cls] if cls != IGNORE_INDEX else "Unlabeled"
249
+ lines = [
250
+ f"**Pixel ({x}, {y})** | Val label: **{label}**", "",
251
+ "| Band | Value |", "|---|---:|",
252
+ ] + [f"| {BAND_DESCRIPTIONS[b]} | {float(full[b, y, x]):.4f} |" for b in range(7)]
253
+ return "\n".join(lines)
254
+
255
+
256
+ def update_step2_index(dataset_state, index_choice: str):
257
+ if dataset_state is None:
258
+ return _blank_rgb()
259
+ key = index_choice.lower()
260
+ arr = dataset_state[key]
261
+ return render_index_map(
262
+ arr, index_choice,
263
+ dataset_state["full_train_mask"], dataset_state["full_val_mask"],
264
  )
265
 
266
 
267
+ def run_baseline_action(dataset_state, k: int, progress=gr.Progress()):
268
+ if dataset_state is None:
269
+ raise gr.Error("Load the dataset first.")
270
+ progress(0.1, desc="Running KNN on full scene...")
271
+ k = int(k)
272
+ full_pred, val_preds, metrics, metrics_md = run_knn_baseline(
273
+ dataset_state["full_image"],
274
+ dataset_state["full_train_mask"],
275
+ dataset_state["full_val_mask"],
276
+ dataset_state["val_images"],
277
+ k=k,
278
+ )
279
+ progress(0.9, desc="Rendering...")
280
+ baseline_state = {
281
+ "k": k,
282
+ "full_pred": full_pred,
283
+ "val_preds": val_preds,
284
+ "metrics": metrics,
285
+ }
286
+ full_ov = render_full_prediction_overlay(
287
+ dataset_state["full_image"], full_pred, dataset_state["full_val_mask"],
288
+ )
289
+ progress(1.0)
290
+ return baseline_state, metrics_md, full_ov
291
+
292
+
293
+ def update_step4_patch(dataset_state, baseline_state, experiments, patch_idx: int):
294
+ if dataset_state is None:
295
+ blank = _blank_rgb()
296
+ return blank, blank, blank
297
+ _, gt_ov, bl_ov, un_ov = _render_step4_row(
298
+ dataset_state, baseline_state, experiments, int(patch_idx)
299
+ )
300
+ return gt_ov, bl_ov, un_ov
301
 
302
 
303
  def train_experiment(
304
  dataset_state: Dict,
305
+ baseline_state: Optional[Dict],
306
  experiments: List[Dict],
307
  learning_rate: float,
308
  batch_size: int,
 
313
  ):
314
  if dataset_state is None or "train_images" not in dataset_state:
315
  raise gr.Error("Load a dataset first.")
316
+ if len(experiments) >= MAX_EXPERIMENTS:
317
+ raise gr.Error(
318
+ f"Maximum {MAX_EXPERIMENTS} experiments reached. "
319
+ "Go to Step 5 to compare, then reload data to start fresh."
320
+ )
321
 
322
  loader = DataLoader(
323
  MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
 
329
 
330
  n_epochs = int(epochs)
331
  history = []
332
+ for ep in range(n_epochs):
333
+ progress(ep / n_epochs, desc=f"Epoch {ep+1}/{n_epochs}")
334
  model.train()
335
+ total, n = 0.0, 0
336
  for xb, yb in loader:
337
  xb, yb = xb.to(DEVICE), yb.to(DEVICE)
338
  optimizer.zero_grad(set_to_none=True)
339
  loss = criterion(model(xb), yb)
340
  loss.backward()
341
  optimizer.step()
342
+ total += float(loss.item())
343
  n += 1
344
+ history.append(total / max(1, n))
345
 
346
  progress(0.95, desc="Running validation inference...")
347
  val_preds, val_probs = build_prediction_cache(
348
  model, dataset_state["val_images"], batch_size=max(1, int(batch_size))
349
  )
350
  global_metrics = compute_metrics(
351
+ val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
352
  )
353
+ progress(1.0)
354
 
355
+ base = (run_name or f"Run {len(experiments)+1}").strip()
356
  existing = {e["name"] for e in experiments}
357
+ name, ctr = base, 2
358
  while name in existing:
359
+ name = f"{base} ({ctr})"
360
+ ctr += 1
361
 
362
  experiment = {
363
+ "name": name,
364
+ "config": {
365
  "learning_rate": float(learning_rate),
366
  "batch_size": int(batch_size),
367
  "epochs": int(epochs),
 
369
  },
370
  "train_loss_history": history,
371
  "global_metrics": global_metrics,
372
+ "val_preds": val_preds.astype(np.int64),
373
+ "val_probs": val_probs.astype(np.float32),
374
  }
 
375
  experiments = experiments + [experiment]
376
+
377
+ summary = "\n\n".join([
378
+ f"**Training finished β€” {name}**",
379
+ f"Device: **{DEVICE}** | Epochs: **{n_epochs}** | Final loss: **{history[-1]:.4f}**",
380
+ f"Val accuracy: **{global_metrics['overall_acc']*100:.2f}%** (labeled px only)",
381
+ f"Val mIoU: **{global_metrics['miou']*100:.2f}%**",
382
  ])
383
 
384
+ choices = [e["name"] for e in experiments]
385
+ val_count = len(dataset_state["val_images"])
386
+
387
+ _, gt_ov, bl_ov, un_ov = _render_step4_row(dataset_state, baseline_state, experiments, 0)
388
+
389
  return (
390
+ experiments,
391
+ summary,
392
+ gr.update(maximum=max(0, val_count-1), value=0), # step4 patch slider
393
+ gt_ov, bl_ov, un_ov,
394
+ experiments_table_markdown(experiments), # step5 table
395
+ gr.update(choices=choices, value=None), # step5 sel_a
396
+ gr.update(choices=choices, value=None), # step5 sel_b
397
  )
398
 
399
 
400
+ def update_step5_comparison(
401
+ dataset_state, experiments, sel_a, sel_b, patch_idx: int
402
+ ):
403
+ idx = int(patch_idx)
404
+ exp_a = _get_exp_by_name(experiments, sel_a)
405
+ exp_b = _get_exp_by_name(experiments, sel_b)
406
+ a_outs = render_step5_panel(dataset_state, exp_a, idx)
407
+ b_outs = render_step5_panel(dataset_state, exp_b, idx)
408
+ return (*a_outs, *b_outs)
 
 
 
409
 
410
 
411
+ def handle_click_step5(
412
+ evt: gr.SelectData,
413
+ dataset_state, experiments, model_name, patch_idx: int,
414
  ) -> str:
415
  try:
416
+ if dataset_state is None:
417
+ return "No dataset loaded."
418
+ idx = max(0, min(int(patch_idx), len(dataset_state["val_images"])-1))
419
  exp = _get_exp_by_name(experiments, model_name)
420
  x, y = evt.index
421
  img7 = dataset_state["val_images"][idx]
422
  gt = dataset_state["val_masks"][idx]
423
+ pred = exp["val_preds"][idx] if (exp and idx < len(exp["val_preds"])) else None
424
+ probs = exp["val_probs"][idx] if (exp and idx < len(exp["val_probs"])) else None
425
+ return pixel_info_markdown(int(x), int(y), img7, gt, pred, probs)
426
  except Exception as e:
427
+ return f"Click error: `{e}`"
 
 
 
 
 
 
 
visualize.py CHANGED
@@ -1,6 +1,15 @@
 
 
 
 
 
1
  import numpy as np
2
- from config import CLASS_COLORS, IGNORE_INDEX
 
 
 
3
 
 
4
 
5
  def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
6
  x = x.astype(np.float32)
@@ -11,16 +20,80 @@ def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> n
11
  return np.clip((x - lo) / (hi - lo), 0, 1)
12
 
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
15
- """img7: (7, H, W) β€” uses H_3/H_2/H_1 for natural colour-like composite."""
16
- r = percentile_stretch(img7[2])
17
- g = percentile_stretch(img7[1])
18
- b = percentile_stretch(img7[0])
19
- return (np.stack([r, g, b], axis=-1) * 255).astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
 
 
 
 
21
 
22
  def mask_to_color(mask: np.ndarray) -> np.ndarray:
23
- """Class indices β†’ RGB. IGNORE_INDEX pixels rendered as light gray."""
24
  out = np.full((*mask.shape, 3), 200, dtype=np.uint8)
25
  labeled = (mask != IGNORE_INDEX) & (mask >= 0)
26
  if labeled.any():
@@ -35,7 +108,7 @@ def overlay_mask(rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.n
35
 
36
 
37
  def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
38
- """Green = correct, red = wrong, gray = unlabeled (IGNORE_INDEX)."""
39
  out = np.full((*pred.shape, 3), 180, dtype=np.uint8)
40
  labeled = gt != IGNORE_INDEX
41
  out[labeled & (pred == gt)] = [0, 220, 0]
@@ -47,3 +120,90 @@ def correctness_overlay(rgb: np.ndarray, pred: np.ndarray, gt: np.ndarray, alpha
47
  cm = correctness_map(pred, gt)
48
  out = ((1 - alpha) * rgb.astype(np.float32) + alpha * cm.astype(np.float32)).clip(0, 255)
49
  return out.astype(np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+
3
+ import matplotlib
4
+ matplotlib.use("Agg")
5
+ import matplotlib.pyplot as plt
6
  import numpy as np
7
+ from PIL import Image, ImageDraw
8
+
9
+ from config import CLASS_COLORS, CLASS_NAMES, BAND_NAMES, BAND_DESCRIPTIONS, IGNORE_INDEX, NUM_CLASSES
10
+
11
 
12
+ # ── Low-level helpers ─────────────────────────────────────────
13
 
14
  def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
15
  x = x.astype(np.float32)
 
20
  return np.clip((x - lo) / (hi - lo), 0, 1)
21
 
22
 
23
def _fig_to_numpy(fig) -> np.ndarray:
    """Rasterise a Matplotlib figure into an RGB image array.

    The figure is written as a PNG (tight bounding box, 110 dpi) into an
    in-memory buffer, closed, and the PNG is decoded again with Pillow.

    Args:
        fig: Figure to render. It is always closed by this function, even
            when saving fails, so the caller must not reuse it afterwards.

    Returns:
        NumPy array of the decoded image after conversion to RGB.
    """
    with io.BytesIO() as buf:
        try:
            fig.savefig(buf, format="png", bbox_inches="tight", dpi=110)
        finally:
            # Close even if savefig raises: figures created through pyplot
            # stay registered until explicitly closed and would otherwise
            # accumulate across failed renders.
            plt.close(fig)
        buf.seek(0)
        # Decode while the buffer is still open; convert() forces the load.
        return np.array(Image.open(buf).convert("RGB"))
29
+
30
+
31
def _blank_rgb(h: int = 300, w: int = 400) -> np.ndarray:
    """Return an (h, w, 3) uint8 placeholder image filled with the value 220."""
    canvas = np.empty((h, w, 3), dtype=np.uint8)
    canvas.fill(220)
    return canvas
33
+
34
+
35
+ # ── Composite rendering (full image or patch) ─────────────────
36
+
37
def render_composite(img7: np.ndarray, r: int, g: int, b: int) -> np.ndarray:
    """Build a three-channel composite from three bands of a band-first image.

    Args:
        img7: Multispectral image indexed as img7[band]; documented by the
            original author as (7, H, W).
        r: Band index shown in the red channel.
        g: Band index shown in the green channel.
        b: Band index shown in the blue channel.

    Returns:
        (H, W, 3) uint8 image; each channel is stretched independently with
        percentile_stretch before being scaled by 255.
    """
    channels = [percentile_stretch(img7[band]) for band in (r, g, b)]
    stacked = np.stack(channels, axis=-1)
    return (stacked * 255).astype(np.uint8)
44
+
45
+
46
def render_single_band(img7: np.ndarray, band_idx: int) -> np.ndarray:
    """Render one band as a grayscale image with three identical channels.

    The selected band is passed through percentile_stretch, scaled by 255,
    cast to uint8 and copied into the R, G and B channels.
    """
    stretched = percentile_stretch(img7[band_idx])
    gray = (stretched * 255).astype(np.uint8)
    return np.stack([gray] * 3, axis=-1)
50
+
51
+
52
def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
    """Return the natural-colour composite of img7 (H_4/H_3/H_2 -> R/G/B)."""
    # Zero-based band indices of H_4, H_3 and H_2.
    red_band, green_band, blue_band = 3, 2, 1
    return render_composite(img7, r=red_band, g=green_band, b=blue_band)
55
+
56
+
57
+ # ── Label markers on full-scene image ────────────────────────
58
+
59
def add_labels_overlay(
    base_rgb: np.ndarray,
    train_mask: np.ndarray,
    val_mask: np.ndarray,
    radius: int = 5,
) -> np.ndarray:
    """Draw class-coloured label markers on top of an RGB image.

    Training labels are drawn as filled squares with a white outline;
    validation labels are drawn as filled circles surrounded by a white ring.
    Classes are drawn in ascending index order, so markers of later classes
    are painted over overlapping markers of earlier ones.

    Args:
        base_rgb: Image to annotate, handed to ``Image.fromarray``
            (presumably (H, W, 3) uint8 - confirm against callers).
        train_mask: 2-D array; a pixel equal to a class index in
            ``range(NUM_CLASSES)`` gets a training marker.
        val_mask: 2-D array; a pixel equal to a class index in
            ``range(NUM_CLASSES)`` gets a validation marker.
        radius: Half-size of a marker in pixels; the white validation ring
            extends two pixels beyond it.

    Returns:
        The annotated image as a new array.
    """
    img = Image.fromarray(base_rgb)
    draw = ImageDraw.Draw(img)
    H, W = base_rgb.shape[:2]
    white = (255, 255, 255)
    ring = radius + 2

    for cls_idx in range(NUM_CLASSES):
        color = tuple(int(c) for c in CLASS_COLORS[cls_idx])

        ys, xs = np.where(train_mask == cls_idx)
        for y, x in zip(ys.tolist(), xs.tolist()):
            # Clamp so the white outline stays visible at the image border.
            box = [
                max(0, x - radius),
                max(0, y - radius),
                min(W - 1, x + radius),
                min(H - 1, y + radius),
            ]
            draw.rectangle(box, fill=color, outline=white)

        ys, xs = np.where(val_mask == cls_idx)
        for y, x in zip(ys.tolist(), xs.tolist()):
            # Do not clamp the ellipses: clamping only one of the two boxes
            # (or only some of its sides) shifts the ring relative to the
            # disc near the border. Pillow clips out-of-image drawing itself,
            # so both stay centred on the labelled pixel.
            draw.ellipse([x - ring, y - ring, x + ring, y + ring], fill=white)
            draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color)

    return np.array(img)
91
+
92
+
93
+ # ── Mask colourisation ────────────────────────────────────────
94
 
95
  def mask_to_color(mask: np.ndarray) -> np.ndarray:
96
+ """Class indices -> RGB. IGNORE_INDEX pixels rendered as light gray."""
97
  out = np.full((*mask.shape, 3), 200, dtype=np.uint8)
98
  labeled = (mask != IGNORE_INDEX) & (mask >= 0)
99
  if labeled.any():
 
108
 
109
 
110
  def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
111
+ """Green = correct, red = wrong, gray = unlabeled."""
112
  out = np.full((*pred.shape, 3), 180, dtype=np.uint8)
113
  labeled = gt != IGNORE_INDEX
114
  out[labeled & (pred == gt)] = [0, 220, 0]
 
120
  cm = correctness_map(pred, gt)
121
  out = ((1 - alpha) * rgb.astype(np.float32) + alpha * cm.astype(np.float32)).clip(0, 255)
122
  return out.astype(np.uint8)
123
+
124
+
125
+ # ── Full-scene prediction rendering ──────────────────────────
126
+
127
def render_full_prediction_overlay(
    full_image: np.ndarray,
    full_pred: np.ndarray,
    val_mask: np.ndarray,
    alpha: float = 0.40,
    dot_radius: int = 6,
) -> np.ndarray:
    """Overlay a full-scene prediction on the natural-colour composite.

    The composite (bands 3/2/1 as R/G/B) is combined with ``full_pred`` via
    ``overlay_mask``. Every validation label then gets a marker: a white disc
    with a smaller inner disc that is green when the prediction at that pixel
    equals the label's class and red otherwise.

    Args:
        full_image: Band-first multispectral scene passed to render_composite.
        full_pred: 2-D array of predicted class indices, indexed as [y, x].
        val_mask: 2-D array; pixels equal to a class index in
            ``range(NUM_CLASSES)`` are treated as validation labels.
        alpha: Blend factor forwarded to ``overlay_mask``.
        dot_radius: Radius of the inner disc in pixels; the white disc is two
            pixels larger.

    Returns:
        The annotated image as an array.
    """
    background = render_composite(full_image, r=3, g=2, b=1)
    blended = overlay_mask(background, full_pred, alpha=alpha)

    canvas = Image.fromarray(blended)
    pen = ImageDraw.Draw(canvas)
    height, width = blended.shape[:2]
    x_max, y_max = width - 1, height - 1

    def clamped_box(cx, cy, half):
        # Bounding box of a disc centred on (cx, cy), limited to the image.
        return [
            max(0, cx - half),
            max(0, cy - half),
            min(x_max, cx + half),
            min(y_max, cy + half),
        ]

    for cls_idx in range(NUM_CLASSES):
        rows, cols = np.where(val_mask == cls_idx)
        for cy, cx in zip(rows.tolist(), cols.tolist()):
            if full_pred[cy, cx] == cls_idx:
                verdict = (0, 200, 0)
            else:
                verdict = (220, 0, 0)
            pen.ellipse(clamped_box(cx, cy, dot_radius + 2), fill=(255, 255, 255))
            pen.ellipse(clamped_box(cx, cy, dot_radius), fill=verdict)

    return np.array(canvas)
157
+
158
+
159
+ # ── Matplotlib charts ─────────────────────────────────────────
160
+
161
def render_spectral_signatures_chart(signatures: dict) -> np.ndarray:
    """Plot each class's mean spectrum with a ±1-sigma band as an image.

    Args:
        signatures: Mapping from class index to a dict with the keys
            "mean", "std" and "n". "mean" and "std" are plotted against the
            band positions, so they presumably hold one value per entry of
            BAND_NAMES; "n" appears in the legend label.

    Returns:
        The rendered chart as returned by ``_fig_to_numpy``.
    """
    fig, ax = plt.subplots(figsize=(8, 3.8))
    band_pos = np.arange(len(BAND_NAMES))

    for cls_idx, stats in signatures.items():
        centre = stats["mean"]
        spread = stats["std"]
        count = stats["n"]
        shade = CLASS_COLORS[cls_idx] / 255.0
        ax.plot(
            band_pos,
            centre,
            "o-",
            color=shade,
            label=f"{CLASS_NAMES[cls_idx]} (n={count})",
            linewidth=2,
            markersize=5,
        )
        ax.fill_between(band_pos, centre - spread, centre + spread, alpha=0.18, color=shade)

    # Break each description before its parenthesised part so ticks stay narrow.
    tick_text = [desc.replace(" (", "\n(") for desc in BAND_DESCRIPTIONS]
    ax.set_xticks(band_pos)
    ax.set_xticklabels(tick_text, fontsize=8)
    ax.set(ylabel="Normalised reflectance", title="Spectral Signatures by Land Cover Class")
    ax.legend(loc="upper left", fontsize=8)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return _fig_to_numpy(fig)
183
+
184
+
185
def render_index_map(
    index_arr: np.ndarray,
    name: str,
    train_mask: np.ndarray,
    val_mask: np.ndarray,
) -> np.ndarray:
    """Render a spectral-index heatmap with class-coloured label markers.

    Args:
        index_arr: 2-D index values; the colour scale is fixed to [-1, 1].
        name: Index name used in the title. "NDVI" selects the "RdYlGn"
            colormap; any other name uses "RdYlBu".
        train_mask: 2-D array; pixels equal to a class index in
            ``range(NUM_CLASSES)`` are drawn as squares and listed in the
            legend.
        val_mask: 2-D array; pixels equal to a class index in
            ``range(NUM_CLASSES)`` are drawn as white-edged circles (no
            legend entry).

    Returns:
        The rendered chart as returned by ``_fig_to_numpy``.
    """
    cmap = "RdYlGn" if name == "NDVI" else "RdYlBu"
    fig, ax = plt.subplots(figsize=(10, 4.5))
    im = ax.imshow(index_arr, cmap=cmap, vmin=-1, vmax=1, aspect="auto")
    # Attach the colorbar through the figure rather than pyplot's global state.
    fig.colorbar(im, ax=ax, fraction=0.018, pad=0.02)

    for cls_idx in range(NUM_CLASSES):
        color = CLASS_COLORS[cls_idx] / 255.0
        class_name = CLASS_NAMES[cls_idx]

        ys, xs = np.where(train_mask == cls_idx)
        ax.scatter(
            xs, ys, c=[color], s=18, marker="s",
            label=f"{class_name} (train)", zorder=5,
        )

        ys, xs = np.where(val_mask == cls_idx)
        ax.scatter(
            xs, ys, c=[color], s=18, marker="o",
            edgecolors="white", linewidths=0.6, zorder=6,
        )

    # The separator is a real em dash; the mis-encoded byte sequence that was
    # here would be rendered verbatim in the chart title.
    ax.set_title(f"{name} — squares=training labels, circles=val labels")
    ax.legend(loc="upper right", fontsize=7, markerscale=1.4)
    fig.tight_layout()
    return _fig_to_numpy(fig)