import torch

from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config

def train(model, device, train_loader, optimizer):
    """Run one training epoch and return the mean loss per batch."""
    model.train()
    total_loss = 0
    for batch in train_loader:
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        loss = model.train_step((input_ids, attention_mask, labels))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(train_loader)

def main():
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    tokenizer = DaedalusTokenizer(config)
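
    # `train_dataset` is consumed by the DataLoader below but is never defined
    # in the original script. A minimal sketch, assuming the tokenizer is
    # callable on raw text, pads/truncates to a fixed length, and returns
    # tensors under 'input_ids' / 'attention_mask', and that config.train_file
    # names a plain-text corpus; none of these names come from the original
    # code, so adapt this to the project's real data pipeline.
    with open(config.train_file, encoding='utf-8') as f:
        texts = [line.strip() for line in f if line.strip()]
    train_dataset = [
        (enc['input_ids'], enc['attention_mask'], enc['input_ids'].clone())
        for enc in (tokenizer(t) for t in texts)
    ]  # labels mirror input_ids, a common language-modelling convention
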
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=config.batch_size, shuffle=True)
    optimizer = model.configure_optimizers()

    for epoch in range(config.epochs):
        loss = train(model, device, train_loader, optimizer)
        print(f'Epoch {epoch + 1}, Loss: {loss:.4f}')

if __name__ == '__main__':
    main()