File size: 963 Bytes
0ccabc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config

def evaluate(model, device, eval_loader):
    """Run a full pass over *eval_loader* and return the mean per-batch loss.

    Args:
        model: object exposing ``eval()`` and ``eval_step(batch) -> loss tensor``.
        device: torch.device to move each batch onto.
        eval_loader: iterable of (input_ids, attention_mask, labels) batches.

    Returns:
        float: total loss divided by the number of batches.
    """
    model.eval()
    running = 0.0
    # Gradients are never needed during evaluation, so disable autograd.
    with torch.no_grad():
        for input_ids, attention_mask, labels in eval_loader:
            batch_on_device = tuple(
                t.to(device) for t in (input_ids, attention_mask, labels)
            )
            running += model.eval_step(batch_on_device).item()
    return running / len(eval_loader)

def main(eval_dataset=None):
    """Evaluate a DaedalusMobile model on *eval_dataset* and print the mean loss.

    Args:
        eval_dataset: torch Dataset yielding (input_ids, attention_mask, labels)
            examples. Backward-compatible addition: the original script referenced
            an ``eval_dataset`` name that was never defined anywhere in the file,
            so it always crashed with NameError at the DataLoader call.

    Raises:
        ValueError: if no evaluation dataset is supplied.
    """
    # Fail fast with an actionable message instead of the original NameError.
    if eval_dataset is None:
        raise ValueError(
            "eval_dataset must be provided: the original script referenced an "
            "undefined `eval_dataset` name"
        )
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): tokenizer is constructed but never used below — presumably
    # kept for side effects or future preprocessing; confirm before removing.
    tokenizer = DaedalusTokenizer(config)
    # shuffle=False: evaluation order must be deterministic.
    eval_loader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
    )
    loss = evaluate(model, device, eval_loader)
    print(f'Loss: {loss:.4f}')

if __name__ == '__main__':
    main()