File size: 7,928 Bytes
089078d
 
 
 
 
db426cb
089078d
 
 
 
db426cb
089078d
db426cb
 
 
 
 
 
 
 
 
 
089078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db426cb
089078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
db426cb
089078d
 
 
 
db426cb
 
089078d
9cb9bec
 
 
 
 
db426cb
 
 
 
 
 
 
 
 
089078d
9cb9bec
 
 
 
db426cb
 
 
 
 
 
 
089078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5fdcdd
 
089078d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d5fdcdd
089078d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
import io

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image, ImageDraw

from config import CLASS_COLORS, CLASS_NAMES, BAND_NAMES, BAND_DESCRIPTIONS, IGNORE_INDEX, NUM_CLASSES


# ── Low-level helpers ──────────────────────────────────────────

def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
    """
    Linearly rescale *x* to [0, 1] between its `low`-th and `high`-th percentiles.

    The input is cast to float32 first. Values at or below the lower
    percentile map to 0 and values at or above the upper one map to 1.
    When both percentiles coincide (e.g. a constant band), the upper bound is
    nudged up by 1e-6 so the division stays finite.
    """
    data = x.astype(np.float32)
    lower, upper = np.percentile(data, low), np.percentile(data, high)
    upper = lower + 1e-6 if upper <= lower else upper
    scaled = (data - lower) / (upper - lower)
    return np.clip(scaled, 0, 1)


def _fig_to_numpy(fig) -> np.ndarray:
    """
    Rasterise a Matplotlib figure to an (H, W, 3) uint8 RGB array.

    The figure is saved as PNG (tight bounding box, 110 dpi) into an in-memory
    buffer and decoded with PIL. The figure is always closed afterwards, even
    if saving raises, so pyplot does not accumulate open figures.
    """
    buf = io.BytesIO()
    try:
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=110)
    finally:
        # Release the figure from pyplot's figure registry regardless of outcome.
        plt.close(fig)
    buf.seek(0)
    with Image.open(buf) as png:
        return np.array(png.convert("RGB"))


def _blank_rgb(h: int = 300, w: int = 400) -> np.ndarray:
    return np.full((h, w, 3), 220, dtype=np.uint8)


# ── Composite rendering (full image or patch) ─────────────────

def render_composite(img7: np.ndarray, r: int, g: int, b: int) -> np.ndarray:
    """
    Build an RGB composite from three bands of a band stack.

    Bands img7[r], img7[g] and img7[b] are each percentile-stretched to
    [0, 1] independently, stacked along a new last axis and scaled to uint8.
    Expected input is a (7, H, W) stack; the output is (H, W, 3) uint8.
    """
    channels = [percentile_stretch(img7[band]) for band in (r, g, b)]
    return (np.stack(channels, axis=-1) * 255).astype(np.uint8)


def render_single_band(img7: np.ndarray, band_idx: int) -> np.ndarray:
    """
    Render one band of the stack as a 3-channel grayscale image.

    The band is percentile-stretched, scaled to uint8 and replicated into the
    R, G and B channels.
    """
    stretched = percentile_stretch(img7[band_idx])
    gray = (stretched * 255).astype(np.uint8)
    return np.repeat(gray[..., np.newaxis], 3, axis=-1)


def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
    """
    Return the natural-colour composite of a band stack.

    Band indices 3, 2 and 1 (H_4 / H_3 / H_2) are mapped to the red, green
    and blue channels respectively.
    """
    red_band, green_band, blue_band = 3, 2, 1
    return render_composite(img7, r=red_band, g=green_band, b=blue_band)


# ── Label markers on full-scene image ────────────────────────

