| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
| import math |
| import os |
| from pathlib import Path |
| from typing import Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
| from torch import Tensor, nn |
|
|
| from pytorch3d.transforms import matrix_to_quaternion |
|
|
|
|
@dataclass
class FrozenVLBackboneConfig:
    """Configuration for ``FrozenVLBackbone`` and its depth-token adapter."""

    # Hugging Face identifier of the CLIP checkpoint to load.
    model_name: str = "openai/clip-vit-base-patch32"
    # Token channel width used by the dummy backbone and the depth adapter.
    # NOTE(review): for the real CLIP branch this is later overridden by the
    # model's projection_dim — confirm the two agree for non-default models.
    hidden_dim: int = 512
    # Maximum text sequence length passed to the tokenizer.
    max_text_tokens: int = 32
    # When True, all CLIP parameters get requires_grad = False.
    freeze_backbone: bool = True
    # Enable HF gradient checkpointing on the vision/text towers if supported.
    gradient_checkpointing: bool = True
    # Replace CLIP with a lightweight deterministic stand-in (for tests).
    use_dummy_backbone: bool = False
    # Side length (pixels) of the square pooling patches for depth maps.
    depth_patch_size: int = 16
    # Channel width of the per-patch geometry feature vector.
    geometry_feature_dim: int = 8
    # Master switch gating geometry- and camera-pose-derived token streams.
    use_camera_geometry: bool = True
    # Per-stream defaults, overridable per call in ``encode_images``.
    use_depth_tokens: bool = True
    use_geometry_tokens: bool = True
    use_camera_pose_tokens: bool = True
