# RoofSegmentation2 / radio_backbone.py
# Last commit: 0462568 (Deagin) — "Fix: Only free modules > 1GB (text encoder), keep feat_mlp"
"""NVIDIA C-RADIOv4-H unified vision backbone.
Distills DINOv3-7B + SAM3 + SigLIP2 into a single 631M-param encoder.
Uses the SigLIP2 adaptor head for zero-shot text-prompted roof segmentation.
Memory optimization: the SigLIP2 text encoder (~7.5GB) is loaded once to
pre-compute text embeddings for our fixed prompt set, then freed from RAM.
Only the vision backbone + adaptor projection head are kept (~1.6GB).
"""
import gc
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
# ViT patch stride: pixels per feature token along each axis; used in
# zero_shot_segment to recover the spatial token grid from snapped_size.
PATCH_SIZE = 16
# Zero-shot text prompts for roof segmentation
ROOF_PROMPTS = [
    "flat roof plane",
    "pitched roof plane",
    "hip roof plane",
    "gable roof plane",
]
# Competing non-roof classes: zero_shot_segment takes an argmax over
# ROOF_PROMPTS + NON_ROOF_PROMPTS, so these absorb non-roof pixels.
NON_ROOF_PROMPTS = [
    "sky",
    "ground",
    "tree",
    "wall",
    "shadow",
]
# Module-level caches (populated by load_model)
_model = None  # cached C-RADIOv4-H backbone; None until load_model() runs
_device: str | None = None  # device string the cached model currently lives on
_cached_text_embeddings: torch.Tensor | None = None # Pre-computed for ROOF_PROMPTS + NON_ROOF_PROMPTS
def load_model(device: str = "cuda", vitdet_window_size: int | None = 8):
    """Load C-RADIOv4-H with siglip2-g adaptor.

    Pre-computes text embeddings for fixed prompts, then frees the
    SigLIP2 text encoder to reclaim ~7.5GB of RAM.

    The model is cached at module level: repeat calls return the cached
    instance and ignore the arguments (use move_to() to change device).

    Args:
        device: 'cuda' or 'cpu'. For ZeroGPU, load to 'cpu' at startup
            and move to 'cuda' per-request.
        vitdet_window_size: Window size for ViTDet attention (8 = input
            must be multiples of 128px). Set to None for global attention.

    Returns:
        The loaded model.
    """
    global _model, _device, _cached_text_embeddings
    # Idempotent: second and later calls are no-ops returning the cache.
    if _model is not None:
        return _model
    print("Loading C-RADIOv4-H (631M params)...")
    kwargs = {
        "version": "c-radio_v4-h",
        "adaptor_names": ["siglip2-g"],
        "progress": True,
        "skip_validation": True,
    }
    # Only pass the key when a window size is requested; omitting it
    # selects global attention per the docstring above.
    if vitdet_window_size is not None:
        kwargs["vitdet_window_size"] = vitdet_window_size
    _model = torch.hub.load("NVlabs/RADIO", "radio_model", **kwargs)
    _model.eval()
    _model.to(device)
    _device = device
    # --- Pre-compute text embeddings, then free the text encoder ---
    all_labels = ROOF_PROMPTS + NON_ROOF_PROMPTS
    print(f"Caching text embeddings for {len(all_labels)} prompts...")
    sig2_adaptor = _model.adaptors["siglip2-g"]
    text_input = sig2_adaptor.tokenizer(all_labels).to(device)
    with torch.no_grad():
        # Detach to CPU (.cpu().clone()) so the embeddings survive after the
        # encoder is freed; zero_shot_segment moves them per-request.
        _cached_text_embeddings = sig2_adaptor.encode_text(
            text_input, normalize=True
        ).cpu().clone()
    # Free the heavy SigLIP2 text encoder (~7.5GB)
    _free_text_encoder(sig2_adaptor)
    print(f"C-RADIOv4-H loaded on {device} (text encoder freed, embeddings cached)")
    return _model
def _free_text_encoder(adaptor):
"""Delete the SigLIP2 text encoder from the adaptor to free ~7.5GB RAM.
Only targets modules > 1GB — the text encoder is ~7.5GB while vision
projection heads (feat_mlp, summary_mlp, etc.) are < 1GB.
"""
freed = 0
# Log all modules so we can see what's there
print(" Adaptor modules:")
for name, module in adaptor.named_children():
param_bytes = sum(
p.numel() * p.element_size() for p in module.parameters()
)
print(f" {name}: {param_bytes / 1e6:.0f} MB")
# Only delete modules > 1GB (the text encoder is ~7.5GB,
# vision projection heads like feat_mlp are < 1GB)
for name in list(dict(adaptor.named_children()).keys()):
module = getattr(adaptor, name)
param_bytes = sum(
p.numel() * p.element_size() for p in module.parameters()
)
if param_bytes > 1_000_000_000: # > 1GB
size_gb = param_bytes / 1e9
print(f" Freeing adaptor.{name} ({size_gb:.1f} GB)")
try:
delattr(adaptor, name)
freed += param_bytes
except Exception:
pass
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if freed > 0:
print(f" Total freed: {freed / 1e9:.1f} GB")
else:
print(" Warning: no module > 1GB found — text encoder may still be in RAM")
def get_model():
    """Return the module-level cached model, loading it on first access."""
    global _model
    if _model is not None:
        return _model
    _model = load_model()
    return _model
