Spaces:
Paused
Paused
"""NVIDIA C-RADIOv4-H unified vision backbone.

Distills DINOv3-7B + SAM3 + SigLIP2 into a single 631M-param encoder.
Uses the SigLIP2 adaptor head for zero-shot text-prompted roof segmentation.

Memory optimization: the SigLIP2 text encoder (~7.5GB) is loaded once to
pre-compute text embeddings for our fixed prompt set, then freed from RAM.
Only the vision backbone + adaptor projection head are kept (~1.6GB).
"""
| import gc | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from PIL import Image | |
| from einops import rearrange | |
# ViT patch size: the backbone emits one feature token per 16x16-pixel patch,
# so a snapped H x W input yields (H/16) * (W/16) tokens.
PATCH_SIZE = 16

# Zero-shot text prompts for roof segmentation.
# NOTE: roof classes MUST come first in the combined label list —
# get_roof_mask() treats the first num_roof_classes indices as roof.
ROOF_PROMPTS = [
    "flat roof plane",
    "pitched roof plane",
    "hip roof plane",
    "gable roof plane",
]
# Distractor classes scored alongside the roof prompts so non-roof pixels
# have a competing label to win during the per-patch argmax.
NON_ROOF_PROMPTS = [
    "sky",
    "ground",
    "tree",
    "wall",
    "shadow",
]

# Module-level caches (populated by load_model()).
_model = None  # the C-RADIOv4-H model from torch.hub
_device = None  # device string the model currently lives on
_cached_text_embeddings = None  # Pre-computed for ROOF_PROMPTS + NON_ROOF_PROMPTS
def load_model(device: str = "cuda", vitdet_window_size: int = 8):
    """Load C-RADIOv4-H with siglip2-g adaptor.

    Pre-computes text embeddings for fixed prompts, then frees the
    SigLIP2 text encoder to reclaim ~7.5GB of RAM. Idempotent: returns
    the cached model on repeat calls without reloading.

    Args:
        device: 'cuda' or 'cpu'. For ZeroGPU, load to 'cpu' at startup
                and move to 'cuda' per-request.
        vitdet_window_size: Window size for ViTDet attention (8 = input
                            must be multiples of 128px). Set to None for
                            global attention.

    Returns:
        The loaded model.
    """
    global _model, _device, _cached_text_embeddings

    # Already loaded — hand back the cached instance.
    if _model is not None:
        return _model

    print("Loading C-RADIOv4-H (631M params)...")
    hub_kwargs = dict(
        version="c-radio_v4-h",
        adaptor_names=["siglip2-g"],
        progress=True,
        skip_validation=True,
    )
    # Only pass the window size when windowed attention is requested.
    if vitdet_window_size is not None:
        hub_kwargs["vitdet_window_size"] = vitdet_window_size

    model = torch.hub.load("NVlabs/RADIO", "radio_model", **hub_kwargs)
    model.eval()
    model.to(device)
    _model = model
    _device = device

    # --- Pre-compute text embeddings, then free the text encoder ---
    labels = ROOF_PROMPTS + NON_ROOF_PROMPTS
    print(f"Caching text embeddings for {len(labels)} prompts...")
    adaptor = model.adaptors["siglip2-g"]
    tokens = adaptor.tokenizer(labels).to(device)
    with torch.no_grad():
        # Keep the cache on CPU so it survives device moves; clone() detaches
        # it from any encoder-owned storage before the encoder is deleted.
        _cached_text_embeddings = (
            adaptor.encode_text(tokens, normalize=True).cpu().clone()
        )

    # Free the heavy SigLIP2 text encoder (~7.5GB)
    _free_text_encoder(adaptor)

    print(f"C-RADIOv4-H loaded on {device} (text encoder freed, embeddings cached)")
    return _model
| def _free_text_encoder(adaptor): | |
| """Delete the SigLIP2 text encoder from the adaptor to free ~7.5GB RAM. | |
| Only targets modules > 1GB — the text encoder is ~7.5GB while vision | |
| projection heads (feat_mlp, summary_mlp, etc.) are < 1GB. | |
| """ | |
| freed = 0 | |
| # Log all modules so we can see what's there | |
| print(" Adaptor modules:") | |
| for name, module in adaptor.named_children(): | |
| param_bytes = sum( | |
| p.numel() * p.element_size() for p in module.parameters() | |
| ) | |
| print(f" {name}: {param_bytes / 1e6:.0f} MB") | |
| # Only delete modules > 1GB (the text encoder is ~7.5GB, | |
| # vision projection heads like feat_mlp are < 1GB) | |
| for name in list(dict(adaptor.named_children()).keys()): | |
| module = getattr(adaptor, name) | |
| param_bytes = sum( | |
| p.numel() * p.element_size() for p in module.parameters() | |
| ) | |
| if param_bytes > 1_000_000_000: # > 1GB | |
| size_gb = param_bytes / 1e9 | |
| print(f" Freeing adaptor.{name} ({size_gb:.1f} GB)") | |
| try: | |
| delattr(adaptor, name) | |
| freed += param_bytes | |
| except Exception: | |
| pass | |
| gc.collect() | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if freed > 0: | |
| print(f" Total freed: {freed / 1e9:.1f} GB") | |
| else: | |
| print(" Warning: no module > 1GB found — text encoder may still be in RAM") | |
def get_model():
    """Get the cached model, loading if necessary."""
    global _model
    if _model is not None:
        return _model
    _model = load_model()
    return _model
