import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


class DaedalusMobile(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # One pretrained checkpoint already contains both the encoder and
        # the decoder; loading 't5-small' twice would duplicate its weights.
        # Passing dropout_rate overrides the checkpoint's dropout with the
        # value from our config.
        self.model = AutoModelForSeq2SeqLM.from_pretrained(
            't5-small', dropout_rate=config.dropout)
    def forward(self, input_ids, attention_mask, labels):
        # T5 derives decoder_input_ids by shifting the labels right, runs
        # the encoder and decoder, and returns per-token vocabulary logits.
        outputs = self.model(input_ids=input_ids,
                             attention_mask=attention_mask,
                             labels=labels)
        return outputs.logits
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.config.lr)
        return optimizer
    def train_step(self, batch):
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        # CrossEntropyLoss expects (N, num_classes) logits and (N,) targets,
        # so flatten the batch and sequence dimensions; -100 marks positions
        # (e.g. padding) that should not contribute to the loss.
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss
    @torch.no_grad()
    def eval_step(self, batch):
        # Same loss computation as train_step, but without building the
        # autograd graph.
        input_ids, attention_mask, labels = batch
        logits = self(input_ids, attention_mask, labels)
        loss = nn.CrossEntropyLoss(ignore_index=-100)(
            logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss