"""NVIDIA C-RADIOv4-H unified vision backbone.
Distills DINOv3-7B + SAM3 + SigLIP2 into a single 631M-param encoder.
Uses the SigLIP2 adaptor head for zero-shot text-prompted roof segmentation.
Memory optimization: the SigLIP2 text encoder (~7.5GB) is loaded once to
pre-compute text embeddings for our fixed prompt set, then freed from RAM.
Only the vision backbone + adaptor projection head are kept (~1.6GB).
"""
import gc
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from einops import rearrange
# ViT patch edge in pixels: feature maps have (H/16, W/16) patch tokens.
PATCH_SIZE = 16

# Zero-shot text prompts for roof segmentation
ROOF_PROMPTS = [
    "flat roof plane",
    "pitched roof plane",
    "hip roof plane",
    "gable roof plane",
]
# Negative classes; argmax against these suppresses non-roof pixels.
NON_ROOF_PROMPTS = [
    "sky",
    "ground",
    "tree",
    "wall",
    "shadow",
]

# Module-level caches (populated by load_model)
_model = None    # RADIO vision backbone + siglip2-g adaptor head
_device = None   # device string the model currently lives on
_cached_text_embeddings = None  # Pre-computed for ROOF_PROMPTS + NON_ROOF_PROMPTS (kept on CPU)
def load_model(device: str = "cuda", vitdet_window_size: int = 8):
    """Load C-RADIOv4-H with siglip2-g adaptor.

    Pre-computes text embeddings for fixed prompts, then frees the
    SigLIP2 text encoder to reclaim ~7.5GB of RAM.

    Args:
        device: 'cuda' or 'cpu'. For ZeroGPU, load to 'cpu' at startup
                and move to 'cuda' per-request.
        vitdet_window_size: Window size for ViTDet attention (8 = input
                            must be multiples of 128px). Set to None for
                            global attention.

    Returns:
        The loaded model.
    """
    global _model, _device, _cached_text_embeddings
    # Idempotent: a second call returns the cached model and ignores the
    # requested device — use move_to() to relocate an already-loaded model.
    if _model is not None:
        return _model
    print("Loading C-RADIOv4-H (631M params)...")
    kwargs = {
        "version": "c-radio_v4-h",
        "adaptor_names": ["siglip2-g"],
        "progress": True,
        "skip_validation": True,
    }
    # Omitting the key entirely (rather than passing None) selects the hub
    # model's default global attention.
    if vitdet_window_size is not None:
        kwargs["vitdet_window_size"] = vitdet_window_size
    _model = torch.hub.load("NVlabs/RADIO", "radio_model", **kwargs)
    _model.eval()
    _model.to(device)
    _device = device
    # --- Pre-compute text embeddings, then free the text encoder ---
    all_labels = ROOF_PROMPTS + NON_ROOF_PROMPTS
    print(f"Caching text embeddings for {len(all_labels)} prompts...")
    sig2_adaptor = _model.adaptors["siglip2-g"]
    text_input = sig2_adaptor.tokenizer(all_labels).to(device)
    with torch.no_grad():
        # Normalized embeddings are moved to CPU and cloned so the cache owns
        # its own storage, independent of the soon-to-be-freed encoder.
        _cached_text_embeddings = sig2_adaptor.encode_text(
            text_input, normalize=True
        ).cpu().clone()
    # Free the heavy SigLIP2 text encoder (~7.5GB).
    # NOTE: must happen AFTER encode_text above — order matters.
    _free_text_encoder(sig2_adaptor)
    print(f"C-RADIOv4-H loaded on {device} (text encoder freed, embeddings cached)")
    return _model
def _free_text_encoder(adaptor):
"""Delete the SigLIP2 text encoder from the adaptor to free ~7.5GB RAM.
Only targets modules > 1GB — the text encoder is ~7.5GB while vision
projection heads (feat_mlp, summary_mlp, etc.) are < 1GB.
"""
freed = 0
# Log all modules so we can see what's there
print(" Adaptor modules:")
for name, module in adaptor.named_children():
param_bytes = sum(
p.numel() * p.element_size() for p in module.parameters()
)
print(f" {name}: {param_bytes / 1e6:.0f} MB")
# Only delete modules > 1GB (the text encoder is ~7.5GB,
# vision projection heads like feat_mlp are < 1GB)
for name in list(dict(adaptor.named_children()).keys()):
module = getattr(adaptor, name)
param_bytes = sum(
p.numel() * p.element_size() for p in module.parameters()
)
if param_bytes > 1_000_000_000: # > 1GB
size_gb = param_bytes / 1e9
print(f" Freeing adaptor.{name} ({size_gb:.1f} GB)")
try:
delattr(adaptor, name)
freed += param_bytes
except Exception:
pass
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
if freed > 0:
print(f" Total freed: {freed / 1e9:.1f} GB")
else:
print(" Warning: no module > 1GB found — text encoder may still be in RAM")
def get_model():
    """Return the module-level model, triggering a load on first use."""
    global _model
    if _model is not None:
        return _model
    _model = load_model()
    return _model
def move_to(device: str):
    """Relocate the cached model to *device* (ZeroGPU per-request pattern)."""
    global _model, _device, _cached_text_embeddings
    # Nothing to do when no model is loaded or it is already in place.
    if _model is None or _device == device:
        return
    _model.to(device)
    _device = device
    # Text embeddings stay on CPU; moved to device in zero_shot_segment
