# daedalus_mobile/evaluate.py
import torch
from daedalus_mobile import DaedalusMobile
from tokenizer import DaedalusTokenizer
from config import config
def evaluate(model, device, eval_loader):
    """Run one pass over *eval_loader* and return the mean per-batch loss.

    Args:
        model: module exposing ``eval()`` and ``eval_step(batch) -> loss tensor``
            (``eval_step`` is a project-defined method — see DaedalusMobile).
        device: ``torch.device`` the batch tensors are moved to.
        eval_loader: iterable of ``(input_ids, attention_mask, labels)`` batches;
            must support ``len()``.

    Returns:
        float: sum of batch losses divided by the number of batches, or 0.0
        when the loader is empty (avoids ZeroDivisionError).
    """
    model.eval()  # disable dropout / switch batch-norm to eval statistics
    total_loss = 0.0
    with torch.no_grad():  # no gradients needed for evaluation
        for batch in eval_loader:
            input_ids, attention_mask, labels = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            loss = model.eval_step((input_ids, attention_mask, labels))
            total_loss += loss.item()
    num_batches = len(eval_loader)
    # Guard the empty-loader case instead of raising ZeroDivisionError.
    return total_loss / num_batches if num_batches else 0.0
def main():
    """Entry point: build the model, run evaluation, and print the mean loss."""
    device = torch.device(config.device)
    model = DaedalusMobile(config)
    model.to(device)
    # NOTE(review): no checkpoint is loaded here, so the model is evaluated
    # with freshly initialized weights — confirm whether a torch.load /
    # load_state_dict step is missing.
    tokenizer = DaedalusTokenizer(config)  # NOTE(review): constructed but never used — verify intent
    # BUG: `eval_dataset` is not defined anywhere in this file; as written this
    # line raises NameError at runtime.
    # TODO: construct/load the evaluation dataset (presumably via `tokenizer`
    # and `config`) before building the DataLoader.
    eval_loader = torch.utils.data.DataLoader(
        dataset=eval_dataset,
        batch_size=config.batch_size,
        shuffle=False,
    )
    loss = evaluate(model, device, eval_loader)
    print(f'Loss: {loss:.4f}')


if __name__ == '__main__':
    main()