Spaces:
Running on Zero
Running on Zero
"""
ResNet-based viewpoint predictor for camera pose estimation.

Predicts camera pose from RGB images using:
- a ResNet backbone (optionally pretrained on ImageNet)
- MLP heads for absolute rotation (6D representation), translation (3D vector),
  spherical/angular parameters (azimuth and elevation as sin/cos pairs, plus
  normalized radius, yaw and pitch), and rotation relative to a canonical pose
  (6D representation)
"""
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.models as models | |
| from .rotation_utils import rotation_6d_to_matrix | |
class ViewpointPredictor(nn.Module):
    """ResNet-based viewpoint predictor.

    Architecture:
        ResNet backbone → Global Average Pool → Flatten
            ├─→ MLP (rotation): features → hidden → hidden → 6D rotation
            ├─→ MLP (translation): features → hidden → hidden → 3D translation
            ├─→ MLP (spherical_angular): features → hidden → hidden → 7D
            │       (sin_az, cos_az, sin_el, cos_el, norm_radius, norm_yaw, norm_pitch)
            └─→ MLP (relative_rotation): features → hidden → hidden → 6D relative rotation

    The "hidden → hidden" diagrams correspond to the default of 3 layers per
    head; the depth of each head is set by the ``num_*_layers`` arguments.

    Args:
        resnet_variant: ResNet variant ('resnet18', 'resnet34', 'resnet50', 'resnet101')
        pretrained: Whether to use ImageNet pretrained weights
        rotation_hidden_dim: Hidden dimension for rotation MLP
        translation_hidden_dim: Hidden dimension for translation MLP
        spherical_angular_hidden_dim: Hidden dimension for spherical_angular MLP
        relative_rotation_hidden_dim: Hidden dimension for relative rotation MLP
        num_rotation_layers: Number of layers in rotation MLP
        num_translation_layers: Number of layers in translation MLP
        num_spherical_angular_layers: Number of layers in spherical_angular MLP
        num_relative_rotation_layers: Number of layers in relative rotation MLP
        dropout: Dropout probability (default: 0.1)
        yaw_pitch_scale: Yaw/pitch outputs are ``tanh(raw) * yaw_pitch_scale``,
            i.e. bounded to (-yaw_pitch_scale, yaw_pitch_scale) (default: 5.0)

    Raises:
        ValueError: If ``resnet_variant`` is not one of the supported variants.
    """

    def __init__(
        self,
        resnet_variant: str = 'resnet18',
        pretrained: bool = True,
        rotation_hidden_dim: int = 256,
        translation_hidden_dim: int = 256,
        spherical_angular_hidden_dim: int = 256,
        relative_rotation_hidden_dim: int = 256,
        num_rotation_layers: int = 3,
        num_translation_layers: int = 3,
        num_spherical_angular_layers: int = 3,
        num_relative_rotation_layers: int = 3,
        dropout: float = 0.1,
        yaw_pitch_scale: float = 5.0,
    ):
        super().__init__()
        self.resnet_variant = resnet_variant
        self.pretrained = pretrained
        self.yaw_pitch_scale = yaw_pitch_scale

        # NOTE: submodules are created in a fixed order (backbone, rotation,
        # translation, spherical_angular, relative_rotation). Keep this order:
        # it determines random initialization and state_dict layout.
        self.backbone = self._build_backbone(resnet_variant, pretrained)
        self.feature_dim = self._get_feature_dim(resnet_variant)

        # Rotation head: 6D rotation representation, converted to 3x3 in forward().
        self.rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=rotation_hidden_dim,
            output_dim=6,
            num_layers=num_rotation_layers,
            dropout=dropout,
        )

        # Translation head: 3D translation vector.
        self.translation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=translation_hidden_dim,
            output_dim=3,
            num_layers=num_translation_layers,
            dropout=dropout,
        )

        # Spherical/angular head. Output layout:
        # [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        self.spherical_angular_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=spherical_angular_hidden_dim,
            output_dim=7,  # 2 + 2 + 1 + 1 + 1
            num_layers=num_spherical_angular_layers,
            dropout=dropout,
        )

        # Relative rotation head: 6D rotation relative to the canonical pose.
        self.relative_rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=relative_rotation_hidden_dim,
            output_dim=6,
            num_layers=num_relative_rotation_layers,
            dropout=dropout,
        )

    def _build_backbone(self, variant: str, pretrained: bool) -> nn.Sequential:
        """Build a torchvision ResNet with its final classification layer removed.

        Args:
            variant: One of 'resnet18', 'resnet34', 'resnet50', 'resnet101'.
            pretrained: If True, load ImageNet weights; otherwise initialize randomly.

        Returns:
            ``nn.Sequential`` of every ResNet child except the final ``fc``
            layer (conv layers + avgpool).

        Raises:
            ValueError: If ``variant`` is not supported.
        """
        resnet_models = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
        }
        if variant not in resnet_models:
            raise ValueError(f"Unknown ResNet variant: {variant}. Choose from {list(resnet_models.keys())}")

        weights = None
        if pretrained:
            # Built only when needed, so pretrained=False never touches the
            # weights enums of torchvision's ``weights=`` API.
            pretrained_weights = {
                'resnet18': models.ResNet18_Weights.IMAGENET1K_V1,
                'resnet34': models.ResNet34_Weights.IMAGENET1K_V1,
                'resnet50': models.ResNet50_Weights.IMAGENET1K_V2,
                'resnet101': models.ResNet101_Weights.IMAGENET1K_V2,
            }
            weights = pretrained_weights[variant]

        resnet = resnet_models[variant](weights=weights)
        # Keep everything except the last child (the fc classification layer).
        return nn.Sequential(*list(resnet.children())[:-1])

    def _get_feature_dim(self, variant: str) -> int:
        """Return the flattened backbone feature size for a ResNet variant.

        Raises:
            KeyError: If ``variant`` is not in the table.
        """
        feature_dims = {
            'resnet18': 512,
            'resnet34': 512,
            'resnet50': 2048,
            'resnet101': 2048,
        }
        return feature_dims[variant]

    def _build_mlp(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float,
    ) -> nn.Module:
        """Build an MLP head.

        Each of the ``num_layers - 1`` hidden blocks is Linear → ReLU → Dropout,
        followed by a final Linear with no activation. When ``num_layers < 2``
        the head is a single Linear(input_dim, output_dim).

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden layer dimension
            output_dim: Output dimension
            num_layers: Number of Linear layers (including the output layer)
            dropout: Dropout probability

        Returns:
            ``nn.Sequential`` MLP module
        """
        if num_layers < 2:
            return nn.Sequential(nn.Linear(input_dim, output_dim))

        layers = []
        in_features = input_dim
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = hidden_dim
        # Output layer (no activation).
        layers.append(nn.Linear(hidden_dim, output_dim))
        return nn.Sequential(*layers)

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Predict all pose quantities for a batch of images.

        Args:
            images: Input images of shape (B, 3, H, W)

        Returns:
            Dictionary containing:
                rotation: Rotation matrices of shape (B, 3, 3)
                translation: Translation vectors of shape (B, 3), unbounded
                relative_rotation: Relative rotation from canonical pose of shape (B, 3, 3)
                azimuth_sincos: Azimuth as an L2-normalized (sin, cos) pair, shape (B, 2)
                elevation_sincos: Elevation as an L2-normalized (sin, cos) pair, shape (B, 2)
                radius: Normalized radius, shape (B,). This is the raw linear
                    output of the head and is NOT clamped or squashed;
                    presumably it is trained against targets in [-1, 1].
                yaw: Normalized yaw, tanh-bounded to
                    (-yaw_pitch_scale, yaw_pitch_scale), shape (B,)
                pitch: Normalized pitch, tanh-bounded to
                    (-yaw_pitch_scale, yaw_pitch_scale), shape (B,)
        """
        features = self.backbone(images)  # (B, feature_dim, 1, 1)
        features = features.flatten(1)  # (B, feature_dim)

        # Absolute rotation: 6D → 3x3.
        rotation_6d = self.rotation_head(features)  # (B, 6)
        rotation = rotation_6d_to_matrix(rotation_6d)  # (B, 3, 3)

        translation = self.translation_head(features)  # (B, 3)

        # Layout: [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        spherical_angular_raw = self.spherical_angular_head(features)  # (B, 7)
        azimuth_raw = spherical_angular_raw[:, 0:2]  # (B, 2)
        elevation_raw = spherical_angular_raw[:, 2:4]  # (B, 2)
        norm_radius = spherical_angular_raw[:, 4]  # (B,)
        norm_yaw_raw = spherical_angular_raw[:, 5]  # (B,)
        norm_pitch_raw = spherical_angular_raw[:, 6]  # (B,)

        # Project (sin, cos) pairs onto the unit circle.
        azimuth_sincos = F.normalize(azimuth_raw, p=2, dim=-1)  # (B, 2)
        elevation_sincos = F.normalize(elevation_raw, p=2, dim=-1)  # (B, 2)

        # Bound yaw and pitch to (-scale, scale).
        norm_yaw = torch.tanh(norm_yaw_raw) * self.yaw_pitch_scale  # (B,)
        norm_pitch = torch.tanh(norm_pitch_raw) * self.yaw_pitch_scale  # (B,)

        # Rotation relative to the canonical pose: 6D → 3x3.
        relative_rotation_6d = self.relative_rotation_head(features)  # (B, 6)
        relative_rotation = rotation_6d_to_matrix(relative_rotation_6d)  # (B, 3, 3)

        return {
            'rotation': rotation,
            'translation': translation,
            'relative_rotation': relative_rotation,
            'azimuth_sincos': azimuth_sincos,
            'elevation_sincos': elevation_sincos,
            'radius': norm_radius,
            'yaw': norm_yaw,
            'pitch': norm_pitch,
        }

    def freeze_backbone(self):
        """Freeze ResNet backbone parameters (set ``requires_grad=False``).

        NOTE: this only stops gradient updates. BatchNorm layers in the
        backbone still update their running statistics while the module is
        in train mode; call ``self.backbone.eval()`` as well if that matters.
        """
        self.backbone.requires_grad_(False)

    def unfreeze_backbone(self):
        """Unfreeze ResNet backbone parameters (set ``requires_grad=True``)."""
        self.backbone.requires_grad_(True)

    def get_num_params(self) -> dict[str, int]:
        """Count total and trainable parameters per component.

        Returns:
            Dict with keys, in this order: 'total', 'backbone', 'rotation_head',
            'translation_head', 'spherical_angular_head', 'relative_rotation_head',
            'trainable_total', 'trainable_backbone', 'trainable_rotation',
            'trainable_translation', 'trainable_spherical_angular',
            'trainable_relative_rotation'.
        """
        # (key for total count, key for trainable count, module)
        components = (
            ('backbone', 'trainable_backbone', self.backbone),
            ('rotation_head', 'trainable_rotation', self.rotation_head),
            ('translation_head', 'trainable_translation', self.translation_head),
            ('spherical_angular_head', 'trainable_spherical_angular', self.spherical_angular_head),
            ('relative_rotation_head', 'trainable_relative_rotation', self.relative_rotation_head),
        )
        totals = {
            total_key: sum(p.numel() for p in module.parameters())
            for total_key, _, module in components
        }
        trainable = {
            trainable_key: sum(p.numel() for p in module.parameters() if p.requires_grad)
            for _, trainable_key, module in components
        }
        return {
            'total': sum(totals.values()),
            **totals,
            'trainable_total': sum(trainable.values()),
            **trainable,
        }
| def test_model(): | |
| """Test model instantiation and forward pass.""" | |
| print("Testing ViewpointPredictor...") | |
| # Test with different ResNet variants | |
| for variant in ['resnet18', 'resnet34', 'resnet50']: | |
| print(f"\nTesting {variant}...") | |
| # Create model | |
| model = ViewpointPredictor( | |
| resnet_variant=variant, | |
| pretrained=False, # Don't download weights for test | |
| rotation_hidden_dim=256, | |
| translation_hidden_dim=256, | |
| ) | |
| # Print parameters | |
| params = model.get_num_params() | |
| print(f"Total parameters: {params['total']:,}") | |
| print(f" Backbone: {params['backbone']:,}") | |
| print(f" Rotation head: {params['rotation_head']:,}") | |
| print(f" Translation head: {params['translation_head']:,}") | |
| print(f" Offset head: {params['offset_head']:,}") | |
| # Test forward pass | |
| batch_size = 4 | |
| images = torch.randn(batch_size, 3, 224, 224) | |
| with torch.no_grad(): | |
| rotation, translation, offset, azimuth_sincos = model(images) | |
| print(f"Input shape: {images.shape}") | |
| print(f"Rotation shape: {rotation.shape} (expected: ({batch_size}, 3, 3))") | |
| print(f"Translation shape: {translation.shape} (expected: ({batch_size}, 3))") | |
| print(f"Offset shape: {offset.shape} (expected: ({batch_size}, 2))") | |
| print(f"Azimuth sincos shape: {azimuth_sincos.shape} (expected: ({batch_size}, 2))") | |
| # Check rotation properties | |
| det = torch.det(rotation) | |
| print(f"Rotation determinants: {det.tolist()} (should be ~1)") | |
| identity = torch.eye(3).unsqueeze(0).expand(batch_size, 3, 3) | |
| ortho_error = torch.norm(torch.bmm(rotation.transpose(-2, -1), rotation) - identity, dim=(1, 2)) | |
| print(f"Orthogonality errors: {ortho_error.tolist()} (should be small)") | |
| # Check azimuth normalization | |
| azimuth_norms = torch.norm(azimuth_sincos, p=2, dim=-1) | |
| print(f"Azimuth sincos norms: {azimuth_norms.tolist()} (should be ~1)") | |
| print("\nTest passed!") | |
| if __name__ == "__main__": | |
| test_model() | |