Spaces:
Sleeping
Sleeping
File size: 942 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import torch
def validation_one_epoch(
    encoder,
    decoder,
    loader,
    criterion,
    device,
):
    """Run one teacher-forced validation pass and return the mean batch loss.

    Both ``encoder`` and ``decoder`` are switched to eval mode (and are left
    in eval mode on return); all forward passes run under ``torch.no_grad()``.

    Args:
        encoder: Module called as ``encoder(images, return_features=True)``.
        decoder: Module called as ``decoder(features, input_tokens)``; its
            output's last dimension is treated as the class (vocabulary) axis.
        loader: Iterable of 4-tuples ``(images, captions, _, _)``; only the
            first two items are used. It does not need to define ``__len__``.
        criterion: Callable taking ``(logits_2d, targets_1d)`` and returning a
            tensor on which ``.item()`` is called (i.e. a scalar loss).
        device: Target passed to ``.to()`` for images and captions.

    Returns:
        float: the sum of per-batch ``loss.item()`` values divided by the
        number of batches actually processed. Every batch is weighted
        equally, regardless of its size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    encoder.eval()
    decoder.eval()
    total_loss = 0.0
    # Count batches ourselves instead of using len(loader): this works for
    # iterables without __len__ and matches what was really evaluated.
    num_batches = 0
    with torch.no_grad():
        for images, captions, _, _ in loader:
            images = images.to(device)  # assumed (B, 3, H, W) — TODO confirm
            captions = captions.to(device)  # assumed (B, seq_len) token ids
            features = encoder(images, return_features=True)
            # Teacher forcing: tokens [0, T-1) are the input, tokens [1, T)
            # are the targets they should predict.
            input_caption = captions[:, :-1]
            target_caption = captions[:, 1:]
            # Expected to return logits shaped (B, seq_len - 1, vocab_size).
            outputs = decoder(features, input_caption)
            # Flatten batch and time so the criterion sees (N, vocab) logits
            # against (N,) integer targets.
            loss = criterion(
                outputs.reshape(-1, outputs.shape[-1]),
                target_caption.reshape(-1),
            )
            total_loss += loss.item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError(
            "loader yielded no batches; cannot compute validation loss"
        )
    return total_loss / num_batches