RoofSegmentation2 / fusion.py
Deagin's picture
Tune segmentation: reduce over-splitting and clean polygon shapes
dda24a3
"""Fusion of geometry-based and appearance-based segmentation.
RANSAC gives shadow-invariant plane geometry from the DSM.
C-RADIOv4-H gives appearance-based segmentation from RGB.
This module merges them: geometry is primary (ground truth for planes),
appearance refines boundaries and detects visual sub-features.
"""
import numpy as np
from scipy.ndimage import binary_dilation, distance_transform_edt, label
def fuse_segmentations(
    ransac_labels: np.ndarray,
    radio_scores: np.ndarray,
    building_mask: np.ndarray,
    num_roof_classes: int = 4,
    *,
    min_plane_pixels: int = 100,
    min_class_fraction: float = 0.30,
    min_component_fraction: float = 0.20,
    min_component_pixels: int = 500,
) -> np.ndarray:
    """Merge RANSAC plane labels with RADIO appearance scores.

    Strategy:
    1. RANSAC labels are the primary source (geometry = ground truth).
    2. RADIO's per-pixel roof class is the argmax over the first
       ``num_roof_classes`` channels of ``radio_scores``.
    3. For each RANSAC plane with at least ``min_plane_pixels`` pixels, every
       non-dominant RADIO class covering at least ``min_class_fraction`` of
       the plane is split into connected components. Each component larger
       than both ``min_component_fraction * plane_pixels`` and
       ``min_component_pixels`` receives a fresh label ID.

    Args:
        ransac_labels: (H, W) integer labels from RANSAC. 0 = background.
        radio_scores: (H, W, C) float score map from zero_shot_segment.
        building_mask: (H, W) binary mask. Not used by the current
            implementation; kept for API compatibility.
        num_roof_classes: Number of leading roof-type channels in
            ``radio_scores`` that take part in the argmax.
        min_plane_pixels: Planes smaller than this are never subdivided.
        min_class_fraction: Minimum fraction of a plane a minority class
            must cover before a split is considered.
        min_component_fraction: A component must be strictly larger than
            this fraction of its plane to be split off.
        min_component_pixels: A component must be strictly larger than
            this many pixels to be split off.

    Returns:
        Fused label map (H, W) with the same dtype as ``ransac_labels``.
        New IDs start at ``max(ransac_labels) + 1``.
    """
    fused = ransac_labels.copy()

    unique_planes = sorted(set(np.unique(ransac_labels)) - {0})
    next_id = max(unique_planes) + 1 if unique_planes else 1

    # RADIO's per-pixel roof class, restricted to the roof-type channels.
    roof_scores = radio_scores[:, :, :num_roof_classes]
    radio_roof_class = np.argmax(roof_scores, axis=2)

    for plane_id in unique_planes:
        plane_mask = ransac_labels == plane_id
        plane_pixels = plane_mask.sum()
        if plane_pixels < min_plane_pixels:
            continue

        unique_classes, counts = np.unique(
            radio_roof_class[plane_mask], return_counts=True
        )
        if len(unique_classes) <= 1:
            continue
        dominant_class = unique_classes[np.argmax(counts)]

        for cls, count in zip(unique_classes, counts):
            if cls == dominant_class:
                continue
            # A minority class must cover >= min_class_fraction of the plane.
            # The RADIO roof-type prompts are similar ("flat/pitched/hip/gable
            # roof plane"), so smaller fractions are treated as noise.
            if count / plane_pixels < min_class_fraction:
                continue

            # Split only spatially coherent, sufficiently large components.
            sub_mask = plane_mask & (radio_roof_class == cls)
            labeled_sub, n_components = label(sub_mask)
            # One pass for all component sizes instead of one full-image
            # comparison per (possibly tiny) component.
            comp_sizes = np.bincount(
                labeled_sub.ravel(), minlength=n_components + 1
            )
            size_floor = plane_pixels * min_component_fraction
            for comp_id in range(1, n_components + 1):
                comp_size = comp_sizes[comp_id]
                if comp_size > size_floor and comp_size > min_component_pixels:
                    fused[labeled_sub == comp_id] = next_id
                    next_id += 1

    return fused
def split_disconnected_regions(label_map: np.ndarray) -> np.ndarray:
    """Assign new IDs to spatially disconnected components of the same label.

    Handles the case where two parallel roof faces have the same pitch
    (RANSAC assigns them the same plane) but are physically separate.
    Connectivity is scipy.ndimage.label's default structure (4-connected in
    2-D), so diagonally touching pixels count as separate components.

    Args:
        label_map: Integer label image. 0 = background.

    Returns:
        Label image with the same dtype as ``label_map``, and sequential IDs
        starting at 1. IDs are assigned in ascending order of the original
        label, then in scipy's component order within each label.
    """
    result = np.zeros_like(label_map)
    next_id = 1
    for old_id in sorted(set(np.unique(label_map)) - {0}):
        mask = label_map == old_id
        labeled, n_components = label(mask)
        # Inside `mask`, `labeled` holds 1..n. Shift every component to its
        # final ID in one assignment instead of one full-image pass per
        # component.
        result[mask] = labeled[mask] + (next_id - 1)
        next_id += n_components
    return result
def merge_small_fragments(
    label_map: np.ndarray,
    building_mask: np.ndarray,
    min_fraction: float = 0.05,
    search_radius: int = 5,
) -> np.ndarray:
    """Merge fragments smaller than ``min_fraction`` of the building area.

    Labels are visited in ascending order. Each fragment with fewer than
    ``int(building_area * min_fraction)`` pixels is absorbed into a
    neighbouring label. The winner is the positive label that occupies the
    most pixels within ``search_radius`` 4-connected dilation steps of the
    fragment; this is a vote, not a strict nearest-neighbour choice.
    Fragments with no positive label in that zone are kept. Sizes are
    measured on the partially merged map, so fragments absorbed earlier
    count toward their new label's size.

    Args:
        label_map: Integer label image. 0 = background.
        building_mask: Building footprint; pixels > 0 count toward the area.
        min_fraction: Size threshold as a fraction of the building area.
        search_radius: Number of 4-connected dilation steps used to find
            neighbours. Must be >= 1.

    Returns:
        Label image with the same dtype as ``label_map``, and sequential IDs
        starting at 1 (0 stays background).

    Raises:
        ValueError: If ``search_radius`` < 1.
    """
    if search_radius < 1:
        # binary_dilation treats iterations < 1 as "dilate until stable",
        # which would silently turn the search zone into a flood fill.
        raise ValueError("search_radius must be >= 1")

    building_area = (building_mask > 0).sum()
    min_pixels = int(building_area * min_fraction)
    merged = label_map.copy()

    for lbl in sorted(set(np.unique(merged)) - {0}):
        mask = merged == lbl
        if mask.sum() >= min_pixels:
            continue
        # Default structure is the 2-D cross (4-connectivity) with a zero
        # border, i.e. the same as repeatedly OR-ing the four shifted copies.
        search_zone = binary_dilation(mask, iterations=search_radius)
        neighbor_labels = merged[search_zone & ~mask & (merged > 0)]
        if neighbor_labels.size == 0:
            continue
        vals, cnts = np.unique(neighbor_labels, return_counts=True)
        merged[mask] = vals[np.argmax(cnts)]

    # Relabel the surviving non-zero IDs to 1..k, preserving their order.
    kept_ids = np.unique(merged)
    kept_ids = kept_ids[kept_ids != 0]
    result = np.zeros_like(merged)
    nonzero = merged != 0
    result[nonzero] = np.searchsorted(kept_ids, merged[nonzero]) + 1
    return result