import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base

VIT_FACTORIES = {
    "vit_nano": vit_nano,
    "vit_tiny": vit_tiny,
    "vit_small": vit_small,
    "vit_middle": vit_middle,
    "vit_base": vit_base,
}


class PSGModalityEncoder(nn.Module):
    """ViT encoder for PSG signals: backbone -> optional projection head -> L2-norm."""

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise: int = 0,
                 patch_size: int = 40,
                 patch_size_ch: int = 4,
                 use_lead_embedding: bool = True,
                 is_proj_head: int = 1):
        super().__init__()
        # Samples per window: sampling rate (Hz) * window length (s).
        token_len = freq * win_sec
        self.token_len = token_len
        self.patch_size = patch_size
        if encoder_name not in VIT_FACTORIES:
            raise ValueError(
                f"Unknown encoder_name: {encoder_name}. "
                f"Choose from {list(VIT_FACTORIES.keys())}")
        self.backbone = VIT_FACTORIES[encoder_name](
            num_leads=channel, seq_len=token_len, patch_size=patch_size,
            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
            use_lead_embedding=use_lead_embedding,
        )
        d_model = self.backbone.width
        if is_proj_head == 1:
            # Two-layer MLP projecting backbone features into the shared embedding space.
            self.proj_head = nn.Sequential(
                nn.Linear(d_model, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
        else:
            self.proj_head = None

    def forward(self, x: torch.Tensor, normalize: bool = True) -> torch.Tensor:
        # x: [B, C, T]
        h = self.backbone(x)  # [B, D]
        if self.proj_head is not None:
            h = self.proj_head(h)  # [B, proj_out]
        if normalize:
            return F.normalize(h, dim=-1)
        return h
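

# Shape sketch (assumes the defaults above: 11 channels sampled at 64 Hz over
# a 30 s window, i.e. T = 1920, and that the vit1d backbones are importable):
#
#     enc = PSGModalityEncoder(encoder_name="vit_tiny")
#     x = torch.randn(8, 11, 64 * 30)   # [B, C, T]
#     z = enc(x)                        # [8, 256], rows L2-normalized
#     h = enc(x, normalize=False)       # projected but not normalized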


class BasePretrainModel(LightningModule):
    """Base pretraining LightningModule; subclasses must implement `shared_step`."""

    def __init__(self,
                 psg_encoder_name: str = "vit_base",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = "cross-attn",
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        # Warm up over the first 10% of training.
        self.warmup_epochs = 0.1 * self.max_epochs
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256
        assert self.training_steps_per_epoch > 1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )
        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = max(0, int(round(self.training_steps_per_epoch * self.warmup_epochs)))
        decay_steps = max(1, total_steps - warmup_steps)
        if warmup_steps > 0:
            # Linear warmup from 0.01*lr to lr, then cosine decay toward eta_min.
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
        # Step the scheduler every optimizer step, not every epoch.
        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

    def training_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict["loss"]

    def validation_step(self, batch, batch_idx):
        # Lightning already runs validation without gradients; the explicit
        # no_grad is kept for clarity.
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict

    def test_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict
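

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the training pipeline).
    # `shared_step` is left to subclasses, so only the encoder forward pass
    # and the optimizer/scheduler wiring are exercised here.
    enc = PSGModalityEncoder(encoder_name="vit_nano")
    x = torch.randn(2, 11, 64 * 30)  # [B, C, T]
    print("encoder output:", tuple(enc(x).shape))  # expected: (2, 256)

    model = BasePretrainModel(training_steps_per_epoch=100, max_epochs=10)
    model.encoder = enc  # register some parameters, as a subclass normally would
    optimizers, schedulers = model.configure_optimizers()
    print("optimizer:", type(optimizers[0]).__name__)
    print("scheduler:", type(schedulers[0]["scheduler"]).__name__)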