Spaces:
Runtime error
Runtime error
| """Pipeline orchestration: address -> GeoJSON roof planes. | |
| Chains together all modules in the correct order: | |
| Google Solar API -> Building isolation -> RANSAC planes -> C-RADIOv4-H -> Fusion -> GeoJSON | |
| """ | |
| from dataclasses import dataclass, field | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from google_solar import geocode_address, fetch_geotiff, parse_geotiff, parse_building_mask, parse_dsm | |
| from building import isolate_primary_building, crop_to_building, recalculate_bounds | |
| from ransac_planes import preprocess_dsm, dsm_to_point_cloud, fit_planes, planes_to_label_map, build_plane_info | |
| from radio_backbone import zero_shot_segment, get_roof_mask, move_to | |
| from fusion import fuse_segmentations, split_disconnected_regions, merge_small_fragments | |
| from geo_export import labels_to_geojson | |
@dataclass
class PipelineResult:
    """All outputs from a pipeline run.

    Every field has a default, so an empty result can be created up front
    and filled in stage by stage. The ``@dataclass`` decorator is required:
    without it the ``field(default_factory=...)`` defaults below would stay
    raw ``Field`` objects instead of becoming fresh per-instance containers.
    """

    # Copy of the RGB image cropped to the primary building.
    original_image: np.ndarray = None
    # Building mask cropped to the same window as ``original_image``.
    building_mask: np.ndarray = None
    # Digital surface model cropped to the same window.
    dsm: np.ndarray = None
    # Label map produced by RANSAC plane fitting.
    ransac_labels: np.ndarray = None
    # Segmentation map returned by the C-RADIOv4-H zero-shot step.
    radio_seg_map: np.ndarray = None
    # Final label map after fusion, splitting and fragment merging.
    fused_labels: np.ndarray = None
    # One metrics dict per roof plane.
    plane_info: list[dict] = field(default_factory=list)
    # GeoJSON document built from the fused label map.
    geojson: dict = field(default_factory=dict)
    # Visualization image with colored planes drawn over the crop.
    overlay: np.ndarray = None
    # Geographic bounds recalculated for the cropped window.
    bounds: tuple = None
    # Human-readable progress messages, in the order they were emitted.
    status: list[str] = field(default_factory=list)
