File size: 8,376 Bytes
caf6ee7
 
906fcb9
caf6ee7
906fcb9
 
 
 
 
1baebae
906fcb9
caf6ee7
906fcb9
 
 
 
 
1baebae
caf6ee7
 
 
 
 
 
1baebae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
906fcb9
 
 
 
1baebae
 
906fcb9
 
1baebae
 
906fcb9
1baebae
 
 
 
 
 
906fcb9
1baebae
906fcb9
 
1baebae
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1baebae
 
 
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1baebae
906fcb9
 
 
 
 
 
 
 
 
 
1baebae
906fcb9
 
 
 
 
 
 
 
 
 
 
caf6ee7
906fcb9
 
1baebae
906fcb9
 
 
 
 
 
 
 
 
 
 
 
 
 
caf6ee7
 
906fcb9
 
 
 
 
 
 
 
 
1baebae
 
906fcb9
 
 
1baebae
 
 
906fcb9
 
 
 
 
caf6ee7
 
906fcb9
caf6ee7
906fcb9
 
1baebae
906fcb9
 
 
 
caf6ee7
 
906fcb9
 
caf6ee7
 
 
906fcb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import argparse
import logging
import time

import numpy as np
import torch
import torch.nn as nn
from monai.metrics import Cumulative, CumulativeAverage
from sklearn.metrics import cohen_kappa_score


def get_lambda_att(epoch: int, max_lambda: float = 2.0, warmup_epochs: int = 10) -> float:
    if epoch < warmup_epochs:
        return (epoch / warmup_epochs) * max_lambda
    else:
        return max_lambda


