"""NVIDIA C-RADIOv4-H unified vision backbone.

Distills DINOv3-7B + SAM3 + SigLIP2 into a single 631M-param encoder. Uses the
SigLIP2 adaptor head for zero-shot text-prompted roof segmentation.

Memory optimization: the SigLIP2 text encoder (~7.5GB) is loaded once to
pre-compute text embeddings for our fixed prompt set, then freed from RAM.
Only the vision backbone + adaptor projection head are kept (~1.6GB).
"""

import gc

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange

# ViT patch size; used as a fallback when the model does not expose `patch_size`.
PATCH_SIZE = 16

# Zero-shot text prompts for roof segmentation
ROOF_PROMPTS = [
    "flat roof plane",
    "pitched roof plane",
    "hip roof plane",
    "gable roof plane",
]
NON_ROOF_PROMPTS = [
    "sky",
    "ground",
    "tree",
    "wall",
    "shadow",
]

# Module-level caches
_model = None
_device = None
_cached_text_embeddings = None  # Pre-computed for ROOF_PROMPTS + NON_ROOF_PROMPTS
_cached_labels = None  # Exact label list _cached_text_embeddings was built from


def load_model(device: str = "cuda", vitdet_window_size: int | None = 8):
    """Load C-RADIOv4-H with the siglip2-g adaptor.

    Pre-computes text embeddings for the fixed prompts, then frees the SigLIP2
    text encoder to reclaim ~7.5GB of RAM. If a model is already cached it is
    returned as-is; `device` and `vitdet_window_size` are ignored in that case
    (use `move_to` to change devices).

    Args:
        device: 'cuda' or 'cpu'. For ZeroGPU, load to 'cpu' at startup and
            move to 'cuda' per-request.
        vitdet_window_size: Window size for ViTDet attention (8 = input must be
            multiples of 128px). Set to None for global attention.

    Returns:
        The loaded model.
    """
    global _model, _device, _cached_text_embeddings, _cached_labels

    if _model is not None:
        return _model

    print("Loading C-RADIOv4-H (631M params)...")
    kwargs = {
        "version": "c-radio_v4-h",
        "adaptor_names": ["siglip2-g"],
        "progress": True,
        "skip_validation": True,
    }
    if vitdet_window_size is not None:
        kwargs["vitdet_window_size"] = vitdet_window_size

    # Build into locals first: the module globals are only published once the
    # text embeddings exist, so a failure here cannot leave a half-initialised
    # cache (model set, embeddings None) that later calls would silently reuse.
    model = torch.hub.load("NVlabs/RADIO", "radio_model", **kwargs)
    model.eval()
    model.to(device)

    # --- Pre-compute text embeddings, then free the text encoder ---
    all_labels = ROOF_PROMPTS + NON_ROOF_PROMPTS
    print(f"Caching text embeddings for {len(all_labels)} prompts...")
    sig2_adaptor = model.adaptors["siglip2-g"]
    text_input = sig2_adaptor.tokenizer(all_labels).to(device)
    with torch.no_grad():
        # Kept on CPU; zero_shot_segment moves them to the compute device.
        text_embeddings = sig2_adaptor.encode_text(
            text_input, normalize=True
        ).cpu().clone()

    # Free the heavy SigLIP2 text encoder (~7.5GB)
    _free_text_encoder(sig2_adaptor)

    _model = model
    _device = device
    _cached_text_embeddings = text_embeddings
    _cached_labels = list(all_labels)

    print(f"C-RADIOv4-H loaded on {device} (text encoder freed, embeddings cached)")
    return _model


def _param_bytes(module) -> int:
    """Return the total size in bytes of all parameters registered under `module`."""
    return sum(p.numel() * p.element_size() for p in module.parameters())


