# NOTE(review): the lines below were scraped from a Hugging Face Spaces page
# (the Space's status banner read "Runtime error"); they are not part of the code.
| import torch | |
| import torch.nn as nn | |
| import yaml | |
| import sys | |
| sys.path.append(".") | |
| sys.path.append("submodules") | |
| sys.path.append("submodules/mast3r") | |
| from mast3r.model import AsymmetricMASt3R | |
| from src.ptv3 import PTV3 | |
| from src.gaussian_head import GaussianHead | |
| from src.utils.points_process import merge_points | |
| from src.losses import GaussianLoss | |
| from src.lseg import LSegFeatureExtractor | |
| import argparse | |
class LSM_MASt3R(nn.Module):
    """Language-aware Scene Model built on a frozen MASt3R backbone.

    Pipeline: MASt3R predicts per-view pointmaps for an image pair, the merged
    points are refined by PointTransformerV3, and a Gaussian head regresses the
    final Gaussian parameters, conditioned on (reduced) LSeg language features.

    Both the MASt3R backbone and the LSeg feature extractor are frozen; only
    the point transformer, the Gaussian head, and the two 1x1-conv
    reduction/expansion layers are trainable.
    """

    def __init__(self,
        mast3r_config,
        point_transformer_config,
        gaussian_head_config,
        lseg_config,
    ):
        """Build all sub-modules from their config dicts.

        Args:
            mast3r_config: kwargs for ``AsymmetricMASt3R.from_pretrained``.
            point_transformer_config: kwargs for ``PTV3``.
            gaussian_head_config: kwargs for ``GaussianHead``; may contain
                ``d_gs_feats`` (feature channel count, default 32).
            lseg_config: kwargs for ``LSegFeatureExtractor.from_pretrained``.
        """
        super().__init__()
        # Keep configs so checkpoints can rebuild the model.
        self.config = {
            'mast3r_config': mast3r_config,
            'point_transformer_config': point_transformer_config,
            'gaussian_head_config': gaussian_head_config,
            'lseg_config': lseg_config
        }

        # Frozen MASt3R backbone (inference only).
        self.mast3r = AsymmetricMASt3R.from_pretrained(**mast3r_config)
        for param in self.mast3r.parameters():
            param.requires_grad = False
        self.mast3r.eval()

        # Trainable point transformer and Gaussian head.
        self.point_transformer = PTV3(**point_transformer_config)
        self.gaussian_head = GaussianHead(**gaussian_head_config)

        # Frozen LSeg language-feature extractor (inference only).
        self.lseg_feature_extractor = LSegFeatureExtractor.from_pretrained(**lseg_config)
        for param in self.lseg_feature_extractor.parameters():
            param.requires_grad = False
        self.lseg_feature_extractor.eval()

        # 1x1 convs to move between the 512-d LSeg space (at half resolution)
        # and the compact Gaussian-feature space (at full resolution).
        d_gs_feats = gaussian_head_config.get('d_gs_feats', 32)
        self.feature_reduction = nn.Sequential(
            nn.Conv2d(512, d_gs_feats, kernel_size=1),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        )  # (b, 512, h//2, w//2) -> (b, d_gs_feats, h, w)
        self.feature_expansion = nn.Sequential(
            nn.Conv2d(d_gs_feats, 512, kernel_size=1),
            nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        )  # (b, d_gs_feats, h, w) -> (b, 512, h//2, w//2)

    def forward(self, view1, view2):
        """Run the full two-view pipeline and return the Gaussian-head output.

        Args:
            view1, view2: per-view dicts as expected by MASt3R (each must at
                least contain an ``'img'`` tensor).
        """
        # MASt3R two-view prediction.
        mast3r_output = self.mast3r(view1, view2)
        # Merge per-view points into one point cloud data dict.
        data_dict = merge_points(mast3r_output, view1, view2)
        # Refine with PointTransformerV3.
        point_transformer_output = self.point_transformer(data_dict)
        # Language features for conditioning.
        lseg_features = self.extract_lseg_features(view1, view2)
        # Regress Gaussians.
        final_output = self.gaussian_head(point_transformer_output, lseg_features)
        return final_output

    def extract_lseg_features(self, view1, view2):
        """Extract LSeg features for both views and reduce them to d_gs_feats.

        Returns a tensor of shape (v*b, d_gs_feats, h, w), with the two views
        stacked along the batch dimension.
        """
        # Stack the two views along the batch axis: (v*b, 3, h, w).
        img = torch.cat([view1['img'], view2['img']], dim=0)
        # Frozen LSeg features: (v*b, 512, h//2, w//2).
        lseg_features = self.lseg_feature_extractor.extract_features(img)
        # Project to d_gs_feats channels and upsample to full resolution.
        lseg_features = self.feature_reduction(lseg_features)
        return lseg_features

    @staticmethod
    def from_pretrained(checkpoint_path, device='cuda'):
        """Rebuild a model from a training checkpoint and load its weights.

        Args:
            checkpoint_path: path to a checkpoint containing ``'args'`` (whose
                ``.model`` attribute is a constructor expression) and
                ``'model'`` (the trainable state dict).
            device: device to move the loaded model to.

        Returns:
            The restored ``LSM_MASt3R`` instance on ``device``.
        """
        # SECURITY: torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources (consider weights_only=True where
        # the checkpoint format allows it).
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        config = ckpt['args']
        # SECURITY: eval executes config.model as code; acceptable only for
        # trusted, self-produced checkpoints.
        model = eval(config.model)
        model.load_state_dict(ckpt['model'])
        model = model.to(device)
        return model

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        """Return a state dict containing only the trainable parameters.

        The frozen ``mast3r.*`` and ``lseg_feature_extractor.*`` entries are
        dropped so checkpoints stay small and backbone weights are not
        duplicated on disk.
        """
        full_state_dict = super().state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        trainable_state_dict = {
            k: v for k, v in full_state_dict.items()
            if not (k.startswith('mast3r.') or k.startswith('lseg_feature_extractor.'))
        }
        return trainable_state_dict

    def load_state_dict(self, state_dict, strict=True):
        """Load only the trainable parameters from ``state_dict``.

        Frozen backbone keys keep their current (pretrained) values. The
        ``strict`` argument is accepted for interface compatibility but the
        merged dict is always loaded non-strictly, since ``state_dict`` is
        expected to omit the frozen modules.
        """
        # Start from the full current state so frozen weights are preserved.
        model_state = super().state_dict()
        # Overwrite only trainable entries that are present in the input.
        for k in list(state_dict.keys()):
            if k in model_state and not (k.startswith('mast3r.') or k.startswith('lseg_feature_extractor.')):
                model_state[k] = state_dict[k]
        # Return the IncompatibleKeys result like nn.Module.load_state_dict.
        return super().load_state_dict(model_state, strict=False)
