"""ResNet-based viewpoint predictor for camera pose estimation.

Predicts camera-pose quantities from RGB images using:
- a ResNet backbone (optionally pretrained on ImageNet), and
- four MLP heads on the pooled backbone features:
    * rotation (6D representation, converted to a 3x3 matrix),
    * translation (3D vector),
    * spherical/angular parameters (azimuth and elevation as sin/cos pairs,
      plus normalized radius, yaw and pitch), and
    * relative rotation from a canonical pose (6D, converted to a 3x3 matrix).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

from .rotation_utils import rotation_6d_to_matrix


class ViewpointPredictor(nn.Module):
    """ResNet-based viewpoint predictor.

    Architecture (shown for the default 3-layer heads; depth is configurable):
        ResNet backbone → Global Average Pool → Flatten
        ├─→ MLP (rotation): features → hidden → hidden → 6D rotation
        ├─→ MLP (translation): features → hidden → hidden → 3D translation
        ├─→ MLP (spherical_angular): features → hidden → hidden → 7D
        │       (sin_az, cos_az, sin_el, cos_el, norm_radius, norm_yaw, norm_pitch)
        └─→ MLP (relative_rotation): features → hidden → hidden → 6D relative rotation

    Args:
        resnet_variant: ResNet variant ('resnet18', 'resnet34', 'resnet50', 'resnet101')
        pretrained: Whether to use ImageNet pretrained weights
        rotation_hidden_dim: Hidden dimension for rotation MLP
        translation_hidden_dim: Hidden dimension for translation MLP
        spherical_angular_hidden_dim: Hidden dimension for spherical_angular MLP
        relative_rotation_hidden_dim: Hidden dimension for relative rotation MLP
        num_rotation_layers: Number of linear layers in rotation MLP (including output layer)
        num_translation_layers: Number of linear layers in translation MLP
        num_spherical_angular_layers: Number of linear layers in spherical_angular MLP
        num_relative_rotation_layers: Number of linear layers in relative rotation MLP
        dropout: Dropout probability applied after every hidden layer (default: 0.1)
        yaw_pitch_scale: Yaw/pitch outputs are tanh-bounded to
            [-yaw_pitch_scale, yaw_pitch_scale] (default: 5.0)

    Raises:
        ValueError: If ``resnet_variant`` is not one of the supported variants.
    """

    def __init__(
        self,
        resnet_variant: str = 'resnet18',
        pretrained: bool = True,
        rotation_hidden_dim: int = 256,
        translation_hidden_dim: int = 256,
        spherical_angular_hidden_dim: int = 256,
        relative_rotation_hidden_dim: int = 256,
        num_rotation_layers: int = 3,
        num_translation_layers: int = 3,
        num_spherical_angular_layers: int = 3,
        num_relative_rotation_layers: int = 3,
        dropout: float = 0.1,
        yaw_pitch_scale: float = 5.0,
    ):
        super().__init__()

        self.resnet_variant = resnet_variant
        self.pretrained = pretrained
        self.yaw_pitch_scale = yaw_pitch_scale

        # Load ResNet backbone (validates resnet_variant)
        self.backbone = self._build_backbone(resnet_variant, pretrained)

        # Get feature dimension
        self.feature_dim = self._get_feature_dim(resnet_variant)

        # NOTE: heads are created in a fixed order; keep it so parameter
        # initialization (RNG consumption) stays reproducible across versions.

        # Build rotation head (predicts 6D rotation)
        self.rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=rotation_hidden_dim,
            output_dim=6,  # 6D rotation representation
            num_layers=num_rotation_layers,
            dropout=dropout,
        )

        # Build translation head (predicts 3D translation)
        self.translation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=translation_hidden_dim,
            output_dim=3,  # 3D translation vector
            num_layers=num_translation_layers,
            dropout=dropout,
        )

        # Build spherical_angular head (predicts azimuth, elevation, radius, yaw, pitch)
        # Output: [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        self.spherical_angular_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=spherical_angular_hidden_dim,
            output_dim=7,  # 2 + 2 + 1 + 1 + 1 = 7 values
            num_layers=num_spherical_angular_layers,
            dropout=dropout,
        )

        # Build relative rotation head (predicts 6D relative rotation from canonical pose)
        self.relative_rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=relative_rotation_hidden_dim,
            output_dim=6,  # 6D rotation representation
            num_layers=num_relative_rotation_layers,
            dropout=dropout,
        )

    def _build_backbone(self, variant: str, pretrained: bool) -> nn.Module:
        """Build a ResNet backbone without its classification head.

        Args:
            variant: One of 'resnet18', 'resnet34', 'resnet50', 'resnet101'.
            pretrained: Load torchvision ImageNet weights when True.

        Returns:
            ``nn.Sequential`` of all ResNet children except the final ``fc``,
            i.e. conv stem + residual stages + global average pool.

        Raises:
            ValueError: If ``variant`` is not supported.
        """
        resnet_models = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
        }

        if variant not in resnet_models:
            raise ValueError(f"Unknown ResNet variant: {variant}. Choose from {list(resnet_models.keys())}")

        weights = None
        if pretrained:
            # Weights enums (torchvision >= 0.13). ResNet18/34 only ship
            # IMAGENET1K_V1; ResNet50/101 use the newer IMAGENET1K_V2 recipe.
            pretrained_weights = {
                'resnet18': models.ResNet18_Weights.IMAGENET1K_V1,
                'resnet34': models.ResNet34_Weights.IMAGENET1K_V1,
                'resnet50': models.ResNet50_Weights.IMAGENET1K_V2,
                'resnet101': models.ResNet101_Weights.IMAGENET1K_V2,
            }
            weights = pretrained_weights[variant]

        resnet = resnet_models[variant](weights=weights)

        # ResNet children: conv layers ... avgpool, fc. Drop only the fc layer so
        # the backbone emits pooled features of shape (B, feature_dim, 1, 1).
        return nn.Sequential(*list(resnet.children())[:-1])

    def _get_feature_dim(self, variant: str) -> int:
        """Return the pooled feature dimension for a ResNet variant.

        Basic-block variants (18/34) produce 512 channels; bottleneck variants
        (50/101) produce 2048. Raises ``KeyError`` for an unknown variant.
        """
        feature_dims = {
            'resnet18': 512,
            'resnet34': 512,
            'resnet50': 2048,
            'resnet101': 2048,
        }
        return feature_dims[variant]

    def _build_mlp(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float,
    ) -> nn.Module:
        """Build an MLP head.

        Each hidden layer is ``Linear → ReLU → Dropout``; the output layer is a
        plain ``Linear`` with no activation.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden layer dimension
            output_dim: Output dimension
            num_layers: Number of linear layers (including output layer).
                Values below 2 yield a single ``Linear(input_dim, output_dim)``.
            dropout: Dropout probability

        Returns:
            MLP module (``nn.Sequential``)
        """
        if num_layers < 2:
            return nn.Sequential(nn.Linear(input_dim, output_dim))

        layers: list[nn.Module] = []
        in_features = input_dim
        # num_layers - 1 hidden blocks: the first maps input_dim → hidden_dim,
        # the rest hidden_dim → hidden_dim.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = hidden_dim

        # Output layer (no activation)
        layers.append(nn.Linear(hidden_dim, output_dim))

        return nn.Sequential(*layers)

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass.

        Args:
            images: Input images of shape (B, 3, H, W)

        Returns:
            Dictionary containing:
                rotation: Rotation matrices of shape (B, 3, 3)
                translation: Translation vectors of shape (B, 3)
                relative_rotation: Relative rotation from canonical pose of shape (B, 3, 3)
                azimuth_sincos: Azimuth as (sin, cos) pair, L2 normalized, shape (B, 2)
                elevation_sincos: Elevation as (sin, cos) pair, L2 normalized, shape (B, 2)
                radius: Normalized radius, shape (B,). This is the raw head output:
                    no activation is applied, so values are NOT clamped to [-1, 1].
                    NOTE(review): yaw/pitch are tanh-bounded but radius is not;
                    confirm this asymmetry is intended.
                yaw: Normalized yaw, tanh-bounded to [-yaw_pitch_scale, yaw_pitch_scale], shape (B,)
                pitch: Normalized pitch, tanh-bounded to [-yaw_pitch_scale, yaw_pitch_scale], shape (B,)
        """
        # Extract features
        features = self.backbone(images)  # (B, feature_dim, 1, 1)
        features = features.flatten(1)  # (B, feature_dim)

        # Predict 6D rotation and convert to a 3x3 rotation matrix
        rotation_6d = self.rotation_head(features)  # (B, 6)
        rotation = rotation_6d_to_matrix(rotation_6d)  # (B, 3, 3)

        # Predict translation
        translation = self.translation_head(features)  # (B, 3)

        # Predict spherical and angular parameters
        # Output: [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        spherical_angular_raw = self.spherical_angular_head(features)  # (B, 7)

        # Split into components
        azimuth_raw = spherical_angular_raw[:, 0:2]  # (B, 2)
        elevation_raw = spherical_angular_raw[:, 2:4]  # (B, 2)
        norm_radius = spherical_angular_raw[:, 4]  # (B,) — left unbounded
        norm_yaw_raw = spherical_angular_raw[:, 5]  # (B,)
        norm_pitch_raw = spherical_angular_raw[:, 6]  # (B,)

        # Project (sin, cos) pairs onto the unit circle so they encode a valid angle
        azimuth_sincos = F.normalize(azimuth_raw, p=2, dim=-1)  # (B, 2)
        elevation_sincos = F.normalize(elevation_raw, p=2, dim=-1)  # (B, 2)

        # Apply tanh and scale to yaw and pitch (bounded to [-scale, scale])
        norm_yaw = torch.tanh(norm_yaw_raw) * self.yaw_pitch_scale  # (B,)
        norm_pitch = torch.tanh(norm_pitch_raw) * self.yaw_pitch_scale  # (B,)

        # Predict relative rotation (6D representation) and convert to 3x3
        relative_rotation_6d = self.relative_rotation_head(features)  # (B, 6)
        relative_rotation = rotation_6d_to_matrix(relative_rotation_6d)  # (B, 3, 3)

        return {
            'rotation': rotation,
            'translation': translation,
            'relative_rotation': relative_rotation,
            'azimuth_sincos': azimuth_sincos,
            'elevation_sincos': elevation_sincos,
            'radius': norm_radius,
            'yaw': norm_yaw,
            'pitch': norm_pitch,
        }

    def freeze_backbone(self):
        """Freeze ResNet backbone parameters (set ``requires_grad=False``).

        NOTE: this only stops gradient updates. BatchNorm layers in the backbone
        still update their running statistics while the model is in train mode.
        """
        self.backbone.requires_grad_(False)

    def unfreeze_backbone(self):
        """Unfreeze ResNet backbone parameters (set ``requires_grad=True``)."""
        self.backbone.requires_grad_(True)

    def get_num_params(self) -> dict[str, int]:
        """Count parameters per component, in total and trainable-only.

        Returns:
            Dict with keys, in order: 'total', 'backbone', 'rotation_head',
            'translation_head', 'spherical_angular_head', 'relative_rotation_head',
            'trainable_total', 'trainable_backbone', 'trainable_rotation',
            'trainable_translation', 'trainable_spherical_angular',
            'trainable_relative_rotation'.
        """
        # (key for total count, short name for trainable count, module)
        components = (
            ('backbone', 'backbone', self.backbone),
            ('rotation_head', 'rotation', self.rotation_head),
            ('translation_head', 'translation', self.translation_head),
            ('spherical_angular_head', 'spherical_angular', self.spherical_angular_head),
            ('relative_rotation_head', 'relative_rotation', self.relative_rotation_head),
        )

        counts: dict[str, int] = {}
        trainable: dict[str, int] = {}
        for count_key, short_name, module in components:
            params = list(module.parameters())
            counts[count_key] = sum(p.numel() for p in params)
            trainable[f'trainable_{short_name}'] = sum(p.numel() for p in params if p.requires_grad)

        return {
            'total': sum(counts.values()),
            **counts,
            'trainable_total': sum(trainable.values()),
            **trainable,
        }


