RoofSegmentation2 / pipeline.py
Deagin's picture
Fix: Free SigLIP2 text encoder after caching embeddings (~7.5GB saved)
8266ce5
"""Pipeline orchestration: address -> GeoJSON roof planes.
Chains together all modules in the correct order:
Google Solar API -> Building isolation -> RANSAC planes -> C-RADIOv4-H -> Fusion -> GeoJSON
"""
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import cv2
from PIL import Image

from google_solar import geocode_address, fetch_geotiff, parse_geotiff, parse_building_mask, parse_dsm
from building import isolate_primary_building, crop_to_building, recalculate_bounds
from ransac_planes import preprocess_dsm, dsm_to_point_cloud, fit_planes, planes_to_label_map, build_plane_info
from radio_backbone import zero_shot_segment, get_roof_mask, move_to
from fusion import fuse_segmentations, split_disconnected_regions, merge_small_fragments
from geo_export import labels_to_geojson
@dataclass
class PipelineResult:
    """All outputs from a pipeline run.

    Array-valued fields and ``bounds`` default to ``None`` until ``run``
    assigns them, so they are annotated as ``Optional``. Field order is part
    of the positional constructor signature and must not change.
    """

    # Copy of the RGB image cropped to the primary building.
    original_image: Optional[np.ndarray] = None
    # Building mask cropped to the same window as ``original_image``.
    building_mask: Optional[np.ndarray] = None
    # Digital surface model cropped to the same window.
    dsm: Optional[np.ndarray] = None
    # Label map produced from the RANSAC plane fit.
    ransac_labels: Optional[np.ndarray] = None
    # Segmentation map returned by the C-RADIOv4-H zero-shot step.
    radio_seg_map: Optional[np.ndarray] = None
    # Label map after fusion, region splitting and fragment merging.
    fused_labels: Optional[np.ndarray] = None
    # One metrics dict per plane; replaced by the rebuilt list after fusion.
    plane_info: list[dict] = field(default_factory=list)
    # Output of the GeoJSON export stage.
    geojson: dict = field(default_factory=dict)
    # Visualization image built by ``build_overlay``.
    overlay: Optional[np.ndarray] = None
    # Bounds recalculated for the cropped window (exact layout is defined by
    # ``recalculate_bounds`` - TODO confirm tuple shape against that helper).
    bounds: Optional[tuple] = None
    # Human-readable progress messages, in the order stages ran.
    status: list[str] = field(default_factory=list)