def prepare_image(
    image: np.ndarray | Image.Image,
    model=None,
) -> tuple[torch.Tensor, tuple[int, int], tuple[int, int]]:
    """Turn an RGB image into a batched float tensor the model accepts.

    Scales pixel values to [0, 1] and bilinearly resizes to the nearest
    supported resolution (multiples of patch_size * window_size).

    Args:
        image: RGB image as numpy array (H, W, 3) or PIL Image.
        model: The RADIO model (for resolution snapping).

    Returns:
        (pixel_values: (1, 3, H', W'), original_size, snapped_size)
    """
    if model is None:
        model = get_model()
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
    original_size = (image.shape[0], image.shape[1])
    # HWC uint8 -> (1, 3, H, W) float in [0, 1]
    pixels = (
        torch.from_numpy(image)
        .permute(2, 0, 1)
        .float()
        .div_(255.0)
        .unsqueeze(0)
    )
    # Snap to the closest resolution the backbone supports.
    snapped = model.get_nearest_supported_resolution(*pixels.shape[-2:])
    if pixels.shape[-2:] != snapped:
        pixels = F.interpolate(pixels, snapped, mode="bilinear", align_corners=False)
    return pixels, original_size, snapped
def zero_shot_segment(
    image: np.ndarray | Image.Image,
    roof_prompts: list[str] | None = None,
    non_roof_prompts: list[str] | None = None,
    model=None,
    device: str = "cuda",
) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Zero-shot roof segmentation via RADSeg approach.

    Uses SigLIP2 adaptor to create dense language-aligned patch features,
    then computes cosine similarity against pre-computed text embeddings.

    Args:
        image: RGB image.
        roof_prompts: Text labels for roof types; defaults to ROOF_PROMPTS.
            Together with non_roof_prompts, must match the prompt set that
            was cached at startup.
        non_roof_prompts: Text labels for non-roof classes; defaults to
            NON_ROOF_PROMPTS.
        model: C-RADIOv4-H model.
        device: Compute device.

    Returns:
        (score_map: H x W x C float, seg_map: H x W int, all_labels: list[str])
        where seg_map[y,x] is the index into all_labels.

    Raises:
        RuntimeError: If no text embeddings are cached (load_model() was
            never called for this process).
        ValueError: If the prompt count differs from the cached embeddings,
            which would silently mislabel every class.
    """
    global _cached_text_embeddings
    # Resolve defaults inside the body so module constants are never exposed
    # as shared mutable default arguments.
    if roof_prompts is None:
        roof_prompts = ROOF_PROMPTS
    if non_roof_prompts is None:
        non_roof_prompts = NON_ROOF_PROMPTS
    if model is None:
        model = get_model()
    if _cached_text_embeddings is None:
        raise RuntimeError(
            "Text embeddings are not cached; call load_model() first."
        )
    all_labels = roof_prompts + non_roof_prompts
    if len(all_labels) != _cached_text_embeddings.shape[0]:
        raise ValueError(
            f"{len(all_labels)} prompts given but "
            f"{_cached_text_embeddings.shape[0]} text embeddings were cached at "
            "startup; the prompt set cannot change after load_model()."
        )
    pixel_values, original_size, snapped_size = prepare_image(image, model)
    pixel_values = pixel_values.to(device)
    with torch.no_grad(), torch.autocast(device, dtype=torch.bfloat16):
        vis_output = model(pixel_values)
        # Get SigLIP2-aligned spatial features
        sig2_summary, sig2_features = vis_output["siglip2-g"]
        # Use pre-computed text embeddings (cached at startup)
        text_embeddings = _cached_text_embeddings.to(device)
        # Cosine similarity: (1, T, D) vs (C, D) -> (1, T, C)
        dense_features = F.normalize(sig2_features.float(), dim=-1)
        text_embeddings = text_embeddings.float()
        scores = torch.einsum("btd,cd->btc", dense_features, text_embeddings)
    # Under autocast the einsum result may be bfloat16, which numpy cannot
    # represent — upcast before interpolation/.numpy() to avoid a runtime error.
    scores = scores.float()
    # Reshape to spatial grid
    h_patches = snapped_size[0] // PATCH_SIZE
    w_patches = snapped_size[1] // PATCH_SIZE
    score_map = rearrange(scores, "b (h w) c -> b c h w", h=h_patches, w=w_patches)
    # Upsample to original image size
    score_map = F.interpolate(
        score_map, size=original_size, mode="bilinear", align_corners=False
    )
    # (1, C, H, W) -> (H, W, C)
    score_map_np = score_map[0].permute(1, 2, 0).cpu().numpy()
    seg_map = score_map[0].argmax(dim=0).cpu().numpy().astype(np.int32)
    return score_map_np, seg_map, all_labels
def get_roof_mask(seg_map: np.ndarray, num_roof_classes: int = 4) -> np.ndarray:
    """Binary roof mask (1 = roof, 0 = other) from a label-index map.

    Relies on the convention that roof classes occupy the first
    num_roof_classes slots of the label list.
    """
    is_roof = np.less(seg_map, num_roof_classes)
    return is_roof.astype(np.uint8)