def test_model():
    """Smoke-test model instantiation and a forward pass for several ResNet variants.

    Builds each model without pretrained weights (no download), prints parameter
    counts, runs random images through the network in eval mode, and asserts the
    output shapes plus basic invariants: unit-norm sin/cos pairs, bounded yaw/pitch,
    and near-orthonormal, determinant-one rotation matrices.

    Raises:
        AssertionError: If an output has an unexpected shape or violates an invariant.
    """
    print("Testing ViewpointPredictor...")

    batch_size = 4
    identity = torch.eye(3)

    # Test with different ResNet variants
    for variant in ['resnet18', 'resnet34', 'resnet50']:
        print(f"\nTesting {variant}...")

        # Create model
        model = ViewpointPredictor(
            resnet_variant=variant,
            pretrained=False,  # Don't download weights for test
            rotation_hidden_dim=256,
            translation_hidden_dim=256,
        )
        # Disable dropout / use BatchNorm running stats for an inference-style pass
        model.eval()

        # Print parameters
        params = model.get_num_params()
        print(f"Total parameters: {params['total']:,}")
        print(f"  Backbone: {params['backbone']:,}")
        print(f"  Rotation head: {params['rotation_head']:,}")
        print(f"  Translation head: {params['translation_head']:,}")
        print(f"  Spherical/angular head: {params['spherical_angular_head']:,}")
        print(f"  Relative rotation head: {params['relative_rotation_head']:,}")

        # Test forward pass
        images = torch.randn(batch_size, 3, 224, 224)
        with torch.no_grad():
            outputs = model(images)

        print(f"Input shape: {images.shape}")
        expected_shapes = {
            'rotation': (batch_size, 3, 3),
            'translation': (batch_size, 3),
            'relative_rotation': (batch_size, 3, 3),
            'azimuth_sincos': (batch_size, 2),
            'elevation_sincos': (batch_size, 2),
            'radius': (batch_size,),
            'yaw': (batch_size,),
            'pitch': (batch_size,),
        }
        for name, expected in expected_shapes.items():
            actual = tuple(outputs[name].shape)
            print(f"{name} shape: {actual} (expected: {expected})")
            assert actual == expected, f"{name}: got shape {actual}, expected {expected}"

        # Check rotation properties (both absolute and relative rotations)
        for name in ('rotation', 'relative_rotation'):
            rot = outputs[name]
            det = torch.det(rot)
            print(f"{name} determinants: {det.tolist()} (should be ~1)")
            assert torch.allclose(det, torch.ones_like(det), atol=1e-4), f"{name}: det != 1"

            ortho_error = torch.norm(rot.transpose(-2, -1) @ rot - identity, dim=(-2, -1))
            print(f"{name} orthogonality errors: {ortho_error.tolist()} (should be small)")
            assert (ortho_error < 1e-4).all(), f"{name}: not orthonormal"

        # Check (sin, cos) normalization
        for name in ('azimuth_sincos', 'elevation_sincos'):
            norms = torch.norm(outputs[name], p=2, dim=-1)
            print(f"{name} norms: {norms.tolist()} (should be ~1)")
            assert torch.allclose(norms, torch.ones_like(norms), atol=1e-5), f"{name}: not unit-norm"

        # Check yaw/pitch bounds
        for name in ('yaw', 'pitch'):
            assert (outputs[name].abs() <= model.yaw_pitch_scale).all(), f"{name}: out of bounds"

    print("\nTest passed!")


if __name__ == "__main__":
    # NOTE: this module uses a relative import, so run it as part of its
    # package (e.g. `python -m <package>.<this_module>`), not as a bare script.
    test_model()