import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from model.casual_transformer import CausalDiT
class AutoregressiveModule(pl.LightningModule):
    """Lightning wrapper that trains a causal transformer via next-token prediction.

    The underlying model is a ``CausalDiT`` built from ``config``; targets at
    padding positions (``config.interpolant.pad_token``) are excluded from the
    cross-entropy loss.
    """

    def __init__(self, config):
        """Build the module.

        Args:
            config: Experiment config; must expose ``training.learning_rate``,
                ``interpolant.pad_token``, and whatever ``CausalDiT`` reads.
        """
        super().__init__()
        self.config = config
        self.learning_rate = config.training.learning_rate
        # Initialize model (causal transformer)
        self.model = CausalDiT(config)

    def forward(self, x):
        """Run the causal transformer on token ids ``x``."""
        return self.model(x)

    @staticmethod
    def _extract_input_ids(batch):
        """Accept either a raw token-id tensor or a dict with an ``input_ids`` key."""
        if isinstance(batch, dict):
            return batch["input_ids"]
        return batch

    def training_loss(self, x1):
        """Next-token-prediction cross-entropy loss.

        Args:
            x1: Long tensor of token ids; assumed shape (batch, seq_len) —
                TODO confirm against the dataloader.

        Returns:
            Scalar loss tensor; pad-token targets are ignored.
        """
        input_ids = x1[:, :-1]  # model sees every token except the last
        logits = self.model(input_ids)
        target_ids = x1[:, 1:]  # each position predicts the next token
        loss = F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            target_ids.reshape(-1),
            ignore_index=self.config.interpolant.pad_token,
        )
        return loss

    def training_step(self, batch, batch_idx):
        """One training step; logs and returns the next-token loss."""
        x1 = self._extract_input_ids(batch)
        loss = self.training_loss(x1)
        self.log("train/total_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step; logs and returns the next-token loss."""
        x1 = self._extract_input_ids(batch)
        loss = self.training_loss(x1)
        # NOTE(review): key style differs from "train/total_loss"; kept as-is
        # because checkpoint/monitor callbacks elsewhere may watch "val_loss".
        self.log("val_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """AdamW over all parameters at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.learning_rate)

    def on_save_checkpoint(self, checkpoint):
        """Persist the config inside the checkpoint so reloads are self-contained."""
        checkpoint["config"] = self.config

    def on_load_checkpoint(self, checkpoint):
        """Restore the config from the checkpoint and resync derived fields."""
        self.config = checkpoint["config"]
        # Fix: learning_rate was derived from the __init__-time config and
        # went stale once config is replaced here — keep it in sync.
        self.learning_rate = self.config.training.learning_rate