File size: 942 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch


def validation_one_epoch(
    encoder,
    decoder,
    loader,
    criterion,
    device,
) -> float:
    """Return the mean per-batch loss over one pass of ``loader``.

    Both models are switched to eval mode (and left in eval mode) and the
    whole pass runs under ``torch.no_grad()``. The decoder is fed each
    caption without its last token and scored against the same caption
    without its first token (next-token prediction).

    Args:
        encoder: Callable with an ``eval()`` method; invoked as
            ``encoder(images, return_features=True)``.
        decoder: Callable with an ``eval()`` method; invoked as
            ``decoder(feature, input_caption)``.
        loader: Iterable yielding 4-tuples whose first two items are the
            image and caption tensors; the last two items are ignored.
            It does not need to support ``len()``.
        criterion: Callable ``criterion(logits_2d, targets_1d)`` returning
            a scalar tensor.
        device: Device the image and caption tensors are moved to.

    Returns:
        The sum of the per-batch losses divided by the number of batches
        consumed. ``float("nan")`` if ``loader`` yields no batches.
    """
    encoder.eval()
    decoder.eval()

    total_loss = 0.0
    # Count the batches actually seen instead of calling len(loader): this
    # works for loaders without __len__ and stays correct if the reported
    # length differs from what iteration yields.
    num_batches = 0

    with torch.no_grad():
        for images, captions, _, __ in loader:
            images = images.to(device)  # expected (B, 3, 224, 224)
            captions = captions.to(device)  # expected (B, seq_len)

            feature = encoder(images, return_features=True)  # expected (B, 49, 512)

            # Shift by one: the decoder sees tokens [0, seq_len-1) and is
            # scored against tokens [1, seq_len).
            input_caption = captions[:, :-1]  # (B, seq_len-1)
            target_caption = captions[:, 1:]  # (B, seq_len-1)

            outputs = decoder(feature, input_caption)  # expected (B, seq_len-1, vocab_size)

            # Flatten batch and time so every token position is one sample.
            loss = criterion(
                outputs.reshape(-1, outputs.shape[-1]),  # (B*(seq_len-1), vocab_size)
                target_caption.reshape(-1),  # (B*(seq_len-1),)
            )

            total_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        # Nothing was evaluated; NaN signals "no measurement" without
        # raising ZeroDivisionError or pretending the loss was zero.
        return float("nan")

    return total_loss / num_batches