# Uploaded with huggingface_hub (revision 4f2b2f4).
import torch
import pytorch_lightning as pl
from omegaconf import DictConfig
import torch.nn.functional as F
from model.transformer import AnyOrderMaskInsertionFlow
from interpolant import AnyOrderMaskInsertionInterpolant, ModelPrediction
from bregman import jump_kernel_elbo, mse
from schedule import get_schedule_from_config
import re
from typing import Dict, Any
def strip_orig_mod_keys(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Remove the '_orig_mod' segments that torch.compile injects into keys.

    Example:
        'model._orig_mod.vocab_embed.embedding' -> 'model.vocab_embed.embedding'

    Returns a new state dict; the input mapping is left untouched.
    """
    pattern = re.compile(r"\._orig_mod\.")
    return {pattern.sub(".", key): value for key, value in state_dict.items()}
class AnyOrderInsertionFlowModule(pl.LightningModule):
def __init__(self, config: DictConfig):
super().__init__()
self.config = config
self.model_type = config.interpolant.type
self.learning_rate = config.training.learning_rate
self.unmask_loss_fn = config.training.loss_fn.unmask
self.insert_loss_fn = config.training.loss_fn.insert
# Initialize model based on type
self.model = AnyOrderMaskInsertionFlow(config)
self.model = torch.compile(self.model)
insert_schedule = get_schedule_from_config(config.interpolant.insert_schedule)
unmask_schedule = get_schedule_from_config(config.interpolant.unmask_schedule)
# Initialize interpolant
self.interpolant = AnyOrderMaskInsertionInterpolant(
insertion_schedule=insert_schedule,
unmask_schedule=unmask_schedule,
vocab_size=config.interpolant.tokens,
mask_token=config.interpolant.mask_token,
pad_token=config.interpolant.pad_token,
max_length=config.interpolant.max_length,
)
# Save hyperparameters
self.save_hyperparameters()
self.ema_decay = config.training.ema_decay or 0.0
self.use_ema = self.ema_decay > 0
self._orig_params = {}
def forward(self, x, t) -> ModelPrediction:
if self.config.training.only_embed_insert:
return self.model(x, self.interpolant.insertion_schedule.at(t))
else:
return self.model(x, t)
def training_loss(self, x1, t):
interpolant_sample = self.interpolant.sample_interpolant(t, x1)
unmask_weight, insert_weight = self.interpolant.elbo_weight(t, x1)
prediction: ModelPrediction = self(interpolant_sample.xt, t)
scale_factor = x1.shape[0] * self.config.interpolant.max_length
match self.unmask_loss_fn:
case "elbo":
mask_indices = interpolant_sample.mask_indices
unmask_loss = unmask_weight[mask_indices] * F.cross_entropy(
prediction.token_logits[mask_indices],
interpolant_sample.unmasked[mask_indices],
reduction="none",
)
unmask_loss = unmask_loss.sum() / scale_factor
case _:
raise ValueError(f"Invalid unmask loss type: {self.unmask_loss_fn}")
match self.insert_loss_fn:
case "expectation":
gaps, gaps_mask = interpolant_sample.gaps_and_mask
insertion_loss = insert_weight[gaps_mask] * jump_kernel_elbo(
gaps[gaps_mask], prediction.expected_gaps[gaps_mask]
)
insertion_loss = insertion_loss.sum() / scale_factor
case "distribution":
gaps, gaps_mask = interpolant_sample.gaps_and_mask
insertion_loss = insert_weight[gaps_mask] * F.cross_entropy(
prediction.length_posterior[gaps_mask], gaps[gaps_mask]
)
insertion_loss = insertion_loss.sum() / scale_factor
total_loss = unmask_loss + insertion_loss
return unmask_loss, insertion_loss, total_loss
def sample_time(self, batch_size: int, device: torch.device) -> torch.Tensor:
eps = 1e-6
interval = 1.0 - eps
interval_size = interval / batch_size
u = torch.rand(batch_size, device=device)
return (torch.arange(batch_size, device=device, dtype=u.dtype) + u) * interval_size
def training_step(self, batch, batch_idx):
# Extract input data
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
t = self.sample_time(x1.shape[0], x1.device)
# Calculate the combined loss normally
unmask_loss, len_loss, loss = self.training_loss(x1, t)
# Log component losses
self.log("train/unmask_loss", unmask_loss, prog_bar=True)
self.log("train/len_loss", len_loss, prog_bar=True)
self.log("train/total_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
if isinstance(batch, dict):
batch = batch["input_ids"]
x1 = batch
t = self.sample_time(x1.shape[0], x1.device)
unmask_loss, len_loss, loss = self.training_loss(x1, t)
self.log("val/unmask_loss", unmask_loss, prog_bar=True, sync_dist=True)
self.log("val/len_loss", len_loss, prog_bar=True, sync_dist=True)
self.log("val_loss", loss, prog_bar=True, sync_dist=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.AdamW(
self.parameters(),
lr=self.learning_rate,
weight_decay=self.config.training.weight_decay,
)
warmup_steps = self.config.training.warmup_steps
max_steps = self.config.training.max_steps
# Always create a fresh schedule starting from step 0
# This allows extending training beyond original max_steps
linear_scheduler = torch.optim.lr_scheduler.LinearLR(
optimizer,
start_factor=1e-6,
end_factor=1.0,
total_iters=warmup_steps,
last_epoch=-1,
)
post_warmup = max_steps - warmup_steps
cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
optimizer,
T_max=post_warmup,
eta_min=0.0,
last_epoch=-1,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
optimizer,
schedulers=[linear_scheduler, cosine_scheduler],
milestones=[warmup_steps],
last_epoch=-1,
)
return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
def optimizer_step(
self,
epoch: int,
batch_idx: int,
optimizer,
optimizer_closure=None,
):
super().optimizer_step(
epoch, batch_idx, optimizer, optimizer_closure=optimizer_closure
)
# log learning rate and gradient norm
lr = optimizer.param_groups[0]["lr"]
self.log("train/lr", lr, on_step=True, prog_bar=True)
grad_norm = torch.sqrt(
sum(p.grad.norm(2) ** 2 for p in self.parameters() if p.grad is not None)
)
self.log("train/grad_norm", grad_norm, on_step=True, prog_bar=True)
# update EMA
if self.use_ema:
for n, p in self.named_parameters():
self.ema_params[n].mul_(self.ema_decay).add_(
p.data.clone().detach(), alpha=1 - self.ema_decay
)
def on_save_checkpoint(self, checkpoint):
checkpoint["config"] = self.config
# save EMA state
if self.use_ema:
checkpoint["ema_params"] = {
n: v.clone() for n, v in self.ema_params.items()
}
def on_load_checkpoint(self, checkpoint):
self.config = checkpoint["config"]
insert_schedule = get_schedule_from_config(
self.config.interpolant.insert_schedule
)
unmask_schedule = get_schedule_from_config(
self.config.interpolant.unmask_schedule
)
self.interpolant = AnyOrderMaskInsertionInterpolant(
insertion_schedule=insert_schedule,
unmask_schedule=unmask_schedule,
vocab_size=self.config.interpolant.tokens,
mask_token=self.config.interpolant.mask_token,
pad_token=self.config.interpolant.pad_token,
max_length=self.config.interpolant.max_length,
)
self.ema_params = checkpoint["ema_params"] if self.use_ema else {}
def swap_to_ema(self):
for name, p in self.named_parameters():
self._orig_params[name] = p.data.clone()
p.data.copy_(self.ema_params[name].to(p.device))
def restore_original(self):
for name, p in self.named_parameters():
p.data.copy_(self._orig_params[name])
self._orig_params.clear()
def on_train_start(self):
# initialize and move EMA buffers once model is on correct device
if self.use_ema:
self.ema_params = {
name: param.clone().detach().to(self.device)
for name, param in self.named_parameters()
}
for buf in self.ema_params.values():
buf.requires_grad = False