def get_attention_scores(
    data: torch.Tensor,
    target: torch.Tensor,
    heatmap: torch.Tensor,
    args: argparse.Namespace,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute attention scores from heatmaps and shuffle data accordingly.
    This function generates attention scores based on spatial heatmaps, applies
    sharpening, and creates shuffled versions of the input data and attention
    labels. For PI-RADS 2 (target < 1), uniform attention scores are assigned.
    Args:
        data (torch.Tensor): Input data tensor of shape (batch_size, num_patches, ...).
        target (torch.Tensor): Target labels tensor of shape (batch_size,).
        heatmap (torch.Tensor): Attention heatmap tensor corresponding to input patches.
        args: Arguments object containing device specification.
    Returns:
        tuple: A tuple containing:
            - att_labels (torch.Tensor): Sharpened and normalized attention scores
              of shape (batch_size, num_patches), moved to args.device.
            - shuffled_images (torch.Tensor): Randomly permuted data samples
              of shape (batch_size, num_patches, ...), moved to args.device.
    Note:
        - Attention scores are computed by summing heatmap values across spatial dimensions.
        - Data and attention labels are shuffled with the same permutation per sample.
        - PI-RADS 2 samples receive uniform attention distribution.
        - Attention scores are squared for sharpening and then normalized.
    """

    attention_score = torch.zeros((data.shape[0], data.shape[1]))
    for i in range(data.shape[0]):
        sample = heatmap[i]
        heatmap_patches = sample.squeeze(1)
        raw_scores = heatmap_patches.view(len(heatmap_patches), -1).sum(dim=1)
        attention_score[i] = raw_scores / raw_scores.sum()
    shuffled_images = torch.empty_like(data).to(args.device)
    att_labels = torch.empty_like(attention_score).to(args.device)
    for i in range(data.shape[0]):
        perm = torch.randperm(data.shape[1])
        shuffled_images[i] = data[i, perm]
        att_labels[i] = attention_score[i, perm]

    att_labels[torch.argwhere(target < 1)] = torch.ones_like(att_labels[0]) / len(
        att_labels[0]
    )  # For PI-RADS 2, uniform scores across patches
    att_labels = att_labels**2  # Sharpening
    att_labels = att_labels / att_labels.sum(dim=1, keepdim=True)

    return att_labels, shuffled_images


def train_epoch(model, loader, optimizer, scaler, epoch, args):
    """One train epoch over the dataset"""
    lambda_att = get_lambda_att(epoch, warmup_epochs=25)

    model.train()
    criterion = nn.CrossEntropyLoss()
    att_criterion = nn.CosineSimilarity(dim=1, eps=1e-6)

    run_att_loss = CumulativeAverage()
    run_loss = CumulativeAverage()
    run_acc = CumulativeAverage()
    batch_norm = CumulativeAverage()

    start_time = time.time()
    loss, acc = 0.0, 0.0

    for idx, batch_data in enumerate(loader):
        eps = 1e-8
        data = batch_data["image"].as_subclass(torch.Tensor)
        target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
        target = target.long()
        if args.use_heatmap:
            att_labels, shuffled_images = get_attention_scores(
                data, target, batch_data["final_heatmap"], args
            )
            att_labels = att_labels + eps
        else:
            shuffled_images = data.to(args.device)

        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast(device_type=str(args.device), enabled=args.amp):
            # Classification Loss
            logits_attn = model(shuffled_images, no_head=True)
            x = logits_attn.to(torch.float32)
            x = x.permute(1, 0, 2)
            x = model.transformer(x)
            x = x.permute(1, 0, 2)
            a = model.attention(x)
            a = torch.softmax(a, dim=1)
            x = torch.sum(x * a, dim=1)
            logits = model.myfc(x)
            class_loss = criterion(logits, target)
            # Attention Loss
            if args.use_heatmap:
                y = logits_attn.to(torch.float32)
                y = y.permute(1, 0, 2)
                y = model.transformer(y)
                y_detach = y.permute(1, 0, 2).detach()
                b = model.attention(y_detach)
                b = b.squeeze(-1)
                b = b + eps
                att_preds = torch.softmax(b, dim=1)
                attn_loss = 1 - att_criterion(att_preds, att_labels).mean()
                loss = class_loss + (lambda_att * attn_loss)
            else:
                loss = class_loss
                attn_loss = torch.tensor(0.0)

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=float("inf"))
        if not torch.isfinite(total_norm):
            logging.warning("Non-finite gradient norm detected, skipping batch.")
            optimizer.zero_grad()
            scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()
            batch_norm.append(total_norm)
            pred = torch.argmax(logits, dim=1)
            acc = (pred == target).sum() / len(pred)

            run_att_loss.append(attn_loss.detach().cpu())
            run_loss.append(loss.detach().cpu())
            run_acc.append(acc.detach().cpu())
            logging.info(
                f"Epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss.item():.4f} attention loss: {attn_loss.item():.4f} acc: {acc:.4f} grad norm: {total_norm:.4f} time {time.time() - start_time:.2f}s"
            )
            start_time = time.time()

    del data, target, shuffled_images, logits, logits_attn
    torch.cuda.empty_cache()
    batch_norm_epoch = batch_norm.aggregate()
    attn_loss_epoch = run_att_loss.aggregate()
    loss_epoch = run_loss.aggregate()
    acc_epoch = run_acc.aggregate()
    return loss_epoch, acc_epoch, attn_loss_epoch, batch_norm_epoch


def val_epoch(model, loader, epoch, args):
    criterion = nn.CrossEntropyLoss()

    run_loss = CumulativeAverage()
    run_acc = CumulativeAverage()
    preds_cumulative = Cumulative()
    targets_cumulative = Cumulative()

    start_time = time.time()
    loss, acc = 0.0, 0.0
    model.eval()
    with torch.no_grad():
        for idx, batch_data in enumerate(loader):
            data = batch_data["image"].as_subclass(torch.Tensor).to(args.device)
            target = batch_data["label"].as_subclass(torch.Tensor).to(args.device)
            target = target.long()

            with torch.amp.autocast(device_type=str(args.device), enabled=args.amp):
                logits = model(data)
                loss = criterion(logits, target)

            data = data.to("cpu")
            target = target.to("cpu")
            logits = logits.to("cpu")
            pred = torch.argmax(logits, dim=1)
            acc = (pred == target).sum() / len(target)

            run_loss.append(loss.detach().cpu())
            run_acc.append(acc.detach().cpu())
            preds_cumulative.extend(pred.detach().cpu())
            targets_cumulative.extend(target.detach().cpu())
            logging.info(
                f"Val epoch {epoch}/{args.epochs} {idx}/{len(loader)} loss: {loss:.4f} acc: {acc:.4f} time {time.time() - start_time:.2f}s"
            )
            start_time = time.time()

            del data, target, logits
            torch.cuda.empty_cache()

        # Calculate QWK metric (Quadratic Weigted Kappa) https://en.wikipedia.org/wiki/Cohen%27s_kappa
        preds_cumulative = preds_cumulative.get_buffer().cpu().numpy()
        targets_cumulative = targets_cumulative.get_buffer().cpu().numpy()
        loss_epoch = run_loss.aggregate()
        acc_epoch = run_acc.aggregate()
        qwk = cohen_kappa_score(
            targets_cumulative.astype(np.float64), preds_cumulative.astype(np.float64)
        )
    return loss_epoch, acc_epoch, qwk