"""Fusion of geometry-based and appearance-based segmentation.

RANSAC gives shadow-invariant plane geometry from the DSM.
C-RADIOv4-H gives appearance-based segmentation from RGB.

This module merges them: geometry is primary (ground truth for planes),
appearance refines boundaries and detects visual sub-features.
"""

import numpy as np
from scipy.ndimage import (
    binary_dilation,
    distance_transform_edt,
    generate_binary_structure,
    label,
)


def _label_dtype(labels: np.ndarray) -> np.dtype:
    """Return a dtype wide enough to hold newly allocated label IDs.

    Writing fresh IDs into a narrow dtype (e.g. uint8) would silently wrap
    around; promoting to at least int32 leaves int32/int64 inputs unchanged.
    """
    return np.promote_types(labels.dtype, np.int32)


def _relabel_sequential(labels: np.ndarray) -> np.ndarray:
    """Map every non-zero label to 1..K in sorted order; 0 stays 0.

    The output has the same shape and dtype as ``labels``.
    """
    uniq, inverse = np.unique(labels, return_inverse=True)
    lookup = np.zeros(uniq.shape, dtype=labels.dtype)
    nonzero = uniq != 0
    lookup[nonzero] = np.arange(1, int(nonzero.sum()) + 1)
    # The shape of ``inverse`` differs across NumPy versions (flat vs. input
    # shape), so reshape explicitly.
    return lookup[inverse].reshape(labels.shape)


def fuse_segmentations(
    ransac_labels: np.ndarray,
    radio_scores: np.ndarray,
    building_mask: np.ndarray,
    num_roof_classes: int = 4,
    *,
    min_plane_pixels: int = 100,
    min_class_fraction: float = 0.30,
    min_component_fraction: float = 0.20,
    min_component_pixels: int = 500,
) -> np.ndarray:
    """Merge RANSAC plane labels with RADIO appearance scores.

    Strategy:
      1. RANSAC labels are the primary source (geometry = ground truth).
      2. Within each RANSAC plane region, check whether RADIO shows strong
         internal class boundaries (different roof types within one plane).
      3. If a RANSAC plane contains pixels where RADIO strongly disagrees on
         class, sub-divide it -- but only where the sub-region is large
         enough to be meaningful.

    Args:
        ransac_labels: (H, W) integer label map from RANSAC. 0 = background.
        radio_scores: (H, W, C) float score map from zero_shot_segment.
        building_mask: (H, W) binary mask. Currently unused; kept for API
            compatibility.
        num_roof_classes: Number of leading channels of ``radio_scores``
            that are roof-type classes.
        min_plane_pixels: Planes smaller than this are never split.
        min_class_fraction: A non-dominant RADIO class must cover at least
            this fraction of the plane before a split is considered. The
            RADIO roof-type prompts are similar ("flat/pitched/hip/gable
            roof plane"), so smaller fractions are noise, not sub-features.
        min_component_fraction: A connected sub-region must be strictly
            larger than this fraction of the plane to become its own label.
        min_component_pixels: ...and strictly larger than this pixel count.

    Returns:
        Fused label map (H, W), at least int32 (int32 for int32 input).
        Split-off sub-regions receive new IDs above the largest input label.
    """
    fused = ransac_labels.astype(_label_dtype(ransac_labels), copy=True)

    plane_ids = [p for p in np.unique(ransac_labels) if p != 0]
    # Clamp at 0 so newly allocated IDs can never collide with background.
    next_id = max(max(plane_ids), 0) + 1 if plane_ids else 1

    # RADIO's per-pixel roof class, restricted to the roof-type channels.
    roof_scores = radio_scores[:, :, :num_roof_classes]
    radio_roof_class = np.argmax(roof_scores, axis=2)

    for plane_id in plane_ids:
        plane_mask = ransac_labels == plane_id
        plane_pixels = int(plane_mask.sum())
        if plane_pixels < min_plane_pixels:
            continue

        # What does RADIO think about this plane's pixels?
        classes, counts = np.unique(radio_roof_class[plane_mask], return_counts=True)
        if len(classes) <= 1:
            continue

        dominant_class = classes[np.argmax(counts)]
        for cls, count in zip(classes, counts):
            if cls == dominant_class:
                continue
            if count / plane_pixels < min_class_fraction:
                continue

            # The minority class is significant -- split off only spatially
            # coherent (connected) pieces that are large enough.
            sub_mask = plane_mask & (radio_roof_class == cls)
            components, n_components = label(sub_mask)
            for comp_id in range(1, n_components + 1):
                comp_mask = components == comp_id
                comp_size = int(comp_mask.sum())
                if (
                    comp_size > plane_pixels * min_component_fraction
                    and comp_size > min_component_pixels
                ):
                    fused[comp_mask] = next_id
                    next_id += 1

    return fused


