# FloodDiffusion-MEI/utils/lightning_module.py
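"""Generic Lightning training module.

Instantiates the model from the config, keeps an exponential-moving-average
(EMA) copy of its weights, and wires up the optimizer, LR scheduler,
checkpoint save/load, timing logs, and train/val/test hooks.
"""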
import os
import time
import torch
from lightning import LightningModule
from lightning.pytorch.utilities import rank_zero_info
from torch_ema import ExponentialMovingAverage
from utils.initialize import (
compare_statedict_and_parameters,
instantiate,
print_model_size,
)
# Set tokenizers parallelism to false to avoid warnings in multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class BasicLightningModule(LightningModule):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.model = instantiate(
target=cfg.model.target, cfg=None, hfstyle=False, **cfg.model.params
)
        # NOTE: at Lightning init time the module is still on CPU, so the EMA
        # does not need to be moved to a device here
self.ema = ExponentialMovingAverage(
self.model.parameters(), decay=cfg.model.ema_decay
)
print_model_size(self.model)
# logging
self.last_batch_end_time, self.batch_ready_time = None, None
self.validation_step_outputs = []
# metric
self.initialize_metrics()
def configure_optimizers(self):
optim_target = self.cfg.optimizer.target
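        # Allow shorthand optimizer names, e.g. "AdamW" -> "torch.optim.AdamW"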
if len(optim_target.split(".")) == 1:
optim_target = "torch.optim." + optim_target
optimizer = instantiate(
target=optim_target,
cfg=None,
hfstyle=False,
params=self.model.parameters(),
**self.cfg.optimizer.params,
)
scheduler_target = self.cfg.lr_scheduler.target
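        # Allow shorthand scheduler names, e.g. "CosineAnnealingLR"
        # -> "torch.optim.lr_scheduler.CosineAnnealingLR"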
if len(scheduler_target.split(".")) == 1:
scheduler_target = "torch.optim.lr_scheduler." + scheduler_target
lr_scheduler = instantiate(
target=scheduler_target,
cfg=None,
hfstyle=False,
optimizer=optimizer,
**self.cfg.lr_scheduler.params,
)
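        # Step the scheduler after every optimizer step rather than once per epoch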
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": lr_scheduler,
"interval": "step",
"frequency": 1,
},
}
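    # Skip Lightning's default weight restoration; model weights and EMA state
    # are restored explicitly in on_load_checkpoint below.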
def load_state_dict(self, state_dict, strict=True):
pass
def on_load_checkpoint(self, checkpoint):
self.model.load_state_dict(checkpoint["state_dict"], strict=True)
if "ema_state" in checkpoint:
self.ema.load_state_dict(checkpoint["ema_state"])
rank_zero_info("init ema from ckpt")
else:
self.ema = ExponentialMovingAverage(
self.model.parameters(), decay=self.cfg.model.ema_decay
)
rank_zero_info("init ema from current model weights")
# Compare state_dict and parameters
compare_statedict_and_parameters(
state_dict=self.model.state_dict(),
named_parameters=self.model.named_parameters(),
named_buffers=self.model.named_buffers(),
)
def on_save_checkpoint(self, checkpoint):
checkpoint["ema_state"] = self.ema.state_dict()
checkpoint["state_dict"] = self.model.state_dict()
def _step(self, batch, is_training=True):
out = self.model(batch)
return out
def on_train_batch_start(self, batch, batch_idx):
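        # Record when the batch became available so data-loading time can be measured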
self.batch_ready_time = time.time()
def training_step(self, batch, batch_idx):
net_start_time = time.time()
# forward
loss_dict = self._step(batch, is_training=True)
# logging
net_end_time = time.time()
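        # data_time: time spent waiting for the dataloader since the previous batch ended;
        # net_time: time spent in the forward pass above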
data_time = (
self.batch_ready_time - self.last_batch_end_time
if self.last_batch_end_time is not None
else 0.0
)
net_time = net_end_time - net_start_time
batch_size = self.cfg.data.train_bs
self.log(
"lr",
self.trainer.optimizers[0].param_groups[0]["lr"],
on_step=True,
prog_bar=True,
batch_size=batch_size,
)
self.log(
"data_time", data_time, on_step=True, prog_bar=True, batch_size=batch_size
)
self.log(
"net_time", net_time, on_step=True, prog_bar=True, batch_size=batch_size
)
for key, value in loss_dict.items():
self.log(
f"train_loss/{key}",
value.item(),
on_step=True,
on_epoch=True,
prog_bar=True,
sync_dist=True,
batch_size=batch_size,
)
return loss_dict["total"]
def on_train_batch_end(self, outputs, batch, batch_idx):
self.last_batch_end_time = time.time()
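        # EMA shadow weights are created on CPU at init; move them to the current
        # device, then update them after this optimizer step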
self.ema.to(self.device)
self.ema.update()
        # Every 100 steps: log the EMA decay and the mean absolute difference
        # between model and EMA weights (vectorized over all trainable parameters)
if self.global_step % 100 == 0:
self.log("ema_decay", self.ema.decay, sync_dist=False)
with torch.no_grad():
model_params = torch.cat(
[p.flatten() for p in self.model.parameters() if p.requires_grad]
)
ema_params = torch.cat(
[
self.ema.shadow_params[i].flatten()
for i, (name, p) in enumerate(self.model.named_parameters())
if p.requires_grad
]
)
avg_diff = torch.abs(model_params - ema_params).mean()
self.log("ema_diff/avg", avg_diff, sync_dist=True)
    # NOTE: Lightning automatically wraps validation in torch.no_grad() and puts the model in eval mode
def validation_step(self, batch, batch_idx, dataloader_idx=0):
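        # dataloader_idx 1 is reserved for periodic test-style evaluation;
        # dataloader_idx 0 computes validation losses and metrics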
if dataloader_idx == 1:
if self.global_step % self.cfg.validation.test_steps == 0:
self.test_step(batch, batch_idx)
else:
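            # Temporarily swap the EMA weights into the model for evaluation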
with self.ema.average_parameters(self.model.parameters()):
loss_dict = self._step(batch, is_training=False)
# logging
batch_size = self.cfg.data.val_bs
for key, value in loss_dict.items():
self.log(
f"val_loss/{key}",
value.item(),
on_step=False,
on_epoch=True,
sync_dist=True,
batch_size=batch_size,
)
# metrics
self.update_metrics(batch)
return
def on_validation_epoch_end(self):
if self.global_step % self.cfg.validation.test_steps == 0:
self.on_test_epoch_end()
# metrics
self.compute_metrics()
    # NOTE: Lightning automatically wraps testing in torch.no_grad() and puts the model in eval mode
def test_step(self, batch, batch_idx):
self.update_test(batch)
return
def on_test_epoch_end(self):
# Only rank 0 does rendering and wandb logging
if self.trainer.global_rank == 0:
self.process_test_results()
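    # Subclass hooks: metric tracking and test-result handling are no-ops in this base class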
def initialize_metrics(self):
pass
def update_metrics(self, batch):
pass
def compute_metrics(self):
pass
def update_test(self, batch):
pass
def process_test_results(self):
pass
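

# Minimal usage sketch (an assumption, not part of this file): cfg is an
# OmegaConf/argparse-style config exposing the model/optimizer/lr_scheduler/
# data/validation sections accessed above.
#
#   from lightning import Trainer
#   module = BasicLightningModule(cfg)
#   trainer = Trainer()
#   trainer.fit(module, train_dataloaders=..., val_dataloaders=...)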