# Palette of visually distinct three-channel colors for plane visualization.
# The overlay code indexes it modulo its length, so colors repeat once there
# are more planes than palette entries.
PLANE_COLORS = [
    (230, 25, 75),
    (60, 180, 75),
    (255, 225, 25),
    (0, 130, 200),
    (245, 130, 48),
    (145, 30, 180),
    (70, 240, 240),
    (240, 50, 230),
    (128, 128, 0),
    (0, 128, 128),
    (220, 190, 255),
    (170, 110, 40),
    (255, 250, 200),
    (128, 0, 0),
    (0, 0, 128),
]
def run(
    address: str,
    api_key: str,
    radius_meters: int = 50,
    ransac_threshold: float = 0.15,
    max_planes: int = 8,
    min_area_sqft: float = 50.0,
    device: str = "cuda",
) -> PipelineResult:
    """Run the full segmentation pipeline for a single address.

    The stages run strictly in order: data acquisition, building isolation,
    RANSAC plane fitting, C-RADIOv4-H zero-shot segmentation, fusion,
    GeoJSON export and overlay rendering. Progress messages are appended to
    ``result.status`` as each stage starts or finishes.

    Args:
        address: Property address string.
        api_key: Google Cloud API key.
        radius_meters: Search radius for Solar API.
        ransac_threshold: RANSAC inlier distance (meters).
        max_planes: Max roof planes to detect.
        min_area_sqft: Minimum polygon area to include in GeoJSON.
        device: Compute device for C-RADIOv4-H.

    Returns:
        PipelineResult with all intermediate and final outputs.

    Raises:
        ValueError: If the location has no building mask or no DSM, if no
            building is found at the address, or if cropping to the
            building yields no crop info.
    """
    result = PipelineResult()

    # --- Stage 1: Data Acquisition ---
    result.status.append("Geocoding address...")
    lat, lng, formatted_address = geocode_address(address, api_key)
    result.status.append(f"Location: {formatted_address} ({lat:.6f}, {lng:.6f})")

    result.status.append("Fetching satellite imagery from Google Solar API...")
    rgb_bytes, mask_bytes, dsm_bytes, _ = fetch_geotiff(lat, lng, api_key, radius_meters)
    image, rgb_bounds = parse_geotiff(rgb_bytes)
    building_mask_full, _ = parse_building_mask(mask_bytes)
    dsm_full, _ = parse_dsm(dsm_bytes)
    if building_mask_full is None:
        raise ValueError("No building mask available for this location.")
    if dsm_full is None:
        raise ValueError("No DSM data available for this location.")

    # --- Stage 2: Building Isolation ---
    result.status.append("Isolating primary building...")
    primary_mask = isolate_primary_building(building_mask_full, lat, lng, rgb_bounds)
    if primary_mask is None:
        raise ValueError("No building found at this address.")

    img_array = np.array(image)
    cropped_img, crop_info, cropped_mask = crop_to_building(img_array, primary_mask)
    if crop_info is None:
        raise ValueError("Building crop failed — empty mask.")

    bounds = recalculate_bounds(rgb_bounds, crop_info)
    result.bounds = bounds
    result.original_image = cropped_img.copy()

    # Crop DSM to match
    rmin, rmax = crop_info["rmin"], crop_info["rmax"]
    cmin, cmax = crop_info["cmin"], crop_info["cmax"]

    # Resize DSM to match image dimensions before cropping, so the crop
    # indices (expressed in image pixels) address the same area in the DSM.
    if dsm_full.shape != img_array.shape[:2]:
        dsm_full = cv2.resize(
            dsm_full.astype(np.float32),
            (img_array.shape[1], img_array.shape[0]),  # cv2 expects (width, height)
            interpolation=cv2.INTER_LINEAR,
        )
    # NOTE(review): assumes crop_info's rmax/cmax are exclusive end indices,
    # matching how crop_to_building sliced the image — confirm.
    dsm_cropped = dsm_full[rmin:rmax, cmin:cmax]
    result.dsm = dsm_cropped
    result.building_mask = cropped_mask

    h, w = cropped_img.shape[:2]
    result.status.append(f"Building isolated: {w}x{h} px, {len(np.unique(cropped_mask))} mask values")

    # --- Stage 3: RANSAC Plane Fitting ---
    result.status.append("Fitting 3D planes to DSM via RANSAC...")
    dsm_smooth = preprocess_dsm(dsm_cropped, cropped_mask)
    points_3d, valid_flat_idx, _ = dsm_to_point_cloud(dsm_smooth, cropped_mask)
    planes = fit_planes(
        points_3d,
        distance_threshold=ransac_threshold,
        min_points=max(100, int(len(points_3d) * 0.02)),  # At least 2% of points
        max_planes=max_planes,
    )
    ransac_labels = planes_to_label_map(planes, valid_flat_idx, (h, w), cropped_mask)
    plane_info = build_plane_info(planes)
    result.ransac_labels = ransac_labels
    result.plane_info = plane_info

    n_planes = len(planes)
    result.status.append(f"RANSAC found {n_planes} plane(s):")
    for p in plane_info:
        result.status.append(
            f" Plane {p['segment_id']}: {p['mean_slope']:.1f} deg pitch, "
            f"{p['mean_aspect']:.0f} deg azimuth, {p['area_sqm']:.1f} m2"
        )

    # --- Stage 4: C-RADIOv4-H Zero-Shot Segmentation ---
    result.status.append("Running C-RADIOv4-H zero-shot segmentation...")
    move_to(device)
    score_map, seg_map, _ = zero_shot_segment(cropped_img, device=device)
    result.radio_seg_map = seg_map

    roof_mask = get_roof_mask(seg_map)
    # Report roof coverage *within the building*: only roof pixels that lie
    # on the building mask are counted, so the figure cannot exceed 100%.
    building_px = cropped_mask > 0
    roof_px = np.asarray(roof_mask) > 0
    if roof_px.shape == building_px.shape:
        roof_px = roof_px & building_px
    # If the shapes differ the masks cannot be intersected; fall back to the
    # plain roof pixel count rather than failing on a status message.
    roof_pct = roof_px.sum() / max(building_px.sum(), 1) * 100
    result.status.append(f"RADIO: {roof_pct:.0f}% of building classified as roof")

    # --- Stage 5: Fusion ---
    result.status.append("Fusing geometry + appearance...")
    fused = fuse_segmentations(ransac_labels, score_map, cropped_mask)
    fused = split_disconnected_regions(fused)
    fused = merge_small_fragments(fused, cropped_mask, min_fraction=0.05)
    result.fused_labels = fused

    n_final = len(set(np.unique(fused)) - {0})
    result.status.append(f"Final: {n_final} roof plane(s) after fusion")

    # Rebuild plane_info for fused labels (carry forward RANSAC metrics
    # for labels that survived, create new entries for split labels)
    fused_plane_info = _rebuild_plane_info(fused, planes, dsm_smooth, cropped_mask)
    result.plane_info = fused_plane_info

    # --- Stage 6: GeoJSON Export ---
    result.status.append("Generating GeoJSON...")
    result.geojson = labels_to_geojson(fused, bounds, fused_plane_info, min_area_sqft)
    n_features = len(result.geojson.get("features", []))
    result.status.append(f"GeoJSON: {n_features} polygon(s) exported")

    # --- Visualization ---
    result.overlay = build_overlay(cropped_img, fused, fused_plane_info)
    return result
def _circular_median_deg(angles_deg: np.ndarray) -> float:
    """Return a wrap-aware median of angles given in degrees.

    A plain median is wrong for directions that straddle the 0/360 seam
    (349 and 11 degrees would give 180). The angles are therefore re-expressed
    as signed offsets from their circular mean direction, the median of those
    offsets is taken, and the result is shifted back. When the data does not
    wrap, this equals the ordinary median up to floating-point rounding.
    """
    radians = np.radians(angles_deg)
    center = np.degrees(np.arctan2(np.sin(radians).mean(), np.cos(radians).mean()))
    offsets = (angles_deg - center + 180.0) % 360.0 - 180.0
    value = float((center + np.median(offsets)) % 360.0)
    # Float modulo of a tiny negative number can round up to exactly 360.0.
    return 0.0 if value >= 360.0 else value


def _rebuild_plane_info(
    fused_labels: np.ndarray,
    original_planes: list[dict],
    dsm_smooth: np.ndarray,
    building_mask: np.ndarray,
    pixel_size_m: float = 0.1,
) -> list[dict]:
    """Rebuild plane info for the fused label map.

    For labels that match original RANSAC planes, carry forward
    the plane equation and metrics. For new labels (from splits),
    compute metrics from DSM.

    Args:
        fused_labels: Label map after fusion; 0 is treated as "no plane".
        original_planes: Plane dicts keyed by ``"segment_id"``; entries that
            are carried forward must also provide ``"pitch_deg"``,
            ``"azimuth_deg"``, ``"mean_height"``, ``"normal"`` and
            ``"plane_d"``.
        dsm_smooth: Height map aligned with ``fused_labels``.
        building_mask: Currently unused; kept so existing callers keep working.
        pixel_size_m: Ground size of one pixel edge in meters. Used for the
            area conversion and for turning per-pixel height differences
            into slopes. Defaults to the previously hard-coded 0.1.

    Returns:
        One dict per non-zero label, ordered by ascending label id.
    """
    # Map original segment_id -> plane info
    orig_lookup = {p["segment_id"]: p for p in original_planes}
    pixel_area_sqm = pixel_size_m ** 2

    # The DSM gradient is the same for every label, so compute it at most
    # once, and only if some label actually needs DSM-derived metrics.
    grad_rows = grad_cols = None

    info = []
    # np.unique returns sorted values, so entries come out in label order.
    for label_id in np.unique(fused_labels):
        if label_id == 0:
            continue
        mask = fused_labels == label_id
        n_pixels = int(mask.sum())

        if label_id in orig_lookup:
            p = orig_lookup[label_id]
            info.append({
                "segment_id": int(label_id),
                "mean_slope": p["pitch_deg"],
                "mean_aspect": p["azimuth_deg"],
                "mean_height": p["mean_height"],
                "area_pixels": n_pixels,
                "area_sqm": float(n_pixels * pixel_area_sqm),
                "plane_normal": p["normal"],
                "plane_d": p["plane_d"],
            })
            continue

        # Compute from DSM
        if grad_rows is None:
            # np.gradient returns derivatives along axis 0 (rows) first,
            # then axis 1 (columns), in height units per pixel.
            grad_rows, grad_cols = np.gradient(dsm_smooth)
        dy = grad_rows[mask]
        dx = grad_cols[mask]
        slope_vals = np.degrees(np.arctan(np.hypot(dx, dy) / pixel_size_m))
        aspect_vals = np.degrees(np.arctan2(-dx, dy)) % 360
        info.append({
            "segment_id": int(label_id),
            "mean_slope": float(np.median(slope_vals)),
            # Aspect is a direction, so its median must respect wrap-around.
            "mean_aspect": _circular_median_deg(aspect_vals),
            "mean_height": float(dsm_smooth[mask].mean()),
            "area_pixels": n_pixels,
            "area_sqm": float(n_pixels * pixel_area_sqm),
        })
    return info
def build_overlay(
    image: np.ndarray,
    label_map: np.ndarray,
    plane_info: list[dict],
    alpha: float = 0.45,
) -> np.ndarray:
    """Render roof planes as translucent colored regions over the image.

    Each entry of ``plane_info`` is painted in a palette color chosen by its
    position in the list, alpha-blended onto a copy of ``image``. A thin
    white outline is then drawn around every plane, and a ``P<id>: <pitch>
    deg`` caption is placed near the centroid of the plane's first contour.

    Args:
        image: Base image; it is copied, not modified.
        label_map: Per-pixel plane ids; values > 0 are blended.
        plane_info: Plane dicts providing ``"segment_id"`` and, optionally,
            ``"mean_slope"`` for the caption.
        alpha: Weight of the plane color in the blend.

    Returns:
        The annotated copy of ``image``.
    """
    # Paint every plane's color into a separate layer first.
    tint = np.zeros_like(image)
    for idx, entry in enumerate(plane_info):
        tint[label_map == entry["segment_id"]] = PLANE_COLORS[idx % len(PLANE_COLORS)]

    # Blend the color layer into the image wherever any label is present.
    blended = image.copy()
    labelled = label_map > 0
    base_px = image[labelled].astype(float)
    tint_px = tint[labelled].astype(float)
    blended[labelled] = ((1 - alpha) * base_px + alpha * tint_px).astype(np.uint8)

    # Outline each plane and caption it.
    white = (255, 255, 255)
    for entry in plane_info:
        plane_id = entry["segment_id"]
        plane_px = (label_map == plane_id).astype(np.uint8)
        outlines, _ = cv2.findContours(plane_px, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(blended, outlines, -1, white, 1, cv2.LINE_AA)

        if not outlines:
            continue
        moments = cv2.moments(outlines[0])
        if not moments["m00"] > 0:
            # Degenerate contour with no area: no centroid to anchor a caption.
            continue

        center_x = int(moments["m10"] / moments["m00"])
        center_y = int(moments["m01"] / moments["m00"])
        pitch = entry.get("mean_slope", 0)
        caption = f"P{plane_id}: {pitch:.0f} deg"
        cv2.putText(
            blended,
            caption,
            (center_x - 30, center_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.35,
            white,
            1,
            cv2.LINE_AA,
        )
    return blended