| import torch | |
| from daedalus_mobile import DaedalusMobile | |
| from tokenizer import DaedalusTokenizer | |
| from config import config | |
def evaluate(model, device, eval_loader):
    """Return the mean per-batch loss of ``model`` over ``eval_loader``.

    The model is switched to evaluation mode with ``model.eval()`` (and is
    left in that mode), and every batch is processed inside
    ``torch.no_grad()``.

    Args:
        model: Object exposing ``eval()`` and ``eval_step(batch)``. The value
            returned by ``eval_step`` must provide ``.item()``.
        device: Target passed to ``.to()`` for each element of a batch.
        eval_loader: Iterable yielding ``(input_ids, attention_mask, labels)``
            triples. It does not need to support ``len()``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
        Batches are weighted equally regardless of their size.

    Raises:
        ValueError: If ``eval_loader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    # Count batches while iterating instead of calling len(eval_loader):
    # generators and loaders without __len__ are then supported as well.
    num_batches = 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in eval_loader:
            batch = (
                input_ids.to(device),
                attention_mask.to(device),
                labels.to(device),
            )
            total_loss += model.eval_step(batch).item()
            num_batches += 1
    if num_batches == 0:
        # An average over zero batches is undefined; fail with a clear
        # message rather than an opaque ZeroDivisionError.
        raise ValueError(
            "eval_loader yielded no batches; cannot compute an average loss"
        )
    return total_loss / num_batches
def main(eval_dataset=None):
    """Build the model, evaluate it, and print the average loss.

    Args:
        eval_dataset: Dataset handed to ``torch.utils.data.DataLoader``.
            When omitted, a module-level ``eval_dataset`` is used if one
            exists.

    Raises:
        NameError: If no dataset is passed in and no module-level
            ``eval_dataset`` is defined.
    """
    # Resolve the dataset first so a missing dataset fails immediately,
    # before the model and tokenizer are constructed.
    if eval_dataset is None:
        eval_dataset = globals().get("eval_dataset")
    if eval_dataset is None:
        raise NameError(
            "eval_dataset is not defined: pass a dataset to main() or define "
            "a module-level eval_dataset before running the evaluation"
        )

    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): no checkpoint is loaded in this function; the model is
    # evaluated exactly as DaedalusMobile(config) builds it - confirm that
    # is intended.

    # The tokenizer is not used below. Its construction is kept because any
    # side effects of DaedalusTokenizer(config) are not visible from here.
    tokenizer = DaedalusTokenizer(config)  # noqa: F841

    eval_loader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
    )
    loss = evaluate(model, device, eval_loader)
    print(f'Loss: {loss:.4f}')
# Run the evaluation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()