# Palette of visually distinct colors for plane visualization. Entries are
# cycled by index when there are more planes than colors. The names in the
# trailing comments assume RGB channel order.
PLANE_COLORS = [
    (230, 25, 75),    # red
    (60, 180, 75),    # green
    (255, 225, 25),   # yellow
    (0, 130, 200),    # blue
    (245, 130, 48),   # orange
    (145, 30, 180),   # purple
    (70, 240, 240),   # cyan
    (240, 50, 230),   # magenta
    (128, 128, 0),    # olive
    (0, 128, 128),    # teal
    (220, 190, 255),  # lavender
    (170, 110, 40),   # brown
    (255, 250, 200),  # beige
    (128, 0, 0),      # maroon
    (0, 0, 128),      # navy
]
def run(
    address: str,
    api_key: str,
    radius_meters: int = 50,
    ransac_threshold: float = 0.15,
    max_planes: int = 8,
    min_area_sqft: float = 50.0,
    device: str = "cuda",
) -> PipelineResult:
    """Execute every pipeline stage for one address and collect the outputs.

    Stages run in this order: data acquisition (geocode + GeoTIFF download),
    primary-building isolation and cropping, RANSAC plane fitting on the DSM,
    C-RADIOv4-H zero-shot segmentation, fusion of the two label sources,
    GeoJSON export, and overlay rendering. Progress lines are appended to
    ``PipelineResult.status`` as the stages run.

    Args:
        address: Property address string.
        api_key: Google Cloud API key.
        radius_meters: Search radius for Solar API.
        ransac_threshold: RANSAC inlier distance (meters).
        max_planes: Max roof planes to detect.
        min_area_sqft: Minimum polygon area to include in GeoJSON.
        device: Compute device for C-RADIOv4-H.

    Returns:
        PipelineResult with all intermediate and final outputs.

    Raises:
        ValueError: If the building mask or DSM is missing, no building is
            found at the address, or cropping to the building fails.
    """
    result = PipelineResult()
    log = result.status.append  # all progress messages go through this

    # --- Stage 1: Data Acquisition ---
    log("Geocoding address...")
    lat, lng, formatted_address = geocode_address(address, api_key)
    log(f"Location: {formatted_address} ({lat:.6f}, {lng:.6f})")

    log("Fetching satellite imagery from Google Solar API...")
    rgb_bytes, mask_bytes, dsm_bytes, _ = fetch_geotiff(lat, lng, api_key, radius_meters)
    rgb_image, rgb_bounds = parse_geotiff(rgb_bytes)
    full_mask, _ = parse_building_mask(mask_bytes)
    full_dsm, _ = parse_dsm(dsm_bytes)
    if full_mask is None:
        raise ValueError("No building mask available for this location.")
    if full_dsm is None:
        raise ValueError("No DSM data available for this location.")

    # --- Stage 2: Building Isolation ---
    log("Isolating primary building...")
    primary_mask = isolate_primary_building(full_mask, lat, lng, rgb_bounds)
    if primary_mask is None:
        raise ValueError("No building found at this address.")

    rgb_array = np.array(rgb_image)
    tile_img, crop_info, tile_mask = crop_to_building(rgb_array, primary_mask)
    if crop_info is None:
        raise ValueError("Building crop failed — empty mask.")

    bounds = recalculate_bounds(rgb_bounds, crop_info)
    result.bounds = bounds
    result.original_image = tile_img.copy()

    # The DSM is cropped with the same row/column window as the image.
    row_window = slice(crop_info["rmin"], crop_info["rmax"])
    col_window = slice(crop_info["cmin"], crop_info["cmax"])

    # The window is expressed in image pixels, so bring the DSM onto the
    # image grid first if the two rasters differ in size.
    image_shape = rgb_array.shape[:2]
    if full_dsm.shape != image_shape:
        full_dsm = cv2.resize(
            full_dsm.astype(np.float32),
            (image_shape[1], image_shape[0]),  # cv2 expects (width, height)
            interpolation=cv2.INTER_LINEAR,
        )
    tile_dsm = full_dsm[row_window, col_window]

    result.dsm = tile_dsm
    result.building_mask = tile_mask

    height, width = tile_img.shape[:2]
    log(f"Building isolated: {width}x{height} px, {len(np.unique(tile_mask))} mask values")

    # --- Stage 3: RANSAC Plane Fitting ---
    log("Fitting 3D planes to DSM via RANSAC...")
    dsm_smooth = preprocess_dsm(tile_dsm, tile_mask)
    points_3d, valid_flat_idx, _valid_flat = dsm_to_point_cloud(dsm_smooth, tile_mask)

    # A plane needs at least 2% of the points, and never fewer than 100.
    min_plane_points = max(100, int(len(points_3d) * 0.02))
    planes = fit_planes(
        points_3d,
        distance_threshold=ransac_threshold,
        min_points=min_plane_points,
        max_planes=max_planes,
    )

    ransac_labels = planes_to_label_map(planes, valid_flat_idx, (height, width), tile_mask)
    ransac_plane_info = build_plane_info(planes)
    result.ransac_labels = ransac_labels
    result.plane_info = ransac_plane_info

    log(f"RANSAC found {len(planes)} plane(s):")
    for entry in ransac_plane_info:
        log(
            f" Plane {entry['segment_id']}: {entry['mean_slope']:.1f} deg pitch, "
            f"{entry['mean_aspect']:.0f} deg azimuth, {entry['area_sqm']:.1f} m2"
        )

    # --- Stage 4: C-RADIOv4-H Zero-Shot Segmentation ---
    log("Running C-RADIOv4-H zero-shot segmentation...")
    move_to(device)
    score_map, seg_map, _labels = zero_shot_segment(tile_img, device=device)
    result.radio_seg_map = seg_map

    roof_mask = get_roof_mask(seg_map)
    # NOTE(review): the numerator counts roof pixels over the whole tile, not
    # only inside the building footprint - confirm this ratio is intended.
    building_px = max((tile_mask > 0).sum(), 1)  # guard against divide-by-zero
    roof_pct = roof_mask.sum() / building_px * 100
    log(f"RADIO: {roof_pct:.0f}% of building classified as roof")

    # --- Stage 5: Fusion ---
    log("Fusing geometry + appearance...")
    fused = fuse_segmentations(ransac_labels, score_map, tile_mask)
    fused = split_disconnected_regions(fused)
    fused = merge_small_fragments(fused, tile_mask, min_fraction=0.05)
    result.fused_labels = fused

    # Label 0 is background and is not counted as a plane.
    n_final = int(np.count_nonzero(np.unique(fused)))
    log(f"Final: {n_final} roof plane(s) after fusion")

    # Replace the RANSAC-only metrics with metrics keyed to the fused labels.
    # NOTE(review): the raw ``planes`` dicts are passed here, not the
    # ``build_plane_info`` output - confirm they carry the keys read downstream.
    fused_plane_info = _rebuild_plane_info(fused, planes, dsm_smooth, tile_mask)
    result.plane_info = fused_plane_info

    # --- Stage 6: GeoJSON Export ---
    log("Generating GeoJSON...")
    result.geojson = labels_to_geojson(fused, bounds, fused_plane_info, min_area_sqft)
    n_features = len(result.geojson.get("features", []))
    log(f"GeoJSON: {n_features} polygon(s) exported")

    # --- Visualization ---
    result.overlay = build_overlay(tile_img, fused, fused_plane_info)
    return result