def split_disconnected_regions(label_map: np.ndarray) -> np.ndarray:
    """Assign new IDs to spatially disconnected components of the same label.

    Handles the case where two parallel roof faces have the same pitch
    (RANSAC assigns them the same plane) but are physically separate.
    Components are found with ``scipy.ndimage.label``'s default (4-connected
    in 2-D) structure. New IDs start at 1 and follow sorted input-label order,
    then the scan order of the components within each label. The output dtype
    is at least int32, so many components cannot overflow a narrow input dtype.
    """
    result = np.zeros(label_map.shape, dtype=_label_dtype(label_map))
    next_id = 1
    for old_id in np.unique(label_map):
        if old_id == 0:
            continue
        mask = label_map == old_id
        components, n = label(mask)
        # label() numbers components 1..n; shift them to next_id..next_id+n-1.
        result[mask] = components[mask] + (next_id - 1)
        next_id += n
    return result


def merge_small_fragments(
    label_map: np.ndarray,
    building_mask: np.ndarray,
    min_fraction: float = 0.05,
    *,
    search_radius: int = 5,
) -> np.ndarray:
    """Merge fragments smaller than ``min_fraction`` of the building area.

    Each small fragment is absorbed into the most frequent positive label
    found within ``search_radius`` steps of 4-connected dilation around it
    (i.e. within that Manhattan distance). Fragments with no such neighbour
    are left as-is. Labels are then renumbered sequentially from 1
    (0 = background is preserved).

    Args:
        label_map: Integer label map. 0 = background.
        building_mask: Mask whose positive pixels define the building area.
        min_fraction: Fragments with fewer than
            ``int(building_area * min_fraction)`` pixels are merged.
        search_radius: Dilation steps used to look for neighbours.

    Returns:
        Label map with the same shape and dtype as ``label_map``.
    """
    building_area = int((building_mask > 0).sum())
    min_pixels = int(building_area * min_fraction)

    merged = label_map.copy()
    # Cross-shaped structuring element: one dilation step = 4-neighbourhood.
    cross = generate_binary_structure(merged.ndim, 1)

    for lbl in [v for v in np.unique(merged) if v != 0]:
        mask = merged == lbl
        size = int(mask.sum())
        # Empty masks belong to labels already absorbed earlier in this loop.
        if size == 0 or size >= min_pixels:
            continue
        if search_radius < 1:
            # No search ring, so there is no neighbour to merge into. (Note
            # binary_dilation treats iterations < 1 as "until convergence".)
            continue

        ring = binary_dilation(mask, structure=cross, iterations=search_radius) & ~mask
        neighbor_labels = merged[ring & (merged > 0)]
        if neighbor_labels.size == 0:
            continue

        # Merge into the most common neighbour (ties -> smallest label).
        vals, cnts = np.unique(neighbor_labels, return_counts=True)
        merged[mask] = vals[np.argmax(cnts)]

    return _relabel_sequential(merged)