def move_to(device: str):
    """Move the model to a different device (for ZeroGPU).

    No-op when the model is not loaded yet or already lives on ``device``.

    Args:
        device: Target device string, e.g. 'cuda' or 'cpu'.
    """
    # Only _device is rebound here; _model is mutated in place via .to(),
    # so it needs no global declaration (the original declared _model and
    # _cached_text_embeddings as global without ever assigning them).
    global _device
    if _model is not None and _device != device:
        _model.to(device)
        _device = device
    # Text embeddings stay on CPU; moved to device in zero_shot_segment
def prepare_image(
    image: np.ndarray | Image.Image,
    model=None,
) -> tuple[torch.Tensor, tuple[int, int], tuple[int, int]]:
    """Prepare image for C-RADIOv4-H inference.

    Converts to [0, 1] float tensor and snaps to nearest supported
    resolution (must be multiples of patch_size * window_size).

    Args:
        image: RGB image as numpy array (H, W, 3) or PIL Image.
        model: The RADIO model (for resolution snapping).

    Returns:
        (pixel_values: (1, 3, H', W'), original_size, snapped_size)
    """
    if model is None:
        model = get_model()
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))

    original_size = image.shape[:2]  # (H, W)

    # HWC uint8 -> (1, 3, H, W) float in [0, 1]
    pixels = torch.from_numpy(image).float() / 255.0
    pixels = pixels.permute(2, 0, 1).unsqueeze(0)

    # Snap to the nearest resolution the backbone supports, resizing
    # only when the input isn't already at a supported size.
    snapped = model.get_nearest_supported_resolution(*pixels.shape[-2:])
    if snapped != pixels.shape[-2:]:
        pixels = F.interpolate(pixels, snapped, mode="bilinear", align_corners=False)

    return pixels, original_size, snapped
def zero_shot_segment(
    image: np.ndarray | Image.Image,
    roof_prompts: list[str] = ROOF_PROMPTS,
    non_roof_prompts: list[str] = NON_ROOF_PROMPTS,
    model=None,
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Zero-shot roof segmentation via RADSeg approach.

    Uses SigLIP2 adaptor to create dense language-aligned patch features,
    then computes cosine similarity against pre-computed text embeddings.

    Args:
        image: RGB image.
        roof_prompts: Text labels for roof types. Must equal ROOF_PROMPTS:
            the text encoder is freed at startup, so only the embeddings
            cached for the module-level prompts exist.
        non_roof_prompts: Text labels for non-roof classes. Must equal
            NON_ROOF_PROMPTS for the same reason.
        model: C-RADIOv4-H model.
        device: Compute device.

    Returns:
        (score_map: H x W x C float, seg_map: H x W int, all_labels: list[str])
        where seg_map[y,x] is the index into all_labels.

    Raises:
        RuntimeError: If load_model() has not cached the text embeddings.
        ValueError: If the prompt lists differ from the cached ones.
    """
    if model is None:
        model = get_model()

    # The cached embeddings are the ONLY text representation available
    # (the encoder was freed). Scoring different labels against them would
    # silently mislabel every pixel, so validate instead of proceeding.
    if _cached_text_embeddings is None:
        raise RuntimeError("Text embeddings not cached — call load_model() first.")
    all_labels = roof_prompts + non_roof_prompts
    if all_labels != ROOF_PROMPTS + NON_ROOF_PROMPTS:
        raise ValueError(
            "Prompts must match ROOF_PROMPTS + NON_ROOF_PROMPTS cached at "
            "startup; the SigLIP2 text encoder has been freed."
        )

    pixel_values, original_size, snapped_size = prepare_image(image, model)
    pixel_values = pixel_values.to(device)

    with torch.no_grad(), torch.autocast(device, dtype=torch.bfloat16):
        vis_output = model(pixel_values)
        # Get SigLIP2-aligned spatial features
        sig2_summary, sig2_features = vis_output["siglip2-g"]

        # Use pre-computed text embeddings (cached on CPU at startup)
        text_embeddings = _cached_text_embeddings.to(device).float()

        # Cosine similarity: (1, T, D) vs (C, D) -> (1, T, C).
        # Features are L2-normalized here; text embeddings were normalized
        # at encode time (normalize=True in load_model).
        dense_features = F.normalize(sig2_features.float(), dim=-1)
        scores = torch.einsum("btd,cd->btc", dense_features, text_embeddings)

        # Reshape token sequence to the spatial patch grid
        h_patches = snapped_size[0] // PATCH_SIZE
        w_patches = snapped_size[1] // PATCH_SIZE
        score_map = rearrange(scores, "b (h w) c -> b c h w", h=h_patches, w=w_patches)

        # Upsample patch-level scores to the original image size
        score_map = F.interpolate(
            score_map, size=original_size, mode="bilinear", align_corners=False
        )

    # (1, C, H, W) -> (H, W, C)
    score_map_np = score_map[0].permute(1, 2, 0).cpu().numpy()
    seg_map = score_map[0].argmax(dim=0).cpu().numpy().astype(np.int32)
    return score_map_np, seg_map, all_labels
def get_roof_mask(seg_map: np.ndarray, num_roof_classes: int = 4) -> np.ndarray:
    """Extract binary roof mask from segmentation map.

    Assumes first num_roof_classes indices in the label list are roof types.
    """
    is_roof = np.less(seg_map, num_roof_classes)
    return is_roof.astype(np.uint8)