Spaces:
Configuration error
Configuration error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import Any, Dict, List, Optional, Tuple | |
| import numpy as np | |
| import re | |
| import roma | |
| import torch | |
| from torch.distributed import all_gather_object, barrier | |
| from lightning import LightningModule | |
| from lightning.pytorch.loggers.wandb import WandbLogger | |
| from torchmetrics import MaxMetric, MeanMetric, MinMetric, SumMetric, Metric | |
| from torchmetrics.aggregation import BaseAggregator | |
| from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR | |
| from stream3r.dust3r.model import FlashDUSt3R | |
| from stream3r.models.stream3r import STream3R | |
| from stream3r.utils import pylogger | |
| log = pylogger.RankedLogger(__name__, rank_zero_only=True) | |
| class AccumulatedSum(BaseAggregator): | |
| def __init__( | |
| self, | |
| **kwargs: Any, | |
| ) -> None: | |
| super().__init__( | |
| fn="sum", | |
| default_value=torch.tensor(0.0, dtype=torch.long), | |
| nan_strategy='warn', | |
| state_name="sum_value", | |
| **kwargs, | |
| ) | |
| def update(self, value: int) -> None: | |
| self.sum_value += value | |
| def compute(self) -> torch.LongTensor: | |
| return self.sum_value | |
| class MultiViewDUSt3RLitModule(LightningModule): | |
| def __init__( | |
| self, | |
| net: torch.nn.Module, | |
| train_criterion: torch.nn.Module, | |
| validation_criterion: torch.nn.Module, | |
| optimizer: torch.optim.Optimizer, | |
| scheduler: torch.optim.lr_scheduler, | |
| compile: bool, | |
| pretrained: Optional[str] = None, | |
| resume_from_checkpoint: Optional[str] = None, | |
| eval_use_pts3d_from_local_head: bool = True, | |
| ) -> None: | |
| super().__init__() | |
| self.save_hyperparameters(logger=False, ignore=['net', 'train_criterion', 'validation_criterion']) | |
| self.net = net | |
| self.train_criterion = train_criterion | |
| self.validation_criterion = validation_criterion | |
| self.pretrained = pretrained | |
| self.resume_from_checkpoint = resume_from_checkpoint | |
| self.eval_use_pts3d_from_local_head = eval_use_pts3d_from_local_head | |
| # use register_buffer to save these with checkpoints | |
| # so that when we resume training, these bookkeeping variables are preserved | |
| self.register_buffer("epoch_fraction", torch.tensor(0.0, dtype=torch.float32, device=self.device)) | |
| self.register_buffer("train_total_samples", torch.tensor(0, dtype=torch.long, device=self.device)) | |
| self.register_buffer("train_total_images", torch.tensor(0, dtype=torch.long, device=self.device)) | |
| self.train_total_samples_per_step = AccumulatedSum() # these need to be reduced across GPUs, so use Metric | |
| self.train_total_images_per_step = AccumulatedSum() # these need to be reduced across GPUs, so use Metric | |
| self.val_loss = MeanMetric() | |
| def load_for_inference(cls, net: STream3R): | |
| lit_module = cls(net=net, train_criterion=None, validation_criterion=None, optimizer=None, scheduler=None, compile=False) | |
| lit_module.eval() | |
| return lit_module | |
| def forward(self, views: List[Dict[str, torch.Tensor]]) -> Any: | |
| return self.net(views) | |
| def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: | |
| # Legacy: if the checkpoint does not contain the epoch_fraction, train_total_samples, and train_total_images | |
| # we manually add them to the checkpoint | |
| # if self.trainer.strategy.strategy_name != "deepseed": | |
| # if checkpoint["state_dict"].get("epoch_fraction") is None: | |
| # checkpoint["state_dict"]["epoch_fraction"] = self.epoch_fraction | |
| # if checkpoint["state_dict"].get("train_total_samples") is None: | |
| # checkpoint["state_dict"]["train_total_samples"] = self.train_total_samples | |
| # if checkpoint["state_dict"].get("train_total_images") is None: | |
| # checkpoint["state_dict"]["train_total_images"] = self.train_total_images | |
| pass | |
| def on_train_start(self) -> None: | |
| """Lightning hook that is called when training begins.""" | |
| # by default lightning executes validation step sanity checks before training starts, | |
| # so it's worth to make sure validation metrics don't store results from these checks | |
| self.val_loss.reset() | |
| # the wandb logger lives in self.loggers | |
| # find the wandb logger and watch the model and gradients | |
| for logger in self.loggers: | |
| if isinstance(logger, WandbLogger): | |
| self.wandb_logger = logger | |
| # log gradients, parameter histogram and model topology | |
| self.wandb_logger.watch(self.net, log="all", log_freq=500, log_graph=False) | |
| def on_train_epoch_start(self) -> None: | |
| # save initial checkpoint to check pretrained model | |
| # if self.trainer.global_step == 0: | |
| # checkpoint_path = os.path.join(self.trainer.checkpoint_callback.dirpath, "step_0.ckpt") | |
| # self.trainer.save_checkpoint(checkpoint_path) | |
| # our custom dataset and sampler has to have epoch set by calling set_epoch | |
| if hasattr(self.trainer.train_dataloader, "dataset") and hasattr(self.trainer.train_dataloader.dataset, "set_epoch"): | |
| self.trainer.train_dataloader.dataset.set_epoch(self.current_epoch) | |
| if hasattr(self.trainer.train_dataloader, "sampler") and hasattr(self.trainer.train_dataloader.sampler, "set_epoch"): | |
| self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch) | |
| def on_validation_epoch_start(self) -> None: | |
| # our custom dataset and sampler has to have epoch set by calling set_epoch | |
| for loader in self.trainer.val_dataloaders: | |
| if hasattr(loader, "dataset") and hasattr(loader.dataset, "set_epoch"): | |
| loader.dataset.set_epoch(0) | |
| if hasattr(loader, "sampler") and hasattr(loader.sampler, "set_epoch"): | |
| loader.sampler.set_epoch(0) | |
| def model_step( | |
| self, batch: List[Dict[str, torch.Tensor]], criterion: torch.nn.Module, | |
| ) -> Tuple[torch.Tensor, Dict]: | |
| device = self.device | |
| # Move data to device | |
| for view in batch: | |
| for name in "img pts3d valid_mask camera_pose camera_intrinsics F_matrix corres".split(): | |
| if name in view: | |
| view[name] = view[name].to(device, non_blocking=True) | |
| views = batch | |
| preds = self.forward(views) | |
| # Compute the loss in higher precision | |
| with torch.autocast(device_type=self.device.type, dtype=torch.float32): | |
| loss, loss_details = criterion(views, preds) if criterion is not None else None | |
| return views, preds, loss, loss_details | |
| def training_step( | |
| self, batch: List[Dict[str, torch.Tensor]], batch_idx: int | |
| ) -> torch.Tensor: | |
| views, preds, loss, loss_details = self.model_step(batch, self.train_criterion) | |
| if not isinstance(loss, (torch.Tensor, dict, type(None))): # this will cause a lightning.fabric.utilities.exceptions.MisconfigurationException | |
| # log loss and the batch information to help debugging | |
| # use print instead of log because the logger only logs on rank 0, but this could happen on any rank | |
| print(f"Loss is not a tensor or dict but {type(loss)}, value: {loss}") | |
| print(f"Loss details: {loss_details}") | |
| print(f"Batch: {batch}") | |
| print(f"Batch index: {batch_idx}") | |
| print(f"Views: {views}") | |
| print(f"Preds: {preds}") | |
| loss = None # set loss to None will still break the training loop in DDP, this is intended - we should fix the data to avoid nan loss in the first place | |
| return loss | |
| self.epoch_fraction = torch.tensor(self.trainer.current_epoch + batch_idx / self.trainer.num_training_batches, device=self.device) | |
| self.log("trainer/epoch", self.epoch_fraction, on_step=True, on_epoch=False, prog_bar=True) | |
| self.log("trainer/lr", self.trainer.lr_scheduler_configs[0].scheduler.get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True) | |
| self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True) | |
| # log the details of the loss | |
| if loss_details is not None: | |
| for key, value in loss_details.items(): | |
| self.log(f"train_detail_{key}", value, on_step=True, on_epoch=False, prog_bar=False) | |
| match = re.search(r'/(\d{1,2})$', key) | |
| if match: | |
| stripped_key = key[:match.start()] | |
| self.log(f"train/{stripped_key}", value, on_step=True, on_epoch=False, prog_bar=False) | |
| # Log the total number of samples seen so far | |
| batch_size = views[0]["img"].shape[0] | |
| self.train_total_samples_per_step(batch_size) # aggregate across all GPUs | |
| self.train_total_samples += self.train_total_samples_per_step.compute() # accumulate across all steps | |
| self.train_total_samples_per_step.reset() | |
| self.log("trainer/total_samples", self.train_total_samples, on_step=True, on_epoch=False, prog_bar=False) | |
| # Log the total number of images seen so far | |
| num_views = len(views) | |
| n_image_cur_step = batch_size * num_views | |
| self.train_total_images_per_step(n_image_cur_step) # aggregate across all GPUs | |
| self.train_total_images += self.train_total_images_per_step.compute() # accumulate across all steps | |
| self.train_total_images_per_step.reset() | |
| self.log("trainer/total_images", self.train_total_images, on_step=True, on_epoch=False, prog_bar=False) | |
| return loss | |
| def validation_step( | |
| self, batch: List[Dict[str, torch.Tensor]], batch_idx: int, dataloader_idx: int = 0, | |
| ) -> torch.Tensor: | |
| views, preds, loss, loss_details = self.model_step(batch, self.validation_criterion) | |
| # Extract the dataset name and batch size | |
| dataset_name = views[0]['dataset'][0] # all views should have the same dataset name because we use "sequential" mode of CombinedLoader | |
| batch_size = views[0]["img"].shape[0] | |
| self.val_loss(loss) | |
| for key, value in loss_details.items(): | |
| self.log( | |
| f"val_detail_{dataset_name}_{key}", | |
| value, | |
| on_step=False, | |
| on_epoch=True, | |
| prog_bar=False, | |
| reduce_fx="mean", | |
| sync_dist=True, | |
| add_dataloader_idx=False, | |
| batch_size=batch_size, | |
| ) | |
| match = re.search(r'/(\d{1,2})$', key) | |
| if match: | |
| stripped_key = key[:match.start()] | |
| self.log(f"val/{dataset_name}_{stripped_key}", value, on_step=False, on_epoch=True, prog_bar=False, reduce_fx="mean", sync_dist=True, add_dataloader_idx=False, batch_size=batch_size) | |
| loss_value = loss.detach().cpu().item() | |
| del loss, loss_details | |
| torch.cuda.empty_cache() | |
| del views, preds | |
| torch.cuda.empty_cache() | |
| return loss_value | |
| def on_validation_epoch_end(self) -> None: | |
| self.log("val/loss", self.val_loss, prog_bar=True) | |
| # if we dont do these, wandb for some reason cannot display the validation loss with them as the x-axis | |
| self.log("trainer/epoch", self.epoch_fraction, sync_dist=True) | |
| self.log("trainer/total_samples", self.train_total_samples.cpu().item(), sync_dist=True) | |
| self.log("trainer/total_images", self.train_total_images.cpu().item(), sync_dist=True) | |
| # def test_step( | |
| # self, batch: List[Dict[str, torch.Tensor]], batch_idx: int | |
| # ) -> None: | |
| # pass | |
| def configure_optimizers(self) -> Dict[str, Any]: | |
| optimizer = self.hparams.optimizer(params=self.trainer.model.parameters()) | |
| if self.hparams.scheduler is not None: | |
| scheduler_config = self.hparams.scheduler | |
| # HACK: if the class is pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR, | |
| # both warmup_epochs and max_epochs should be scaled. | |
| # more specifically, max_epochs should be scaled to total number of steps that we will have during training, | |
| # and warmup_epochs should be scaled up proportionally. | |
| if scheduler_config.func is LinearWarmupCosineAnnealingLR: | |
| # Extract the keyword arguments from the partial object | |
| scheduler_kwargs = {k: v for k, v in scheduler_config.keywords.items()} | |
| original_warmup_epochs = scheduler_kwargs['warmup_epochs'] | |
| original_max_epochs = scheduler_kwargs['max_epochs'] | |
| total_steps = self.trainer.estimated_stepping_batches # total number of total steps in all training epochs | |
| # Scale warmup_epochs and max_epochs | |
| scaled_warmup_epochs = int(original_warmup_epochs * total_steps / original_max_epochs) | |
| scaled_max_epochs = total_steps | |
| # Update the kwargs with scaled values | |
| scheduler_kwargs.update({ | |
| 'warmup_epochs': scaled_warmup_epochs, | |
| 'max_epochs': scaled_max_epochs | |
| }) | |
| # Re-initialize the scheduler with updated parameters | |
| scheduler = LinearWarmupCosineAnnealingLR( | |
| optimizer=optimizer, | |
| **scheduler_kwargs | |
| ) | |
| else: | |
| scheduler = scheduler_config(optimizer=optimizer) | |
| return { | |
| 'optimizer': optimizer, | |
| 'lr_scheduler': { | |
| 'name': 'train/lr', # put lr inside train group in loggers | |
| 'scheduler': scheduler, | |
| 'interval': 'step' if scheduler_config.func is LinearWarmupCosineAnnealingLR else 'epoch', | |
| 'frequency': 1, | |
| } | |
| } | |
| return {"optimizer": optimizer} | |
| def setup(self, stage: str) -> None: | |
| if self.hparams.compile and stage == "fit": | |
| self.net = torch.compile(self.net) | |
| # Load pretrained weights if available and not resuming | |
| # note that if resume_from_checkpoint is set, the Trainer is responsible for actually loading the checkpoint | |
| # so we are only using resume_from_checkpoint as a check of whether we should load the pretrained weights | |
| if self.pretrained and not self.resume_from_checkpoint: | |
| self._load_pretrained_weights() | |
| def _load_pretrained_weights(self) -> None: | |
| log.info(f"Loading pretrained: {self.pretrained}") | |
| if isinstance(self.net, FlashDUSt3R): # if the model is FlashDUSt3R, use the weights of the first head only | |
| ckpt = torch.load(self.pretrained) | |
| ckpt = self._update_ckpt_keys(ckpt, new_head_name='downstream_head', head_to_keep='downstream_head1', head_to_discard='downstream_head2') | |
| self.net.load_state_dict(ckpt["model"], strict=False) | |
| del ckpt # in case it occupies memory | |
| elif isinstance(self.net, STream3R): | |
| # if the checkpoint is also STream3R, load all weights | |
| log.info(f"Loading pretrained weights from {self.pretrained}") | |
| checkpoint = torch.load(self.pretrained) | |
| missing_keys, unexpected_keys = self.net.load_state_dict(checkpoint, strict=False) | |
| log.info(f"Missing keys: {missing_keys}") | |
| log.info(f"Unexpected keys: {unexpected_keys}") | |
| def _update_ckpt_keys(ckpt, new_head_name='downstream_head', head_to_keep='downstream_head1', head_to_discard='downstream_head2'): | |
| """Helper function to use the weights of a model with multiple heads in a model with a single head. | |
| specifically, keep only the weights of the first head and delete the weights of the second head. | |
| """ | |
| new_ckpt = {'model': {}} | |
| for key, value in ckpt['model'].items(): | |
| if key.startswith(head_to_keep): | |
| new_key = key.replace(head_to_keep, new_head_name) | |
| new_ckpt['model'][new_key] = value | |
| elif key.startswith(head_to_discard): | |
| continue | |
| else: | |
| new_ckpt['model'][key] = value | |
| return new_ckpt |