def move_to(device: str):
    """Move the model to a different device (for ZeroGPU)."""
    global _model, _device, _cached_text_embeddings
    # Nothing to do before the model is loaded, or if it is already there.
    if _model is None or _device == device:
        return
    _model.to(device)
    _device = device
    # Text embeddings stay on CPU; moved to device in zero_shot_segment
def prepare_image(
image: np.ndarray | Image.Image,
model=None,
) -> tuple[torch.Tensor, tuple[int, int], tuple[int, int]]:
"""Prepare image for C-RADIOv4-H inference.
Converts to [0, 1] float tensor and snaps to nearest supported
resolution (must be multiples of patch_size * window_size).
Args:
image: RGB image as numpy array (H, W, 3) or PIL Image.
model: The RADIO model (for resolution snapping).
Returns:
(pixel_values: (1, 3, H', W'), original_size, snapped_size)
"""
if model is None:
model = get_model()
if isinstance(image, Image.Image):
image = np.array(image.convert("RGB"))
original_size = (image.shape[0], image.shape[1])
# Convert HWC uint8 -> CHW float [0, 1]
x = torch.from_numpy(image).permute(2, 0, 1).float().div_(255.0)
x = x.unsqueeze(0) # (1, 3, H, W)
# Snap to nearest supported resolution
snapped = model.get_nearest_supported_resolution(*x.shape[-2:])
if snapped != x.shape[-2:]:
x = F.interpolate(x, snapped, mode="bilinear", align_corners=False)
return x, original_size, snapped
def zero_shot_segment(
    image: np.ndarray | Image.Image,
    roof_prompts: list[str] | None = None,
    non_roof_prompts: list[str] | None = None,
    model=None,
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Zero-shot roof segmentation via RADSeg approach.

    Uses SigLIP2 adaptor to create dense language-aligned patch features,
    then computes cosine similarity against pre-computed text embeddings.

    Args:
        image: RGB image.
        roof_prompts: Text labels for roof types. Defaults to ROOF_PROMPTS;
            must match the prompts cached at startup by load_model().
        non_roof_prompts: Text labels for non-roof classes. Defaults to
            NON_ROOF_PROMPTS; same startup-matching requirement.
        model: C-RADIOv4-H model.
        device: Compute device.

    Returns:
        (score_map: H x W x C float, seg_map: H x W int, all_labels: list[str])
        where seg_map[y,x] is the index into all_labels.

    Raises:
        RuntimeError: if load_model() has not cached text embeddings yet.
        ValueError: if the prompt count differs from the cached embeddings.
    """
    global _cached_text_embeddings
    # Resolve defaults at call time (avoids mutable module lists as defaults).
    roof_prompts = ROOF_PROMPTS if roof_prompts is None else roof_prompts
    non_roof_prompts = (
        NON_ROOF_PROMPTS if non_roof_prompts is None else non_roof_prompts
    )
    if model is None:
        model = get_model()
    # Embeddings are only populated by load_model(); fail loudly instead of
    # hitting an opaque AttributeError on None below.
    if _cached_text_embeddings is None:
        raise RuntimeError(
            "Text embeddings not cached — call load_model() before "
            "zero_shot_segment()."
        )
    all_labels = roof_prompts + non_roof_prompts
    # The similarity below pairs row i of the cached embeddings with label i,
    # so the prompt count must match what was encoded at startup.
    if len(all_labels) != _cached_text_embeddings.shape[0]:
        raise ValueError(
            f"{len(all_labels)} prompts given but "
            f"{_cached_text_embeddings.shape[0]} text embeddings were cached; "
            "prompts must match the startup prompt set."
        )
    pixel_values, original_size, snapped_size = prepare_image(image, model)
    pixel_values = pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device, dtype=torch.bfloat16):
        vis_output = model(pixel_values)
        # Get SigLIP2-aligned spatial features (summary unused here)
        sig2_summary, sig2_features = vis_output["siglip2-g"]
        # Use pre-computed text embeddings (cached at startup)
        text_embeddings = _cached_text_embeddings.to(device)
        # Cosine similarity: (1, T, D) vs (C, D) -> (1, T, C)
        dense_features = F.normalize(sig2_features.float(), dim=-1)
        text_embeddings = text_embeddings.float()
        scores = torch.einsum("btd,cd->btc", dense_features, text_embeddings)
        # Reshape flat token axis back to the spatial patch grid
        h_patches = snapped_size[0] // PATCH_SIZE
        w_patches = snapped_size[1] // PATCH_SIZE
        score_map = rearrange(
            scores, "b (h w) c -> b c h w", h=h_patches, w=w_patches
        )
        # Upsample to original image size
        score_map = F.interpolate(
            score_map, size=original_size, mode="bilinear", align_corners=False
        )
        # (1, C, H, W) -> (H, W, C)
        score_map_np = score_map[0].permute(1, 2, 0).cpu().numpy()
        seg_map = score_map[0].argmax(dim=0).cpu().numpy().astype(np.int32)
    return score_map_np, seg_map, all_labels
def get_roof_mask(seg_map: np.ndarray, num_roof_classes: int = 4) -> np.ndarray:
    """Extract binary roof mask from segmentation map.

    The first num_roof_classes indices in the label list are treated as
    roof classes; all others map to 0.
    """
    roof_pixels = seg_map < num_roof_classes
    return np.asarray(roof_pixels, dtype=np.uint8)