File size: 2,411 Bytes
9eda8e3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from __future__ import annotations

import warnings

import cv2
import numpy as np

# Default colour palette for up to 10 classes. Each entry is an [R, G, B]
# triple and its position in the list is the class ID it colours. Used by
# colourize_labels() when the caller does not supply a palette.
DEFAULT_PALETTE = [
    [0, 0, 0],        # 0 — background (black)
    [255, 50, 50],     # 1 — building (red)
    [50, 255, 50],     # 2 — green
    [50, 50, 255],     # 3 — blue
    [255, 255, 50],    # 4 — yellow
    [255, 50, 255],    # 5 — magenta
    [50, 255, 255],    # 6 — cyan
    [200, 200, 200],   # 7 — light grey
    [128, 0, 255],     # 8 — violet
    [0, 128, 255],     # 9 — azure
]

def overlay_mask(
    image: np.ndarray,
    mask: np.ndarray,
    alpha: float = 0.45,
    colour: tuple[int, int, int] = (255, 50, 50),
) -> np.ndarray:
    """Tint the foreground of a mask onto an RGB image.

    The pixels selected by ``mask`` are painted with ``colour`` on a copy
    of ``image``; the painted copy is then blended with the untouched
    image as ``alpha * painted + (1 - alpha) * image``. The input image
    itself is never modified.

    Args:
        image:  HWC uint8 RGB image.
        mask:   HW boolean or int array; any truthy (non-zero) element
                marks a foreground pixel.
        alpha:  Weight of the painted copy in the blend (overlay opacity).
        colour: RGB colour painted onto foreground pixels.

    Returns:
        Blended HWC uint8 RGB image.
    """
    foreground = mask.astype(bool)

    # Paint the flat colour onto a private copy so the caller's image
    # stays intact.
    tinted = image.copy()
    tinted[foreground] = np.array(colour, dtype=np.uint8)

    # Weighted sum of the painted copy and the original image.
    blended = cv2.addWeighted(tinted, alpha, image, 1.0 - alpha, 0)
    return blended

def colourize_labels(
    label: np.ndarray,
    palette: list[list[int]] | None = None,
) -> np.ndarray:
    """Convert an integer label map to an RGB colour image.

    Args:
        label:   HW int array of class IDs.
        palette: list of [R, G, B] per class index.

    Returns:
        HWC uint8 RGB image.
    """
    if palette is None:
        palette = DEFAULT_PALETTE

    h, w = label.shape
    rgb = np.zeros((h, w, 3), dtype=np.uint8)
    for cls_id, colour in enumerate(palette):
        rgb[label == cls_id] = colour
    return rgb

def plot_predictions(
    image: np.ndarray,
    gt_label: np.ndarray,
    pred_label: np.ndarray,
    save_path: str | None = None,
) -> np.ndarray:
    """Create a side-by-side comparison: image | GT overlay | prediction overlay.

    Every label value greater than zero is treated as foreground. The
    ground-truth foreground is tinted green and the predicted foreground
    red, both at 40 % opacity.

    Args:
        image:      HWC uint8 RGB image.
        gt_label:   HW array of ground-truth class IDs.
        pred_label: HW array of predicted class IDs.
        save_path:  Optional file path; when given, the comparison image
                    is also written to disk.

    Returns:
        Concatenated HWC uint8 RGB image (width = 3 × original).

    Warns:
        RuntimeWarning: if ``save_path`` is given but the image could not
            be written. The comparison image is still returned.
    """
    gt_overlay = overlay_mask(image, gt_label > 0, alpha=0.4, colour=(50, 255, 50))
    pred_overlay = overlay_mask(image, pred_label > 0, alpha=0.4, colour=(255, 50, 50))

    # Panels are joined along the width axis.
    combined = np.concatenate([image, gt_overlay, pred_overlay], axis=1)

    if save_path is not None:
        # OpenCV writes images in BGR channel order, so convert first.
        bgr = cv2.cvtColor(combined, cv2.COLOR_RGB2BGR)
        # cv2.imwrite reports failure (e.g. a missing directory) by
        # returning False rather than raising; surface that instead of
        # silently dropping the file.
        if not cv2.imwrite(save_path, bgr):
            warnings.warn(
                f"Failed to write prediction plot to {save_path!r}",
                RuntimeWarning,
                stacklevel=2,
            )

    return combined