import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base

# Map config strings to ViT-1D constructors; each backbone exposes a
# ``width`` attribute that sizes the optional projection head below.
VIT_FACTORIES = {
    "vit_nano": vit_nano,
    "vit_tiny": vit_tiny,
    "vit_small": vit_small,
    "vit_middle": vit_middle,
    "vit_base": vit_base,
}


class PSGModalityEncoder(nn.Module):
    """ViT encoder for PSG signals: backbone -> optional projection -> L2-norm"""

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise: int = 0,
                 patch_size: int = 40,
                 patch_size_ch: int = 4,
                 use_lead_embedding: bool = True,
                 is_proj_head: int = 1):
        super().__init__()
        token_len = freq * win_sec  # samples per window, e.g. 64 Hz * 30 s = 1920
        self.token_len = token_len
        self.patch_size = patch_size

        if encoder_name not in VIT_FACTORIES:
            raise ValueError(f"Unknown encoder_name: {encoder_name}. Choose from {list(VIT_FACTORIES.keys())}")

        self.backbone = VIT_FACTORIES[encoder_name](
            num_leads=channel, seq_len=token_len, patch_size=patch_size,
            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
            use_lead_embedding=use_lead_embedding,
        )

        d_model = self.backbone.width
        if is_proj_head == 1:
            self.proj_head = nn.Sequential(
                nn.Linear(d_model, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
        else:
            self.proj_head = None

    def forward(self, x, normalize=True):
        # x: [B, C, T]
        h = self.backbone(x)  # [B, D]
        if self.proj_head is not None:
            h = self.proj_head(h)  # [B, proj_out]
        if normalize:
            return F.normalize(h, dim=-1)
        return h
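

# Usage sketch (illustrative only): encode a batch of random PSG windows with
# the defaults above (11 channels, 64 Hz * 30 s = 1920 samples per window).
#
#   enc = PSGModalityEncoder(encoder_name="vit_tiny")
#   x = torch.randn(8, 11, 64 * 30)   # [B, C, T]
#   z = enc(x)                        # [8, 256]; rows have unit L2 norm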


class BasePretrainModel(LightningModule):
    """Shared optimization and logging scaffolding for pretraining models.

    Subclasses must implement ``shared_step(batch, batch_idx)`` and return
    ``(loss_dict, metrics_dict)``, where ``loss_dict`` contains the key
    ``"loss"``.
    """

    def __init__(self,
                 psg_encoder_name: str = "vit_base",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = "cross-attn",
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        self.warmup_epochs = 0.1 * self.max_epochs  # LR warmup spans 10% of training
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256

        assert self.training_steps_per_epoch > 1, \
            "training_steps_per_epoch must be > 1 to build the LR schedule"

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )

        # Linear warmup over warmup_epochs (10% of max_epochs by default),
        # then cosine decay to eta_min; stepped once per optimizer step.
        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
        warmup_steps = max(0, warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)

        if warmup_steps > 0:
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)

        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

    def training_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        # Lightning already disables gradients during validation; the explicit
        # no_grad() is redundant but harmless.
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict

    def test_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict
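

# Illustrative subclass (an assumption, not the project's real objective):
# demonstrates the ``shared_step`` contract that the hooks above rely on.
class ExampleContrastivePretrainModel(BasePretrainModel):
    """Minimal InfoNCE sketch of the ``shared_step`` contract.

    ``shared_step`` must return ``(loss_dict, metrics_dict)`` with the key
    ``"loss"`` in ``loss_dict``. The batch layout (PSG windows paired with
    precomputed text embeddings) and the symmetric InfoNCE objective are
    assumptions made for illustration.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.psg_encoder = PSGModalityEncoder(
            encoder_name=self.psg_encoder_name,
            proj_out=self.proj_out,
            proj_hidden=self.proj_hidden,
        )
        self.temperature = 0.07  # assumed CLIP-style temperature

    def shared_step(self, batch, batch_idx):
        psg, text_emb = batch                  # assumed: ([B, C, T], [B, proj_out])
        z_psg = self.psg_encoder(psg)          # [B, proj_out], unit norm
        z_txt = F.normalize(text_emb, dim=-1)  # [B, proj_out], unit norm
        logits = z_psg @ z_txt.t() / self.temperature
        labels = torch.arange(logits.size(0), device=logits.device)
        # Symmetric InfoNCE over in-batch PSG/text pairs.
        loss = 0.5 * (F.cross_entropy(logits, labels)
                      + F.cross_entropy(logits.t(), labels))
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        return {"loss": loss}, {"acc": acc}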