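"""Frozen CLIP vision-language backbone with adapters that turn depth maps, camera
geometry, and camera poses into additional tokens alongside the RGB and text features."""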
from __future__ import annotations
from dataclasses import dataclass
import math
import os
from pathlib import Path
from typing import Sequence
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from pytorch3d.transforms import matrix_to_quaternion
@dataclass
class FrozenVLBackboneConfig:
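    """Configuration for FrozenVLBackbone: CLIP checkpoint, token dimensions, and which
    depth / geometry / camera-pose token streams are enabled."""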
model_name: str = "openai/clip-vit-base-patch32"
hidden_dim: int = 512
max_text_tokens: int = 32
freeze_backbone: bool = True
gradient_checkpointing: bool = True
use_dummy_backbone: bool = False
depth_patch_size: int = 16
geometry_feature_dim: int = 8
use_camera_geometry: bool = True
use_depth_tokens: bool = True
use_geometry_tokens: bool = True
use_camera_pose_tokens: bool = True
class DepthPatchAdapter(nn.Module):
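    """Turns per-view depth maps into depth patch tokens, per-patch geometry tokens, and a
    per-view camera-pose token, all projected to ``hidden_dim``."""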
def __init__(
self,
hidden_dim: int,
patch_size: int = 16,
geometry_feature_dim: int = 8,
) -> None:
super().__init__()
self.hidden_dim = hidden_dim
self.patch_size = patch_size
self.geometry_feature_dim = geometry_feature_dim
self.depth_proj = nn.Sequential(
nn.LayerNorm(2 + geometry_feature_dim),
nn.Linear(2 + geometry_feature_dim, hidden_dim),
nn.GELU(),
nn.Linear(hidden_dim, hidden_dim),
)
self.geometry_proj = nn.Sequential(
nn.LayerNorm(geometry_feature_dim),
nn.Linear(geometry_feature_dim, hidden_dim),
nn.GELU(),
)
self.camera_proj = nn.Sequential(
nn.LayerNorm(7),
nn.Linear(7, hidden_dim),
nn.GELU(),
)
def _patchify(self, tensor: Tensor) -> Tensor:
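        """Average-pool into non-overlapping ``patch_size`` cells and flatten to (B * V, num_patches, C)."""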
pooled = F.avg_pool2d(tensor, kernel_size=self.patch_size, stride=self.patch_size)
return pooled.flatten(2).transpose(1, 2)
def _geometry_features(
self,
depths: Tensor,
camera_intrinsics: Tensor | None = None,
camera_extrinsics: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
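        """Build per-patch geometry features and a per-view camera summary.

        ``geometry`` holds normalized patch-center coordinates, ray directions (world-frame
        when extrinsics are given), and the camera translation per patch; ``camera_summary``
        is the rotation quaternion concatenated with the camera translation (identity pose
        when extrinsics are missing).
        """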
batch_views, _, height, width = depths.shape
grid_h = max(1, height // self.patch_size)
grid_w = max(1, width // self.patch_size)
patch_center_y = torch.linspace(
self.patch_size * 0.5,
max(self.patch_size * 0.5, height - (self.patch_size * 0.5)),
steps=grid_h,
device=depths.device,
dtype=depths.dtype,
)
patch_center_x = torch.linspace(
self.patch_size * 0.5,
max(self.patch_size * 0.5, width - (self.patch_size * 0.5)),
steps=grid_w,
device=depths.device,
dtype=depths.dtype,
)
pixel_y, pixel_x = torch.meshgrid(patch_center_y, patch_center_x, indexing="ij")
norm_x = ((pixel_x / max(width - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
norm_y = ((pixel_y / max(height - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
coords = torch.cat([norm_x, norm_y], dim=-1).expand(batch_views, -1, -1)
if camera_intrinsics is not None:
fx = camera_intrinsics[:, 0, 0].unsqueeze(-1)
fy = camera_intrinsics[:, 1, 1].unsqueeze(-1)
cx = camera_intrinsics[:, 0, 2].unsqueeze(-1)
cy = camera_intrinsics[:, 1, 2].unsqueeze(-1)
patch_x = pixel_x.reshape(1, grid_h * grid_w).expand(batch_views, -1)
patch_y = pixel_y.reshape(1, grid_h * grid_w).expand(batch_views, -1)
ray_x = (patch_x - cx) / fx.clamp_min(1e-6)
ray_y = (patch_y - cy) / fy.clamp_min(1e-6)
else:
ray_x = coords[..., 0]
ray_y = coords[..., 1]
ray_camera = torch.stack([ray_x, ray_y, torch.ones_like(ray_x)], dim=-1)
ray_camera = F.normalize(ray_camera, dim=-1)
if camera_extrinsics is not None:
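            # Extrinsics are treated as camera-to-world transforms here: R maps camera-frame
            # rays into the world frame and t is the camera position in world coordinates.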
rotation = camera_extrinsics[:, :3, :3]
translation = camera_extrinsics[:, :3, 3].unsqueeze(1).expand(-1, grid_h * grid_w, -1)
ray_world = torch.matmul(rotation, ray_camera.transpose(1, 2)).transpose(1, 2)
quaternion = matrix_to_quaternion(rotation)
else:
rotation = None
translation = torch.zeros(batch_views, grid_h * grid_w, 3, device=depths.device, dtype=depths.dtype)
ray_world = ray_camera
quaternion = torch.zeros(batch_views, 4, device=depths.device, dtype=depths.dtype)
quaternion[:, 0] = 1.0
geometry = torch.cat([coords, ray_world, translation], dim=-1)
if geometry.shape[-1] < self.geometry_feature_dim:
pad = self.geometry_feature_dim - geometry.shape[-1]
geometry = F.pad(geometry, (0, pad))
elif geometry.shape[-1] > self.geometry_feature_dim:
geometry = geometry[..., : self.geometry_feature_dim]
if camera_extrinsics is not None:
translation_summary = camera_extrinsics[:, :3, 3]
else:
translation_summary = torch.zeros(batch_views, 3, device=depths.device, dtype=depths.dtype)
camera_summary = torch.cat([quaternion, translation_summary], dim=-1)
return geometry, camera_summary
def forward(
self,
depths: Tensor,
depth_valid: Tensor | None = None,
camera_intrinsics: Tensor | None = None,
camera_extrinsics: Tensor | None = None,
include_geometry_features: bool = True,
include_camera_pose: bool = True,
) -> dict[str, Tensor]:
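        """Encode depths of shape [B, V, H, W] (or [B, V, 1, H, W]) into a dict with
        ``depth_tokens`` [B, V, P, D], ``geometry_tokens`` [B, V, P, D], and
        ``camera_tokens`` [B, V, 1, D]."""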
if depths.ndim == 4:
depths = depths.unsqueeze(2)
if depth_valid is None:
depth_valid = torch.ones_like(depths)
if depth_valid.ndim == 4:
depth_valid = depth_valid.unsqueeze(2)
if depths.ndim != 5:
raise ValueError(f"Expected depths to have shape [B, V, H, W] or [B, V, 1, H, W], got {tuple(depths.shape)}")
if depths.shape[2] != 1:
depths = depths.mean(dim=2, keepdim=True)
if depth_valid.shape[2] != 1:
depth_valid = depth_valid.mean(dim=2, keepdim=True)
batch_size, num_views = depths.shape[:2]
flat_depths = depths.reshape(batch_size * num_views, 1, depths.shape[-2], depths.shape[-1]).float()
flat_valid = depth_valid.reshape(batch_size * num_views, 1, depth_valid.shape[-2], depth_valid.shape[-1]).float()
flat_intrinsics = None
flat_extrinsics = None
if camera_intrinsics is not None:
flat_intrinsics = camera_intrinsics.reshape(batch_size * num_views, *camera_intrinsics.shape[-2:]).float()
if camera_extrinsics is not None:
flat_extrinsics = camera_extrinsics.reshape(batch_size * num_views, *camera_extrinsics.shape[-2:]).float()
depth_patch = self._patchify(flat_depths)
valid_patch = self._patchify(flat_valid)
geometry_features, camera_summary = self._geometry_features(
flat_depths,
camera_intrinsics=flat_intrinsics,
camera_extrinsics=flat_extrinsics,
)
if not include_geometry_features:
geometry_features = torch.zeros_like(geometry_features)
if not include_camera_pose:
camera_summary = torch.zeros_like(camera_summary)
# Keep depth tokens depth-only so depth, geometry, and pose ablations are separable.
token_inputs = torch.cat([depth_patch, valid_patch, torch.zeros_like(geometry_features)], dim=-1)
depth_tokens = self.depth_proj(token_inputs)
geometry_tokens = self.geometry_proj(geometry_features)
camera_tokens = self.camera_proj(camera_summary).unsqueeze(1)
return {
"depth_tokens": depth_tokens.view(batch_size, num_views, depth_tokens.shape[1], depth_tokens.shape[2]),
"geometry_tokens": geometry_tokens.view(batch_size, num_views, geometry_tokens.shape[1], geometry_tokens.shape[2]),
"camera_tokens": camera_tokens.view(batch_size, num_views, 1, camera_tokens.shape[-1]),
}
class _DummyTextTokenizer:
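    """Character-level stand-in tokenizer used when ``use_dummy_backbone`` is enabled."""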
def __init__(self, vocab_size: int = 8192, max_length: int = 32) -> None:
self.vocab_size = vocab_size
self.max_length = max_length
def __call__(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
token_ids = torch.zeros((len(texts), self.max_length), dtype=torch.long, device=device)
attention_mask = torch.zeros_like(token_ids)
for row, text in enumerate(texts):
encoded = [min(ord(char), self.vocab_size - 1) for char in text[: self.max_length]]
if encoded:
token_ids[row, : len(encoded)] = torch.tensor(encoded, dtype=torch.long, device=device)
attention_mask[row, : len(encoded)] = 1
return {"input_ids": token_ids, "attention_mask": attention_mask}
class FrozenVLBackbone(nn.Module):
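    """Frozen CLIP vision/text encoders plus a DepthPatchAdapter, exposing RGB, text, depth,
    geometry, and camera-pose tokens. A lightweight dummy backbone can stand in for CLIP
    when ``use_dummy_backbone`` is set."""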
def __init__(self, config: FrozenVLBackboneConfig) -> None:
super().__init__()
self.config = config
self.hidden_dim = config.hidden_dim
self.use_dummy_backbone = config.use_dummy_backbone
self.depth_adapter = DepthPatchAdapter(
hidden_dim=config.hidden_dim,
patch_size=config.depth_patch_size,
geometry_feature_dim=config.geometry_feature_dim,
)
if config.use_dummy_backbone:
self.image_patch_size = 16
self.tokenizer = _DummyTextTokenizer(max_length=config.max_text_tokens)
else:
from transformers import AutoTokenizer, CLIPModel
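            # Resolve a local copy of the CLIP checkpoint first: the explicit /workspace
            # directory, then the pinned Hugging Face cache snapshot, then any cached snapshot.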
local_model_source: str | None = None
if config.model_name == "openai/clip-vit-base-patch32":
explicit_local_dir = Path("/workspace/models/openai_clip_vit_base_patch32")
if (explicit_local_dir / "config.json").exists():
local_model_source = str(explicit_local_dir)
cache_home = Path(os.environ.get("HF_HOME", "/workspace/.cache/huggingface"))
cache_root = cache_home / "hub" / "models--openai--clip-vit-base-patch32"
if local_model_source is None:
ref_path = cache_root / "refs" / "main"
if ref_path.exists():
snapshot_id = ref_path.read_text(encoding="utf-8").strip()
snapshot_dir = cache_root / "snapshots" / snapshot_id
if (snapshot_dir / "config.json").exists():
local_model_source = str(snapshot_dir)
if local_model_source is None:
snapshot_root = cache_root / "snapshots"
if snapshot_root.exists():
for snapshot_dir in sorted(snapshot_root.iterdir(), reverse=True):
if (snapshot_dir / "config.json").exists():
local_model_source = str(snapshot_dir)
break
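            # Try local sources first (safetensors preferred), then fall back to fetching
            # config.model_name from the Hugging Face Hub.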
clip_model = None
last_clip_error: Exception | None = None
model_sources: list[tuple[str, dict[str, object]]] = []
if local_model_source is not None:
model_sources.append((local_model_source, {"use_safetensors": True, "local_files_only": True}))
model_sources.append((local_model_source, {"local_files_only": True}))
model_sources.append((config.model_name, {"use_safetensors": True}))
model_sources.append((config.model_name, {}))
for source, kwargs in model_sources:
try:
clip_model = CLIPModel.from_pretrained(source, **kwargs)
break
except Exception as exc:
last_clip_error = exc
if clip_model is None:
assert last_clip_error is not None
raise last_clip_error
self.vision_model = clip_model.vision_model
self.text_model = clip_model.text_model
self.visual_projection = clip_model.visual_projection
self.text_projection = clip_model.text_projection
tokenizer = None
last_tokenizer_error: Exception | None = None
tokenizer_sources: list[tuple[str, dict[str, object]]] = []
if local_model_source is not None:
tokenizer_sources.append((local_model_source, {"local_files_only": True}))
tokenizer_sources.append((config.model_name, {}))
for source, kwargs in tokenizer_sources:
try:
tokenizer = AutoTokenizer.from_pretrained(source, **kwargs)
break
except Exception as exc:
last_tokenizer_error = exc
if tokenizer is None:
assert last_tokenizer_error is not None
raise last_tokenizer_error
self.tokenizer = tokenizer
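            # Note: the depth adapter above was built with config.hidden_dim; this assumes the
            # checkpoint's projection_dim matches it (both 512 for openai/clip-vit-base-patch32).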
self.hidden_dim = clip_model.config.projection_dim
if config.gradient_checkpointing:
if hasattr(self.vision_model, "gradient_checkpointing_enable"):
self.vision_model.gradient_checkpointing_enable()
if hasattr(self.text_model, "gradient_checkpointing_enable"):
self.text_model.gradient_checkpointing_enable()
if config.freeze_backbone and not config.use_dummy_backbone:
for module in (
getattr(self, "vision_model", None),
getattr(self, "text_model", None),
getattr(self, "visual_projection", None),
getattr(self, "text_projection", None),
):
if module is None:
continue
for parameter in module.parameters():
parameter.requires_grad = False
def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
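        """Tokenize prompts with the dummy character tokenizer or the CLIP tokenizer
        (padded and truncated to ``max_text_tokens``), returning tensors on ``device``."""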
if self.use_dummy_backbone:
return self.tokenizer(texts, device=device)
return self.tokenizer(
list(texts),
padding=True,
truncation=True,
max_length=self.config.max_text_tokens,
return_tensors="pt",
).to(device)
def _encode_rgb_tokens(self, images: Tensor) -> Tensor:
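        """Encode [B, V, C, H, W] images into per-view patch tokens [B, V, N, D] using either
        the dummy pooling encoder or the frozen CLIP vision tower plus visual projection."""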
batch_size, num_views, channels, height, width = images.shape
flat_images = images.reshape(batch_size * num_views, channels, height, width)
if self.use_dummy_backbone:
pooled = F.avg_pool2d(flat_images.float(), kernel_size=self.image_patch_size, stride=self.image_patch_size)
patch_tokens = pooled.flatten(2).transpose(1, 2)
grid_h, grid_w = pooled.shape[-2], pooled.shape[-1]
y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=images.device)
x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=images.device)
grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2)
coords = coords.expand(patch_tokens.shape[0], -1, -1)
intensity = patch_tokens.mean(dim=-1, keepdim=True)
base = torch.cat([patch_tokens, intensity, coords], dim=-1)
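            # Tile the low-dimensional pooled features up to hidden_dim so dummy tokens have
            # the same width as CLIP-projected tokens.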
repeat_factor = math.ceil(self.hidden_dim / base.shape[-1])
tokens = base.repeat(1, 1, repeat_factor)[..., : self.hidden_dim]
else:
outputs = self.vision_model(pixel_values=flat_images)
tokens = self.visual_projection(outputs.last_hidden_state)
num_tokens = tokens.shape[1]
return tokens.reshape(batch_size, num_views, num_tokens, -1)
def encode_images(
self,
images: Tensor,
depths: Tensor | None = None,
depth_valid: Tensor | None = None,
camera_intrinsics: Tensor | None = None,
camera_extrinsics: Tensor | None = None,
return_aux: bool = False,
use_depth_tokens: bool | None = None,
use_geometry_tokens: bool | None = None,
use_camera_pose_tokens: bool | None = None,
) -> Tensor | dict[str, Tensor | None]:
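        """Encode RGB (and optional depth/camera) inputs.

        Returns the RGB tokens alone when no auxiliary inputs are given and ``return_aux`` is
        False; otherwise returns a dict with ``rgb_tokens``, ``depth_tokens``,
        ``geometry_tokens``, and ``camera_tokens`` (entries are None for disabled streams).
        """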
rgb_tokens = self._encode_rgb_tokens(images)
        wants_aux = (
            return_aux
            or depths is not None
            or depth_valid is not None
            or camera_intrinsics is not None
            or camera_extrinsics is not None
        )
if not wants_aux:
return rgb_tokens
depth_enabled = self.config.use_depth_tokens if use_depth_tokens is None else use_depth_tokens
geometry_enabled = self.config.use_geometry_tokens if use_geometry_tokens is None else use_geometry_tokens
camera_pose_enabled = self.config.use_camera_pose_tokens if use_camera_pose_tokens is None else use_camera_pose_tokens
geometry_enabled = bool(self.config.use_camera_geometry and geometry_enabled)
camera_pose_enabled = bool(self.config.use_camera_geometry and camera_pose_enabled)
depth_outputs: dict[str, Tensor | None] = {
"depth_tokens": None,
"geometry_tokens": None,
"camera_tokens": None,
}
if depths is not None:
depth_outputs = self.depth_adapter(
depths=depths,
depth_valid=depth_valid,
camera_intrinsics=camera_intrinsics,
camera_extrinsics=camera_extrinsics,
include_geometry_features=geometry_enabled,
include_camera_pose=camera_pose_enabled,
)
if not depth_enabled:
depth_outputs["depth_tokens"] = None
if not geometry_enabled:
depth_outputs["geometry_tokens"] = None
if not camera_pose_enabled:
depth_outputs["camera_tokens"] = None
return {
"rgb_tokens": rgb_tokens,
"depth_tokens": depth_outputs["depth_tokens"],
"geometry_tokens": depth_outputs["geometry_tokens"],
"camera_tokens": depth_outputs["camera_tokens"],
}
def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
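        """Encode tokenized text into per-token embeddings, via a sinusoidal hash for the
        dummy backbone or the frozen CLIP text tower plus text projection."""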
if self.use_dummy_backbone:
vocab_scale = float(self.tokenizer.vocab_size - 1)
token_values = input_ids.float() / vocab_scale
frequencies = torch.linspace(
1.0,
4.0,
steps=max(1, self.hidden_dim // 2),
device=input_ids.device,
dtype=token_values.dtype,
)
phases = token_values.unsqueeze(-1) * frequencies.view(1, 1, -1) * (2.0 * math.pi)
embeddings = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)[..., : self.hidden_dim]
return embeddings * attention_mask.unsqueeze(-1).float()
outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
return self.text_projection(outputs.last_hidden_state)
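

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the training pipeline): uses the dummy backbone
    # so no CLIP weights or network access are required. Shapes below are illustrative.
    config = FrozenVLBackboneConfig(use_dummy_backbone=True)
    backbone = FrozenVLBackbone(config)
    images = torch.rand(2, 2, 3, 224, 224)  # [batch, views, C, H, W]
    depths = torch.rand(2, 2, 224, 224)     # [batch, views, H, W]
    outputs = backbone.encode_images(images, depths=depths, return_aux=True)
    for name, value in outputs.items():
        print(name, None if value is None else tuple(value.shape))
    text_batch = backbone.tokenize_text(["pick up the red block"], device=torch.device("cpu"))
    text_tokens = backbone.encode_text(text_batch["input_ids"], text_batch["attention_mask"])
    print("text_tokens", tuple(text_tokens.shape))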