if __name__ == "__main__":
    # Smoke test: build the model, run one batch from ScanNet, compute the loss.
    from torch.utils.data import DataLoader

    # NOTE: argparse is already imported at the top of the file; the duplicate
    # local import was removed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint', type=str)
    args = parser.parse_args()

    # Load the model configuration.
    with open("configs/model_config.yaml", "r") as f:
        config = yaml.safe_load(f)

    # Build the model, either from a checkpoint or from the config alone.
    if args.checkpoint is not None:
        model = LSM_MASt3R.from_pretrained(args.checkpoint, device='cuda')
    else:
        model = LSM_MASt3R(**config).to('cuda')
    model.eval()
    print(model)

    # Load one training batch from ScanNet.
    from src.datasets.scannet import Scannet
    dataset = Scannet(split='train', ROOT="data/scannet_processed", resolution=[(512, 384)])
    print(dataset)

    data_loader = DataLoader(dataset, batch_size=3, shuffle=True)
    data = next(iter(data_loader))

    # Move every per-view tensor the pipeline needs onto the GPU.
    for view in data:
        view['img'] = view['img'].to('cuda')
        view['depthmap'] = view['depthmap'].to('cuda')
        view['camera_pose'] = view['camera_pose'].to('cuda')
        view['camera_intrinsics'] = view['camera_intrinsics'].to('cuda')

    # Forward pass on the first two views, then evaluate the Gaussian loss.
    output = model(*data[:2])
    loss = GaussianLoss()
    loss_value = loss(*data, *output, model)
    print(loss_value)