|
|
|
|
class DepthPatchAdapter(nn.Module):
    """Converts per-view depth maps and camera metadata into transformer tokens.

    Three token streams are produced, each projected to ``hidden_dim``:

    * ``depth_tokens``: one token per ``patch_size``-square patch, built from
      pooled depth, pooled validity, and per-patch geometry features.
    * ``geometry_tokens``: one token per patch from geometry features alone.
    * ``camera_tokens``: one token per view summarising the camera pose
      (quaternion + translation).
    """

    def __init__(
        self,
        hidden_dim: int,
        patch_size: int = 16,
        geometry_feature_dim: int = 8,
    ) -> None:
        """Build the projection heads.

        Args:
            hidden_dim: Output channel width of every token stream.
            patch_size: Side length (pixels) of the square pooling patches.
            geometry_feature_dim: Channel width of the per-patch geometry
                vector (features are padded/truncated to this size).
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.patch_size = patch_size
        self.geometry_feature_dim = geometry_feature_dim
        # Depth tokens see [pooled depth, pooled validity] + geometry features.
        self.depth_proj = nn.Sequential(
            nn.LayerNorm(2 + geometry_feature_dim),
            nn.Linear(2 + geometry_feature_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )
        self.geometry_proj = nn.Sequential(
            nn.LayerNorm(geometry_feature_dim),
            nn.Linear(geometry_feature_dim, hidden_dim),
            nn.GELU(),
        )
        # 7 = quaternion (4) + translation (3).
        self.camera_proj = nn.Sequential(
            nn.LayerNorm(7),
            nn.Linear(7, hidden_dim),
            nn.GELU(),
        )

    def _patchify(self, tensor: Tensor) -> Tensor:
        """Average-pool [N, C, H, W] into per-patch tokens [N, P, C]."""
        pooled = F.avg_pool2d(tensor, kernel_size=self.patch_size, stride=self.patch_size)
        return pooled.flatten(2).transpose(1, 2)

    def _geometry_features(
        self,
        depths: Tensor,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
    ) -> tuple[Tensor, Tensor]:
        """Compute per-patch geometry features and a per-view camera summary.

        Args:
            depths: Flattened depth maps [B*V, 1, H, W]; only the spatial
                shape, device, and dtype are consumed here.
            camera_intrinsics: Optional [B*V, 3, 3] pinhole intrinsics.
            camera_extrinsics: Optional [B*V, 4, 4] camera-to-world matrices
                (rotation in [:3, :3], translation in [:3, 3] — TODO confirm
                against the data pipeline's convention).

        Returns:
            ``(geometry, camera_summary)`` where ``geometry`` is
            [B*V, P, geometry_feature_dim] (normalized patch coords, viewing
            ray, camera translation) and ``camera_summary`` is [B*V, 7]
            (quaternion + translation).
        """
        batch_views, _, height, width = depths.shape
        grid_h = max(1, height // self.patch_size)
        grid_w = max(1, width // self.patch_size)
        # Pixel coordinates of each patch center (clamped for tiny inputs).
        patch_center_y = torch.linspace(
            self.patch_size * 0.5,
            max(self.patch_size * 0.5, height - (self.patch_size * 0.5)),
            steps=grid_h,
            device=depths.device,
            dtype=depths.dtype,
        )
        patch_center_x = torch.linspace(
            self.patch_size * 0.5,
            max(self.patch_size * 0.5, width - (self.patch_size * 0.5)),
            steps=grid_w,
            device=depths.device,
            dtype=depths.dtype,
        )
        pixel_y, pixel_x = torch.meshgrid(patch_center_y, patch_center_x, indexing="ij")
        # Normalize patch centers to [-1, 1].
        norm_x = ((pixel_x / max(width - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
        norm_y = ((pixel_y / max(height - 1, 1)) * 2.0 - 1.0).reshape(1, grid_h * grid_w, 1)
        coords = torch.cat([norm_x, norm_y], dim=-1).expand(batch_views, -1, -1)

        if camera_intrinsics is not None:
            # Back-project patch centers through the pinhole model.
            fx = camera_intrinsics[:, 0, 0].unsqueeze(-1)
            fy = camera_intrinsics[:, 1, 1].unsqueeze(-1)
            cx = camera_intrinsics[:, 0, 2].unsqueeze(-1)
            cy = camera_intrinsics[:, 1, 2].unsqueeze(-1)
            patch_x = pixel_x.reshape(1, grid_h * grid_w).expand(batch_views, -1)
            patch_y = pixel_y.reshape(1, grid_h * grid_w).expand(batch_views, -1)
            ray_x = (patch_x - cx) / fx.clamp_min(1e-6)
            ray_y = (patch_y - cy) / fy.clamp_min(1e-6)
        else:
            # No intrinsics: fall back to normalized image coordinates.
            ray_x = coords[..., 0]
            ray_y = coords[..., 1]
        ray_camera = torch.stack([ray_x, ray_y, torch.ones_like(ray_x)], dim=-1)
        ray_camera = F.normalize(ray_camera, dim=-1)

        if camera_extrinsics is not None:
            rotation = camera_extrinsics[:, :3, :3]
            translation = camera_extrinsics[:, :3, 3].unsqueeze(1).expand(-1, grid_h * grid_w, -1)
            # Rotate per-patch rays into the world frame.
            ray_world = torch.matmul(rotation, ray_camera.transpose(1, 2)).transpose(1, 2)
            quaternion = matrix_to_quaternion(rotation)
        else:
            translation = torch.zeros(batch_views, grid_h * grid_w, 3, device=depths.device, dtype=depths.dtype)
            ray_world = ray_camera
            # Identity rotation as a (w, x, y, z) quaternion.
            quaternion = torch.zeros(batch_views, 4, device=depths.device, dtype=depths.dtype)
            quaternion[:, 0] = 1.0

        geometry = torch.cat([coords, ray_world, translation], dim=-1)
        # Pad or truncate to the configured feature width (2 + 3 + 3 = 8 by default).
        if geometry.shape[-1] < self.geometry_feature_dim:
            pad = self.geometry_feature_dim - geometry.shape[-1]
            geometry = F.pad(geometry, (0, pad))
        elif geometry.shape[-1] > self.geometry_feature_dim:
            geometry = geometry[..., : self.geometry_feature_dim]

        if camera_extrinsics is not None:
            translation_summary = camera_extrinsics[:, :3, 3]
        else:
            translation_summary = torch.zeros(batch_views, 3, device=depths.device, dtype=depths.dtype)
        camera_summary = torch.cat([quaternion, translation_summary], dim=-1)
        return geometry, camera_summary

    def forward(
        self,
        depths: Tensor,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        include_geometry_features: bool = True,
        include_camera_pose: bool = True,
    ) -> dict[str, Tensor]:
        """Tokenize depth maps into depth/geometry/camera token streams.

        Args:
            depths: [B, V, H, W] or [B, V, C, H, W] depth maps; multi-channel
                input is mean-reduced over the channel axis.
            depth_valid: Optional validity mask shaped like ``depths``;
                defaults to all-valid.
            camera_intrinsics: Optional [B, V, 3, 3] intrinsics.
            camera_extrinsics: Optional [B, V, 4, 4] extrinsics.
            include_geometry_features: When False, geometry features are
                zeroed before both projections.
            include_camera_pose: When False, the camera summary is zeroed.

        Returns:
            Dict with ``depth_tokens`` [B, V, P, hidden_dim],
            ``geometry_tokens`` [B, V, P, hidden_dim], and ``camera_tokens``
            [B, V, 1, hidden_dim].

        Raises:
            ValueError: If ``depths`` is not 4- or 5-dimensional.
        """
        if depths.ndim == 4:
            depths = depths.unsqueeze(2)
        if depth_valid is None:
            depth_valid = torch.ones_like(depths)
        if depth_valid.ndim == 4:
            depth_valid = depth_valid.unsqueeze(2)
        if depths.ndim != 5:
            raise ValueError(f"Expected depths to have shape [B, V, H, W] or [B, V, 1, H, W], got {tuple(depths.shape)}")
        if depths.shape[2] != 1:
            depths = depths.mean(dim=2, keepdim=True)
        if depth_valid.shape[2] != 1:
            depth_valid = depth_valid.mean(dim=2, keepdim=True)

        batch_size, num_views = depths.shape[:2]
        # Fold views into the batch axis so pooling/projection run per view.
        flat_depths = depths.reshape(batch_size * num_views, 1, depths.shape[-2], depths.shape[-1]).float()
        flat_valid = depth_valid.reshape(batch_size * num_views, 1, depth_valid.shape[-2], depth_valid.shape[-1]).float()
        flat_intrinsics = None
        flat_extrinsics = None
        if camera_intrinsics is not None:
            flat_intrinsics = camera_intrinsics.reshape(batch_size * num_views, *camera_intrinsics.shape[-2:]).float()
        if camera_extrinsics is not None:
            flat_extrinsics = camera_extrinsics.reshape(batch_size * num_views, *camera_extrinsics.shape[-2:]).float()

        depth_patch = self._patchify(flat_depths)
        valid_patch = self._patchify(flat_valid)
        geometry_features, camera_summary = self._geometry_features(
            flat_depths,
            camera_intrinsics=flat_intrinsics,
            camera_extrinsics=flat_extrinsics,
        )
        if not include_geometry_features:
            geometry_features = torch.zeros_like(geometry_features)
        if not include_camera_pose:
            camera_summary = torch.zeros_like(camera_summary)

        # BUGFIX: this previously concatenated torch.zeros_like(geometry_features),
        # so depth tokens never saw the geometry slot that depth_proj was sized
        # for (2 + geometry_feature_dim) and the include_geometry_features flag
        # had no effect on them. Concatenate the real features; the flag above
        # still zeroes them on request.
        token_inputs = torch.cat([depth_patch, valid_patch, geometry_features], dim=-1)
        depth_tokens = self.depth_proj(token_inputs)
        geometry_tokens = self.geometry_proj(geometry_features)
        camera_tokens = self.camera_proj(camera_summary).unsqueeze(1)
        return {
            "depth_tokens": depth_tokens.view(batch_size, num_views, depth_tokens.shape[1], depth_tokens.shape[2]),
            "geometry_tokens": geometry_tokens.view(batch_size, num_views, geometry_tokens.shape[1], geometry_tokens.shape[2]),
            "camera_tokens": camera_tokens.view(batch_size, num_views, 1, camera_tokens.shape[-1]),
        }
|
|
|
|
| class _DummyTextTokenizer: |
| def __init__(self, vocab_size: int = 8192, max_length: int = 32) -> None: |
| self.vocab_size = vocab_size |
| self.max_length = max_length |
|
|
| def __call__(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]: |
| token_ids = torch.zeros((len(texts), self.max_length), dtype=torch.long, device=device) |
| attention_mask = torch.zeros_like(token_ids) |
| for row, text in enumerate(texts): |
| encoded = [min(ord(char), self.vocab_size - 1) for char in text[: self.max_length]] |
| if encoded: |
| token_ids[row, : len(encoded)] = torch.tensor(encoded, dtype=torch.long, device=device) |
| attention_mask[row, : len(encoded)] = 1 |
| return {"input_ids": token_ids, "attention_mask": attention_mask} |
|
|
|
|
class FrozenVLBackbone(nn.Module):
    """CLIP-based vision-language encoder with optional depth/geometry tokens.

    Wraps a (typically frozen) CLIP vision + text backbone and a
    ``DepthPatchAdapter``. When ``config.use_dummy_backbone`` is set, a
    deterministic lightweight stand-in replaces CLIP (useful for tests and
    offline runs).
    """

    def __init__(self, config: FrozenVLBackboneConfig) -> None:
        """Load (or fake) the backbone and build the depth adapter.

        Raises:
            Exception: Re-raises the last loader error if neither the local
                cache nor the hub yields a usable model/tokenizer.
        """
        super().__init__()
        self.config = config
        self.hidden_dim = config.hidden_dim
        self.use_dummy_backbone = config.use_dummy_backbone
        # NOTE(review): the adapter is sized with config.hidden_dim, but for
        # the real CLIP branch self.hidden_dim is overwritten below with
        # clip_model.config.projection_dim — if those differ, depth tokens and
        # RGB/text tokens have mismatched widths. TODO confirm intended.
        self.depth_adapter = DepthPatchAdapter(
            hidden_dim=config.hidden_dim,
            patch_size=config.depth_patch_size,
            geometry_feature_dim=config.geometry_feature_dim,
        )


        if config.use_dummy_backbone:
            self.image_patch_size = 16
            self.tokenizer = _DummyTextTokenizer(max_length=config.max_text_tokens)
        else:
            # Imported lazily so the dummy path has no transformers dependency.
            from transformers import AutoTokenizer, CLIPModel


            # Resolve a local copy of the default model first: an explicit
            # workspace directory, then the HF-cache snapshot pinned by
            # refs/main, then any cached snapshot containing a config.json.
            local_model_source: str | None = None
            if config.model_name == "openai/clip-vit-base-patch32":
                explicit_local_dir = Path("/workspace/models/openai_clip_vit_base_patch32")
                if (explicit_local_dir / "config.json").exists():
                    local_model_source = str(explicit_local_dir)
                cache_home = Path(os.environ.get("HF_HOME", "/workspace/.cache/huggingface"))
                cache_root = cache_home / "hub" / "models--openai--clip-vit-base-patch32"
                if local_model_source is None:
                    ref_path = cache_root / "refs" / "main"
                    if ref_path.exists():
                        snapshot_id = ref_path.read_text(encoding="utf-8").strip()
                        snapshot_dir = cache_root / "snapshots" / snapshot_id
                        if (snapshot_dir / "config.json").exists():
                            local_model_source = str(snapshot_dir)
                if local_model_source is None:
                    snapshot_root = cache_root / "snapshots"
                    if snapshot_root.exists():
                        for snapshot_dir in sorted(snapshot_root.iterdir(), reverse=True):
                            if (snapshot_dir / "config.json").exists():
                                local_model_source = str(snapshot_dir)
                                break
            # Try model sources in order of preference (local safetensors,
            # local any-format, hub safetensors, hub any-format); remember the
            # last failure so it can be re-raised if nothing loads.
            clip_model = None
            last_clip_error: Exception | None = None
            model_sources: list[tuple[str, dict[str, object]]] = []
            if local_model_source is not None:
                model_sources.append((local_model_source, {"use_safetensors": True, "local_files_only": True}))
                model_sources.append((local_model_source, {"local_files_only": True}))
            model_sources.append((config.model_name, {"use_safetensors": True}))
            model_sources.append((config.model_name, {}))
            for source, kwargs in model_sources:
                try:
                    clip_model = CLIPModel.from_pretrained(source, **kwargs)
                    break
                except Exception as exc:
                    last_clip_error = exc
            if clip_model is None:
                assert last_clip_error is not None
                raise last_clip_error
            # Keep only the sub-modules we use; the CLIPModel wrapper is not
            # registered, so its other parameters are dropped.
            self.vision_model = clip_model.vision_model
            self.text_model = clip_model.text_model
            self.visual_projection = clip_model.visual_projection
            self.text_projection = clip_model.text_projection
            # Same local-first fallback ladder for the tokenizer.
            tokenizer = None
            last_tokenizer_error: Exception | None = None
            tokenizer_sources: list[tuple[str, dict[str, object]]] = []
            if local_model_source is not None:
                tokenizer_sources.append((local_model_source, {"local_files_only": True}))
            tokenizer_sources.append((config.model_name, {}))
            for source, kwargs in tokenizer_sources:
                try:
                    tokenizer = AutoTokenizer.from_pretrained(source, **kwargs)
                    break
                except Exception as exc:
                    last_tokenizer_error = exc
            if tokenizer is None:
                assert last_tokenizer_error is not None
                raise last_tokenizer_error
            self.tokenizer = tokenizer
            # Token width follows CLIP's projection head, not config.hidden_dim.
            self.hidden_dim = clip_model.config.projection_dim
            if config.gradient_checkpointing:
                if hasattr(self.vision_model, "gradient_checkpointing_enable"):
                    self.vision_model.gradient_checkpointing_enable()
                if hasattr(self.text_model, "gradient_checkpointing_enable"):
                    self.text_model.gradient_checkpointing_enable()


        if config.freeze_backbone and not config.use_dummy_backbone:
            # Freeze every backbone parameter; the depth adapter stays trainable.
            for module in (
                getattr(self, "vision_model", None),
                getattr(self, "text_model", None),
                getattr(self, "visual_projection", None),
                getattr(self, "text_projection", None),
            ):
                if module is None:
                    continue
                for parameter in module.parameters():
                    parameter.requires_grad = False


    def tokenize_text(self, texts: Sequence[str], device: torch.device) -> dict[str, Tensor]:
        """Tokenize *texts* into ``input_ids``/``attention_mask`` tensors on *device*."""
        if self.use_dummy_backbone:
            return self.tokenizer(texts, device=device)
        return self.tokenizer(
            list(texts),
            padding=True,
            truncation=True,
            max_length=self.config.max_text_tokens,
            return_tensors="pt",
        ).to(device)


    def _encode_rgb_tokens(self, images: Tensor) -> Tensor:
        """Encode [B, V, C, H, W] images into per-patch tokens [B, V, N, D]."""
        batch_size, num_views, channels, height, width = images.shape
        # Fold views into the batch axis for the backbone pass.
        flat_images = images.reshape(batch_size * num_views, channels, height, width)
        if self.use_dummy_backbone:
            # Deterministic CLIP stand-in: average-pooled patch values plus
            # mean intensity and normalized grid coordinates, tiled up to
            # hidden_dim channels.
            pooled = F.avg_pool2d(flat_images.float(), kernel_size=self.image_patch_size, stride=self.image_patch_size)
            patch_tokens = pooled.flatten(2).transpose(1, 2)
            grid_h, grid_w = pooled.shape[-2], pooled.shape[-1]
            y_coords = torch.linspace(-1.0, 1.0, steps=grid_h, device=images.device)
            x_coords = torch.linspace(-1.0, 1.0, steps=grid_w, device=images.device)
            grid_y, grid_x = torch.meshgrid(y_coords, x_coords, indexing="ij")
            coords = torch.stack([grid_x, grid_y], dim=-1).reshape(1, grid_h * grid_w, 2)
            coords = coords.expand(patch_tokens.shape[0], -1, -1)
            intensity = patch_tokens.mean(dim=-1, keepdim=True)
            base = torch.cat([patch_tokens, intensity, coords], dim=-1)
            repeat_factor = math.ceil(self.hidden_dim / base.shape[-1])
            tokens = base.repeat(1, 1, repeat_factor)[..., : self.hidden_dim]
        else:
            # Project all patch embeddings (not just the pooled CLS output).
            outputs = self.vision_model(pixel_values=flat_images)
            tokens = self.visual_projection(outputs.last_hidden_state)
        num_tokens = tokens.shape[1]
        return tokens.reshape(batch_size, num_views, num_tokens, -1)


    def encode_images(
        self,
        images: Tensor,
        depths: Tensor | None = None,
        depth_valid: Tensor | None = None,
        camera_intrinsics: Tensor | None = None,
        camera_extrinsics: Tensor | None = None,
        return_aux: bool = False,
        use_depth_tokens: bool | None = None,
        use_geometry_tokens: bool | None = None,
        use_camera_pose_tokens: bool | None = None,
    ) -> Tensor | dict[str, Tensor | None]:
        """Encode RGB (and optionally depth/camera) inputs into token streams.

        Returns just the RGB token tensor when no auxiliary inputs are given
        and ``return_aux`` is False; otherwise a dict with ``rgb_tokens``,
        ``depth_tokens``, ``geometry_tokens``, and ``camera_tokens`` (disabled
        or unavailable streams are ``None``). Each ``use_*`` override falls
        back to the corresponding config flag when ``None``.
        """
        rgb_tokens = self._encode_rgb_tokens(images)
        wants_aux = return_aux or depths is not None or depth_valid is not None or camera_intrinsics is not None or camera_extrinsics is not None
        if not wants_aux:
            return rgb_tokens


        depth_enabled = self.config.use_depth_tokens if use_depth_tokens is None else use_depth_tokens
        geometry_enabled = self.config.use_geometry_tokens if use_geometry_tokens is None else use_geometry_tokens
        camera_pose_enabled = self.config.use_camera_pose_tokens if use_camera_pose_tokens is None else use_camera_pose_tokens
        # use_camera_geometry globally gates both geometry-derived streams.
        geometry_enabled = bool(self.config.use_camera_geometry and geometry_enabled)
        camera_pose_enabled = bool(self.config.use_camera_geometry and camera_pose_enabled)


        depth_outputs: dict[str, Tensor | None] = {
            "depth_tokens": None,
            "geometry_tokens": None,
            "camera_tokens": None,
        }
        if depths is not None:
            depth_outputs = self.depth_adapter(
                depths=depths,
                depth_valid=depth_valid,
                camera_intrinsics=camera_intrinsics,
                camera_extrinsics=camera_extrinsics,
                include_geometry_features=geometry_enabled,
                include_camera_pose=camera_pose_enabled,
            )
        # Null out streams the caller (or config) disabled.
        if not depth_enabled:
            depth_outputs["depth_tokens"] = None
        if not geometry_enabled:
            depth_outputs["geometry_tokens"] = None
        if not camera_pose_enabled:
            depth_outputs["camera_tokens"] = None


        return {
            "rgb_tokens": rgb_tokens,
            "depth_tokens": depth_outputs["depth_tokens"],
            "geometry_tokens": depth_outputs["geometry_tokens"],
            "camera_tokens": depth_outputs["camera_tokens"],
        }


    def encode_text(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
        """Encode token ids into per-token embeddings [B, T, D].

        Dummy path: deterministic sinusoidal features of the normalized token
        id, zeroed where ``attention_mask`` is 0. Real path: CLIP text
        transformer followed by the text projection.
        """
        if self.use_dummy_backbone:
            vocab_scale = float(self.tokenizer.vocab_size - 1)
            token_values = input_ids.float() / vocab_scale
            frequencies = torch.linspace(
                1.0,
                4.0,
                steps=max(1, self.hidden_dim // 2),
                device=input_ids.device,
                dtype=token_values.dtype,
            )
            phases = token_values.unsqueeze(-1) * frequencies.view(1, 1, -1) * (2.0 * math.pi)
            # sin/cos pairs give 2 * (hidden_dim // 2) channels, sliced to hidden_dim.
            embeddings = torch.cat([torch.sin(phases), torch.cos(phases)], dim=-1)[..., : self.hidden_dim]
            return embeddings * attention_mask.unsqueeze(-1).float()
        outputs = self.text_model(input_ids=input_ids, attention_mask=attention_mask)
        return self.text_projection(outputs.last_hidden_state)
|
|