Spaces:
Sleeping
Redesign: five-step pedagogical flow with spectral baseline
Browse filesStep 1 - Know Your Data: full-scene composite viewer (natural/false/SWIR/single-band)
with labelled-pixel markers overlaid, clickable pixel inspection.
Step 2 - Spectral Signatures: per-class meanΒ±sigma chart and NDVI/NDWI heatmaps
computed from raw (unnormalised) bands for correct ratio semantics.
Step 3 - Spectral Baseline: pure-numpy chunked KNN classifier that predicts the
full 501x1001 scene from 7-band spectral features alone (no spatial context).
Step 4 - Deep Learning: UNet training with side-by-side ground-truth / KNN /
UNet patch comparison to make the spatial-context benefit concrete.
Step 5 - Experiment Lab: compare up to 3 UNet experiments; max capped at 3.
Also fixes natural-colour composite to use H4/H3/H2 (Red/Green/Blue) instead of
the previous H3/H2/H1, and adds matplotlib + scipy to requirements.
- app.py +220 -101
- baseline.py +79 -0
- config.py +21 -4
- data.py +84 -39
- requirements.txt +2 -0
- train.py +266 -174
- visualize.py +168 -8
|
@@ -1,151 +1,270 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
-
from config import
|
|
|
|
|
|
|
|
|
|
| 4 |
from train import (
|
| 5 |
load_dataset_action,
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
train_experiment,
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
handle_click_exp_b,
|
| 12 |
)
|
| 13 |
|
| 14 |
set_seed(SEED)
|
| 15 |
|
|
|
|
|
|
|
| 16 |
custom_css = """
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
}
|
| 20 |
-
.small-note { font-size: 0.9rem; opacity: 0.85; }
|
| 21 |
"""
|
| 22 |
|
| 23 |
with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
|
| 24 |
-
gr.Markdown(f"# {APP_TITLE}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
dataset_state = gr.State(None)
|
|
|
|
| 27 |
experiments_state = gr.State([])
|
| 28 |
|
| 29 |
-
# ββ
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
with gr.Row():
|
| 32 |
with gr.Column(scale=1):
|
| 33 |
-
patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, label="Patch size")
|
| 34 |
load_btn = gr.Button("Load dataset", variant="primary")
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
gr.Markdown(
|
| 37 |
-
"
|
| 38 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
)
|
| 40 |
-
with gr.Column(scale=2, elem_id="explorer"):
|
| 41 |
-
explorer_sample_index = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
|
| 42 |
-
with gr.Row():
|
| 43 |
-
explorer_rgb = gr.Image(label="RGB composite", type="numpy", height=400)
|
| 44 |
-
explorer_gt = gr.Image(label="Ground truth mask", type="numpy", height=400)
|
| 45 |
-
explorer_overlay = gr.Image(label="Ground truth overlay", type="numpy", height=400)
|
| 46 |
-
explorer_click_info = gr.Markdown("### Click the RGB image to inspect a pixel")
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
with gr.Row():
|
| 51 |
with gr.Column(scale=1):
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
with gr.Column(scale=1):
|
| 59 |
-
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
gr.Markdown(
|
| 62 |
-
"
|
| 63 |
-
"
|
|
|
|
| 64 |
)
|
|
|
|
| 65 |
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 69 |
with gr.Row():
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
with gr.Row():
|
| 73 |
-
with gr.Column(scale=1
|
| 74 |
gr.Markdown("## Left")
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
with gr.Column(scale=1, elem_id="compare-b"):
|
| 83 |
gr.Markdown("## Right")
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
]
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
outputs=[
|
| 105 |
-
dataset_state, experiments_state,
|
| 106 |
-
dataset_info,
|
| 107 |
-
explorer_rgb, explorer_gt, explorer_overlay,
|
| 108 |
-
explorer_click_info,
|
| 109 |
-
explorer_sample_index,
|
| 110 |
-
compare_sample_index,
|
| 111 |
-
compare_sel_a, compare_sel_b,
|
| 112 |
-
],
|
| 113 |
)
|
| 114 |
|
| 115 |
-
|
| 116 |
-
fn=
|
| 117 |
-
inputs=[dataset_state
|
| 118 |
-
outputs=[
|
| 119 |
)
|
| 120 |
|
| 121 |
-
|
| 122 |
-
fn=
|
| 123 |
-
inputs=[dataset_state,
|
| 124 |
-
outputs=[
|
| 125 |
)
|
| 126 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
train_btn.click(
|
| 128 |
fn=train_experiment,
|
| 129 |
-
inputs=[
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
| 131 |
)
|
| 132 |
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
for
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
)
|
| 144 |
-
for img in [compare_b_rgb, compare_b_overlay]:
|
| 145 |
img.select(
|
| 146 |
-
fn=
|
| 147 |
-
inputs=[dataset_state, experiments_state,
|
| 148 |
-
outputs=[
|
| 149 |
)
|
| 150 |
|
| 151 |
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
|
| 3 |
+
from config import (
|
| 4 |
+
APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE,
|
| 5 |
+
COMPOSITE_PRESETS, BAND_DESCRIPTIONS, MAX_EXPERIMENTS,
|
| 6 |
+
)
|
| 7 |
from train import (
|
| 8 |
load_dataset_action,
|
| 9 |
+
update_step1_composite,
|
| 10 |
+
handle_click_step1,
|
| 11 |
+
update_step2_index,
|
| 12 |
+
run_baseline_action,
|
| 13 |
+
update_step4_patch,
|
| 14 |
train_experiment,
|
| 15 |
+
update_step5_comparison,
|
| 16 |
+
handle_click_step5,
|
|
|
|
| 17 |
)
|
| 18 |
|
| 19 |
set_seed(SEED)
|
| 20 |
|
| 21 |
+
_COMPOSITE_CHOICES = list(COMPOSITE_PRESETS.keys()) + BAND_DESCRIPTIONS
|
| 22 |
+
|
| 23 |
custom_css = """
|
| 24 |
+
#step1-img img, #step3-pred img, #step3-correct img { image-rendering: pixelated; }
|
| 25 |
+
.step-header { font-size: 1.05rem; font-weight: 600; margin-bottom: 4px; }
|
| 26 |
+
.hint { font-size: 0.88rem; color: #666; }
|
|
|
|
| 27 |
"""
|
| 28 |
|
| 29 |
with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
|
| 30 |
+
gr.Markdown(f"# {APP_TITLE}")
|
| 31 |
+
gr.Markdown(
|
| 32 |
+
"A five-step journey from raw satellite pixels to deep-learning segmentation. "
|
| 33 |
+
"Work through the tabs in order β each step builds on the previous one."
|
| 34 |
+
)
|
| 35 |
|
| 36 |
dataset_state = gr.State(None)
|
| 37 |
+
baseline_state = gr.State(None)
|
| 38 |
experiments_state = gr.State([])
|
| 39 |
|
| 40 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 41 |
+
# Step 1 β Know Your Data
|
| 42 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 43 |
+
with gr.Tab("Step 1 Β· Know Your Data"):
|
| 44 |
+
gr.Markdown(
|
| 45 |
+
"**Start here.** Load the dataset, then explore the 7 spectral bands. "
|
| 46 |
+
"Squares on the image are training labels; circles are validation labels.",
|
| 47 |
+
elem_classes="hint",
|
| 48 |
+
)
|
| 49 |
with gr.Row():
|
| 50 |
with gr.Column(scale=1):
|
| 51 |
+
patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, label="Patch size (for training)")
|
| 52 |
load_btn = gr.Button("Load dataset", variant="primary")
|
| 53 |
+
composite_dd = gr.Dropdown(
|
| 54 |
+
choices=_COMPOSITE_CHOICES,
|
| 55 |
+
value="Natural Color (R/G/B)",
|
| 56 |
+
label="View mode",
|
| 57 |
+
)
|
| 58 |
+
step1_info = gr.Markdown("*Load the dataset to begin.*")
|
| 59 |
+
|
| 60 |
+
with gr.Column(scale=3, elem_id="step1-img"):
|
| 61 |
+
step1_image = gr.Image(label="Full scene β click to inspect a pixel", type="numpy")
|
| 62 |
+
step1_click = gr.Markdown("*Click anywhere on the image.*")
|
| 63 |
+
|
| 64 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 65 |
+
# Step 2 β Spectral Signatures
|
| 66 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 67 |
+
with gr.Tab("Step 2 Β· Spectral Signatures"):
|
| 68 |
+
gr.Markdown(
|
| 69 |
+
"Each land cover type has a characteristic pattern of brightness across the 7 bands β "
|
| 70 |
+
"its **spectral signature**. Notice how H_5 (NIR) separates vegetation from water. "
|
| 71 |
+
"NDVI and NDWI are hand-crafted indices that exploit this difference.",
|
| 72 |
+
elem_classes="hint",
|
| 73 |
+
)
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column(scale=1):
|
| 76 |
+
index_radio = gr.Radio(
|
| 77 |
+
choices=["NDVI", "NDWI"],
|
| 78 |
+
value="NDVI",
|
| 79 |
+
label="Spectral index map",
|
| 80 |
+
)
|
| 81 |
gr.Markdown(
|
| 82 |
+
"**NDVI** = (NIR β Red) / (NIR + Red) \n"
|
| 83 |
+
"High values β dense vegetation (Forest, Agriculture) \n"
|
| 84 |
+
"Low / negative β water, urban, bare soil\n\n"
|
| 85 |
+
"**NDWI** = (Green β NIR) / (Green + NIR) \n"
|
| 86 |
+
"Positive β water; negative β land",
|
| 87 |
+
elem_classes="hint",
|
| 88 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
+
with gr.Column(scale=3):
|
| 91 |
+
step2_sig_chart = gr.Image(label="Spectral signatures (training labels)", type="numpy")
|
| 92 |
+
step2_index_map = gr.Image(label="Index map", type="numpy")
|
| 93 |
+
|
| 94 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 95 |
+
# Step 3 β Spectral Baseline (KNN)
|
| 96 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 97 |
+
with gr.Tab("Step 3 Β· Spectral Baseline"):
|
| 98 |
+
gr.Markdown(
|
| 99 |
+
"**No convolutions here.** We classify every pixel using only its 7 band values, "
|
| 100 |
+
"finding the k nearest training pixels in spectral space. "
|
| 101 |
+
"This shows you what's achievable without any spatial context.",
|
| 102 |
+
elem_classes="hint",
|
| 103 |
+
)
|
| 104 |
with gr.Row():
|
| 105 |
with gr.Column(scale=1):
|
| 106 |
+
k_slider = gr.Slider(1, 5, value=3, step=2, label="k (number of neighbours)")
|
| 107 |
+
baseline_btn = gr.Button("Run KNN baseline", variant="primary")
|
| 108 |
+
step3_metrics = gr.Markdown("*Run the baseline to see results.*")
|
| 109 |
+
|
| 110 |
+
with gr.Column(scale=3, elem_id="step3-pred"):
|
| 111 |
+
step3_full_pred = gr.Image(
|
| 112 |
+
label="Full-scene prediction Β· overlaid on natural-colour image Β· coloured dots = val labels (green=correct, red=wrong)",
|
| 113 |
+
type="numpy",
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 117 |
+
# Step 4 β Deep Learning (UNet)
|
| 118 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 119 |
+
with gr.Tab("Step 4 Β· Deep Learning"):
|
| 120 |
+
gr.Markdown(
|
| 121 |
+
"A **U-Net** sees a patch of pixels at once, not just one pixel. "
|
| 122 |
+
"Its encoder captures local texture; skip connections preserve spatial detail. "
|
| 123 |
+
"Train a model and compare it patch-by-patch against the KNN baseline.",
|
| 124 |
+
elem_classes="hint",
|
| 125 |
+
)
|
| 126 |
+
with gr.Row():
|
| 127 |
with gr.Column(scale=1):
|
| 128 |
+
run_name = gr.Textbox(label="Experiment name", placeholder="e.g. lr-1e-3_ch-16")
|
| 129 |
+
learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, label="Learning rate")
|
| 130 |
+
batch_size = gr.Slider(2, 32, value=8, step=2, label="Batch size")
|
| 131 |
+
epochs = gr.Slider(1, 20, value=5, step=1, label="Epochs")
|
| 132 |
+
base_channels = gr.Slider(8, 64, value=16, step=8, label="Model width (base channels)")
|
| 133 |
+
train_btn = gr.Button("Train model", variant="primary")
|
| 134 |
gr.Markdown(
|
| 135 |
+
f"*Max {MAX_EXPERIMENTS} experiments total. "
|
| 136 |
+
"Reload data to reset.*",
|
| 137 |
+
elem_classes="hint",
|
| 138 |
)
|
| 139 |
+
step4_summary = gr.Markdown("*Train a model to see results.*")
|
| 140 |
|
| 141 |
+
with gr.Column(scale=3):
|
| 142 |
+
step4_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
|
| 143 |
+
with gr.Row():
|
| 144 |
+
step4_gt_img = gr.Image(label="Ground truth overlay", type="numpy", height=280)
|
| 145 |
+
step4_bl_img = gr.Image(label="KNN Baseline prediction", type="numpy", height=280)
|
| 146 |
+
step4_un_img = gr.Image(label="UNet prediction", type="numpy", height=280)
|
| 147 |
+
|
| 148 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 149 |
+
# Step 5 β Experiment Lab
|
| 150 |
+
# ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 151 |
+
with gr.Tab("Step 5 Β· Experiment Lab"):
|
| 152 |
+
gr.Markdown(
|
| 153 |
+
f"Compare up to **{MAX_EXPERIMENTS}** UNet experiments side by side. "
|
| 154 |
+
"Try different learning rates, epochs, or model widths and see what changes.",
|
| 155 |
+
elem_classes="hint",
|
| 156 |
+
)
|
| 157 |
with gr.Row():
|
| 158 |
+
step5_sel_a = gr.Dropdown(choices=[], value=None, label="Left model", interactive=True)
|
| 159 |
+
step5_sel_b = gr.Dropdown(choices=[], value=None, label="Right model", interactive=True)
|
| 160 |
+
step5_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Validation patch index")
|
| 161 |
+
step5_table = gr.Markdown("*No experiments yet.*")
|
| 162 |
+
|
| 163 |
+
gr.Markdown(
|
| 164 |
+
"**Guiding questions**\n"
|
| 165 |
+
"- Double the epochs β does mIoU keep improving or plateau?\n"
|
| 166 |
+
"- Halve the learning rate β does training become more stable?\n"
|
| 167 |
+
"- Increase base channels from 16 to 32 β worth the extra time?",
|
| 168 |
+
elem_classes="hint",
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
with gr.Row():
|
| 172 |
+
with gr.Column(scale=1):
|
| 173 |
gr.Markdown("## Left")
|
| 174 |
+
s5_a_rgb = gr.Image(label="RGB", type="numpy", height=240)
|
| 175 |
+
s5_a_pred = gr.Image(label="Prediction", type="numpy", height=240)
|
| 176 |
+
s5_a_overlay = gr.Image(label="Overlay", type="numpy", height=240)
|
| 177 |
+
s5_a_metrics = gr.Markdown("*No model selected.*")
|
| 178 |
+
s5_a_error = gr.Image(label="Correctness map",type="numpy", height=240)
|
| 179 |
+
|
| 180 |
+
with gr.Column(scale=1):
|
|
|
|
| 181 |
gr.Markdown("## Right")
|
| 182 |
+
s5_b_rgb = gr.Image(label="RGB", type="numpy", height=240)
|
| 183 |
+
s5_b_pred = gr.Image(label="Prediction", type="numpy", height=240)
|
| 184 |
+
s5_b_overlay = gr.Image(label="Overlay", type="numpy", height=240)
|
| 185 |
+
s5_b_metrics = gr.Markdown("*No model selected.*")
|
| 186 |
+
s5_b_error = gr.Image(label="Correctness map",type="numpy", height=240)
|
| 187 |
+
|
| 188 |
+
# ββ Event wiring βββββββββββββββββββββββββββββββββββββββββ
|
| 189 |
+
|
| 190 |
+
_load_outputs = [
|
| 191 |
+
dataset_state, baseline_state, experiments_state,
|
| 192 |
+
# Tab 1
|
| 193 |
+
step1_info, step1_image, step1_click,
|
| 194 |
+
# Tab 2
|
| 195 |
+
step2_sig_chart, step2_index_map,
|
| 196 |
+
# Tab 3
|
| 197 |
+
step3_metrics, step3_full_pred,
|
| 198 |
+
# Tab 4
|
| 199 |
+
step4_summary, step4_patch_slider,
|
| 200 |
+
step4_gt_img, step4_bl_img, step4_un_img,
|
| 201 |
+
# Tab 5
|
| 202 |
+
step5_table, step5_sel_a, step5_sel_b,
|
| 203 |
]
|
| 204 |
+
|
| 205 |
+
load_btn.click(fn=load_dataset_action, inputs=[patch_size], outputs=_load_outputs)
|
| 206 |
+
|
| 207 |
+
composite_dd.change(
|
| 208 |
+
fn=update_step1_composite,
|
| 209 |
+
inputs=[dataset_state, composite_dd],
|
| 210 |
+
outputs=[step1_image, step1_click],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
)
|
| 212 |
|
| 213 |
+
step1_image.select(
|
| 214 |
+
fn=handle_click_step1,
|
| 215 |
+
inputs=[dataset_state],
|
| 216 |
+
outputs=[step1_click],
|
| 217 |
)
|
| 218 |
|
| 219 |
+
index_radio.change(
|
| 220 |
+
fn=update_step2_index,
|
| 221 |
+
inputs=[dataset_state, index_radio],
|
| 222 |
+
outputs=[step2_index_map],
|
| 223 |
)
|
| 224 |
|
| 225 |
+
baseline_btn.click(
|
| 226 |
+
fn=run_baseline_action,
|
| 227 |
+
inputs=[dataset_state, k_slider],
|
| 228 |
+
outputs=[baseline_state, step3_metrics, step3_full_pred],
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
_train_outputs = [
|
| 232 |
+
experiments_state,
|
| 233 |
+
step4_summary, step4_patch_slider,
|
| 234 |
+
step4_gt_img, step4_bl_img, step4_un_img,
|
| 235 |
+
step5_table, step5_sel_a, step5_sel_b,
|
| 236 |
+
]
|
| 237 |
+
|
| 238 |
train_btn.click(
|
| 239 |
fn=train_experiment,
|
| 240 |
+
inputs=[
|
| 241 |
+
dataset_state, baseline_state, experiments_state,
|
| 242 |
+
learning_rate, batch_size, epochs, base_channels, run_name,
|
| 243 |
+
],
|
| 244 |
+
outputs=_train_outputs,
|
| 245 |
)
|
| 246 |
|
| 247 |
+
step4_patch_slider.change(
|
| 248 |
+
fn=update_step4_patch,
|
| 249 |
+
inputs=[dataset_state, baseline_state, experiments_state, step4_patch_slider],
|
| 250 |
+
outputs=[step4_gt_img, step4_bl_img, step4_un_img],
|
| 251 |
+
)
|
| 252 |
|
| 253 |
+
_s5_inputs = [dataset_state, experiments_state, step5_sel_a, step5_sel_b, step5_patch_slider]
|
| 254 |
+
_s5_outputs = [
|
| 255 |
+
s5_a_rgb, s5_a_pred, s5_a_overlay, s5_a_metrics, s5_a_error,
|
| 256 |
+
s5_b_rgb, s5_b_pred, s5_b_overlay, s5_b_metrics, s5_b_error,
|
| 257 |
+
]
|
| 258 |
|
| 259 |
+
for trigger in [step5_sel_a, step5_sel_b, step5_patch_slider]:
|
| 260 |
+
trigger.change(fn=update_step5_comparison, inputs=_s5_inputs, outputs=_s5_outputs)
|
| 261 |
+
|
| 262 |
+
for img, sel in [(s5_a_rgb, step5_sel_a), (s5_a_overlay, step5_sel_a),
|
| 263 |
+
(s5_b_rgb, step5_sel_b), (s5_b_overlay, step5_sel_b)]:
|
|
|
|
|
|
|
| 264 |
img.select(
|
| 265 |
+
fn=handle_click_step5,
|
| 266 |
+
inputs=[dataset_state, experiments_state, sel, step5_patch_slider],
|
| 267 |
+
outputs=[s5_a_metrics if sel == step5_sel_a else s5_b_metrics],
|
| 268 |
)
|
| 269 |
|
| 270 |
|
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Spectral baseline classifier: KNN on raw 7-band pixel values, no spatial context."""
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
from config import NUM_CHANNELS, NUM_CLASSES, IGNORE_INDEX
|
| 5 |
+
from metrics import compute_metrics, metrics_markdown
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _knn_predict(
|
| 9 |
+
train_X: np.ndarray,
|
| 10 |
+
train_y: np.ndarray,
|
| 11 |
+
query_X: np.ndarray,
|
| 12 |
+
k: int,
|
| 13 |
+
chunk: int = 50_000,
|
| 14 |
+
) -> np.ndarray:
|
| 15 |
+
"""Chunked nearest-neighbour prediction to keep peak RAM reasonable."""
|
| 16 |
+
N = len(query_X)
|
| 17 |
+
preds = np.empty(N, dtype=np.int64)
|
| 18 |
+
k = min(k, len(train_X))
|
| 19 |
+
|
| 20 |
+
for start in range(0, N, chunk):
|
| 21 |
+
end = min(start + chunk, N)
|
| 22 |
+
block = query_X[start:end] # (B, 7)
|
| 23 |
+
dists = np.sum((block[:, None, :] - train_X[None, :, :]) ** 2, axis=2) # (B, N_tr)
|
| 24 |
+
nn_idx = np.argpartition(dists, k - 1, axis=1)[:, :k] # (B, k)
|
| 25 |
+
labels = train_y[nn_idx] # (B, k)
|
| 26 |
+
if k == 1:
|
| 27 |
+
preds[start:end] = labels[:, 0]
|
| 28 |
+
else:
|
| 29 |
+
# Vectorised majority vote
|
| 30 |
+
votes = (labels[:, :, None] == np.arange(NUM_CLASSES)[None, None, :]).sum(axis=1)
|
| 31 |
+
preds[start:end] = votes.argmax(axis=1)
|
| 32 |
+
|
| 33 |
+
return preds
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def run_knn_baseline(
|
| 37 |
+
full_image: np.ndarray,
|
| 38 |
+
full_train_mask: np.ndarray,
|
| 39 |
+
full_val_mask: np.ndarray,
|
| 40 |
+
val_images: np.ndarray,
|
| 41 |
+
k: int = 3,
|
| 42 |
+
):
|
| 43 |
+
"""
|
| 44 |
+
Train KNN on labeled training pixels; predict (a) the full scene and (b) each
|
| 45 |
+
validation patch. Evaluate against the full validation mask.
|
| 46 |
+
|
| 47 |
+
Returns
|
| 48 |
+
-------
|
| 49 |
+
full_pred : (H, W) β class index for every pixel in the scene
|
| 50 |
+
val_preds : (N, ph, pw) β patch-level predictions for step-4 comparison
|
| 51 |
+
metrics : dict
|
| 52 |
+
metrics_md : str
|
| 53 |
+
"""
|
| 54 |
+
C, H, W = full_image.shape
|
| 55 |
+
|
| 56 |
+
labeled = full_train_mask != IGNORE_INDEX
|
| 57 |
+
if not labeled.any():
|
| 58 |
+
raise ValueError("No labeled training pixels found in TRAINING.tif.")
|
| 59 |
+
|
| 60 |
+
train_X = full_image[:, labeled].T # (N_tr, 7)
|
| 61 |
+
train_y = full_train_mask[labeled] # (N_tr,)
|
| 62 |
+
|
| 63 |
+
# --- Full scene prediction ---
|
| 64 |
+
all_X = full_image.reshape(C, H * W).T # (H*W, 7)
|
| 65 |
+
full_pred = _knn_predict(train_X, train_y, all_X, k).reshape(H, W)
|
| 66 |
+
|
| 67 |
+
# --- Validation patch predictions (same patches as UNet) ---
|
| 68 |
+
N_val, ph, pw = val_images.shape[0], val_images.shape[2], val_images.shape[3]
|
| 69 |
+
all_patch_X = np.concatenate(
|
| 70 |
+
[p.reshape(C, -1).T for p in val_images], axis=0
|
| 71 |
+
) # (N_val * ph * pw, 7)
|
| 72 |
+
patch_preds_flat = _knn_predict(train_X, train_y, all_patch_X, k)
|
| 73 |
+
val_preds = patch_preds_flat.reshape(N_val, ph, pw).astype(np.int64)
|
| 74 |
+
|
| 75 |
+
# --- Metrics on full val mask ---
|
| 76 |
+
metrics = compute_metrics(full_pred.ravel(), full_val_mask.ravel())
|
| 77 |
+
metrics_md = metrics_markdown(metrics, title=f"KNN Baseline (k={k})")
|
| 78 |
+
|
| 79 |
+
return full_pred.astype(np.int64), val_preds, metrics, metrics_md
|
|
@@ -7,19 +7,36 @@ DEFAULT_PATCH_SIZE = 128
|
|
| 7 |
NUM_CHANNELS = 7
|
| 8 |
NUM_CLASSES = 4
|
| 9 |
IGNORE_INDEX = 255
|
|
|
|
| 10 |
|
| 11 |
BAND_NAMES = ["H_1", "H_2", "H_3", "H_4", "H_5", "H_6", "H_7"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
CLASS_NAMES = ["Water", "Urban", "Agriculture", "Forest"]
|
| 13 |
CLASS_COLORS = np.array(
|
| 14 |
[
|
| 15 |
-
[30, 144, 255], # Water
|
| 16 |
-
[220, 50, 50], # Urban
|
| 17 |
-
[255, 215, 0], # Agriculture
|
| 18 |
-
[34, 139, 34], # Forest
|
| 19 |
],
|
| 20 |
dtype=np.uint8,
|
| 21 |
)
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 24 |
|
| 25 |
|
|
|
|
| 7 |
NUM_CHANNELS = 7
|
| 8 |
NUM_CLASSES = 4
|
| 9 |
IGNORE_INDEX = 255
|
| 10 |
+
MAX_EXPERIMENTS = 3
|
| 11 |
|
| 12 |
BAND_NAMES = ["H_1", "H_2", "H_3", "H_4", "H_5", "H_6", "H_7"]
|
| 13 |
+
BAND_DESCRIPTIONS = [
|
| 14 |
+
"H_1 (Coastal/Aerosol)",
|
| 15 |
+
"H_2 (Blue)",
|
| 16 |
+
"H_3 (Green)",
|
| 17 |
+
"H_4 (Red)",
|
| 18 |
+
"H_5 (NIR)",
|
| 19 |
+
"H_6 (SWIR-1)",
|
| 20 |
+
"H_7 (SWIR-2)",
|
| 21 |
+
]
|
| 22 |
CLASS_NAMES = ["Water", "Urban", "Agriculture", "Forest"]
|
| 23 |
CLASS_COLORS = np.array(
|
| 24 |
[
|
| 25 |
+
[30, 144, 255], # Water - blue
|
| 26 |
+
[220, 50, 50], # Urban - red
|
| 27 |
+
[255, 215, 0], # Agriculture - yellow
|
| 28 |
+
[34, 139, 34], # Forest - green
|
| 29 |
],
|
| 30 |
dtype=np.uint8,
|
| 31 |
)
|
| 32 |
|
| 33 |
+
# (R-band-index, G-band-index, B-band-index) for composite presets
|
| 34 |
+
COMPOSITE_PRESETS = {
|
| 35 |
+
"Natural Color (R/G/B)": (3, 2, 1),
|
| 36 |
+
"False Color NIR (NIR/R/G)": (4, 3, 2),
|
| 37 |
+
"SWIR Composite (SWIR2/NIR/R)": (6, 4, 3),
|
| 38 |
+
}
|
| 39 |
+
|
| 40 |
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
| 41 |
|
| 42 |
|
|
@@ -19,9 +19,14 @@ VAL_MASK_FILE = "GROUND TRUTH.tif"
|
|
| 19 |
|
| 20 |
# ββ File helpers βββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
|
| 22 |
-
def
|
|
|
|
|
|
|
|
|
|
| 23 |
if not DATASET_REPO:
|
| 24 |
-
raise EnvironmentError(
|
|
|
|
|
|
|
| 25 |
return hf_hub_download(
|
| 26 |
repo_id=DATASET_REPO,
|
| 27 |
filename=filename,
|
|
@@ -39,11 +44,10 @@ def _read_band(path: str) -> np.ndarray:
|
|
| 39 |
|
| 40 |
|
| 41 |
def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
|
| 42 |
-
"""Returns (raw_array, nodata_value, info_string)."""
|
| 43 |
with rasterio.open(path) as src:
|
| 44 |
data = src.read(1)
|
| 45 |
nodata = src.nodata
|
| 46 |
-
info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata}
|
| 47 |
return data, nodata, info
|
| 48 |
|
| 49 |
|
|
@@ -59,22 +63,33 @@ def _normalize(image: np.ndarray) -> np.ndarray:
|
|
| 59 |
return out
|
| 60 |
|
| 61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
|
| 63 |
"""
|
| 64 |
-
Map raw pixel values
|
| 65 |
-
Value 0
|
| 66 |
-
Nodata pixels β IGNORE_INDEX.
|
| 67 |
-
Returns (remapped_mask, sorted_raw_class_values_used).
|
| 68 |
"""
|
| 69 |
if nodata_val is not None:
|
| 70 |
nodata_px = raw == int(nodata_val)
|
| 71 |
else:
|
| 72 |
nodata_px = np.zeros(raw.shape, dtype=bool)
|
| 73 |
|
| 74 |
-
|
| 75 |
-
background_px = raw == 0
|
| 76 |
-
ignore_px = nodata_px | background_px
|
| 77 |
-
|
| 78 |
valid = ~ignore_px
|
| 79 |
raw_unique = sorted(int(v) for v in np.unique(raw[valid]))
|
| 80 |
|
|
@@ -97,7 +112,6 @@ def _extract_patches(
|
|
| 97 |
stride = patch_size // 2
|
| 98 |
imgs, masks = [], []
|
| 99 |
|
| 100 |
-
# Build step lists that always include the last valid position (covers edges)
|
| 101 |
def steps(size):
|
| 102 |
s = list(range(0, size - patch_size + 1, stride))
|
| 103 |
if not s:
|
|
@@ -110,12 +124,10 @@ def _extract_patches(
|
|
| 110 |
for x in steps(W):
|
| 111 |
pm = mask[y : y + patch_size, x : x + patch_size]
|
| 112 |
pi = image[:, y : y + patch_size, x : x + patch_size]
|
| 113 |
-
# Include any patch that contains at least one labeled pixel
|
| 114 |
if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
|
| 115 |
imgs.append(pi)
|
| 116 |
masks.append(pm)
|
| 117 |
|
| 118 |
-
# Last resort: pad with zeros/IGNORE if image is smaller than patch_size
|
| 119 |
if not imgs:
|
| 120 |
ph = min(patch_size, H)
|
| 121 |
pw = min(patch_size, W)
|
|
@@ -129,6 +141,29 @@ def _extract_patches(
|
|
| 129 |
return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
# ββ Dataset class βββββββββββββββββββββββββββββββββββββββββββββ
|
| 133 |
|
| 134 |
class MultiSpectralDataset(Dataset):
|
|
@@ -146,28 +181,38 @@ class MultiSpectralDataset(Dataset):
|
|
| 146 |
# ββ Public API ββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββ
|
| 147 |
|
| 148 |
def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
|
| 149 |
-
#
|
| 150 |
-
|
| 151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 152 |
|
| 153 |
-
#
|
| 154 |
-
|
| 155 |
-
raw_val, nd_val, info_val = _read_mask_raw(_download(VAL_MASK_FILE))
|
| 156 |
|
| 157 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
train_mask, train_vals = _remap_mask(raw_train, nd_train)
|
| 159 |
val_mask, val_vals = _remap_mask(raw_val, nd_val)
|
| 160 |
|
| 161 |
if not train_vals:
|
| 162 |
raise ValueError(
|
| 163 |
-
f"TRAINING.tif has no labeled pixels
|
| 164 |
-
f"
|
| 165 |
)
|
| 166 |
|
| 167 |
-
#
|
| 168 |
tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
|
| 169 |
va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
|
| 170 |
|
|
|
|
|
|
|
|
|
|
| 171 |
train_labeled = int((train_mask != IGNORE_INDEX).sum())
|
| 172 |
val_labeled = int((val_mask != IGNORE_INDEX).sum())
|
| 173 |
|
|
@@ -179,21 +224,21 @@ def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
|
|
| 179 |
return " | ".join(parts)
|
| 180 |
|
| 181 |
status = "\n".join([
|
| 182 |
-
f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}
|
| 183 |
-
"",
|
| 184 |
-
f"**
|
| 185 |
-
f"Raw values β classes: `{dict(zip(train_vals, CLASS_NAMES[:len(train_vals)]))}`",
|
| 186 |
-
f"Labeled pixels: **{train_labeled:,}** β {_class_dist(train_mask, train_labeled)}",
|
| 187 |
-
"",
|
| 188 |
-
f"**GROUND TRUTH.tif** `{info_val}`",
|
| 189 |
-
f"Raw values β classes: `{dict(zip(val_vals, CLASS_NAMES[:len(val_vals)]))}`",
|
| 190 |
-
f"Labeled pixels: **{val_labeled:,}** β {_class_dist(val_mask, val_labeled)}",
|
| 191 |
])
|
| 192 |
|
| 193 |
return {
|
| 194 |
-
"
|
| 195 |
-
"
|
| 196 |
-
"
|
| 197 |
-
"
|
| 198 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
}
|
|
|
|
| 19 |
|
| 20 |
# ββ File helpers βββββββββββββββββββββββββββββββββββββββββββββ
|
| 21 |
|
| 22 |
+
def _get_path(filename: str) -> str:
|
| 23 |
+
"""Use local file if it exists, otherwise download from HuggingFace."""
|
| 24 |
+
if os.path.exists(filename):
|
| 25 |
+
return filename
|
| 26 |
if not DATASET_REPO:
|
| 27 |
+
raise EnvironmentError(
|
| 28 |
+
f"'{filename}' not found locally and DATASET_REPO is not set."
|
| 29 |
+
)
|
| 30 |
return hf_hub_download(
|
| 31 |
repo_id=DATASET_REPO,
|
| 32 |
filename=filename,
|
|
|
|
| 44 |
|
| 45 |
|
| 46 |
def _read_mask_raw(path: str) -> Tuple[np.ndarray, object, str]:
|
|
|
|
| 47 |
with rasterio.open(path) as src:
|
| 48 |
data = src.read(1)
|
| 49 |
nodata = src.nodata
|
| 50 |
+
info = f"shape={src.shape} dtype={src.dtypes[0]} nodata={nodata}"
|
| 51 |
return data, nodata, info
|
| 52 |
|
| 53 |
|
|
|
|
| 63 |
return out
|
| 64 |
|
| 65 |
|
| 66 |
+
def _compute_ndvi(raw_image: np.ndarray) -> np.ndarray:
|
| 67 |
+
"""NDVI = (NIR - Red) / (NIR + Red). H_5=NIR (idx 4), H_4=Red (idx 3)."""
|
| 68 |
+
nir = raw_image[4].astype(np.float32)
|
| 69 |
+
red = raw_image[3].astype(np.float32)
|
| 70 |
+
denom = nir + red
|
| 71 |
+
return np.where(np.abs(denom) > 1e-6, (nir - red) / denom, 0.0).clip(-1.0, 1.0)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def _compute_ndwi(raw_image: np.ndarray) -> np.ndarray:
|
| 75 |
+
"""NDWI = (Green - NIR) / (Green + NIR). H_3=Green (idx 2), H_5=NIR (idx 4)."""
|
| 76 |
+
green = raw_image[2].astype(np.float32)
|
| 77 |
+
nir = raw_image[4].astype(np.float32)
|
| 78 |
+
denom = green + nir
|
| 79 |
+
return np.where(np.abs(denom) > 1e-6, (green - nir) / denom, 0.0).clip(-1.0, 1.0)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
def _remap_mask(raw: np.ndarray, nodata_val) -> Tuple[np.ndarray, List[int]]:
|
| 83 |
"""
|
| 84 |
+
Map raw pixel values -> 0..NUM_CLASSES-1.
|
| 85 |
+
Value 0 and nodata -> IGNORE_INDEX.
|
|
|
|
|
|
|
| 86 |
"""
|
| 87 |
if nodata_val is not None:
|
| 88 |
nodata_px = raw == int(nodata_val)
|
| 89 |
else:
|
| 90 |
nodata_px = np.zeros(raw.shape, dtype=bool)
|
| 91 |
|
| 92 |
+
ignore_px = nodata_px | (raw == 0)
|
|
|
|
|
|
|
|
|
|
| 93 |
valid = ~ignore_px
|
| 94 |
raw_unique = sorted(int(v) for v in np.unique(raw[valid]))
|
| 95 |
|
|
|
|
| 112 |
stride = patch_size // 2
|
| 113 |
imgs, masks = [], []
|
| 114 |
|
|
|
|
| 115 |
def steps(size):
|
| 116 |
s = list(range(0, size - patch_size + 1, stride))
|
| 117 |
if not s:
|
|
|
|
| 124 |
for x in steps(W):
|
| 125 |
pm = mask[y : y + patch_size, x : x + patch_size]
|
| 126 |
pi = image[:, y : y + patch_size, x : x + patch_size]
|
|
|
|
| 127 |
if pm.shape == (patch_size, patch_size) and (pm != IGNORE_INDEX).any():
|
| 128 |
imgs.append(pi)
|
| 129 |
masks.append(pm)
|
| 130 |
|
|
|
|
| 131 |
if not imgs:
|
| 132 |
ph = min(patch_size, H)
|
| 133 |
pw = min(patch_size, W)
|
|
|
|
| 141 |
return np.stack(imgs).astype(np.float32), np.stack(masks).astype(np.int64)
|
| 142 |
|
| 143 |
|
| 144 |
+
# ββ Spectral analysis βββββββββββββββββββββββββββββββββββββββββ
|
| 145 |
+
|
| 146 |
+
def compute_spectral_signatures(full_image: np.ndarray, full_mask: np.ndarray) -> Dict:
|
| 147 |
+
"""Per-class mean and std across the 7 normalized bands, from labeled pixels."""
|
| 148 |
+
sigs = {}
|
| 149 |
+
for cls_idx in range(NUM_CLASSES):
|
| 150 |
+
px = full_mask == cls_idx
|
| 151 |
+
if px.sum() == 0:
|
| 152 |
+
sigs[cls_idx] = {
|
| 153 |
+
"mean": np.zeros(NUM_CHANNELS, dtype=np.float32),
|
| 154 |
+
"std": np.zeros(NUM_CHANNELS, dtype=np.float32),
|
| 155 |
+
"n": 0,
|
| 156 |
+
}
|
| 157 |
+
else:
|
| 158 |
+
vals = full_image[:, px] # (7, N)
|
| 159 |
+
sigs[cls_idx] = {
|
| 160 |
+
"mean": vals.mean(axis=1).astype(np.float32),
|
| 161 |
+
"std": vals.std(axis=1).astype(np.float32),
|
| 162 |
+
"n": int(px.sum()),
|
| 163 |
+
}
|
| 164 |
+
return sigs
|
| 165 |
+
|
| 166 |
+
|
| 167 |
# ββ Dataset class βββββββββββββββββββββββββββββββββββββββββββββ
|
| 168 |
|
| 169 |
class MultiSpectralDataset(Dataset):
|
|
|
|
| 181 |
# ββ Public API ββββββββββββββββββββββββββββοΏ½οΏ½οΏ½βββββββββββββββββββ
|
| 182 |
|
| 183 |
def load_data(patch_size: int = DEFAULT_PATCH_SIZE) -> Dict:
|
| 184 |
+
# Read raw bands
|
| 185 |
+
raw_bands = [_read_band(_get_path(f)) for f in BAND_FILES]
|
| 186 |
+
raw_image = np.stack(raw_bands, axis=0) # (7, H, W) raw float32
|
| 187 |
+
|
| 188 |
+
# Compute spectral indices from raw values (ratio, so must use raw)
|
| 189 |
+
ndvi = _compute_ndvi(raw_image)
|
| 190 |
+
ndwi = _compute_ndwi(raw_image)
|
| 191 |
|
| 192 |
+
# Normalize
|
| 193 |
+
image = _normalize(raw_image) # (7, H, W) normalized [0,1]
|
|
|
|
| 194 |
|
| 195 |
+
# Read masks
|
| 196 |
+
raw_train, nd_train, info_train = _read_mask_raw(_get_path(TRAIN_MASK_FILE))
|
| 197 |
+
raw_val, nd_val, info_val = _read_mask_raw(_get_path(VAL_MASK_FILE))
|
| 198 |
+
|
| 199 |
+
# Remap
|
| 200 |
train_mask, train_vals = _remap_mask(raw_train, nd_train)
|
| 201 |
val_mask, val_vals = _remap_mask(raw_val, nd_val)
|
| 202 |
|
| 203 |
if not train_vals:
|
| 204 |
raise ValueError(
|
| 205 |
+
f"TRAINING.tif has no labeled pixels. Info: {info_train} | "
|
| 206 |
+
f"Unique raw values: {np.unique(raw_train).tolist()}"
|
| 207 |
)
|
| 208 |
|
| 209 |
+
# Patches
|
| 210 |
tr_imgs, tr_masks = _extract_patches(image, train_mask, patch_size)
|
| 211 |
va_imgs, va_masks = _extract_patches(image, val_mask, patch_size)
|
| 212 |
|
| 213 |
+
# Spectral signatures from training labels
|
| 214 |
+
signatures = compute_spectral_signatures(image, train_mask)
|
| 215 |
+
|
| 216 |
train_labeled = int((train_mask != IGNORE_INDEX).sum())
|
| 217 |
val_labeled = int((val_mask != IGNORE_INDEX).sum())
|
| 218 |
|
|
|
|
| 224 |
return " | ".join(parts)
|
| 225 |
|
| 226 |
status = "\n".join([
|
| 227 |
+
f"Train patches: **{len(tr_imgs)}** | Val patches: **{len(va_imgs)}** | Patch: **{patch_size}x{patch_size}**",
|
| 228 |
+
f"Training labels: **{train_labeled:,}** px β {_class_dist(train_mask, train_labeled)}",
|
| 229 |
+
f"Validation labels: **{val_labeled:,}** px β {_class_dist(val_mask, val_labeled)}",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
])
|
| 231 |
|
| 232 |
return {
|
| 233 |
+
"full_image": image,
|
| 234 |
+
"full_train_mask": train_mask,
|
| 235 |
+
"full_val_mask": val_mask,
|
| 236 |
+
"ndvi": ndvi,
|
| 237 |
+
"ndwi": ndwi,
|
| 238 |
+
"signatures": signatures,
|
| 239 |
+
"train_images": tr_imgs,
|
| 240 |
+
"train_masks": tr_masks,
|
| 241 |
+
"val_images": va_imgs,
|
| 242 |
+
"val_masks": va_masks,
|
| 243 |
+
"status": status,
|
| 244 |
}
|
|
@@ -4,3 +4,5 @@ Pillow
|
|
| 4 |
torch
|
| 5 |
rasterio
|
| 6 |
huggingface_hub
|
|
|
|
|
|
|
|
|
| 4 |
torch
|
| 5 |
rasterio
|
| 6 |
huggingface_hub
|
| 7 |
+
matplotlib
|
| 8 |
+
scipy
|
|
@@ -9,11 +9,20 @@ import gradio as gr
|
|
| 9 |
|
| 10 |
from config import (
|
| 11 |
DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
|
| 12 |
-
BAND_NAMES, CLASS_NAMES, IGNORE_INDEX,
|
|
|
|
| 13 |
)
|
| 14 |
from data import MultiSpectralDataset, load_data
|
| 15 |
from model import SmallUNet
|
| 16 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
from metrics import compute_metrics, metrics_markdown
|
| 18 |
|
| 19 |
|
|
@@ -22,8 +31,8 @@ from metrics import compute_metrics, metrics_markdown
|
|
| 22 |
def build_prediction_cache(
|
| 23 |
model: nn.Module, images: np.ndarray, batch_size: int = 8
|
| 24 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 25 |
-
|
| 26 |
-
ds = MultiSpectralDataset(images,
|
| 27 |
loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
|
| 28 |
preds, probs = [], []
|
| 29 |
model.eval()
|
|
@@ -36,16 +45,19 @@ def build_prediction_cache(
|
|
| 36 |
return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
|
| 37 |
|
| 38 |
|
| 39 |
-
# ββ
|
| 40 |
|
| 41 |
-
def
|
| 42 |
-
|
|
|
|
|
|
|
| 43 |
|
| 44 |
|
| 45 |
def pixel_info_markdown(
|
| 46 |
x: int, y: int,
|
| 47 |
img7: np.ndarray, gt: np.ndarray,
|
| 48 |
-
pred: Optional[np.ndarray]
|
|
|
|
| 49 |
) -> str:
|
| 50 |
h, w = gt.shape
|
| 51 |
x = int(np.clip(x, 0, w - 1))
|
|
@@ -53,165 +65,244 @@ def pixel_info_markdown(
|
|
| 53 |
|
| 54 |
gt_class = int(gt[y, x])
|
| 55 |
gt_name = CLASS_NAMES[gt_class] if gt_class != IGNORE_INDEX else "Unlabeled"
|
| 56 |
-
lines
|
| 57 |
|
| 58 |
if pred is not None:
|
|
|
|
|
|
|
| 59 |
if gt_class != IGNORE_INDEX:
|
| 60 |
-
pred_class =
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
f"{CLASS_NAMES[i]} {probs[i, y, x] * 100:.1f}%" for i in top_ids
|
| 67 |
-
))
|
| 68 |
-
else:
|
| 69 |
-
lines.append("- Prediction: β *(unlabeled pixel)*")
|
| 70 |
-
else:
|
| 71 |
-
lines.append("- Prediction: β")
|
| 72 |
|
| 73 |
lines += ["", "**Band values**"] + [
|
| 74 |
-
f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
return "\n".join(lines)
|
| 77 |
|
| 78 |
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
|
|
|
| 84 |
|
| 85 |
-
def
|
| 86 |
-
dataset_state: Dict,
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
return no_data
|
| 93 |
val_images = dataset_state["val_images"]
|
| 94 |
val_masks = dataset_state["val_masks"]
|
| 95 |
-
|
| 96 |
-
return no_data
|
| 97 |
|
| 98 |
-
idx = max(0, min(int(sample_idx), len(val_images) - 1))
|
| 99 |
rgb = multispectral_to_rgb(val_images[idx])
|
| 100 |
gt = val_masks[idx]
|
|
|
|
| 101 |
|
| 102 |
-
if
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
|
| 110 |
if idx >= len(exp["val_preds"]):
|
| 111 |
-
return (
|
| 112 |
-
rgb, mask_to_color(gt), overlay_mask(rgb, gt),
|
| 113 |
-
"### Dataset reloaded β retrain to refresh",
|
| 114 |
-
_blank(),
|
| 115 |
-
"### Retrain needed",
|
| 116 |
-
)
|
| 117 |
|
| 118 |
pred = exp["val_preds"][idx].astype(np.int64)
|
| 119 |
probs = exp["val_probs"][idx].astype(np.float32)
|
| 120 |
-
|
| 121 |
return (
|
| 122 |
rgb,
|
| 123 |
mask_to_color(pred),
|
| 124 |
overlay_mask(rgb, pred),
|
| 125 |
-
metrics_markdown(
|
| 126 |
correctness_overlay(rgb, pred, gt),
|
| 127 |
-
pixel_info_markdown(0, 0, val_images[idx], gt, pred, probs),
|
| 128 |
-
)
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
def render_compare_view(
|
| 132 |
-
dataset_state, experiments, name_a, name_b, sample_idx: int
|
| 133 |
-
) -> Tuple:
|
| 134 |
-
return (
|
| 135 |
-
*render_experiment_panel(dataset_state, _get_exp_by_name(experiments, name_a), sample_idx),
|
| 136 |
-
*render_experiment_panel(dataset_state, _get_exp_by_name(experiments, name_b), sample_idx),
|
| 137 |
)
|
| 138 |
|
| 139 |
|
| 140 |
-
def experiments_table_markdown(experiments: List[Dict]) -> str:
|
| 141 |
-
if not experiments:
|
| 142 |
-
return "### No models trained yet"
|
| 143 |
-
lines = [
|
| 144 |
-
"### Trained models", "",
|
| 145 |
-
"| # | Name | LR | Epochs | Base Ch | Val Acc | mIoU |",
|
| 146 |
-
"|---|---|---:|---:|---:|---:|---:|",
|
| 147 |
-
]
|
| 148 |
-
for i, e in enumerate(experiments):
|
| 149 |
-
cfg = e["config"]
|
| 150 |
-
lines.append(
|
| 151 |
-
f"| {i + 1} | {e['name']} | {cfg['learning_rate']:.4f} | {cfg['epochs']} "
|
| 152 |
-
f"| {cfg['base_channels']} | {e['global_metrics']['overall_acc'] * 100:.1f}% "
|
| 153 |
-
f"| {e['global_metrics']['miou'] * 100:.1f}% |"
|
| 154 |
-
)
|
| 155 |
-
return "\n".join(lines)
|
| 156 |
-
|
| 157 |
-
|
| 158 |
# ββ Gradio action functions βββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββ
|
| 159 |
|
| 160 |
def load_dataset_action(patch_size: int):
|
| 161 |
-
|
| 162 |
-
patch_size = int(patch_size)
|
| 163 |
dataset_state = load_data(patch_size)
|
| 164 |
-
val_count
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
"
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 173 |
])
|
| 174 |
|
|
|
|
|
|
|
| 175 |
return (
|
| 176 |
dataset_state,
|
| 177 |
-
|
|
|
|
|
|
|
| 178 |
dataset_info,
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 187 |
)
|
| 188 |
|
| 189 |
|
| 190 |
-
def
|
| 191 |
-
if dataset_state is None
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
)
|
| 205 |
|
| 206 |
|
| 207 |
-
def
|
| 208 |
-
if dataset_state is None
|
| 209 |
-
raise gr.Error("Load
|
| 210 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
def train_experiment(
|
| 214 |
dataset_state: Dict,
|
|
|
|
| 215 |
experiments: List[Dict],
|
| 216 |
learning_rate: float,
|
| 217 |
batch_size: int,
|
|
@@ -222,6 +313,11 @@ def train_experiment(
|
|
| 222 |
):
|
| 223 |
if dataset_state is None or "train_images" not in dataset_state:
|
| 224 |
raise gr.Error("Load a dataset first.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
|
| 226 |
loader = DataLoader(
|
| 227 |
MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
|
|
@@ -233,39 +329,39 @@ def train_experiment(
|
|
| 233 |
|
| 234 |
n_epochs = int(epochs)
|
| 235 |
history = []
|
| 236 |
-
for
|
| 237 |
-
progress(
|
| 238 |
model.train()
|
| 239 |
-
|
| 240 |
for xb, yb in loader:
|
| 241 |
xb, yb = xb.to(DEVICE), yb.to(DEVICE)
|
| 242 |
optimizer.zero_grad(set_to_none=True)
|
| 243 |
loss = criterion(model(xb), yb)
|
| 244 |
loss.backward()
|
| 245 |
optimizer.step()
|
| 246 |
-
|
| 247 |
n += 1
|
| 248 |
-
history.append(
|
| 249 |
|
| 250 |
progress(0.95, desc="Running validation inference...")
|
| 251 |
val_preds, val_probs = build_prediction_cache(
|
| 252 |
model, dataset_state["val_images"], batch_size=max(1, int(batch_size))
|
| 253 |
)
|
| 254 |
global_metrics = compute_metrics(
|
| 255 |
-
val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
|
| 256 |
)
|
| 257 |
-
progress(1.0
|
| 258 |
|
| 259 |
-
base = (run_name or f"Run {len(experiments)
|
| 260 |
existing = {e["name"] for e in experiments}
|
| 261 |
-
name,
|
| 262 |
while name in existing:
|
| 263 |
-
name = f"{base} ({
|
| 264 |
-
|
| 265 |
|
| 266 |
experiment = {
|
| 267 |
-
"name":
|
| 268 |
-
"config":
|
| 269 |
"learning_rate": float(learning_rate),
|
| 270 |
"batch_size": int(batch_size),
|
| 271 |
"epochs": int(epochs),
|
|
@@ -273,63 +369,59 @@ def train_experiment(
|
|
| 273 |
},
|
| 274 |
"train_loss_history": history,
|
| 275 |
"global_metrics": global_metrics,
|
| 276 |
-
"val_preds":
|
| 277 |
-
"val_probs":
|
| 278 |
}
|
| 279 |
-
|
| 280 |
experiments = experiments + [experiment]
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
f"
|
| 284 |
-
f"
|
| 285 |
-
f"
|
| 286 |
-
f"
|
| 287 |
])
|
| 288 |
|
| 289 |
-
choices
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
return (
|
| 291 |
-
experiments,
|
| 292 |
-
|
| 293 |
-
gr.update(
|
| 294 |
-
|
|
|
|
|
|
|
|
|
|
| 295 |
)
|
| 296 |
|
| 297 |
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
dataset_state["val_images"][idx], dataset_state["val_masks"][idx],
|
| 308 |
-
None, None,
|
| 309 |
-
)
|
| 310 |
|
| 311 |
|
| 312 |
-
def
|
| 313 |
-
evt: gr.SelectData,
|
| 314 |
-
|
| 315 |
) -> str:
|
| 316 |
try:
|
| 317 |
-
if dataset_state is None
|
| 318 |
-
return "
|
| 319 |
-
idx = max(0, min(int(
|
| 320 |
exp = _get_exp_by_name(experiments, model_name)
|
| 321 |
x, y = evt.index
|
| 322 |
img7 = dataset_state["val_images"][idx]
|
| 323 |
gt = dataset_state["val_masks"][idx]
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
return pixel_info_markdown(int(x), int(y), img7, gt,
|
| 327 |
except Exception as e:
|
| 328 |
-
return f"
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
def handle_click_exp_a(evt, dataset_state, experiments, sel_a, sample_idx):
|
| 332 |
-
return _handle_click_experiment(evt, dataset_state, experiments, sel_a, sample_idx)
|
| 333 |
-
|
| 334 |
-
def handle_click_exp_b(evt, dataset_state, experiments, sel_b, sample_idx):
|
| 335 |
-
return _handle_click_experiment(evt, dataset_state, experiments, sel_b, sample_idx)
|
|
|
|
| 9 |
|
| 10 |
from config import (
|
| 11 |
DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
|
| 12 |
+
BAND_NAMES, BAND_DESCRIPTIONS, CLASS_NAMES, IGNORE_INDEX,
|
| 13 |
+
COMPOSITE_PRESETS, MAX_EXPERIMENTS,
|
| 14 |
)
|
| 15 |
from data import MultiSpectralDataset, load_data
|
| 16 |
from model import SmallUNet
|
| 17 |
+
from baseline import run_knn_baseline
|
| 18 |
+
from visualize import (
|
| 19 |
+
render_composite, render_single_band,
|
| 20 |
+
add_labels_overlay, multispectral_to_rgb,
|
| 21 |
+
mask_to_color, overlay_mask, correctness_overlay,
|
| 22 |
+
render_full_prediction_overlay,
|
| 23 |
+
render_spectral_signatures_chart, render_index_map,
|
| 24 |
+
_blank_rgb,
|
| 25 |
+
)
|
| 26 |
from metrics import compute_metrics, metrics_markdown
|
| 27 |
|
| 28 |
|
|
|
|
| 31 |
def build_prediction_cache(
|
| 32 |
model: nn.Module, images: np.ndarray, batch_size: int = 8
|
| 33 |
) -> Tuple[np.ndarray, np.ndarray]:
|
| 34 |
+
dummy = np.zeros((len(images), images.shape[-2], images.shape[-1]), dtype=np.int64)
|
| 35 |
+
ds = MultiSpectralDataset(images, dummy)
|
| 36 |
loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
|
| 37 |
preds, probs = [], []
|
| 38 |
model.eval()
|
|
|
|
| 45 |
return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
|
| 46 |
|
| 47 |
|
| 48 |
+
# ββ Shared render helpers ββββββββββββββββββββββββββββββββββββ
|
| 49 |
|
| 50 |
+
def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]:
|
| 51 |
+
if not name:
|
| 52 |
+
return None
|
| 53 |
+
return next((e for e in experiments if e["name"] == name), None)
|
| 54 |
|
| 55 |
|
| 56 |
def pixel_info_markdown(
|
| 57 |
x: int, y: int,
|
| 58 |
img7: np.ndarray, gt: np.ndarray,
|
| 59 |
+
pred: Optional[np.ndarray] = None,
|
| 60 |
+
probs: Optional[np.ndarray] = None,
|
| 61 |
) -> str:
|
| 62 |
h, w = gt.shape
|
| 63 |
x = int(np.clip(x, 0, w - 1))
|
|
|
|
| 65 |
|
| 66 |
gt_class = int(gt[y, x])
|
| 67 |
gt_name = CLASS_NAMES[gt_class] if gt_class != IGNORE_INDEX else "Unlabeled"
|
| 68 |
+
lines = [f"**Pixel ({x}, {y})**", f"Ground truth: **{gt_name}**"]
|
| 69 |
|
| 70 |
if pred is not None:
|
| 71 |
+
pred_class = int(pred[y, x])
|
| 72 |
+
lines.append(f"Prediction: **{CLASS_NAMES[pred_class]}**")
|
| 73 |
if gt_class != IGNORE_INDEX:
|
| 74 |
+
lines.append(f"Correct: **{'Yes' if pred_class == gt_class else 'No'}**")
|
| 75 |
+
if probs is not None:
|
| 76 |
+
top = np.argsort(probs[:, y, x])[::-1][:3]
|
| 77 |
+
lines.append("Top probs: " + ", ".join(
|
| 78 |
+
f"{CLASS_NAMES[i]} {probs[i, y, x]*100:.1f}%" for i in top
|
| 79 |
+
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 80 |
|
| 81 |
lines += ["", "**Band values**"] + [
|
| 82 |
+
f"{BAND_DESCRIPTIONS[b]}: **{float(img7[b, y, x]):.3f}**"
|
| 83 |
+
for b in range(img7.shape[0])
|
| 84 |
+
]
|
| 85 |
+
return "\n\n".join(lines)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def experiments_table_markdown(experiments: List[Dict]) -> str:
|
| 89 |
+
if not experiments:
|
| 90 |
+
return "No experiments trained yet."
|
| 91 |
+
lines = [
|
| 92 |
+
"| # | Name | LR | Epochs | Channels | Val Acc | mIoU |",
|
| 93 |
+
"|---|---|---:|---:|---:|---:|---:|",
|
| 94 |
]
|
| 95 |
+
for i, e in enumerate(experiments):
|
| 96 |
+
cfg = e["config"]
|
| 97 |
+
lines.append(
|
| 98 |
+
f"| {i+1} | {e['name']} | {cfg['learning_rate']:.4f} | {cfg['epochs']} "
|
| 99 |
+
f"| {cfg['base_channels']} "
|
| 100 |
+
f"| {e['global_metrics']['overall_acc']*100:.1f}% "
|
| 101 |
+
f"| {e['global_metrics']['miou']*100:.1f}% |"
|
| 102 |
+
)
|
| 103 |
return "\n".join(lines)
|
| 104 |
|
| 105 |
|
| 106 |
+
# ββ Step 1 render helpers ββββββββββββββββββββββββββοΏ½οΏ½βββββββββ
|
| 107 |
+
|
| 108 |
+
def _render_step1_image(dataset_state: Dict, composite_choice: str) -> np.ndarray:
|
| 109 |
+
full = dataset_state["full_image"]
|
| 110 |
+
if composite_choice in COMPOSITE_PRESETS:
|
| 111 |
+
r, g, b = COMPOSITE_PRESETS[composite_choice]
|
| 112 |
+
base = render_composite(full, r, g, b)
|
| 113 |
+
else:
|
| 114 |
+
band_idx = BAND_DESCRIPTIONS.index(composite_choice)
|
| 115 |
+
base = render_single_band(full, band_idx)
|
| 116 |
+
return add_labels_overlay(base, dataset_state["full_train_mask"], dataset_state["full_val_mask"])
|
| 117 |
+
|
| 118 |
|
| 119 |
+
# ββ Step 4 render helpers ββββββββββββββββββββββββββββββββββββ
|
| 120 |
|
| 121 |
+
def _render_step4_row(
|
| 122 |
+
dataset_state: Dict,
|
| 123 |
+
baseline_state: Optional[Dict],
|
| 124 |
+
experiments: List[Dict],
|
| 125 |
+
patch_idx: int,
|
| 126 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
|
| 127 |
+
"""Returns (rgb, gt_overlay, baseline_overlay, unet_overlay)."""
|
|
|
|
| 128 |
val_images = dataset_state["val_images"]
|
| 129 |
val_masks = dataset_state["val_masks"]
|
| 130 |
+
idx = max(0, min(patch_idx, len(val_images) - 1))
|
|
|
|
| 131 |
|
|
|
|
| 132 |
rgb = multispectral_to_rgb(val_images[idx])
|
| 133 |
gt = val_masks[idx]
|
| 134 |
+
gt_ov = overlay_mask(rgb, gt)
|
| 135 |
|
| 136 |
+
if baseline_state is not None and idx < len(baseline_state["val_preds"]):
|
| 137 |
+
bl_ov = overlay_mask(rgb, baseline_state["val_preds"][idx])
|
| 138 |
+
else:
|
| 139 |
+
bl_ov = _blank_rgb(*rgb.shape[:2])
|
| 140 |
+
|
| 141 |
+
if experiments and idx < len(experiments[-1]["val_preds"]):
|
| 142 |
+
un_ov = overlay_mask(rgb, experiments[-1]["val_preds"][idx])
|
| 143 |
+
else:
|
| 144 |
+
un_ov = _blank_rgb(*rgb.shape[:2])
|
| 145 |
+
|
| 146 |
+
return rgb, gt_ov, bl_ov, un_ov
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# ββ Step 5 render helpers ββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
def render_step5_panel(
|
| 152 |
+
dataset_state: Dict,
|
| 153 |
+
exp: Optional[Dict],
|
| 154 |
+
patch_idx: int,
|
| 155 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str, np.ndarray]:
|
| 156 |
+
"""Returns (rgb, pred_color, overlay, metrics_md, error_map)."""
|
| 157 |
+
blank = _blank_rgb()
|
| 158 |
+
if dataset_state is None or exp is None:
|
| 159 |
+
return blank, blank, blank, "No model selected.", blank
|
| 160 |
+
|
| 161 |
+
val_images = dataset_state["val_images"]
|
| 162 |
+
val_masks = dataset_state["val_masks"]
|
| 163 |
+
idx = max(0, min(patch_idx, len(val_images) - 1))
|
| 164 |
+
rgb = multispectral_to_rgb(val_images[idx])
|
| 165 |
+
gt = val_masks[idx]
|
| 166 |
|
| 167 |
if idx >= len(exp["val_preds"]):
|
| 168 |
+
return rgb, mask_to_color(gt), overlay_mask(rgb, gt), "Dataset reloaded β retrain.", blank
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 169 |
|
| 170 |
pred = exp["val_preds"][idx].astype(np.int64)
|
| 171 |
probs = exp["val_probs"][idx].astype(np.float32)
|
| 172 |
+
m = compute_metrics(pred, gt)
|
| 173 |
return (
|
| 174 |
rgb,
|
| 175 |
mask_to_color(pred),
|
| 176 |
overlay_mask(rgb, pred),
|
| 177 |
+
metrics_markdown(m, title=f"{exp['name']} Β· patch {idx}"),
|
| 178 |
correctness_overlay(rgb, pred, gt),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
)
|
| 180 |
|
| 181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 182 |
# ββ Gradio action functions βββββοΏ½οΏ½οΏ½ββββββββββββββββββββββββββββ
|
| 183 |
|
| 184 |
def load_dataset_action(patch_size: int):
|
| 185 |
+
patch_size = int(patch_size)
|
|
|
|
| 186 |
dataset_state = load_data(patch_size)
|
| 187 |
+
val_count = len(dataset_state["val_images"])
|
| 188 |
+
|
| 189 |
+
step1_img = _render_step1_image(dataset_state, "Natural Color (R/G/B)")
|
| 190 |
+
sig_chart = render_spectral_signatures_chart(dataset_state["signatures"])
|
| 191 |
+
ndvi_map = render_index_map(
|
| 192 |
+
dataset_state["ndvi"], "NDVI",
|
| 193 |
+
dataset_state["full_train_mask"], dataset_state["full_val_mask"],
|
| 194 |
+
)
|
| 195 |
+
blank = _blank_rgb()
|
| 196 |
+
|
| 197 |
+
dataset_info = "\n\n".join([
|
| 198 |
+
"**Dataset loaded.**",
|
| 199 |
+
dataset_state["status"],
|
| 200 |
+
f"Bands: {', '.join(BAND_NAMES)} | Classes: {', '.join(CLASS_NAMES)}",
|
| 201 |
+
"**Squares** = training labels Β· **Circles** = validation labels",
|
| 202 |
])
|
| 203 |
|
| 204 |
+
slider_upd = gr.update(maximum=max(0, val_count - 1), value=0)
|
| 205 |
+
|
| 206 |
return (
|
| 207 |
dataset_state,
|
| 208 |
+
None, # baseline_state reset
|
| 209 |
+
[], # experiments_state reset
|
| 210 |
+
# Tab 1
|
| 211 |
dataset_info,
|
| 212 |
+
step1_img,
|
| 213 |
+
"Click the image to inspect a pixel.",
|
| 214 |
+
# Tab 2
|
| 215 |
+
sig_chart,
|
| 216 |
+
ndvi_map,
|
| 217 |
+
# Tab 3
|
| 218 |
+
"Run KNN baseline after loading the dataset.",
|
| 219 |
+
blank,
|
| 220 |
+
# Tab 4
|
| 221 |
+
"Train a model in Step 4.",
|
| 222 |
+
slider_upd,
|
| 223 |
+
blank, blank, blank,
|
| 224 |
+
# Tab 5
|
| 225 |
+
"No experiments yet.",
|
| 226 |
+
gr.update(choices=[], value=None),
|
| 227 |
+
gr.update(choices=[], value=None),
|
| 228 |
)
|
| 229 |
|
| 230 |
|
| 231 |
+
def update_step1_composite(dataset_state, composite_choice: str):
|
| 232 |
+
if dataset_state is None:
|
| 233 |
+
return _blank_rgb(), "Load the dataset first."
|
| 234 |
+
img = _render_step1_image(dataset_state, composite_choice)
|
| 235 |
+
return img, "Click the image to inspect a pixel."
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def handle_click_step1(evt: gr.SelectData, dataset_state):
|
| 239 |
+
if dataset_state is None:
|
| 240 |
+
return "Load the dataset first."
|
| 241 |
+
x, y = evt.index
|
| 242 |
+
full = dataset_state["full_image"]
|
| 243 |
+
fmask = dataset_state["full_val_mask"]
|
| 244 |
+
H, W = fmask.shape
|
| 245 |
+
x, y = int(np.clip(x, 0, W-1)), int(np.clip(y, 0, H-1))
|
| 246 |
+
|
| 247 |
+
cls = int(fmask[y, x])
|
| 248 |
+
label = CLASS_NAMES[cls] if cls != IGNORE_INDEX else "Unlabeled"
|
| 249 |
+
lines = [
|
| 250 |
+
f"**Pixel ({x}, {y})** | Val label: **{label}**", "",
|
| 251 |
+
"| Band | Value |", "|---|---:|",
|
| 252 |
+
] + [f"| {BAND_DESCRIPTIONS[b]} | {float(full[b, y, x]):.4f} |" for b in range(7)]
|
| 253 |
+
return "\n".join(lines)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def update_step2_index(dataset_state, index_choice: str):
|
| 257 |
+
if dataset_state is None:
|
| 258 |
+
return _blank_rgb()
|
| 259 |
+
key = index_choice.lower()
|
| 260 |
+
arr = dataset_state[key]
|
| 261 |
+
return render_index_map(
|
| 262 |
+
arr, index_choice,
|
| 263 |
+
dataset_state["full_train_mask"], dataset_state["full_val_mask"],
|
| 264 |
)
|
| 265 |
|
| 266 |
|
| 267 |
+
def run_baseline_action(dataset_state, k: int, progress=gr.Progress()):
|
| 268 |
+
if dataset_state is None:
|
| 269 |
+
raise gr.Error("Load the dataset first.")
|
| 270 |
+
progress(0.1, desc="Running KNN on full scene...")
|
| 271 |
+
k = int(k)
|
| 272 |
+
full_pred, val_preds, metrics, metrics_md = run_knn_baseline(
|
| 273 |
+
dataset_state["full_image"],
|
| 274 |
+
dataset_state["full_train_mask"],
|
| 275 |
+
dataset_state["full_val_mask"],
|
| 276 |
+
dataset_state["val_images"],
|
| 277 |
+
k=k,
|
| 278 |
+
)
|
| 279 |
+
progress(0.9, desc="Rendering...")
|
| 280 |
+
baseline_state = {
|
| 281 |
+
"k": k,
|
| 282 |
+
"full_pred": full_pred,
|
| 283 |
+
"val_preds": val_preds,
|
| 284 |
+
"metrics": metrics,
|
| 285 |
+
}
|
| 286 |
+
full_ov = render_full_prediction_overlay(
|
| 287 |
+
dataset_state["full_image"], full_pred, dataset_state["full_val_mask"],
|
| 288 |
+
)
|
| 289 |
+
progress(1.0)
|
| 290 |
+
return baseline_state, metrics_md, full_ov
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def update_step4_patch(dataset_state, baseline_state, experiments, patch_idx: int):
|
| 294 |
+
if dataset_state is None:
|
| 295 |
+
blank = _blank_rgb()
|
| 296 |
+
return blank, blank, blank
|
| 297 |
+
_, gt_ov, bl_ov, un_ov = _render_step4_row(
|
| 298 |
+
dataset_state, baseline_state, experiments, int(patch_idx)
|
| 299 |
+
)
|
| 300 |
+
return gt_ov, bl_ov, un_ov
|
| 301 |
|
| 302 |
|
| 303 |
def train_experiment(
|
| 304 |
dataset_state: Dict,
|
| 305 |
+
baseline_state: Optional[Dict],
|
| 306 |
experiments: List[Dict],
|
| 307 |
learning_rate: float,
|
| 308 |
batch_size: int,
|
|
|
|
| 313 |
):
|
| 314 |
if dataset_state is None or "train_images" not in dataset_state:
|
| 315 |
raise gr.Error("Load a dataset first.")
|
| 316 |
+
if len(experiments) >= MAX_EXPERIMENTS:
|
| 317 |
+
raise gr.Error(
|
| 318 |
+
f"Maximum {MAX_EXPERIMENTS} experiments reached. "
|
| 319 |
+
"Go to Step 5 to compare, then reload data to start fresh."
|
| 320 |
+
)
|
| 321 |
|
| 322 |
loader = DataLoader(
|
| 323 |
MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
|
|
|
|
| 329 |
|
| 330 |
n_epochs = int(epochs)
|
| 331 |
history = []
|
| 332 |
+
for ep in range(n_epochs):
|
| 333 |
+
progress(ep / n_epochs, desc=f"Epoch {ep+1}/{n_epochs}")
|
| 334 |
model.train()
|
| 335 |
+
total, n = 0.0, 0
|
| 336 |
for xb, yb in loader:
|
| 337 |
xb, yb = xb.to(DEVICE), yb.to(DEVICE)
|
| 338 |
optimizer.zero_grad(set_to_none=True)
|
| 339 |
loss = criterion(model(xb), yb)
|
| 340 |
loss.backward()
|
| 341 |
optimizer.step()
|
| 342 |
+
total += float(loss.item())
|
| 343 |
n += 1
|
| 344 |
+
history.append(total / max(1, n))
|
| 345 |
|
| 346 |
progress(0.95, desc="Running validation inference...")
|
| 347 |
val_preds, val_probs = build_prediction_cache(
|
| 348 |
model, dataset_state["val_images"], batch_size=max(1, int(batch_size))
|
| 349 |
)
|
| 350 |
global_metrics = compute_metrics(
|
| 351 |
+
val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
|
| 352 |
)
|
| 353 |
+
progress(1.0)
|
| 354 |
|
| 355 |
+
base = (run_name or f"Run {len(experiments)+1}").strip()
|
| 356 |
existing = {e["name"] for e in experiments}
|
| 357 |
+
name, ctr = base, 2
|
| 358 |
while name in existing:
|
| 359 |
+
name = f"{base} ({ctr})"
|
| 360 |
+
ctr += 1
|
| 361 |
|
| 362 |
experiment = {
|
| 363 |
+
"name": name,
|
| 364 |
+
"config": {
|
| 365 |
"learning_rate": float(learning_rate),
|
| 366 |
"batch_size": int(batch_size),
|
| 367 |
"epochs": int(epochs),
|
|
|
|
| 369 |
},
|
| 370 |
"train_loss_history": history,
|
| 371 |
"global_metrics": global_metrics,
|
| 372 |
+
"val_preds": val_preds.astype(np.int64),
|
| 373 |
+
"val_probs": val_probs.astype(np.float32),
|
| 374 |
}
|
|
|
|
| 375 |
experiments = experiments + [experiment]
|
| 376 |
+
|
| 377 |
+
summary = "\n\n".join([
|
| 378 |
+
f"**Training finished β {name}**",
|
| 379 |
+
f"Device: **{DEVICE}** | Epochs: **{n_epochs}** | Final loss: **{history[-1]:.4f}**",
|
| 380 |
+
f"Val accuracy: **{global_metrics['overall_acc']*100:.2f}%** (labeled px only)",
|
| 381 |
+
f"Val mIoU: **{global_metrics['miou']*100:.2f}%**",
|
| 382 |
])
|
| 383 |
|
| 384 |
+
choices = [e["name"] for e in experiments]
|
| 385 |
+
val_count = len(dataset_state["val_images"])
|
| 386 |
+
|
| 387 |
+
_, gt_ov, bl_ov, un_ov = _render_step4_row(dataset_state, baseline_state, experiments, 0)
|
| 388 |
+
|
| 389 |
return (
|
| 390 |
+
experiments,
|
| 391 |
+
summary,
|
| 392 |
+
gr.update(maximum=max(0, val_count-1), value=0), # step4 patch slider
|
| 393 |
+
gt_ov, bl_ov, un_ov,
|
| 394 |
+
experiments_table_markdown(experiments), # step5 table
|
| 395 |
+
gr.update(choices=choices, value=None), # step5 sel_a
|
| 396 |
+
gr.update(choices=choices, value=None), # step5 sel_b
|
| 397 |
)
|
| 398 |
|
| 399 |
|
| 400 |
+
def update_step5_comparison(
|
| 401 |
+
dataset_state, experiments, sel_a, sel_b, patch_idx: int
|
| 402 |
+
):
|
| 403 |
+
idx = int(patch_idx)
|
| 404 |
+
exp_a = _get_exp_by_name(experiments, sel_a)
|
| 405 |
+
exp_b = _get_exp_by_name(experiments, sel_b)
|
| 406 |
+
a_outs = render_step5_panel(dataset_state, exp_a, idx)
|
| 407 |
+
b_outs = render_step5_panel(dataset_state, exp_b, idx)
|
| 408 |
+
return (*a_outs, *b_outs)
|
|
|
|
|
|
|
|
|
|
| 409 |
|
| 410 |
|
| 411 |
+
def handle_click_step5(
|
| 412 |
+
evt: gr.SelectData,
|
| 413 |
+
dataset_state, experiments, model_name, patch_idx: int,
|
| 414 |
) -> str:
|
| 415 |
try:
|
| 416 |
+
if dataset_state is None:
|
| 417 |
+
return "No dataset loaded."
|
| 418 |
+
idx = max(0, min(int(patch_idx), len(dataset_state["val_images"])-1))
|
| 419 |
exp = _get_exp_by_name(experiments, model_name)
|
| 420 |
x, y = evt.index
|
| 421 |
img7 = dataset_state["val_images"][idx]
|
| 422 |
gt = dataset_state["val_masks"][idx]
|
| 423 |
+
pred = exp["val_preds"][idx] if (exp and idx < len(exp["val_preds"])) else None
|
| 424 |
+
probs = exp["val_probs"][idx] if (exp and idx < len(exp["val_probs"])) else None
|
| 425 |
+
return pixel_info_markdown(int(x), int(y), img7, gt, pred, probs)
|
| 426 |
except Exception as e:
|
| 427 |
+
return f"Click error: `{e}`"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -1,6 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
-
from
|
|
|
|
|
|
|
|
|
|
| 3 |
|
|
|
|
| 4 |
|
| 5 |
def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
|
| 6 |
x = x.astype(np.float32)
|
|
@@ -11,16 +20,80 @@ def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> n
|
|
| 11 |
return np.clip((x - lo) / (hi - lo), 0, 1)
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
|
| 15 |
-
"""
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def mask_to_color(mask: np.ndarray) -> np.ndarray:
|
| 23 |
-
"""Class indices
|
| 24 |
out = np.full((*mask.shape, 3), 200, dtype=np.uint8)
|
| 25 |
labeled = (mask != IGNORE_INDEX) & (mask >= 0)
|
| 26 |
if labeled.any():
|
|
@@ -35,7 +108,7 @@ def overlay_mask(rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.n
|
|
| 35 |
|
| 36 |
|
| 37 |
def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
|
| 38 |
-
"""Green = correct, red = wrong, gray = unlabeled
|
| 39 |
out = np.full((*pred.shape, 3), 180, dtype=np.uint8)
|
| 40 |
labeled = gt != IGNORE_INDEX
|
| 41 |
out[labeled & (pred == gt)] = [0, 220, 0]
|
|
@@ -47,3 +120,90 @@ def correctness_overlay(rgb: np.ndarray, pred: np.ndarray, gt: np.ndarray, alpha
|
|
| 47 |
cm = correctness_map(pred, gt)
|
| 48 |
out = ((1 - alpha) * rgb.astype(np.float32) + alpha * cm.astype(np.float32)).clip(0, 255)
|
| 49 |
return out.astype(np.uint8)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import io
|
| 2 |
+
|
| 3 |
+
import matplotlib
|
| 4 |
+
matplotlib.use("Agg")
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
import numpy as np
|
| 7 |
+
from PIL import Image, ImageDraw
|
| 8 |
+
|
| 9 |
+
from config import CLASS_COLORS, CLASS_NAMES, BAND_NAMES, BAND_DESCRIPTIONS, IGNORE_INDEX, NUM_CLASSES
|
| 10 |
+
|
| 11 |
|
| 12 |
+
# ββ Low-level helpers βββββββββββββββββββββββββββββββββββββββββ
|
| 13 |
|
| 14 |
def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
|
| 15 |
x = x.astype(np.float32)
|
|
|
|
| 20 |
return np.clip((x - lo) / (hi - lo), 0, 1)
|
| 21 |
|
| 22 |
|
| 23 |
+
def _fig_to_numpy(fig) -> np.ndarray:
|
| 24 |
+
buf = io.BytesIO()
|
| 25 |
+
fig.savefig(buf, format="png", bbox_inches="tight", dpi=110)
|
| 26 |
+
plt.close(fig)
|
| 27 |
+
buf.seek(0)
|
| 28 |
+
return np.array(Image.open(buf).convert("RGB"))
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _blank_rgb(h: int = 300, w: int = 400) -> np.ndarray:
|
| 32 |
+
return np.full((h, w, 3), 220, dtype=np.uint8)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ββ Composite rendering (full image or patch) βββββββββββββββββ
|
| 36 |
+
|
| 37 |
+
def render_composite(img7: np.ndarray, r: int, g: int, b: int) -> np.ndarray:
|
| 38 |
+
"""img7: (7, H, W) -> (H, W, 3) uint8."""
|
| 39 |
+
return (np.stack([
|
| 40 |
+
percentile_stretch(img7[r]),
|
| 41 |
+
percentile_stretch(img7[g]),
|
| 42 |
+
percentile_stretch(img7[b]),
|
| 43 |
+
], axis=-1) * 255).astype(np.uint8)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def render_single_band(img7: np.ndarray, band_idx: int) -> np.ndarray:
|
| 47 |
+
"""Single band as grayscale RGB."""
|
| 48 |
+
gray = (percentile_stretch(img7[band_idx]) * 255).astype(np.uint8)
|
| 49 |
+
return np.stack([gray, gray, gray], axis=-1)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
|
| 53 |
+
"""Natural colour composite: H_4/H_3/H_2 -> R/G/B."""
|
| 54 |
+
return render_composite(img7, r=3, g=2, b=1)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
# ββ Label markers on full-scene image ββββββββββββββββββββββββ
|
| 58 |
+
|
| 59 |
+
def add_labels_overlay(
|
| 60 |
+
base_rgb: np.ndarray,
|
| 61 |
+
train_mask: np.ndarray,
|
| 62 |
+
val_mask: np.ndarray,
|
| 63 |
+
radius: int = 5,
|
| 64 |
+
) -> np.ndarray:
|
| 65 |
+
"""
|
| 66 |
+
Draw class-coloured markers on base_rgb.
|
| 67 |
+
Training labels -> filled squares; validation labels -> circles with white ring.
|
| 68 |
+
"""
|
| 69 |
+
img = Image.fromarray(base_rgb)
|
| 70 |
+
draw = ImageDraw.Draw(img)
|
| 71 |
+
H, W = base_rgb.shape[:2]
|
| 72 |
+
|
| 73 |
+
for cls_idx in range(NUM_CLASSES):
|
| 74 |
+
color = tuple(int(c) for c in CLASS_COLORS[cls_idx])
|
| 75 |
+
|
| 76 |
+
ys, xs = np.where(train_mask == cls_idx)
|
| 77 |
+
for y, x in zip(ys.tolist(), xs.tolist()):
|
| 78 |
+
box = [x - radius, y - radius, x + radius, y + radius]
|
| 79 |
+
box = [max(0, box[0]), max(0, box[1]), min(W-1, box[2]), min(H-1, box[3])]
|
| 80 |
+
draw.rectangle(box, fill=color, outline=(255, 255, 255))
|
| 81 |
+
|
| 82 |
+
ys, xs = np.where(val_mask == cls_idx)
|
| 83 |
+
for y, x in zip(ys.tolist(), xs.tolist()):
|
| 84 |
+
outer = [x-radius-2, y-radius-2, x+radius+2, y+radius+2]
|
| 85 |
+
inner = [x-radius, y-radius, x+radius, y+radius ]
|
| 86 |
+
outer = [max(0, v) for v in outer]
|
| 87 |
+
draw.ellipse(outer, fill=(255, 255, 255))
|
| 88 |
+
draw.ellipse(inner, fill=color)
|
| 89 |
|
| 90 |
+
return np.array(img)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ββ Mask colourisation ββββββββββββββββββββββββββββββββββββββββ
|
| 94 |
|
| 95 |
def mask_to_color(mask: np.ndarray) -> np.ndarray:
|
| 96 |
+
"""Class indices -> RGB. IGNORE_INDEX pixels rendered as light gray."""
|
| 97 |
out = np.full((*mask.shape, 3), 200, dtype=np.uint8)
|
| 98 |
labeled = (mask != IGNORE_INDEX) & (mask >= 0)
|
| 99 |
if labeled.any():
|
|
|
|
| 108 |
|
| 109 |
|
| 110 |
def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
|
| 111 |
+
"""Green = correct, red = wrong, gray = unlabeled."""
|
| 112 |
out = np.full((*pred.shape, 3), 180, dtype=np.uint8)
|
| 113 |
labeled = gt != IGNORE_INDEX
|
| 114 |
out[labeled & (pred == gt)] = [0, 220, 0]
|
|
|
|
| 120 |
cm = correctness_map(pred, gt)
|
| 121 |
out = ((1 - alpha) * rgb.astype(np.float32) + alpha * cm.astype(np.float32)).clip(0, 255)
|
| 122 |
return out.astype(np.uint8)
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ββ Full-scene prediction rendering ββββββββββββββββββββββββββ
|
| 126 |
+
|
| 127 |
+
def render_full_prediction_overlay(
|
| 128 |
+
full_image: np.ndarray,
|
| 129 |
+
full_pred: np.ndarray,
|
| 130 |
+
val_mask: np.ndarray,
|
| 131 |
+
alpha: float = 0.40,
|
| 132 |
+
dot_radius: int = 6,
|
| 133 |
+
) -> np.ndarray:
|
| 134 |
+
"""
|
| 135 |
+
Blend predicted class colours over natural-colour composite, then draw
|
| 136 |
+
correctness markers at every validation label location.
|
| 137 |
+
"""
|
| 138 |
+
rgb = render_composite(full_image, r=3, g=2, b=1)
|
| 139 |
+
base = overlay_mask(rgb, full_pred, alpha=alpha)
|
| 140 |
+
|
| 141 |
+
img = Image.fromarray(base)
|
| 142 |
+
draw = ImageDraw.Draw(img)
|
| 143 |
+
H, W = base.shape[:2]
|
| 144 |
+
|
| 145 |
+
for cls_idx in range(NUM_CLASSES):
|
| 146 |
+
ys, xs = np.where(val_mask == cls_idx)
|
| 147 |
+
for y, x in zip(ys.tolist(), xs.tolist()):
|
| 148 |
+
correct = (full_pred[y, x] == cls_idx)
|
| 149 |
+
ring = (0, 200, 0) if correct else (220, 0, 0)
|
| 150 |
+
r = dot_radius
|
| 151 |
+
outer = [max(0, x-r-2), max(0, y-r-2), min(W-1, x+r+2), min(H-1, y+r+2)]
|
| 152 |
+
inner = [max(0, x-r), max(0, y-r), min(W-1, x+r), min(H-1, y+r) ]
|
| 153 |
+
draw.ellipse(outer, fill=(255, 255, 255))
|
| 154 |
+
draw.ellipse(inner, fill=ring)
|
| 155 |
+
|
| 156 |
+
return np.array(img)
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ββ Matplotlib charts βββββββββββββββββββββββββββββββββββββββββ
|
| 160 |
+
|
| 161 |
+
def render_spectral_signatures_chart(signatures: dict) -> np.ndarray:
|
| 162 |
+
"""Line chart of per-class mean Β± 1-sigma across the 7 bands."""
|
| 163 |
+
fig, ax = plt.subplots(figsize=(8, 3.8))
|
| 164 |
+
x = np.arange(len(BAND_NAMES))
|
| 165 |
+
|
| 166 |
+
for cls_idx, sig in signatures.items():
|
| 167 |
+
mean = sig["mean"]
|
| 168 |
+
std = sig["std"]
|
| 169 |
+
n = sig["n"]
|
| 170 |
+
color = CLASS_COLORS[cls_idx] / 255.0
|
| 171 |
+
label = f"{CLASS_NAMES[cls_idx]} (n={n})"
|
| 172 |
+
ax.plot(x, mean, "o-", color=color, label=label, linewidth=2, markersize=5)
|
| 173 |
+
ax.fill_between(x, mean - std, mean + std, alpha=0.18, color=color)
|
| 174 |
+
|
| 175 |
+
ax.set_xticks(x)
|
| 176 |
+
ax.set_xticklabels([d.replace(" (", "\n(") for d in BAND_DESCRIPTIONS], fontsize=8)
|
| 177 |
+
ax.set_ylabel("Normalised reflectance")
|
| 178 |
+
ax.set_title("Spectral Signatures by Land Cover Class")
|
| 179 |
+
ax.legend(loc="upper left", fontsize=8)
|
| 180 |
+
ax.grid(True, alpha=0.3)
|
| 181 |
+
fig.tight_layout()
|
| 182 |
+
return _fig_to_numpy(fig)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def render_index_map(
|
| 186 |
+
index_arr: np.ndarray,
|
| 187 |
+
name: str,
|
| 188 |
+
train_mask: np.ndarray,
|
| 189 |
+
val_mask: np.ndarray,
|
| 190 |
+
) -> np.ndarray:
|
| 191 |
+
"""NDVI or NDWI heatmap with class-coloured label markers."""
|
| 192 |
+
cmap = "RdYlGn" if name == "NDVI" else "RdYlBu"
|
| 193 |
+
fig, ax = plt.subplots(figsize=(10, 4.5))
|
| 194 |
+
im = ax.imshow(index_arr, cmap=cmap, vmin=-1, vmax=1, aspect="auto")
|
| 195 |
+
plt.colorbar(im, ax=ax, fraction=0.018, pad=0.02)
|
| 196 |
+
|
| 197 |
+
for cls_idx in range(NUM_CLASSES):
|
| 198 |
+
color = CLASS_COLORS[cls_idx] / 255.0
|
| 199 |
+
name_ = CLASS_NAMES[cls_idx]
|
| 200 |
+
ys, xs = np.where(train_mask == cls_idx)
|
| 201 |
+
ax.scatter(xs, ys, c=[color], s=18, marker="s", label=f"{name_} (train)", zorder=5)
|
| 202 |
+
ys, xs = np.where(val_mask == cls_idx)
|
| 203 |
+
ax.scatter(xs, ys, c=[color], s=18, marker="o",
|
| 204 |
+
edgecolors="white", linewidths=0.6, zorder=6)
|
| 205 |
+
|
| 206 |
+
ax.set_title(f"{name} β squares=training labels, circles=val labels")
|
| 207 |
+
ax.legend(loc="upper right", fontsize=7, markerscale=1.4)
|
| 208 |
+
fig.tight_layout()
|
| 209 |
+
return _fig_to_numpy(fig)
|