File size: 901 Bytes
c1596ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch


def train_one_epoch(
    encoder,
    decoder,
    loader,
    criterion,
    optimizer,
    device,
    scheduler=None,
):
    """Run one teacher-forced training epoch over ``loader`` and return the mean loss.

    For every batch the images and captions are moved to ``device``. The encoder
    turns the images into features. The decoder is fed the caption without its
    last token and is trained to predict the caption without its first token.

    Args:
        encoder: Image encoder module, called as
            ``encoder(images, return_features=True)``.
        decoder: Caption decoder module, called as
            ``decoder(features, input_tokens)``. Its output is flattened to
            ``(-1, outputs.shape[-1])`` before the loss, so the last dimension
            must hold the per-token class scores (presumably vocabulary
            logits; TODO confirm).
        loader: Iterable of ``(images, captions)`` batches. It does not need to
            support ``len()``. ``captions`` is sliced along dim 1, so it must be
            at least 2-D; presumably ``(batch, seq_len)`` token ids (TODO
            confirm against the dataset).
        criterion: Loss called with the flattened scores and flattened targets.
        optimizer: Optimizer over the trainable parameters.
        device: Device each batch is moved to.
        scheduler: Optional LR scheduler. If given, it is stepped once per batch,
            *after* ``optimizer.step()``.

    Returns:
        The average of the per-batch ``loss.item()`` values. Returns
        ``float("nan")`` if the loader yielded no batches.
    """
    encoder.train()
    decoder.train()

    total_loss = 0.0
    # Count batches ourselves instead of calling len(loader), so loaders
    # without __len__ (generators, iterable-style datasets) also work.
    num_batches = 0
    for images, captions in loader:
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images, return_features=True)

        # Teacher forcing: input is tokens [0, T-1), target is tokens [1, T).
        input_caption = captions[:, :-1]
        target_caption = captions[:, 1:]

        outputs = decoder(features, input_caption)

        loss = criterion(
            outputs.reshape(-1, outputs.shape[-1]),
            target_caption.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # torch.optim.lr_scheduler schedulers must step AFTER the optimizer.
        # Stepping them first skips the schedule's initial learning rate and
        # triggers a UserWarning.
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        # The mean over zero batches is undefined. NaN avoids a
        # ZeroDivisionError and can't be mistaken for a real (e.g. "best") loss.
        return float("nan")
    return total_loss / num_batches