def _free_text_encoder(adaptor, min_bytes: int = 1_000_000_000) -> int:
    """Delete oversized child modules (the SigLIP2 text encoder) from the adaptor.

    Only direct children whose parameters exceed `min_bytes` (default 1GB) are
    removed. The text encoder is ~7.5GB, while the vision projection heads
    (feat_mlp, summary_mlp, etc.) are < 1GB.

    NOTE(review): deleting the attribute only releases memory if nothing else
    still references the submodule — confirm against the RADIO adaptor code.

    Args:
        adaptor: The siglip2-g adaptor module.
        min_bytes: Children strictly larger than this many bytes are deleted.

    Returns:
        The number of parameter bytes whose modules were deleted.
    """
    # Size every child once; reused for both the log and the delete pass.
    sizes = {name: _param_bytes(module) for name, module in adaptor.named_children()}

    print("  Adaptor modules:")
    for name, param_bytes in sizes.items():
        print(f"    {name}: {param_bytes / 1e6:.0f} MB")

    freed = 0
    for name, param_bytes in sizes.items():
        if param_bytes <= min_bytes:
            continue
        print(f"  Freeing adaptor.{name} ({param_bytes / 1e9:.1f} GB)")
        try:
            delattr(adaptor, name)
        except Exception as err:  # best-effort: keep loading, but say why
            print(f"  Warning: could not free adaptor.{name}: {err!r}")
            continue
        freed += param_bytes

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if freed > 0:
        print(f"  Total freed: {freed / 1e9:.1f} GB")
    else:
        print("  Warning: no module > 1GB found — text encoder may still be in RAM")
    return freed


def get_model():
    """Return the cached model, loading it with default settings if necessary."""
    if _model is None:
        return load_model()
    return _model


def move_to(device: str):
    """Move the cached model to a different device (for ZeroGPU).

    No-op if no model is loaded or it is already on `device`. Cached text
    embeddings stay on CPU; zero_shot_segment moves them per call.
    """
    global _model, _device
    if _model is not None and _device != device:
        _model.to(device)
        _device = device


def prepare_image(
    image: np.ndarray | Image.Image,
    model=None,
) -> tuple[torch.Tensor, tuple[int, int], tuple[int, int]]:
    """Prepare an image for C-RADIOv4-H inference.

    Converts to a [0, 1] float tensor and snaps to the nearest supported
    resolution (must be multiples of patch_size * window_size).

    Args:
        image: RGB image as a PIL Image or a numpy array of shape (H, W, 3)
            with values in [0, 255] (typically uint8). Grayscale (H, W) arrays
            are replicated to 3 channels; an alpha channel (H, W, 4) is dropped.
            The caller's array is never modified.
        model: The RADIO model (for resolution snapping).

    Returns:
        (pixel_values: (1, 3, H', W'), original_size, snapped_size)

    Raises:
        ValueError: If the array cannot be interpreted as an RGB image.
    """
    if model is None:
        model = get_model()

    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    image = np.asarray(image)

    if image.ndim == 2:
        image = np.stack([image] * 3, axis=-1)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = image[..., :3]
    if image.ndim != 3 or image.shape[2] != 3:
        raise ValueError(
            f"Expected an RGB image of shape (H, W, 3), got shape {image.shape}"
        )

    original_size = (image.shape[0], image.shape[1])

    # HWC -> CHW float in [0, 1]. ascontiguousarray handles negative-stride
    # views (which torch.from_numpy rejects); the division is out-of-place
    # because .float() is a no-op on float32 input and an in-place div_ would
    # then overwrite the caller's array.
    x = torch.from_numpy(np.ascontiguousarray(image)).permute(2, 0, 1).float() / 255.0
    x = x.unsqueeze(0)  # (1, 3, H, W)

    # Snap to nearest supported resolution
    snapped = model.get_nearest_supported_resolution(*x.shape[-2:])
    if snapped != x.shape[-2:]:
        x = F.interpolate(x, snapped, mode="bilinear", align_corners=False)

    return x, original_size, snapped


def _model_patch_size(model) -> int:
    """Return `model.patch_size` if it is an int, otherwise PATCH_SIZE."""
    patch_size = getattr(model, "patch_size", None)
    return patch_size if isinstance(patch_size, int) else PATCH_SIZE


