import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from model.MDM_transformer import DDiTNoLengthModel
from interpolant import MDMInterpolant # replaced relative import
from schedule import get_schedule_from_config
class MaskedDiffusionModule(pl.LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.learning_rate = config.training.learning_rate
# Initialize model (no length head)
self.model = DDiTNoLengthModel(config)
self.model = torch.compile(self.model)
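        # Note: torch.compile wraps the module, so checkpoint state_dict keys are
        # prefixed with "model._orig_mod."; keep this in mind when loading the
        # weights outside of this LightningModule.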
unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule)
# Initialize interpolant
self.interpolant = MDMInterpolant(
unmask_schedule=unmask_schedule,
vocab_size=config.interpolant.tokens,
mask_token=config.interpolant.mask_token,
pad_token=config.interpolant.pad_token,
max_length=config.interpolant.max_length,
)
# Save hyperparameters
self.save_hyperparameters()
self.ema_decay = config.training.ema_decay or 0.0
self.use_ema = self.ema_decay > 0
self._orig_params = {}
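        # EMA buffers are created lazily in on_train_start (once the model sits on
        # its final device); _orig_params temporarily holds the live weights while
        # swap_to_ema() is in effect.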
def forward(self, x, t) -> torch.Tensor:
return self.model(x, t)
def training_loss(self, x1, t):
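        """Monte-Carlo estimate of the masked-diffusion ELBO loss.

        Cross-entropy is evaluated only at masked positions of the interpolant,
        reweighted by the schedule-dependent ELBO weight, and normalized by
        batch_size * max_length so sequences with few masked tokens do not
        dominate the gradient.
        """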
# sample interpolant and elbo weight
interpolant_result = self.interpolant.sample_interpolant(t, x1)
unmask_weight = self.interpolant.elbo_weight(t, x1)
# model prediction
predicted_logits = self(interpolant_result.xt, t)
mask_indices = interpolant_result.mask_indices
# compute unmask loss
loss = unmask_weight[mask_indices] * F.cross_entropy(
predicted_logits[mask_indices],
interpolant_result.unmasked[mask_indices],
reduction="none",
)
loss = loss.sum() / (x1.shape[0] * self.config.interpolant.max_length)
return loss
def training_step(self, batch, batch_idx):
# Extract input data
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
batch_size = x1.shape[0]
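        # Sample one diffusion time per sequence, uniform on [0, 1).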
t = torch.rand(batch_size, device=x1.device)
loss = self.training_loss(x1, t)
self.log("train/total_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
batch_size = x1.shape[0]
t = torch.rand(batch_size, device=x1.device)
loss = self.training_loss(x1, t)
self.log("val_loss", loss, prog_bar=True)
return loss
def configure_optimizers(self):
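        """AdamW with linear warmup followed by cosine annealing with warm restarts.

        The LinearLR phase ramps the learning rate from ~0 to its target over
        `warmup_steps`; afterwards CosineAnnealingWarmRestarts cycles every
        (max_steps - warmup_steps) // 10 steps. The scheduler is stepped per
        optimizer step ("interval": "step"), not per epoch.
        """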
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.learning_rate,
weight_decay=self.config.training.weight_decay,
)
warmup_steps = self.config.training.warmup_steps
max_steps = self.config.training.max_steps
linear_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1e-6,
end_factor=1.0,
total_iters=warmup_steps,
)
post_warmup = max_steps - warmup_steps
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
optimizer,
T_0=post_warmup // 10,
T_mult=1,
eta_min=0.0,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[linear_scheduler, cosine_scheduler],
milestones=[warmup_steps],
)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def optimizer_step(
self,
epoch: int,
batch_idx: int,
optimizer,
optimizer_closure=None,
):
super().optimizer_step(
epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
)
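        # Gradients from this step's backward pass are still populated here (they
        # are zeroed again inside the next step's closure), so the global L2 norm
        # below reflects the update that was just applied.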
# log learning rate and gradient norm
lr = optimizer.param_groups[0]["lr"]
self.log("train/lr", lr, on_step=True, prog_bar=True)
grad_norm = torch.sqrt(
sum(p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None)
)
self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True)
# update EMA
if self.use_ema:
for n, p in self.named_parameters():
                # ema <- decay * ema + (1 - decay) * param (in place, no grad tracking)
                self.ema_params[n].mul_(self.ema_decay).add_(
                    p.detach(), alpha=1 - self.ema_decay
                )
def on_save_checkpoint(self, checkpoint):
checkpoint["config"] = self.config
# save EMA state
if self.use_ema:
checkpoint["ema_params"] = {n: v.cpu() for n, v in self.ema_params.items()}
def on_load_checkpoint(self, checkpoint):
self.config = checkpoint["config"]
unmask_schedule = get_schedule_from_config(
self.config.interpolant.unmask_schedule
)
self.interpolant = MDMInterpolant(
unmask_schedule=unmask_schedule,
vocab_size=self.config.interpolant.tokens,
mask_token=self.config.interpolant.mask_token,
pad_token=self.config.interpolant.pad_token,
max_length=self.config.interpolant.max_length,
)
        self.ema_params = checkpoint.get("ema_params", {}) if self.use_ema else {}
def swap_to_ema(self):
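        # Copy the EMA weights into the live parameters (e.g. for validation or
        # sampling); call restore_original() afterwards to put the training
        # weights back.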
for name, p in self.named_parameters():
self._orig_params[name] = p.data.clone()
p.data.copy_(self.ema_params[name].to(p.device))
def restore_original(self):
for name, p in self.named_parameters():
p.data.copy_(self._orig_params[name])
self._orig_params.clear()
    def on_train_start(self):
        # Initialize or relocate the EMA buffers once the model is on its device.
        # When resuming from a checkpoint, on_load_checkpoint has already restored
        # the EMA state, so keep it instead of overwriting it with the raw weights.
        if self.use_ema:
            if getattr(self, "ema_params", None):
                self.ema_params = {
                    name: buf.to(self.device) for name, buf in self.ema_params.items()
                }
            else:
                self.ema_params = {
                    name: param.detach().clone().to(self.device)
                    for name, param in self.named_parameters()
                }
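

if __name__ == "__main__":
    # Minimal usage sketch, not part of the training pipeline. Assumptions: the
    # config is an OmegaConf/attribute-style object exposing the `training` and
    # `interpolant` fields referenced above, loaded here from a hypothetical
    # "configs/train.yaml", and the dataloaders yield dicts with an "input_ids"
    # tensor of shape [batch, max_length].
    from omegaconf import OmegaConf

    config = OmegaConf.load("configs/train.yaml")  # hypothetical path
    module = MaskedDiffusionModule(config)
    trainer = pl.Trainer(
        max_steps=config.training.max_steps,
        accelerator="auto",
        log_every_n_steps=50,
    )
    # trainer.fit(module, train_dataloaders=..., val_dataloaders=...)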