def _rebuild_plane_info(
    fused_labels: np.ndarray,
    original_planes: list[dict],
    dsm_smooth: np.ndarray,
    building_mask: np.ndarray,
    pixel_size_m: float = 0.1,
) -> list[dict]:
    """Rebuild plane info for the fused label map.

    For a label whose id matches an entry of ``original_planes``, the plane's
    pitch, azimuth, height, normal and offset are carried forward and only
    the area is recomputed from the fused mask. For any other label (for
    example one created by a split), slope, aspect and height are estimated
    from ``dsm_smooth``.

    Args:
        fused_labels: 2-D label map; 0 is treated as background and skipped.
        original_planes: Plane dicts with ``segment_id``, ``pitch_deg``,
            ``azimuth_deg``, ``mean_height``, ``normal`` and ``plane_d``.
        dsm_smooth: Height raster with the same shape as ``fused_labels``.
        building_mask: Unused; kept so existing callers keep working.
        pixel_size_m: Ground size of one pixel in meters. Used for the area
            conversion and to turn the per-pixel height gradient into a
            slope. Defaults to the 0.1 m that was previously hard-coded.

    Returns:
        One dict per non-zero label, in ascending label order.

    Raises:
        ValueError: If ``pixel_size_m`` is not positive.
    """
    if pixel_size_m <= 0:
        raise ValueError("pixel_size_m must be positive.")

    # Map original segment_id -> plane info
    orig_lookup = {p["segment_id"]: p for p in original_planes}
    pixel_area_sqm = pixel_size_m ** 2

    # The DSM gradient is identical for every label, so compute it at most
    # once, and only if some label actually needs it. np.gradient rejects
    # rasters with fewer than two samples along an axis, so staying lazy
    # avoids raising when every label is carried forward.
    gradients = None

    info = []
    for label_id in np.unique(fused_labels):  # np.unique returns sorted ids
        if label_id == 0:
            continue
        mask = fused_labels == label_id
        n_pixels = int(mask.sum())
        area_sqm = float(n_pixels * pixel_area_sqm)

        if label_id in orig_lookup:
            p = orig_lookup[label_id]
            info.append({
                "segment_id": int(label_id),
                "mean_slope": p["pitch_deg"],
                "mean_aspect": p["azimuth_deg"],
                "mean_height": p["mean_height"],
                "area_pixels": n_pixels,
                "area_sqm": area_sqm,
                "plane_normal": p["normal"],
                "plane_d": p["plane_d"],
            })
            continue

        # Compute from DSM
        if gradients is None:
            gradients = np.gradient(dsm_smooth)
        dy, dx = gradients  # axis 0 (rows) first, then axis 1 (columns)
        seg_dx = dx[mask]
        seg_dy = dy[mask]

        slope_vals = np.degrees(
            np.arctan(np.sqrt(seg_dx ** 2 + seg_dy ** 2) / pixel_size_m)
        )
        # Take the direction of the median gradient vector rather than the
        # median of per-pixel angles. Angles wrap at 0/360, so a plane whose
        # aspect straddles that seam would otherwise report a median near
        # 180 degrees. For a uniform gradient both forms agree.
        aspect = np.degrees(
            np.arctan2(-np.median(seg_dx), np.median(seg_dy))
        ) % 360

        info.append({
            "segment_id": int(label_id),
            "mean_slope": float(np.median(slope_vals)),
            "mean_aspect": float(aspect),
            "mean_height": float(dsm_smooth[mask].mean()),
            "area_pixels": n_pixels,
            "area_sqm": area_sqm,
        })
    return info
