import torch
import torch.nn as nn
from .. import sonata
from typing import Dict, Optional, Tuple
from pathlib import Path
class SonataFeatureExtractor(nn.Module):
"""
Feature extractor using Sonata backbone with MLP projection.
Supports batch processing and gradient computation.
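
    Example (illustrative; the checkpoint path is a placeholder):
        >>> extractor = SonataFeatureExtractor(ckpt_path="path/to/checkpoint.ckpt")
        >>> feats = extractor(torch.rand(2048, 3))  # -> [2048, 512] features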
"""
def __init__(
self,
        ckpt_path: Optional[str] = None,
):
super().__init__()
# Load Sonata model
self.sonata = sonata.load_by_config(
str(Path(__file__).parent.parent.parent / "config" / "sonata.json")
)
# Store original dtype for later reference
# self._original_dtype = next(self.parameters()).dtype
# Define MLP projection head (same as in train-sonata.py)
self.mlp = nn.Sequential(
nn.Linear(1232, 512),
nn.GELU(),
nn.Linear(512, 512),
nn.GELU(),
nn.Linear(512, 512),
)
# Define transform
self.transform = sonata.transform.default()
# Load checkpoint if provided
if ckpt_path:
self.load_checkpoint(ckpt_path)
def load_checkpoint(self, checkpoint_path: str):
"""Load model weights from checkpoint."""
checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Extract state dict from Lightning checkpoint
if "state_dict" in checkpoint:
state_dict = checkpoint["state_dict"]
            # Strip Lightning's 'model.' prefix if present
            state_dict = {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }
else:
state_dict = checkpoint
# Debug: Show all keys in checkpoint
print("\n=== Checkpoint Keys ===")
print(f"Total keys in checkpoint: {len(state_dict)}")
print("\nSample keys:")
for i, key in enumerate(list(state_dict.keys())[:10]):
print(f" {key}")
if len(state_dict) > 10:
print(f" ... and {len(state_dict) - 10} more keys")
# Load only the relevant weights
sonata_dict = {
k.replace("sonata.", ""): v
for k, v in state_dict.items()
if k.startswith("sonata.")
}
mlp_dict = {
k.replace("mlp.", ""): v
for k, v in state_dict.items()
if k.startswith("mlp.")
}
print(f"\nFound {len(sonata_dict)} Sonata keys")
print(f"Found {len(mlp_dict)} MLP keys")
# Load Sonata weights and show missing/unexpected keys
if sonata_dict:
print("\n=== Loading Sonata Weights ===")
result = self.sonata.load_state_dict(sonata_dict, strict=False)
if result.missing_keys:
print(f"\nMissing keys ({len(result.missing_keys)}):")
for key in result.missing_keys[:20]: # Show first 20
print(f" - {key}")
if len(result.missing_keys) > 20:
print(f" ... and {len(result.missing_keys) - 20} more")
else:
print("No missing keys!")
if result.unexpected_keys:
print(f"\nUnexpected keys ({len(result.unexpected_keys)}):")
for key in result.unexpected_keys[:20]: # Show first 20
print(f" - {key}")
if len(result.unexpected_keys) > 20:
print(f" ... and {len(result.unexpected_keys) - 20} more")
else:
print("No unexpected keys!")
# Load MLP weights
if mlp_dict:
print("\n=== Loading MLP Weights ===")
result = self.mlp.load_state_dict(mlp_dict, strict=False)
if result.missing_keys:
print(f"\nMissing keys: {result.missing_keys}")
if result.unexpected_keys:
print(f"Unexpected keys: {result.unexpected_keys}")
print("MLP weights loaded successfully!")
print(f"\n✓ Loaded checkpoint from {checkpoint_path}")
def prepare_batch_data(
self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
    ) -> Tuple[Dict, int, int]:
"""
Prepare batch data for Sonata model.
Args:
points: [B, N, 3] or [N, 3] tensor of point coordinates
normals: [B, N, 3] or [N, 3] tensor of normals (optional)
Returns:
            Tuple of (data_dict formatted for Sonata input, batch size B, points per cloud N)
"""
# Handle single batch case
if points.dim() == 2:
points = points.unsqueeze(0)
if normals is not None:
normals = normals.unsqueeze(0)
# print('Sonata points shape: ', points.shape)
B, N, _ = points.shape
# Prepare batch indices
batch_idx = torch.arange(B).view(-1, 1).repeat(1, N).reshape(-1)
# Flatten points for Sonata format
coord = points.reshape(B * N, 3)
if normals is not None:
normal = normals.reshape(B * N, 3)
else:
# Generate dummy normals if not provided
normal = torch.ones_like(coord)
# Generate dummy colors
color = torch.ones_like(coord)
# Function to convert tensor to numpy array, handling BFloat16
def to_numpy(tensor):
# First convert to CPU if needed
if tensor.is_cuda:
tensor = tensor.cpu()
# Convert BFloat16 or other unsupported dtypes to float32
if tensor.dtype not in [
torch.float32,
torch.float64,
torch.int32,
torch.int64,
torch.uint8,
torch.int8,
torch.int16,
]:
tensor = tensor.to(torch.float32)
# Then convert to numpy
return tensor.numpy()
# Create data dict
data_dict = {
"coord": to_numpy(coord),
"normal": to_numpy(normal),
"color": to_numpy(color),
"batch": to_numpy(batch_idx),
}
# Apply transform
data_dict = self.transform(data_dict)
return data_dict, B, N
def forward(
self, points: torch.Tensor, normals: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Extract features from point clouds.
Args:
points: [B, N, 3] or [N, 3] tensor of point coordinates
normals: [B, N, 3] or [N, 3] tensor of normals (optional)
Returns:
features: [B, N, 512] or [N, 512] tensor of features
"""
# Store original shape
original_shape = points.shape
single_batch = points.dim() == 2
# Prepare data for Sonata
data_dict, B, N = self.prepare_batch_data(points, normals)
# Move to GPU if needed and convert to appropriate dtype
device = points.device
dtype = points.dtype
# Make sure the entire model is in the correct dtype
# if dtype != self._original_dtype:
# self.to(dtype)
# self._original_dtype = dtype
for key in data_dict.keys():
if isinstance(data_dict[key], torch.Tensor):
# Convert tensors to the right device and dtype if they're floating point
if data_dict[key].is_floating_point():
data_dict[key] = data_dict[key].to(device=device, dtype=dtype)
else:
# For integer tensors, just move to device without changing dtype
data_dict[key] = data_dict[key].to(device)
# Extract Sonata features
point = self.sonata(data_dict)
# Handle pooling layers (same as in train-sonata.py)
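        # Each pooling stage recorded its parent point set and an inverse index map;
        # walking back up the hierarchy broadcasts coarse features to the finer level
        # (via the inverse indices) and concatenates them, until the un-pooled
        # resolution is reached with multi-scale features.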
while "pooling_parent" in point.keys():
assert "pooling_inverse" in point.keys()
parent = point.pop("pooling_parent")
inverse = point.pop("pooling_inverse")
parent.feat = torch.cat([parent.feat, point.feat[inverse]], dim=-1)
point = parent
# Get features and apply MLP
feat = point.feat # [M, 1232]
feat = self.mlp(feat) # [M, 512]
# Map back to original points
feat = feat[point.inverse] # [B*N, 512]
# Reshape to batch format
feat = feat.reshape(B, -1, feat.shape[-1]) # [B, N, 512]
# Return in original format
if single_batch:
feat = feat.squeeze(0) # [N, 512]
return feat
def extract_features_batch(
self,
points_list: list,
normals_list: Optional[list] = None,
batch_size: int = 8,
) -> list:
"""
Extract features for multiple point clouds in batches.
Args:
points_list: List of [N_i, 3] tensors
normals_list: List of [N_i, 3] tensors (optional)
batch_size: Batch size for processing
Returns:
List of [N_i, 512] feature tensors
"""
features_list = []
# Process in batches
for i in range(0, len(points_list), batch_size):
batch_points = points_list[i : i + batch_size]
batch_normals = normals_list[i : i + batch_size] if normals_list else None
# Find max points in batch
max_n = max(p.shape[0] for p in batch_points)
# Pad to same size
padded_points = []
masks = []
for points in batch_points:
n = points.shape[0]
if n < max_n:
                    padding = torch.zeros(
                        max_n - n, 3, device=points.device, dtype=points.dtype
                    )
points = torch.cat([points, padding], dim=0)
padded_points.append(points)
mask = torch.zeros(max_n, dtype=torch.bool, device=points.device)
mask[:n] = True
masks.append(mask)
# Stack batch
batch_tensor = torch.stack(padded_points) # [B, max_n, 3]
# Handle normals similarly if provided
if batch_normals:
padded_normals = []
                for normals in batch_normals:
n = normals.shape[0]
if n < max_n:
                        padding = torch.ones(
                            max_n - n, 3, device=normals.device, dtype=normals.dtype
                        )
normals = torch.cat([normals, padding], dim=0)
padded_normals.append(normals)
normals_tensor = torch.stack(padded_normals)
else:
normals_tensor = None
# Extract features
with torch.cuda.amp.autocast(enabled=True):
batch_features = self.forward(
batch_tensor, normals_tensor
) # [B, max_n, 512]
# Unpad and add to results
            for feat, mask in zip(batch_features, masks):
features_list.append(feat[mask])
return features_list
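

# Minimal usage sketch (illustrative, not part of the class). The checkpoint path
# and point counts below are placeholders; adjust them to your setup. Because this
# module uses a relative import, run it as a module (python -m <package>.<module>)
# rather than as a standalone script.
if __name__ == "__main__":
    extractor = SonataFeatureExtractor(ckpt_path="checkpoints/sonata_mlp.ckpt").eval()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    extractor = extractor.to(device)

    # Single point cloud: [N, 3] -> [N, 512]
    points = torch.rand(4096, 3, device=device)
    with torch.no_grad():
        feats = extractor(points)
    print(feats.shape)  # torch.Size([4096, 512])

    # Variable-size point clouds: list of [N_i, 3] -> list of [N_i, 512]
    clouds = [torch.rand(n, 3, device=device) for n in (1024, 2048, 4096)]
    with torch.no_grad():
        feats_list = extractor.extract_features_batch(clouds, batch_size=2)
    print([f.shape for f in feats_list])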