def add_labels_overlay(
    base_rgb:   np.ndarray,
    train_mask: np.ndarray,
    val_mask:   np.ndarray,
    radius:     int = 5,
) -> np.ndarray:
    """
    Draw class-coloured label markers on top of an RGB image.

    Training labels are drawn as filled squares with a white outline.
    Validation labels are drawn as filled circles surrounded by a 2-pixel
    white ring. A marker is drawn for every pixel whose mask value equals a
    class index in range(NUM_CLASSES). Any other value (presumably including
    IGNORE_INDEX) gets no marker.

    Every marker bounding box is clamped to the image on all four sides, so
    markers near any edge are treated the same way.

    Args:
        base_rgb:   (H, W, 3) image accepted by PIL.Image.fromarray
                    (presumably uint8).
        train_mask: per-pixel class indices for training labels (assumed to
                    share base_rgb's H x W).
        val_mask:   per-pixel class indices for validation labels.
        radius:     half-size in pixels of each marker.

    Returns:
        A new (H, W, 3) array with the markers drawn.
    """
    img  = Image.fromarray(base_rgb)
    draw = ImageDraw.Draw(img)
    H, W = base_rgb.shape[:2]

    def _clamped_box(x: int, y: int, half: int) -> list:
        # Bounding box of half-size `half` around (x, y), clipped to the image.
        return [max(0, x - half), max(0, y - half), min(W - 1, x + half), min(H - 1, y + half)]

    for cls_idx in range(NUM_CLASSES):
        color = tuple(int(c) for c in CLASS_COLORS[cls_idx])

        ys, xs = np.where(train_mask == cls_idx)
        for y, x in zip(ys.tolist(), xs.tolist()):
            draw.rectangle(_clamped_box(x, y, radius), fill=color, outline=(255, 255, 255))

        ys, xs = np.where(val_mask == cls_idx)
        for y, x in zip(ys.tolist(), xs.tolist()):
            # White outer disc (radius + 2) first; the coloured inner disc on
            # top leaves a white ring around it.
            draw.ellipse(_clamped_box(x, y, radius + 2), fill=(255, 255, 255))
            draw.ellipse(_clamped_box(x, y, radius), fill=color)

    return np.array(img)


# ── Mask colourisation ──────────────────────────────────────────

def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """
    Map a class-index mask to an RGB image using CLASS_COLORS.

    A pixel is coloured with CLASS_COLORS[value] when it is neither
    IGNORE_INDEX nor negative. All remaining pixels stay light gray
    (200, 200, 200).
    """
    rgb = np.full(mask.shape + (3,), 200, dtype=np.uint8)
    valid = np.logical_and(mask != IGNORE_INDEX, mask >= 0)
    if not valid.any():
        return rgb
    rgb[valid] = CLASS_COLORS[mask[valid].astype(np.int64)]
    return rgb