def build_overlay(
    image: np.ndarray,
    label_map: np.ndarray,
    plane_info: list[dict],
    alpha: float = 0.45,
) -> np.ndarray:
    """Build a visualization overlay: colored planes on the original image.

    Each entry of ``plane_info`` gets a color from ``PLANE_COLORS`` (cycled
    by position in the list) that is alpha-blended over the pixels carrying
    its ``segment_id``. White contour outlines and a ``P<id>: <pitch> deg``
    text label are then drawn on top.

    Args:
        image: Image to draw on; it is copied, not modified.
        label_map: Per-pixel segment ids; values <= 0 are never tinted.
        plane_info: Dicts with ``segment_id`` and optionally ``mean_slope``.
        alpha: Weight of the plane color in the blend (0 keeps the image).

    Returns:
        A new array with the same shape as ``image``.
    """
    overlay = image.copy()
    color_layer = np.zeros_like(image)
    painted = np.zeros(label_map.shape, dtype=bool)

    segments = []
    for i, info in enumerate(plane_info):
        mask = label_map == info["segment_id"]
        color_layer[mask] = PLANE_COLORS[i % len(PLANE_COLORS)]
        painted |= mask
        segments.append((info, mask))

    # Alpha blend. Only pixels that actually received a color are blended:
    # a positive label without a plane_info entry would otherwise be mixed
    # with the black fill of color_layer and come out darkened.
    blend = painted & (label_map > 0)
    overlay[blend] = (
        (1 - alpha) * image[blend].astype(float)
        + alpha * color_layer[blend].astype(float)
    ).astype(np.uint8)

    # Draw contours and labels after blending so they stay pure white.
    for info, mask in segments:
        contours, _ = cv2.findContours(
            mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if not contours:
            continue
        cv2.drawContours(overlay, contours, -1, (255, 255, 255), 1, cv2.LINE_AA)

        # Label text: anchor on the largest contour. The first contour
        # returned is arbitrary and may be a tiny fragment of the segment.
        M = cv2.moments(max(contours, key=cv2.contourArea))
        if M["m00"] > 0:
            cx = int(M["m10"] / M["m00"])
            cy = int(M["m01"] / M["m00"])
            pitch = info.get("mean_slope", 0)
            text = f"P{info['segment_id']}: {pitch:.0f} deg"
            cv2.putText(overlay, text, (cx - 30, cy),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)
    return overlay