File size: 2,497 Bytes
8b98de9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import torch
from torch import nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
import numpy as np
from tqdm import tqdm
from typing import Dict, Tuple


from utils import barrier, reduce_mean, update_loss_info


def _forward_and_loss(
    model: nn.Module,
    loss_fn: nn.Module,
    image: torch.Tensor,
    target_points: list,
    target_density: torch.Tensor,
    regression: bool,
):
    """Run one forward pass and return whatever ``loss_fn`` returns.

    Regression models (``bins is None``) return only a density map. Otherwise
    the model returns ``(pred_class, pred_density)`` and ``loss_fn`` receives
    ``pred_class`` as an extra leading argument.
    """
    if regression:
        pred_density = model(image)
        return loss_fn(pred_density, target_density, target_points)
    pred_class, pred_density = model(image)
    return loss_fn(pred_class, pred_density, target_density, target_points)


def train(
    model: nn.Module,
    data_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: Optimizer,
    grad_scaler: GradScaler,
    device: torch.device,
    rank: int,
    nprocs: int,
) -> Tuple[nn.Module, Optimizer, GradScaler, Dict[str, float]]:
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Network to train. When ``nprocs > 1`` the underlying network is
            read from ``model.module`` (e.g. a DDP wrapper). If the underlying
            network's ``bins`` is ``None``, the model is treated as a
            regression model returning only a density map. Otherwise it must
            return ``(pred_class, pred_density)``.
        data_loader: Iterable of ``(image, target_points, target_density)``
            batches. Each element of ``target_points`` is moved to ``device``
            individually.
        loss_fn: Called as ``loss_fn(pred_density, target_density, target_points)``
            for regression, or with ``pred_class`` prepended otherwise. It must
            return ``(loss, loss_info)``, where ``loss_info`` maps names to tensors.
        optimizer: Zeroed and stepped once per batch.
        grad_scaler: Optional AMP scaler. When not ``None``, the forward pass
            runs under ``autocast(enabled=grad_scaler.is_enabled())``, and
            backward/step/update go through the scaler.
        device: Device every batch tensor is moved to.
        rank: Process rank; only rank 0 wraps the loader in a tqdm progress bar.
        nprocs: Number of processes. ``> 1`` reduces each logged loss term with
            ``reduce_mean`` before logging.

    Returns:
        ``(model, optimizer, grad_scaler, info)``. ``info`` maps each loss-term
        name to ``np.mean`` of the values accumulated via ``update_loss_info``.
        ``info`` is ``{}`` if ``data_loader`` yields no batches.
    """
    model.train()
    info = None
    data_iter = tqdm(data_loader) if rank == 0 else data_loader
    ddp = nprocs > 1
    regression = (model.module.bins is None) if ddp else (model.bins is None)

    for image, target_points, target_density in data_iter:
        image = image.to(device)
        target_points = [p.to(device) for p in target_points]
        target_density = target_density.to(device)

        # Build the autograd graph even if the caller is inside torch.no_grad().
        with torch.set_grad_enabled(True):
            if grad_scaler is not None:
                # Only the forward pass and loss run under autocast. Backward
                # happens outside it, which is the recommended AMP pattern.
                with autocast(enabled=grad_scaler.is_enabled()):
                    loss, loss_info = _forward_and_loss(
                        model, loss_fn, image, target_points, target_density, regression
                    )
            else:
                loss, loss_info = _forward_and_loss(
                    model, loss_fn, image, target_points, target_density, regression
                )

        optimizer.zero_grad()
        if grad_scaler is not None:
            grad_scaler.scale(loss).backward()
            grad_scaler.step(optimizer)
            grad_scaler.update()
        else:
            loss.backward()
            optimizer.step()

        # Turn each loss term into a Python float. Under multi-process training,
        # reduce it across processes first so every rank logs the same value.
        loss_info = {
            k: (reduce_mean(v.detach(), nprocs) if ddp else v.detach()).item()
            for k, v in loss_info.items()
        }
        info = update_loss_info(info, loss_info)

        barrier(ddp)

    if info is None:
        # No batches were seen, so there is nothing to average. Previously
        # this path crashed on `None.items()`.
        return model, optimizer, grad_scaler, {}
    return model, optimizer, grad_scaler, {k: np.mean(v) for k, v in info.items()}