Spaces:
Sleeping
Sleeping
| from lightning.pytorch.callbacks import BaseFinetuning | |
| import torch | |
| import torch.nn as nn | |
class FreezeClustering(BaseFinetuning):
    """Finetuning callback that freezes a fixed set of submodules.

    Before training starts, the ``batch_norm``, ``gatr``, ``clustering`` and
    ``beta`` attributes of the LightningModule are passed to
    ``BaseFinetuning.freeze``. This callback never unfreezes anything:
    ``finetune_function`` only prints a message.
    """

    def __init__(self) -> None:
        # No configuration of its own; just run the base-class setup.
        super().__init__()

    def freeze_before_training(self, pl_module) -> None:
        """Freeze the selected submodules of ``pl_module``.

        Args:
            pl_module: The LightningModule being trained. It must expose the
                attributes ``batch_norm``, ``gatr``, ``clustering`` and
                ``beta``.

        Other attributes such as ``Dense_1``, ``postgn_dense`` and
        ``ScaledGooeyBatchNorm2_2`` are not touched by this callback.
        """
        print("freezing the following module:", pl_module)

        # NOTE(review): `freeze` is called with its default `train_bn`; per
        # the Lightning docs that default keeps BatchNorm layers trainable.
        # Confirm this is what is wanted for `batch_norm` (and for any
        # BatchNorm layers nested inside the other submodules).
        frozen_attrs = ("batch_norm", "gatr", "clustering", "beta")
        for attr_name in frozen_attrs:
            self.freeze(getattr(pl_module, attr_name))

        print("CLUSTERING HAS BEEN FROOOZEN")

    def finetune_function(self, pl_module, current_epoch, optimizer) -> None:
        """Finetuning hook; intentionally does no unfreezing.

        Args:
            pl_module: The LightningModule being trained (unused).
            current_epoch: Epoch index passed in by the caller (unused).
            optimizer: Optimizer passed in by the caller (unused).
        """
        print("Not finetunning")