import torch
import torch.nn as nn

from .. import sonata
from typing import Dict, Tuple, Union, Optional
from pathlib import Path


class SonataFeatureExtractor(nn.Module):
    """
    Feature extractor using Sonata backbone with MLP projection.
    Supports batch processing and gradient computation.
    """

    def __init__(
        self,
        ckpt_path: Optional[str] = "",
    ):
        super().__init__()

        # Load Sonata model
        self.sonata = sonata.load_by_config(
            str(Path(__file__).parent.parent.parent / "config" / "sonata.json")
        )

        # Store original dtype for later reference
        # self._original_dtype = next(self.parameters()).dtype

        # Define MLP projection head (same as in train-sonata.py)
        self.mlp = nn.Sequential(
            nn.Linear(1232, 512),
            nn.GELU(),
            nn.Linear(512, 512),
            nn.GELU(),
            nn.Linear(512, 512),
        )

        # Define transform
        self.transform = sonata.transform.default()

        # Load checkpoint if provided
        if ckpt_path:
            self.load_checkpoint(ckpt_path)

    def load_checkpoint(self, checkpoint_path: str):
        """Load model weights from checkpoint."""
        checkpoint = torch.load(checkpoint_path, map_location="cpu")

        # Extract state dict from Lightning checkpoint
        if "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
            # Remove 'model.' prefix if present from Lightning
            state_dict = {k.replace("model.", ""): v for k, v in state_dict.items()}
        else:
            state_dict = checkpoint

        # Debug: Show all keys in checkpoint
        print("\n=== Checkpoint Keys ===")
        print(f"Total keys in checkpoint: {len(state_dict)}")
        print("\nSample keys:")
        for i, key in enumerate(list(state_dict.keys())[:10]):
            print(f"  {key}")
        if len(state_dict) > 10:
            print(f"  ... and {len(state_dict) - 10} more keys")

        # Load only the relevant weights
        sonata_dict = {
            k.replace("sonata.", ""): v
            for k, v in state_dict.items()
            if k.startswith("sonata.")
        }
        mlp_dict = {
            k.replace("mlp.", ""): v
            for k, v in state_dict.items()
            if k.startswith("mlp.")
        }

        print(f"\nFound {len(sonata_dict)} Sonata keys")
        print(f"Found {len(mlp_dict)} MLP keys")

        # Load Sonata weights and show missing/unexpected keys
        if sonata_dict:
            print("\n=== Loading Sonata Weights ===")
            result = self.sonata.load_state_dict(sonata_dict, strict=False)

            if result.missing_keys:
                print(f"\nMissing keys ({len(result.missing_keys)}):")
                for key in result.missing_keys[:20]:  # Show first 20
                    print(f"  - {key}")
                if len(result.missing_keys) > 20:
                    print(f"  ... and {len(result.missing_keys) - 20} more")
            else:
                print("No missing keys!")

            if result.unexpected_keys:
                print(f"\nUnexpected keys ({len(result.unexpected_keys)}):")
                for key in result.unexpected_keys[:20]:  # Show first 20
                    print(f"  - {key}")
                if len(result.unexpected_keys) > 20:
                    print(f"  ... and {len(result.unexpected_keys) - 20} more")
            else:
                print("No unexpected keys!")

        # Load MLP weights
        if mlp_dict:
            print("\n=== Loading MLP Weights ===")
            result = self.mlp.load_state_dict(mlp_dict, strict=False)
            if result.missing_keys:
                print(f"\nMissing keys: {result.missing_keys}")
            if result.unexpected_keys:
                print(f"Unexpected keys: {result.unexpected_keys}")
            print("MLP weights loaded successfully!")

        print(f"\n✓ Loaded checkpoint from {checkpoint_path}")

    def prepare_batch_data(
        self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
    ) -> Tuple[Dict, int, int]:
        """
        Prepare batch data for Sonata model.

        Args:
            points: [B, N, 3] or [N, 3] tensor of point coordinates
            normals: [B, N, 3] or [N, 3] tensor of normals (optional)

        Returns:
            Tuple of (data_dict, B, N), where data_dict is formatted for Sonata input
        """
        # Handle single batch case
        if points.dim() == 2:
            points = points.unsqueeze(0)
            if normals is not None:
                normals = normals.unsqueeze(0)

        # print('Sonata points shape: ', points.shape)
        B, N, _ = points.shape

        # Prepare batch indices
        batch_idx = torch.arange(B).view(-1, 1).repeat(1, N).reshape(-1)

        # Flatten points for Sonata format
        coord = points.reshape(B * N, 3)

        if normals is not None:
            normal = normals.reshape(B * N, 3)
        else:
            # Generate dummy normals if not provided
            normal = torch.ones_like(coord)

        # Generate dummy colors
        color = torch.ones_like(coord)

        # Function to convert tensor to numpy array, handling BFloat16
        def to_numpy(tensor):
            # First convert to CPU if needed
            if tensor.is_cuda:
                tensor = tensor.cpu()
            # Convert BFloat16 or other unsupported dtypes to float32
            if tensor.dtype not in [
                torch.float32,
                torch.float64,
                torch.int32,
                torch.int64,
                torch.uint8,
                torch.int8,
                torch.int16,
            ]:
                tensor = tensor.to(torch.float32)
            # Then convert to numpy
            return tensor.numpy()

        # Create data dict
        data_dict = {
            "coord": to_numpy(coord),
            "normal": to_numpy(normal),
            "color": to_numpy(color),
            "batch": to_numpy(batch_idx),
        }

        # Apply transform
        data_dict = self.transform(data_dict)

        return data_dict, B, N

    def forward(
        self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Extract features from point clouds.

        Args:
            points: [B, N, 3] or [N, 3] tensor of point coordinates
            normals: [B, N, 3] or [N, 3] tensor of normals (optional)

        Returns:
            features: [B, N, 512] or [N, 512] tensor of features
        """
        # Store original shape
        original_shape = points.shape
        single_batch = points.dim() == 2

        # Prepare data for Sonata
        data_dict, B, N = self.prepare_batch_data(points, normals)

        # Move to GPU if needed and convert to appropriate dtype
        device = points.device
        dtype = points.dtype

        # Make sure the entire model is in the correct dtype
        # if dtype != self._original_dtype:
        #     self.to(dtype)
        #     self._original_dtype = dtype

        for key in data_dict.keys():
            if isinstance(data_dict[key], torch.Tensor):
                # Convert tensors to the right device and dtype if they're floating point
                if data_dict[key].is_floating_point():
                    data_dict[key] = data_dict[key].to(device=device, dtype=dtype)
                else:
                    # For integer tensors, just move to device without changing dtype
                    data_dict[key] = data_dict[key].to(device)

        # Extract Sonata features
        point = self.sonata(data_dict)

        # Handle pooling layers (same as in train-sonata.py)
        while "pooling_parent" in point.keys():
            assert "pooling_inverse" in point.keys()
            parent = point.pop("pooling_parent")
            inverse = point.pop("pooling_inverse")
            parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
            point = parent

        # Get features and apply MLP
        feat = point.feat  # [M, 1232]
        feat = self.mlp(feat)  # [M, 512]

        # Map back to original points
        feat = feat[point.inverse]  # [B*N, 512]

        # Reshape to batch format
        feat = feat.reshape(B, -1, feat.shape[-1])  # [B, N, 512]

        # Return in original format
        if single_batch:
            feat = feat.squeeze(0)  # [N, 512]

        return feat

    def extract_features_batch(
        self,
        points_list: list,
        normals_list: Optional[list] = None,
        batch_size: int = 8,
    ) -> list:
        """
        Extract features for multiple point clouds in batches.

        Args:
            points_list: List of [N_i, 3] tensors
            normals_list: List of [N_i, 3] tensors (optional)
            batch_size: Batch size for processing

        Returns:
            List of [N_i, 512] feature tensors
        """
        features_list = []

        # Process in batches
        for i in range(0, len(points_list), batch_size):
            batch_points = points_list[i : i + batch_size]
            batch_normals = normals_list[i : i + batch_size] if normals_list else None

            # Find max points in batch
            max_n = max(p.shape[0] for p in batch_points)

            # Pad to same size
            padded_points = []
            masks = []

            for points in batch_points:
                n = points.shape[0]
                if n < max_n:
                    padding = torch.zeros(max_n - n, 3, device=points.device)
                    points = torch.cat([points, padding], dim=0)
                padded_points.append(points)

                mask = torch.zeros(max_n, dtype=torch.bool, device=points.device)
                mask[:n] = True
                masks.append(mask)

            # Stack batch
            batch_tensor = torch.stack(padded_points)  # [B, max_n, 3]

            # Handle normals similarly if provided
            if batch_normals:
                padded_normals = []
                for j, normals in enumerate(batch_normals):
                    n = normals.shape[0]
                    if n < max_n:
                        padding = torch.ones(max_n - n, 3, device=normals.device)
                        normals = torch.cat([normals, padding], dim=0)
                    padded_normals.append(normals)
                normals_tensor = torch.stack(padded_normals)
            else:
                normals_tensor = None

            # Extract features
            with torch.cuda.amp.autocast(enabled=True):
                batch_features = self.forward(
                    batch_tensor, normals_tensor
                )  # [B, max_n, 512]

            # Unpad and add to results
            for j, (feat, mask) in enumerate(zip(batch_features, masks)):
                features_list.append(feat[mask])

        return features_list
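

# --- Usage sketch (illustrative, not part of the module's API) ---------------
# Minimal example of how SonataFeatureExtractor might be driven. The device
# choice, point counts, and empty ckpt_path below are assumptions for the
# sketch; pass a real Lightning checkpoint path to load trained weights.
# Because this module uses a relative import (`from .. import sonata`), run it
# as a module (e.g. `python -m <package>.<this_module>`), not as a script.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # With ckpt_path="" the Sonata backbone and MLP keep their initial weights.
    extractor = SonataFeatureExtractor(ckpt_path="").to(device).eval()

    # Single cloud: [N, 3] in -> [N, 512] features out.
    points = torch.rand(2048, 3, device=device)
    with torch.no_grad():
        feats = extractor(points)
    print("single-cloud features:", tuple(feats.shape))

    # Variable-sized clouds via the batched helper: list of [N_i, 3] tensors
    # in -> list of [N_i, 512] tensors out, padded and unpadded internally.
    clouds = [torch.rand(n, 3, device=device) for n in (1024, 1536, 2048)]
    with torch.no_grad():
        feats_list = extractor.extract_features_batch(clouds, batch_size=2)
    print("per-cloud feature shapes:", [tuple(f.shape) for f in feats_list])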