Spaces:
Sleeping
Sleeping
File size: 901 Bytes
c1596ac | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 | import torch
def train_one_epoch(
    encoder,
    decoder,
    loader,
    criterion,
    optimizer,
    device,
    scheduler=None,
):
    """Train ``encoder`` and ``decoder`` for one pass over ``loader``.

    Each ``(images, captions)`` batch is moved to ``device``. The encoder is
    called with ``return_features=True``. The decoder receives the captions
    without their last token, and its outputs are scored against the captions
    without their first token.

    Args:
        encoder: Module called as ``encoder(images, return_features=True)``.
        decoder: Module called as ``decoder(features, input_caption)``. The
            last dimension of its output is flattened as the class dimension
            passed to ``criterion``.
        loader: Iterable yielding ``(images, captions)`` pairs. ``captions``
            must support ``[:, :-1]`` slicing; presumably a
            ``(batch, seq_len)`` tensor of token ids -- TODO confirm against
            the dataset.
        criterion: Loss called with flattened outputs ``(N, C)`` and
            flattened targets ``(N,)``.
        optimizer: Optimizer updating the trainable parameters.
        device: Device that ``images`` and ``captions`` are moved to.
        scheduler: Optional LR scheduler. It is stepped once per batch,
            after ``optimizer.step()``. Pass ``None`` if the schedule is
            stepped elsewhere (e.g. once per epoch by the caller).

    Returns:
        The unweighted mean of the per-batch loss values as a float, or
        ``0.0`` if ``loader`` yields no batches.
    """
    encoder.train()
    decoder.train()

    total_loss = 0.0
    num_batches = 0
    for images, captions in loader:
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images, return_features=True)

        # Shift by one: the decoder sees tokens [0, T-1) and is scored
        # against tokens [1, T).
        input_caption = captions[:, :-1]
        target_caption = captions[:, 1:]
        outputs = decoder(features, input_caption)

        loss = criterion(
            outputs.reshape(-1, outputs.shape[-1]),
            target_caption.reshape(-1),
        )

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # The scheduler must step AFTER the optimizer (PyTorch >= 1.1);
        # stepping it first skips the schedule's initial learning rate.
        if scheduler is not None:
            scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    # Count batches instead of using len(loader): this also works for
    # loaders without __len__ and avoids dividing by zero when empty.
    if num_batches == 0:
        return 0.0
    return total_loss / num_batches