"""Frozen CLIP vision-language backbone with adapter tokens for depth, patch geometry, and camera pose."""

from __future__ import annotations

import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Sequence

import torch
import torch.nn.functional as F
from torch import Tensor, nn
from pytorch3d.transforms import matrix_to_quaternion


@dataclass
class FrozenVLBackboneConfig:
    """Configuration for FrozenVLBackbone and its depth adapter."""

    model_name: str = "openai/clip-vit-base-patch32"
    hidden_dim: int = 512
    max_text_tokens: int = 32
    freeze_backbone: bool = True
    gradient_checkpointing: bool = True
    use_dummy_backbone: bool = False
    depth_patch_size: int = 16
    geometry_feature_dim: int = 8
    use_camera_geometry: bool = True
    use_depth_tokens: bool = True
    use_geometry_tokens: bool = True
    use_camera_pose_tokens: bool = True


class DepthPatchAdapter(nn.Module):
    """Projects patchified depth, per-patch geometry, and camera pose into backbone-sized tokens."""

    def __init__(
        self,
        hidden_dim: int,
        patch_size: int = 16,
        geometry_feature_dim: int = 8,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.geometry_feature_dim = geometry_feature_dim
        self.depth_proj = nn.Sequential(
            nn.LayerNorm(2 + geometry_feature_dim),
            nn.Linear(2 + geometry_feature_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.geometry_proj = nn.Sequential(
            nn.LayerNorm(geometry_feature_dim),
            nn.Linear(geometry_feature_dim, hidden_dim),
            nn.GELU(),
        )
        self.camera_proj = nn.Sequential(
            nn.LayerNorm(7),
            nn.Linear(7, hidden_dim),
            nn.GELU(),
        )

    def _patchify(self, tensor: Tensor) -> Tensor:
        # Average-pool into non-overlapping patches, then flatten to [N, num_patches, C].
        pooled = F.avg_pool2d(tensor, kernel_size=self.patch_size, stride=self.patch_size)
        return pooled.flatten(2).transpose(1, 2)

    def _geometry_features(
        self,
        depths: Tensor,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
    ) -> tuple[Tensor, Tensor]:
        batch_views, _, height, width = depths.shape
        grid_h = max(1, height // self.patch_size)
        grid_w = max(1, width // self.patch_size)
        # Patch-centre pixel coordinates, later normalised to [-1, 1].
        patch_center_y = torch.linspace(
            self.patch_size * 0.5,
            max(self.patch_size * 0.5, height - (self.patch_size * 0.5)),
            steps=grid_h,
            device=depths.device,
            dtype=depths.dtype,
        )
        patch_center_x = torch.linspace(
            self.patch_size * 0.5,
            max(self.patch_size * 0.5, width - (self.patch_size * 0.5)),
            steps=grid_w,
            device=depths.device,
            dtype=depths.dtype,
        )
        pixel_y, pixel_x = torch.meshgrid(patch_center_y, patch_center_x, indexing="ij")
        norm_x = ((pixel_x / max(width - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
        norm_y = ((pixel_y / max(height - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
        coords = torch.cat([norm_x, norm_y], dim=-1).expand(batch_views, -1, -1)
        if camera_intrinsics is not None:
            # Back-project patch centres to view rays with the pinhole intrinsics.
            fx = camera_intrinsics[:, 0, 0].unsqueeze(-1)
            fy = camera_intrinsics[:, 1, 1].unsqueeze(-1)
            cx = camera_intrinsics[:, 0, 2].unsqueeze(-1)
            cy = camera_intrinsics[:, 1, 2].unsqueeze(-1)
            patch_x = pixel_x.reshape(1, grid_h * grid_w).expand(batch_views, -1)
            patch_y = pixel_y.reshape(1, grid_h * grid_w).expand(batch_views, -1)
            ray_x = (patch_x - cx) / fx.clamp_min(1e-6)
            ray_y = (patch_y - cy) / fy.clamp_min(1e-6)
        else:
            ray_x = coords[..., 0]
            ray_y = coords[..., 1]
        ray_camera = torch.stack([ray_x, ray_y, torch.ones_like(ray_x)], dim=-1)
        ray_camera = F.normalize(ray_camera, dim=-1)
        if camera_extrinsics is not None:
            # Rotate rays into the world frame and summarise pose as quaternion + translation.
            rotation = camera_extrinsics[:, :3, :3]
            translation = camera_extrinsics[:, :3, 3].unsqueeze(1).expand(-1, grid_h * grid_w, -1)
            ray_world = torch.matmul(rotation, ray_camera.transpose(1, 2)).transpose(1, 2)
            quaternion = matrix_to_quaternion(rotation)
        else:
            # No extrinsics: keep camera-frame rays and use an identity quaternion.
            rotation = None
            translation = torch.zeros(batch_views, grid_h * grid_w, 3, device=depths.device, dtype=depths.dtype)
            ray_world = ray_camera
            quaternion = torch.zeros(batch_views, 4, device=depths.device, dtype=depths.dtype)
            quaternion[:, 0] = 1.0
        geometry = torch.cat([coords, ray_world, translation], dim=-1)
        if geometry.shape[-1] < self.geometry_feature_dim:
            pad = self.geometry_feature_dim - geometry.shape[-1]
            geometry = F.pad(geometry, (0, pad))
        elif geometry.shape[-1] > self.geometry_feature_dim:
            geometry = geometry[..., : self.geometry_feature_dim]
        if camera_extrinsics is not None:
            translation_summary = camera_extrinsics[:, :3, 3]
        else:
            translation_summary = torch.zeros(batch_views, 3, device=depths.device, dtype=depths.dtype)
        camera_summary = torch.cat([quaternion, translation_summary], dim=-1)
        return geometry, camera_summary

    def forward(
        self,
        depths: Tensor,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        include_geometry_features: bool = True,
        include_camera_pose: bool = True,
    ) -> dict[str, Tensor]:
        if depths.ndim == 4:
            depths = depths.unsqueeze(2)
        if depth_valid is None:
            depth_valid = torch.ones_like(depths)
        if depth_valid.ndim == 4:
            depth_valid = depth_valid.unsqueeze(2)
        if depths.ndim != 5:
            raise ValueError(
                f"Expected depths to have shape [B, V, H, W] or [B, V, 1, H, W], got {tuple(depths.shape)}"
            )
        if depths.shape[2] != 1:
            depths = depths.mean(dim=2, keepdim=True)
        if depth_valid.shape[2] != 1:
            depth_valid = depth_valid.mean(dim=2, keepdim=True)
        batch_size, num_views = depths.shape[:2]
        flat_depths = depths.reshape(batch_size * num_views, 1, depths.shape[-2], depths.shape[-1]).float()
        flat_valid = depth_valid.reshape(
            batch_size * num_views, 1, depth_valid.shape[-2], depth_valid.shape[-1]
        ).float()
        flat_intrinsics = None
        flat_extrinsics = None
        if camera_intrinsics is not None:
            flat_intrinsics = camera_intrinsics.reshape(batch_size * num_views, *camera_intrinsics.shape[-2:]).float()
        if camera_extrinsics is not None:
            flat_extrinsics = camera_extrinsics.reshape(batch_size * num_views, *camera_extrinsics.shape[-2:]).float()
        depth_patch = self._patchify(flat_depths)
        valid_patch = self._patchify(flat_valid)
        geometry_features, camera_summary = self._geometry_features(
            flat_depths,
            camera_intrinsics=flat_intrinsics,
            camera_extrinsics=flat_extrinsics,
        )
        if not include_geometry_features:
            geometry_features = torch.zeros_like(geometry_features)
        if not include_camera_pose:
            camera_summary = torch.zeros_like(camera_summary)
        # Keep depth tokens depth-only so depth, geometry, and pose ablations are separable.
        token_inputs = torch.cat([depth_patch, valid_patch, torch.zeros_like(geometry_features)], dim=-1)
        depth_tokens = self.depth_proj(token_inputs)
        geometry_tokens = self.geometry_proj(geometry_features)
        camera_tokens = self.camera_proj(camera_summary).unsqueeze(1)
        return {
            "depth_tokens": depth_tokens.view(batch_size, num_views, depth_tokens.shape[1], depth_tokens.shape[2]),
            "geometry_tokens": geometry_tokens.view(
                batch_size, num_views, geometry_tokens.shape[1], geometry_tokens.shape[2]
            ),
            "camera_tokens": camera_tokens.view(batch_size, num_views, 1, camera_tokens.shape[-1]),
        }


class _DummyTextTokenizer:
    """Character-ordinal tokenizer used when the real CLIP backbone is unavailable."""

    def __init__(self, vocab_size: int = 8192, max_length: int = 32) -> None:
        self.vocab_size = vocab_size
        self.max_length = max_length

    def __call__(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
        token_ids = torch.zeros((len(texts), self.max_length), dtype=torch.long, device=device)
        attention_mask = torch.zeros_like(token_ids)
        for row, text in enumerate(texts):
            encoded = [min(ord(char), self.vocab_size - 1) for char in text[: self.max_length]]
            if encoded:
                token_ids[row, : len(encoded)] = torch.tensor(encoded, dtype=torch.long, device=device)
                attention_mask[row, : len(encoded)] = 1
        return {"input_ids": token_ids, "attention_mask": attention_mask}


class FrozenVLBackbone(nn.Module):
    """Frozen CLIP backbone augmented with depth, geometry, and camera-pose tokens."""

    def __init__(self, config: FrozenVLBackboneConfig) -> None:
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.use_dummy_backbone = config.use_dummy_backbone
        self.depth_adapter = DepthPatchAdapter(
            hidden_dim=config.hidden_dim,
            patch_size=config.depth_patch_size,
            geometry_feature_dim=config.geometry_feature_dim,
        )
        if config.use_dummy_backbone:
            self.image_patch_size = 16
            self.tokenizer = _DummyTextTokenizer(max_length=config.max_text_tokens)
        else:
            from transformers import AutoTokenizer, CLIPModel

            # Prefer an explicit local checkpoint or the HF cache before falling back to the hub.
            local_model_source: str | None = None
            if config.model_name == "openai/clip-vit-base-patch32":
                explicit_local_dir = Path("/workspace/models/openai_clip_vit_base_patch32")
                if (explicit_local_dir / "config.json").exists():
                    local_model_source = str(explicit_local_dir)
                cache_home = Path(os.environ.get("HF_HOME", "/workspace/.cache/huggingface"))
                cache_root = cache_home / "hub" / "models--openai--clip-vit-base-patch32"
                if local_model_source is None:
                    ref_path = cache_root / "refs" / "main"
                    if ref_path.exists():
                        snapshot_id = ref_path.read_text(encoding="utf-8").strip()
                        snapshot_dir = cache_root / "snapshots" / snapshot_id
                        if (snapshot_dir / "config.json").exists():
                            local_model_source = str(snapshot_dir)
                if local_model_source is None:
                    snapshot_root = cache_root / "snapshots"
                    if snapshot_root.exists():
                        for snapshot_dir in sorted(snapshot_root.iterdir(), reverse=True):
                            if (snapshot_dir / "config.json").exists():
                                local_model_source = str(snapshot_dir)
                                break

            clip_model = None
            last_clip_error: Exception | None = None
            model_sources: list[tuple[str, dict[str, object]]] = []
            if local_model_source is not None:
                model_sources.append((local_model_source, {"use_safetensors": True, "local_files_only": True}))
                model_sources.append((local_model_source, {"local_files_only": True}))
            model_sources.append((config.model_name, {"use_safetensors": True}))
            model_sources.append((config.model_name, {}))
            for source, kwargs in model_sources:
                try:
                    clip_model = CLIPModel.from_pretrained(source, **kwargs)
                    break
                except Exception as exc:
                    last_clip_error = exc
            if clip_model is None:
                assert last_clip_error is not None
                raise last_clip_error

            self.vision_model = clip_model.vision_model
            self.text_model = clip_model.text_model
            self.visual_projection = clip_model.visual_projection
            self.text_projection = clip_model.text_projection

            tokenizer = None
            last_tokenizer_error: Exception | None = None
            tokenizer_sources: list[tuple[str, dict[str, object]]] = []
            if local_model_source is not None:
                tokenizer_sources.append((local_model_source, {"local_files_only": True}))
            tokenizer_sources.append((config.model_name, {}))
            for source, kwargs in tokenizer_sources:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(source, **kwargs)
                    break
                except Exception as exc:
                    last_tokenizer_error = exc
            if tokenizer is None:
                assert last_tokenizer_error is not None
                raise last_tokenizer_error
            self.tokenizer = tokenizer

            # Projection dim (512 for CLIP ViT-B/32) defines the shared token width.
            self.hidden_dim = clip_model.config.projection_dim
            if config.gradient_checkpointing:
                if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                    self.vision_model.gradient_checkpointing_enable()
                if hasattr(self.text_model, "gradient_checkpointing_enable"):
                    self.text_model.gradient_checkpointing_enable()

        if config.freeze_backbone and not config.use_dummy_backbone:
            for module in (
                getattr(self, "vision_model", None),
                getattr(self, "text_model", None),
                getattr(self, "visual_projection", None),
                getattr(self, "text_projection", None),
            ):
                if module is None:
                    continue
                for parameter in module.parameters():
                    parameter.requires_grad = False

    def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
        if self.use_dummy_backbone:
            return self.tokenizer(texts, device=device)
        return self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.config.max_text_tokens,
            return_tensors="pt",
        ).to(device)

    def _encode_rgb_tokens(self, images: Tensor) -> Tensor:
        batch_size, num_views, channels, height, width = images.shape
        flat_images = images.reshape(batch_size * num_views, channels, height, width)
        if self.use_dummy_backbone:
            # Cheap stand-in for CLIP features: pooled patches plus intensity and patch coordinates,
            # tiled to the backbone width.
            pooled = F.avg_pool2d(
                flat_images.float(), kernel_size=self.image_patch_size, stride=self.image_patch_size
            )
            patch_tokens = pooled.flatten(2).transpose(1, 2)
            grid_h, grid_w = pooled.shape[-2], pooled.shape[-1]
            y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=images.device)
            x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=images.device)
            grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
            coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2)
            coords = coords.expand(patch_tokens.shape[0], -1, -1)
            intensity = patch_tokens.mean(dim=-1, keepdim=True)
            base = torch.cat([patch_tokens, intensity, coords], dim=-1)
            repeat_factor = math.ceil(self.hidden_dim / base.shape[-1])
            tokens = base.repeat(1, 1, repeat_factor)[..., : self.hidden_dim]
        else:
            outputs = self.vision_model(pixel_values=flat_images)
            tokens = self.visual_projection(outputs.last_hidden_state)
        num_tokens = tokens.shape[1]
        return tokens.reshape(batch_size, num_views, num_tokens, -1)

    def encode_images(
        self,
        images: Tensor,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        return_aux: bool = False,
        use_depth_tokens: bool | None = None,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> Tensor | dict[str, Tensor | None]:
        rgb_tokens = self._encode_rgb_tokens(images)
        wants_aux = (
            return_aux
            or depths is not None
            or depth_valid is not None
            or camera_intrinsics is not None
            or camera_extrinsics is not None
        )
        if not wants_aux:
            return rgb_tokens
        # Per-call overrides fall back to the config defaults; geometry and pose tokens also
        # require use_camera_geometry.
        depth_enabled = self.config.use_depth_tokens if use_depth_tokens is None else use_depth_tokens
        geometry_enabled = self.config.use_geometry_tokens if use_geometry_tokens is None else use_geometry_tokens
        camera_pose_enabled = (
            self.config.use_camera_pose_tokens if use_camera_pose_tokens is None else use_camera_pose_tokens
        )
        geometry_enabled = bool(self.config.use_camera_geometry and geometry_enabled)
        camera_pose_enabled = bool(self.config.use_camera_geometry and camera_pose_enabled)
        depth_outputs: dict[str, Tensor | None] = {
            "depth_tokens": None,
            "geometry_tokens": None,
            "camera_tokens": None,
        }
        if depths is not None:
            depth_outputs = self.depth_adapter(
                depths=depths,
                depth_valid=depth_valid,
                camera_intrinsics=camera_intrinsics,
                camera_extrinsics=camera_extrinsics,
                include_geometry_features=geometry_enabled,
                include_camera_pose=camera_pose_enabled,
            )
            if not depth_enabled:
                depth_outputs["depth_tokens"] = None
            if not geometry_enabled:
                depth_outputs["geometry_tokens"] = None
            if not camera_pose_enabled:
                depth_outputs["camera_tokens"] = None
        return {
            "rgb_tokens": rgb_tokens,
            "depth_tokens": depth_outputs["depth_tokens"],
            "geometry_tokens": depth_outputs["geometry_tokens"],
            "camera_tokens": depth_outputs["camera_tokens"],
        }

    def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        if self.use_dummy_backbone:
            # Deterministic sinusoidal embedding of character ordinals for the dummy backbone.
            vocab_scale = float(self.tokenizer.vocab_size - 1)
            token_values = input_ids.float() / vocab_scale
            frequencies = torch.linspace(
                1.0,
                4.0,
                steps=max(1, self.hidden_dim // 2),
                device=input_ids.device,
                dtype=token_values.dtype,
            )
            phases = token_values.unsqueeze(-1) * frequencies.view(1, 1, -1) * (2.0 * math.pi)
            embeddings = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)[..., : self.hidden_dim]
            return embeddings * attention_mask.unsqueeze(-1).float()
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(outputs.last_hidden_state)
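

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline). The batch
# sizes, resolutions, and prompt below are illustrative assumptions; it uses
# the dummy backbone so it runs without downloading CLIP weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = FrozenVLBackboneConfig(use_dummy_backbone=True)
    backbone = FrozenVLBackbone(config)

    batch_size, num_views, height, width = 2, 2, 224, 224
    images = torch.rand(batch_size, num_views, 3, height, width)
    depths = torch.rand(batch_size, num_views, height, width)
    intrinsics = torch.eye(3).expand(batch_size, num_views, 3, 3)
    extrinsics = torch.eye(4).expand(batch_size, num_views, 4, 4)

    outputs = backbone.encode_images(
        images,
        depths=depths,
        camera_intrinsics=intrinsics,
        camera_extrinsics=extrinsics,
        return_aux=True,
    )
    for name, value in outputs.items():
        print(name, None if value is None else tuple(value.shape))

    tokens = backbone.tokenize_text(["pick up the red block"], device=images.device)
    text_embeddings = backbone.encode_text(tokens["input_ids"], tokens["attention_mask"])
    print("text_embeddings", tuple(text_embeddings.shape))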