def zero_shot_segment(
    image: np.ndarray | Image.Image,
    roof_prompts: list[str] = ROOF_PROMPTS,
    non_roof_prompts: list[str] = NON_ROOF_PROMPTS,
    model=None,
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Zero-shot roof segmentation via the RADSeg approach.

    Uses the SigLIP2 adaptor to create dense language-aligned patch features,
    then computes cosine similarity against the text embeddings pre-computed
    in load_model.

    NOTE(review): the model must already be on `device` (see move_to); only
    the input image is moved here.

    Args:
        image: RGB image.
        roof_prompts: Text labels for roof types. Must equal the prompts cached
            at load time, because the text encoder has been freed.
        non_roof_prompts: Text labels for non-roof classes (same constraint).
        model: C-RADIOv4-H model.
        device: Compute device (e.g. 'cuda', 'cuda:0', 'cpu').

    Returns:
        (score_map: H x W x C float, seg_map: H x W int, all_labels: list[str])
        where seg_map[y,x] is the index into all_labels.

    Raises:
        RuntimeError: If no text embeddings have been cached (load_model was
            never called).
        ValueError: If the prompts differ from those cached at load time.
    """
    if model is None:
        model = get_model()

    all_labels = roof_prompts + non_roof_prompts
    if _cached_text_embeddings is None or _cached_labels is None:
        raise RuntimeError(
            "No cached text embeddings; call load_model() before zero_shot_segment()."
        )
    # Score channels come from the cached embeddings, so the returned labels
    # must be exactly the ones those embeddings were computed for.
    if all_labels != _cached_labels:
        raise ValueError(
            "Prompts must match those cached at load time "
            f"({_cached_labels}); got {all_labels}. The text encoder has been "
            "freed, so new prompts cannot be embedded."
        )

    pixel_values, original_size, snapped_size = prepare_image(image, model)
    pixel_values = pixel_values.to(device)

    # torch.autocast wants a device *type* ("cuda"), not "cuda:0".
    device_type = torch.device(device).type

    with torch.no_grad():
        # Only the backbone forward runs in bf16; the similarity, upsampling
        # and argmax below stay in float32 outside the autocast region.
        with torch.autocast(device_type, dtype=torch.bfloat16):
            vis_output = model(pixel_values)

        # Get SigLIP2-aligned spatial features
        _sig2_summary, sig2_features = vis_output["siglip2-g"]

        # Cosine similarity: (1, T, D) vs (C, D) -> (1, T, C)
        dense_features = F.normalize(sig2_features.float(), dim=-1)
        text_embeddings = _cached_text_embeddings.to(
            device=dense_features.device, dtype=torch.float32
        )
        scores = torch.einsum("btd,cd->btc", dense_features, text_embeddings)

        # Reshape the token sequence back to the patch grid
        patch_size = _model_patch_size(model)
        h_patches = snapped_size[0] // patch_size
        w_patches = snapped_size[1] // patch_size
        score_map = rearrange(
            scores, "b (h w) c -> b c h w", h=h_patches, w=w_patches
        )

        # Upsample to original image size
        score_map = F.interpolate(
            score_map, size=original_size, mode="bilinear", align_corners=False
        )

    # (1, C, H, W) -> (H, W, C)
    score_map_np = score_map[0].permute(1, 2, 0).cpu().numpy()
    seg_map = score_map[0].argmax(dim=0).cpu().numpy().astype(np.int32)

    return score_map_np, seg_map, all_labels


def get_roof_mask(
    seg_map: np.ndarray, num_roof_classes: int = len(ROOF_PROMPTS)
) -> np.ndarray:
    """Extract a binary roof mask from a segmentation map.

    Assumes the first `num_roof_classes` indices in the label list are roof
    types. This is how zero_shot_segment orders labels: roof prompts first.

    Args:
        seg_map: Integer label map (H, W) as returned by zero_shot_segment.
        num_roof_classes: Number of leading roof labels (default: len(ROOF_PROMPTS)).

    Returns:
        uint8 array of the same shape: 1 where the label is a roof class, else 0.
    """
    return (seg_map < num_roof_classes).astype(np.uint8)