def overlay_mask(rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """
    Alpha-blend the class-colour rendering of *mask* over *rgb*.

    The result is (1 - alpha) * rgb + alpha * mask_to_color(mask), computed in
    float32, clipped to [0, 255] and truncated to uint8.
    """
    tint = mask_to_color(mask).astype(np.float32)
    base = rgb.astype(np.float32)
    blended = (1 - alpha) * base + alpha * tint
    return np.clip(blended, 0, 255).astype(np.uint8)


def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """
    Colour each pixel by whether the prediction matches the ground truth.

    Where gt != IGNORE_INDEX, a pixel is green (0, 220, 0) if pred == gt and
    red (220, 0, 0) otherwise. IGNORE_INDEX pixels stay gray (180, 180, 180).
    """
    canvas = np.full(pred.shape + (3,), 180, dtype=np.uint8)
    scored = gt != IGNORE_INDEX
    right = scored & (pred == gt)
    wrong = scored & (pred != gt)
    canvas[right] = (0, 220, 0)
    canvas[wrong] = (220, 0, 0)
    return canvas


def correctness_overlay(rgb: np.ndarray, pred: np.ndarray, gt: np.ndarray, alpha: float = 0.38) -> np.ndarray:
    """
    Alpha-blend the green/red correctness map of *pred* vs *gt* over *rgb*.

    The blend is computed in float32, clipped to [0, 255] and truncated to
    uint8.
    """
    verdict = correctness_map(pred, gt).astype(np.float32)
    base = rgb.astype(np.float32)
    blended = (1 - alpha) * base + alpha * verdict
    return np.clip(blended, 0, 255).astype(np.uint8)


# ── Full-scene prediction rendering ──────────────────────────

def render_full_prediction_overlay(
    full_image: np.ndarray,
    full_pred:  np.ndarray,
    val_mask:   np.ndarray,
    alpha:      float = 0.40,
    dot_radius: int   = 6,
) -> np.ndarray:
    """
    Overlay predicted classes on the natural-colour composite and mark every
    validation label with a correctness dot.

    Each validation pixel whose value is a class index in range(NUM_CLASSES)
    gets a white disc of radius dot_radius + 2. On top of it goes a disc of
    radius dot_radius: green (0, 200, 0) if full_pred agrees with the label
    there, red (220, 0, 0) otherwise. All bounding boxes are clamped to the
    image.
    """
    natural = render_composite(full_image, r=3, g=2, b=1)
    blended = overlay_mask(natural, full_pred, alpha=alpha)

    canvas = Image.fromarray(blended)
    pen = ImageDraw.Draw(canvas)
    height, width = blended.shape[:2]

    def _box(cx: int, cy: int, half: int) -> list:
        # Square bounding box of half-size `half` around (cx, cy), clipped to the image.
        return [max(0, cx - half), max(0, cy - half),
                min(width - 1, cx + half), min(height - 1, cy + half)]

    for cls_idx in range(NUM_CLASSES):
        rows, cols = np.where(val_mask == cls_idx)
        for cy, cx in zip(rows.tolist(), cols.tolist()):
            hit = full_pred[cy, cx] == cls_idx
            dot_fill = (0, 200, 0) if hit else (220, 0, 0)
            pen.ellipse(_box(cx, cy, dot_radius + 2), fill=(255, 255, 255))
            pen.ellipse(_box(cx, cy, dot_radius), fill=dot_fill)

    return np.array(canvas)


# ── Matplotlib charts ───────────────────────────────────────────

def render_spectral_signatures_chart(signatures: dict) -> np.ndarray:
    """
    Line chart of per-class mean ± 1-sigma across the spectral bands.

    Args:
        signatures: mapping class index -> {"mean": ..., "std": ..., "n": ...}.
            "mean" and "std" are per-band sequences (assumed to have one entry
            per BAND_NAMES item). Lists and arrays are both accepted. "n" is
            shown in the legend label (presumably the sample count).

    Returns:
        The rendered chart as an (H, W, 3) uint8 RGB array.
    """
    fig, ax = plt.subplots(figsize=(8, 3.8))
    x = np.arange(len(BAND_NAMES))

    for cls_idx, sig in signatures.items():
        # asarray so plain lists work too: `mean - std` needs array arithmetic.
        mean  = np.asarray(sig["mean"])
        std   = np.asarray(sig["std"])
        n     = sig["n"]
        color = CLASS_COLORS[cls_idx] / 255.0
        label = f"{CLASS_NAMES[cls_idx]} (n={n})"
        ax.plot(x, mean, "o-", color=color, label=label, linewidth=2, markersize=5)
        ax.fill_between(x, mean - std, mean + std, alpha=0.18, color=color)

    ax.set_xticks(x)
    # Break long band descriptions before the parenthesised part for readability.
    ax.set_xticklabels([d.replace(" (", "\n(") for d in BAND_DESCRIPTIONS], fontsize=8)
    ax.set_ylabel("Réflectance normalisée")
    ax.set_title("Signatures spectrales par classe d'occupation du sol")
    # With no signatures there are no labelled artists; skip the empty legend.
    if signatures:
        ax.legend(loc="upper left", fontsize=8)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    return _fig_to_numpy(fig)


def render_index_map(
    index_arr:  np.ndarray,
    name:       str,
    train_mask: np.ndarray,
    val_mask:   np.ndarray,
) -> np.ndarray:
    """
    NDVI or NDWI heatmap with class-coloured label markers.

    The index is drawn with a fixed [-1, 1] colour range: "RdYlGn" when
    name == "NDVI", "RdYlBu" otherwise. Training labels are overlaid as
    squares and validation labels as white-edged circles, coloured by class.
    Only the training series carry legend entries.

    Returns:
        The rendered figure as an (H, W, 3) uint8 RGB array.
    """
    cmap = "RdYlGn" if name == "NDVI" else "RdYlBu"
    fig, ax = plt.subplots(figsize=(10, 4.5))
    im = ax.imshow(index_arr, cmap=cmap, vmin=-1, vmax=1, aspect="auto")
    plt.colorbar(im, ax=ax, fraction=0.018, pad=0.02)

    for cls_idx in range(NUM_CLASSES):
        color = CLASS_COLORS[cls_idx] / 255.0
        name_ = CLASS_NAMES[cls_idx]
        ys, xs = np.where(train_mask == cls_idx)
        ax.scatter(xs, ys, c=[color], s=18, marker="s", label=f"{name_} (train)", zorder=5)
        ys, xs = np.where(val_mask == cls_idx)
        # Validation markers sit above training ones (higher zorder) and are
        # unlabelled so the legend lists each class only once.
        ax.scatter(xs, ys, c=[color], s=18, marker="o",
                   edgecolors="white", linewidths=0.6, zorder=6)

    ax.set_title(f"{name} — carrés=étiquettes d'entraînement, cercles=étiquettes de validation")
    ax.legend(loc="upper right", fontsize=7, markerscale=1.4)
    fig.tight_layout()
    return _fig_to_numpy(fig)