# viewtoken-harmon-demo/src/models/viewpoint_predictor.py
# Uploaded by XinxuanLu
# Commit: "Initial demo" (becf13a, verified)
"""
ResNet-based viewpoint predictor for camera pose estimation.
Predicts camera pose (rotation, translation, spherical/angular parameters, and relative rotation) from RGB images using:
- ResNet backbone (optionally pretrained on ImageNet)
- MLP heads for rotation (6D representation), translation (3D vector), spherical/angular parameters (7D), and relative rotation (6D representation)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from .rotation_utils import rotation_6d_to_matrix
class ViewpointPredictor(nn.Module):
    """ResNet-based viewpoint predictor.

    Architecture:
        ResNet backbone → Global Average Pool → Flatten
        ├─→ MLP (rotation): features → hidden → hidden → 6D rotation
        ├─→ MLP (translation): features → hidden → hidden → 3D translation
        ├─→ MLP (spherical_angular): features → hidden → hidden → 7D
        │       (sin_az, cos_az, sin_el, cos_el, norm_radius, norm_yaw, norm_pitch)
        └─→ MLP (relative_rotation): features → hidden → hidden → 6D relative rotation

    Args:
        resnet_variant: ResNet variant ('resnet18', 'resnet34', 'resnet50', 'resnet101')
        pretrained: Whether to use ImageNet pretrained weights
        rotation_hidden_dim: Hidden dimension for rotation MLP
        translation_hidden_dim: Hidden dimension for translation MLP
        spherical_angular_hidden_dim: Hidden dimension for spherical_angular MLP
        relative_rotation_hidden_dim: Hidden dimension for relative rotation MLP
        num_rotation_layers: Number of linear layers in rotation MLP
        num_translation_layers: Number of linear layers in translation MLP
        num_spherical_angular_layers: Number of linear layers in spherical_angular MLP
        num_relative_rotation_layers: Number of linear layers in relative rotation MLP
        dropout: Dropout probability applied after every hidden activation (default: 0.1)
        yaw_pitch_scale: Yaw/pitch outputs are computed as ``tanh(raw) * yaw_pitch_scale``,
            so they lie within [-yaw_pitch_scale, yaw_pitch_scale] (default: 5.0)

    Raises:
        ValueError: If ``resnet_variant`` is not one of the supported variants.
    """

    def __init__(
        self,
        resnet_variant: str = 'resnet18',
        pretrained: bool = True,
        rotation_hidden_dim: int = 256,
        translation_hidden_dim: int = 256,
        spherical_angular_hidden_dim: int = 256,
        relative_rotation_hidden_dim: int = 256,
        num_rotation_layers: int = 3,
        num_translation_layers: int = 3,
        num_spherical_angular_layers: int = 3,
        num_relative_rotation_layers: int = 3,
        dropout: float = 0.1,
        yaw_pitch_scale: float = 5.0,
    ):
        super().__init__()
        self.resnet_variant = resnet_variant
        self.pretrained = pretrained
        self.yaw_pitch_scale = yaw_pitch_scale

        # NOTE: sub-modules are created in a fixed order (backbone, rotation,
        # translation, spherical_angular, relative_rotation). The order decides
        # parameter ordering and the sequence of random initialisation, so it
        # must not be shuffled.

        # ResNet trunk without its classification layer.
        self.backbone = self._build_backbone(resnet_variant, pretrained)

        # Width of the pooled feature vector fed to every head.
        self.feature_dim = self._get_feature_dim(resnet_variant)

        # Rotation head: 6D rotation representation.
        self.rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=rotation_hidden_dim,
            output_dim=6,
            num_layers=num_rotation_layers,
            dropout=dropout,
        )

        # Translation head: 3D translation vector.
        self.translation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=translation_hidden_dim,
            output_dim=3,
            num_layers=num_translation_layers,
            dropout=dropout,
        )

        # Spherical/angular head, 2 + 2 + 1 + 1 + 1 = 7 values:
        # [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        self.spherical_angular_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=spherical_angular_hidden_dim,
            output_dim=7,
            num_layers=num_spherical_angular_layers,
            dropout=dropout,
        )

        # Relative rotation head: 6D rotation relative to the canonical pose.
        self.relative_rotation_head = self._build_mlp(
            input_dim=self.feature_dim,
            hidden_dim=relative_rotation_hidden_dim,
            output_dim=6,
            num_layers=num_relative_rotation_layers,
            dropout=dropout,
        )

    def _build_backbone(self, variant: str, pretrained: bool) -> nn.Sequential:
        """Build a ResNet backbone without its classification head.

        Args:
            variant: One of 'resnet18', 'resnet34', 'resnet50', 'resnet101'.
            pretrained: If True, load ImageNet weights; otherwise the network
                is randomly initialised.

        Returns:
            ``nn.Sequential`` holding every child of the torchvision ResNet
            except the final ``fc`` layer (i.e. conv stages + average pool).

        Raises:
            ValueError: If ``variant`` is not a supported ResNet variant.
        """
        resnet_models = {
            'resnet18': models.resnet18,
            'resnet34': models.resnet34,
            'resnet50': models.resnet50,
            'resnet101': models.resnet101,
        }
        if variant not in resnet_models:
            raise ValueError(f"Unknown ResNet variant: {variant}. Choose from {list(resnet_models.keys())}")

        weights = None
        if pretrained:
            # torchvision's weights-enum API. The table is only built when
            # pretrained weights are requested, so the enums are not touched
            # for randomly initialised models.
            pretrained_weights = {
                'resnet18': models.ResNet18_Weights.IMAGENET1K_V1,
                'resnet34': models.ResNet34_Weights.IMAGENET1K_V1,
                'resnet50': models.ResNet50_Weights.IMAGENET1K_V2,
                'resnet101': models.ResNet101_Weights.IMAGENET1K_V2,
            }
            weights = pretrained_weights[variant]

        resnet = resnet_models[variant](weights=weights)

        # Keep everything except the last child (the fc classifier). Wrapping
        # the children in a Sequential keeps the numeric state_dict keys
        # ("backbone.0...", "backbone.1...", ...).
        return nn.Sequential(*list(resnet.children())[:-1])

    def _get_feature_dim(self, variant: str) -> int:
        """Return the pooled feature width for a ResNet variant.

        Raises:
            KeyError: If ``variant`` is not a supported ResNet variant.
        """
        feature_dims = {
            'resnet18': 512,
            'resnet34': 512,
            'resnet50': 2048,
            'resnet101': 2048,
        }
        return feature_dims[variant]

    def _build_mlp(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        dropout: float,
    ) -> nn.Module:
        """Build an MLP head.

        Every hidden stage is ``Linear → ReLU → Dropout``; the output layer is
        a plain ``Linear`` with no activation.

        Args:
            input_dim: Input feature dimension
            hidden_dim: Hidden layer dimension
            output_dim: Output dimension
            num_layers: Number of linear layers (including output layer).
                Any value below 2 yields a single ``Linear(input_dim, output_dim)``.
            dropout: Dropout probability

        Returns:
            MLP module (``nn.Sequential``)
        """
        if num_layers < 2:
            return nn.Sequential(nn.Linear(input_dim, output_dim))

        layers = []
        in_features = input_dim
        # num_layers - 1 hidden stages; the first one maps from input_dim.
        for _ in range(num_layers - 1):
            layers.extend([
                nn.Linear(in_features, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(dropout),
            ])
            in_features = hidden_dim

        # Output layer (no activation)
        layers.append(nn.Linear(hidden_dim, output_dim))
        return nn.Sequential(*layers)

    def forward(self, images: torch.Tensor) -> dict[str, torch.Tensor]:
        """Forward pass.

        Args:
            images: Input images of shape (B, 3, H, W)

        Returns:
            Dictionary containing:
                rotation: Rotation matrices of shape (B, 3, 3)
                translation: Translation vectors of shape (B, 3)
                relative_rotation: Relative rotation from canonical pose of shape (B, 3, 3)
                azimuth_sincos: Azimuth as (sin, cos) pair, L2 normalized, shape (B, 2)
                elevation_sincos: Elevation as (sin, cos) pair, L2 normalized, shape (B, 2)
                radius: Normalized radius, shape (B,). This is the raw linear
                    output of the head; no bounding activation is applied here.
                yaw: Normalized yaw, ``tanh(raw) * yaw_pitch_scale``, shape (B,)
                pitch: Normalized pitch, ``tanh(raw) * yaw_pitch_scale``, shape (B,)
        """
        # Extract features
        features = self.backbone(images)  # (B, feature_dim, 1, 1)
        features = features.flatten(1)  # (B, feature_dim)

        # Absolute rotation: 6D prediction converted to a 3x3 matrix.
        rotation_6d = self.rotation_head(features)  # (B, 6)
        rotation = rotation_6d_to_matrix(rotation_6d)  # (B, 3, 3)

        # Translation
        translation = self.translation_head(features)  # (B, 3)

        # Spherical and angular parameters:
        # [sin(az), cos(az), sin(el), cos(el), norm_radius, norm_yaw, norm_pitch]
        spherical_angular_raw = self.spherical_angular_head(features)  # (B, 7)
        azimuth_raw = spherical_angular_raw[:, 0:2]  # (B, 2)
        elevation_raw = spherical_angular_raw[:, 2:4]  # (B, 2)
        norm_radius = spherical_angular_raw[:, 4]  # (B,)
        norm_yaw_raw = spherical_angular_raw[:, 5]  # (B,)
        norm_pitch_raw = spherical_angular_raw[:, 6]  # (B,)

        # Project the (sin, cos) pairs onto the unit circle.
        azimuth_sincos = F.normalize(azimuth_raw, p=2, dim=-1)  # (B, 2)
        elevation_sincos = F.normalize(elevation_raw, p=2, dim=-1)  # (B, 2)

        # tanh bounds yaw and pitch to [-scale, scale].
        norm_yaw = torch.tanh(norm_yaw_raw) * self.yaw_pitch_scale  # (B,)
        norm_pitch = torch.tanh(norm_pitch_raw) * self.yaw_pitch_scale  # (B,)

        # Relative rotation: 6D prediction converted to a 3x3 matrix.
        relative_rotation_6d = self.relative_rotation_head(features)  # (B, 6)
        relative_rotation = rotation_6d_to_matrix(relative_rotation_6d)  # (B, 3, 3)

        return {
            'rotation': rotation,
            'translation': translation,
            'relative_rotation': relative_rotation,
            'azimuth_sincos': azimuth_sincos,
            'elevation_sincos': elevation_sincos,
            'radius': norm_radius,
            'yaw': norm_yaw,
            'pitch': norm_pitch,
        }

    def freeze_backbone(self) -> None:
        """Freeze ResNet backbone parameters (sets ``requires_grad = False``)."""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self) -> None:
        """Unfreeze ResNet backbone parameters (sets ``requires_grad = True``)."""
        for param in self.backbone.parameters():
            param.requires_grad = True

    @staticmethod
    def _count_params(module: nn.Module, trainable_only: bool = False) -> int:
        """Count the parameter elements of ``module``, optionally only trainable ones."""
        return sum(
            p.numel()
            for p in module.parameters()
            if p.requires_grad or not trainable_only
        )

    def get_num_params(self) -> dict[str, int]:
        """Get number of parameters.

        Returns:
            Dictionary with the total and per-component parameter counts
            ('total', 'backbone', '<name>_head') followed by the trainable
            counts ('trainable_total', 'trainable_backbone', 'trainable_<name>'),
            where <name> is rotation, translation, spherical_angular or
            relative_rotation.
        """
        # Insertion order here defines the key order of the returned dict.
        components = {
            'backbone': self.backbone,
            'rotation': self.rotation_head,
            'translation': self.translation_head,
            'spherical_angular': self.spherical_angular_head,
            'relative_rotation': self.relative_rotation_head,
        }
        totals = {
            name: self._count_params(module)
            for name, module in components.items()
        }
        trainable = {
            name: self._count_params(module, trainable_only=True)
            for name, module in components.items()
        }

        counts = {'total': sum(totals.values())}
        for name, count in totals.items():
            # The backbone is reported under its bare name, heads as '<name>_head'.
            key = name if name == 'backbone' else f'{name}_head'
            counts[key] = count
        counts['trainable_total'] = sum(trainable.values())
        for name, count in trainable.items():
            counts[f'trainable_{name}'] = count
        return counts
def test_model():
    """Smoke-test model instantiation and the forward pass.

    Builds a randomly initialised ``ViewpointPredictor`` for several ResNet
    variants, prints the parameter counts, runs one forward pass on random
    images and checks that every entry of the returned dictionary has the
    documented shape.

    Raises:
        AssertionError: If an output shape differs from the documented one or
            the azimuth/elevation (sin, cos) pairs are not unit length.
    """
    print("Testing ViewpointPredictor...")

    # Test with different ResNet variants
    for variant in ['resnet18', 'resnet34', 'resnet50']:
        print(f"\nTesting {variant}...")

        model = ViewpointPredictor(
            resnet_variant=variant,
            pretrained=False,  # Don't download weights for test
            rotation_hidden_dim=256,
            translation_hidden_dim=256,
        )
        # Inference mode: disables dropout and freezes batch-norm statistics.
        model.eval()

        # Print parameters (keys as returned by get_num_params)
        params = model.get_num_params()
        print(f"Total parameters: {params['total']:,}")
        print(f" Backbone: {params['backbone']:,}")
        print(f" Rotation head: {params['rotation_head']:,}")
        print(f" Translation head: {params['translation_head']:,}")
        print(f" Spherical/angular head: {params['spherical_angular_head']:,}")
        print(f" Relative rotation head: {params['relative_rotation_head']:,}")

        # Test forward pass
        batch_size = 4
        images = torch.randn(batch_size, 3, 224, 224)
        with torch.no_grad():
            # forward() returns a dict, so outputs are read by key.
            outputs = model(images)
        print(f"Input shape: {images.shape}")

        expected_shapes = {
            'rotation': (batch_size, 3, 3),
            'translation': (batch_size, 3),
            'relative_rotation': (batch_size, 3, 3),
            'azimuth_sincos': (batch_size, 2),
            'elevation_sincos': (batch_size, 2),
            'radius': (batch_size,),
            'yaw': (batch_size,),
            'pitch': (batch_size,),
        }
        for name, expected in expected_shapes.items():
            actual = tuple(outputs[name].shape)
            print(f"{name} shape: {actual} (expected: {expected})")
            assert actual == expected, f"{name}: got shape {actual}, expected {expected}"

        # Report rotation properties for both predicted rotations
        identity = torch.eye(3).unsqueeze(0).expand(batch_size, 3, 3)
        for name in ('rotation', 'relative_rotation'):
            matrix = outputs[name]
            det = torch.det(matrix)
            print(f"{name} determinants: {det.tolist()} (should be ~1)")
            ortho_error = torch.norm(
                torch.bmm(matrix.transpose(-2, -1), matrix) - identity,
                dim=(1, 2),
            )
            print(f"{name} orthogonality errors: {ortho_error.tolist()} (should be small)")

        # Check (sin, cos) normalization
        for name in ('azimuth_sincos', 'elevation_sincos'):
            norms = torch.norm(outputs[name], p=2, dim=-1)
            print(f"{name} norms: {norms.tolist()} (should be ~1)")
            assert torch.allclose(norms, torch.ones_like(norms), atol=1e-4), (
                f"{name} is not unit length: {norms.tolist()}"
            )

    print("\nTest passed!")
# Script entry point: run the smoke test only when this file is executed
# directly, not when it is imported. NOTE(review): the module uses a relative
# import (`from .rotation_utils import ...`), so it presumably has to be
# launched as a package module (`python -m <package>.viewpoint_predictor`) —
# confirm against the project layout.
if __name__ == "__main__":
    test_model()