import numpy as np

from config import CLASS_COLORS, IGNORE_INDEX


def _labeled_pixels(mask: np.ndarray) -> np.ndarray:
    """Return a boolean array marking entries that carry a usable class label.

    An entry counts as labeled when it differs from ``IGNORE_INDEX`` and is
    non-negative. Negative values are excluded because they would wrap around
    when used as palette indices.
    """
    return (mask != IGNORE_INDEX) & (mask >= 0)


def _blend(rgb: np.ndarray, layer: np.ndarray, alpha: float) -> np.ndarray:
    """Alpha-blend ``layer`` over ``rgb`` and return the result as uint8.

    Both inputs are converted to float32 for the mix; the result is clipped to
    [0, 255] before the cast. Assumes both are on a 0..255 scale (TODO confirm
    for callers that pass float images).
    """
    mixed = (1 - alpha) * rgb.astype(np.float32) + alpha * layer.astype(np.float32)
    return mixed.clip(0, 255).astype(np.uint8)


def percentile_stretch(x: np.ndarray, low: float = 2.0, high: float = 98.0) -> np.ndarray:
    """Linearly rescale ``x`` so its ``low``..``high`` percentile range maps to [0, 1].

    Args:
        x: Input array; converted to float32 before processing.
        low: Percentile mapped to 0.
        high: Percentile mapped to 1.

    Returns:
        Array of the same shape as ``x`` with values clipped to [0, 1].
    """
    x = x.astype(np.float32)
    lo = np.percentile(x, low)
    span = np.percentile(x, high) - lo
    if span <= 0:
        # Flat (or inverted) range. Substitute a tiny positive span directly
        # instead of computing ``lo + 1e-6``: in float32 that sum can round
        # back to ``lo`` for larger magnitudes, leaving a zero divisor and
        # NaN output.
        span = 1e-6
    return np.clip((x - lo) / span, 0, 1)


def multispectral_to_rgb(img7: np.ndarray) -> np.ndarray:
    """Build a natural-colour-like uint8 composite from a multispectral image.

    Args:
        img7: (7, H, W) array. Bands at indices 2, 1 and 0 (H_3/H_2/H_1) are
            used as red, green and blue respectively.

    Returns:
        (H, W, 3) uint8 image; each band is percentile-stretched on its own.
    """
    r = percentile_stretch(img7[2])
    g = percentile_stretch(img7[1])
    b = percentile_stretch(img7[0])
    return (np.stack([r, g, b], axis=-1) * 255).astype(np.uint8)


def mask_to_color(mask: np.ndarray) -> np.ndarray:
    """Map class indices to RGB colours.

    Pixels equal to ``IGNORE_INDEX`` (and any negative values) are rendered as
    light gray (200, 200, 200).

    Args:
        mask: Array of class indices.

    Returns:
        uint8 array of shape ``(*mask.shape, 3)``.

    Raises:
        IndexError: If a labeled pixel's class index is beyond the palette.
    """
    out = np.full((*mask.shape, 3), 200, dtype=np.uint8)
    labeled = _labeled_pixels(mask)
    if labeled.any():
        # np.asarray lets a plain list/tuple palette be indexed by an index
        # array; it is a no-op when CLASS_COLORS is already an ndarray.
        # Presumably the palette is (num_classes, 3) — verify against config.
        palette = np.asarray(CLASS_COLORS)
        out[labeled] = palette[mask[labeled].astype(np.int64)]
    return out


def overlay_mask(rgb: np.ndarray, mask: np.ndarray, alpha: float = 0.45) -> np.ndarray:
    """Blend the colourised class ``mask`` over ``rgb``.

    Args:
        rgb: Base image; assumed to be (H, W, 3) on a 0..255 scale.
        mask: Class-index array, colourised with ``mask_to_color``.
        alpha: Weight of the mask layer; the base image gets ``1 - alpha``.

    Returns:
        Blended uint8 image.
    """
    return _blend(rgb, mask_to_color(mask), alpha)


def correctness_map(pred: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Colour-code prediction correctness against the ground truth.

    Green = correct, red = wrong, gray = unlabeled. A ground-truth pixel is
    unlabeled when it equals ``IGNORE_INDEX`` or is negative, matching the
    rule ``mask_to_color`` uses.

    Args:
        pred: Predicted class indices.
        gt: Ground-truth class indices.

    Returns:
        uint8 array of shape ``(*pred.shape, 3)``.
    """
    out = np.full((*pred.shape, 3), 180, dtype=np.uint8)
    labeled = _labeled_pixels(gt)
    out[labeled & (pred == gt)] = [0, 220, 0]
    out[labeled & (pred != gt)] = [220, 0, 0]
    return out


def correctness_overlay(
    rgb: np.ndarray, pred: np.ndarray, gt: np.ndarray, alpha: float = 0.38
) -> np.ndarray:
    """Blend the correctness map of ``pred`` vs ``gt`` over ``rgb``.

    Args:
        rgb: Base image; assumed to be (H, W, 3) on a 0..255 scale.
        pred: Predicted class indices.
        gt: Ground-truth class indices.
        alpha: Weight of the correctness layer; the base image gets
            ``1 - alpha``.

    Returns:
        Blended uint8 image.
    """
    return _blend(rgb, correctness_map(pred, gt), alpha)