import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule

from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base

VIT_FACTORIES = {
    "vit_nano": vit_nano,
    "vit_tiny": vit_tiny,
    "vit_small": vit_small,
    "vit_middle": vit_middle,
    "vit_base": vit_base,
}

class PSGModalityEncoder(nn.Module):
    """ViT encoder for PSG signals: backbone -> optional projection head -> L2 normalization."""

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise: int = 0,
                 patch_size: int = 40,
                 patch_size_ch: int = 4,
                 use_lead_embedding: bool = True,
                 is_proj_head: int = 1):
        super().__init__()
        # Samples per epoch window, e.g. 64 Hz * 30 s = 1920.
        token_len = freq * win_sec
        self.token_len = token_len
        self.patch_size = patch_size

        if encoder_name not in VIT_FACTORIES:
            raise ValueError(
                f"Unknown encoder_name: {encoder_name}. Choose from {list(VIT_FACTORIES.keys())}")

        self.backbone = VIT_FACTORIES[encoder_name](
            num_leads=channel, seq_len=token_len, patch_size=patch_size,
            lead_wise=lead_wise, patch_size_ch=patch_size_ch,
            use_lead_embedding=use_lead_embedding,
        )

        d_model = self.backbone.width
        if is_proj_head == 1:
            self.proj_head = nn.Sequential(
                nn.Linear(d_model, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
        else:
            self.proj_head = None

    def forward(self, x, normalize=True):
        # The backbone produces one embedding per input; the optional projection
        # head maps it into the shared embedding space before L2 normalization.
        h = self.backbone(x)
        if self.proj_head is not None:
            h = self.proj_head(h)
        if normalize:
            return F.normalize(h, dim=-1)
        return h

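
# A minimal usage sketch (illustration only; the shapes assume the vit1d
# backbone consumes inputs of shape (batch, channel, freq * win_sec), which the
# constructor arguments above imply but this file does not assert):
#
#   encoder = PSGModalityEncoder(encoder_name="vit_tiny", channel=11, freq=64, win_sec=30)
#   x = torch.randn(8, 11, 64 * 30)        # 8 epochs, 11 channels, 30 s at 64 Hz
#   z = encoder(x)                         # L2-normalized embeddings, (8, proj_out)
#   h = encoder(x, normalize=False)        # same shape, without normalization
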

class BasePretrainModel(LightningModule):
    """Pretraining base module.

    Subclasses are expected to implement ``shared_step(batch, batch_idx)`` and
    return a ``(loss_dict, metrics_dict)`` pair in which ``loss_dict`` contains
    a ``"loss"`` entry; the training/validation/test steps below only log and
    return what ``shared_step`` produces.
    """

    def __init__(self,
                 psg_encoder_name: str = "vit_base",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = "cross-attn",
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        # Linear warmup over the first 10% of training, cosine decay afterwards.
        self.warmup_epochs = 0.1 * self.max_epochs
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256

        assert self.training_steps_per_epoch > 1

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )

        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
        warmup_steps = max(0, warmup_steps)
        decay_steps = max(1, total_steps - warmup_steps)

        if warmup_steps > 0:
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)

        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

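    # Worked example with the defaults above (illustration only): with
    # training_steps_per_epoch=7000, max_epochs=100 and warmup_epochs=10,
    # total_steps = 700_000, warmup_steps = 70_000 and decay_steps = 630_000,
    # so the LR ramps up linearly over the first 10% of optimizer steps and then
    # follows a cosine decay toward eta_min over the remaining 90%.
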
    def training_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"train/{k}", v, on_step=True, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"val/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict

    def test_step(self, batch, batch_idx):
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        for k, v in loss_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        for k, v in metrics_dict.items():
            self.log(f"test/{k}", v, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss_dict
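

# ---------------------------------------------------------------------------
# Illustration only: BasePretrainModel leaves `shared_step` to subclasses, and
# training_step/validation_step/test_step expect it to return a pair
# (loss_dict, metrics_dict) with a "loss" entry in loss_dict. The hypothetical
# subclass below (not part of the original codebase) sketches that contract
# with a simple two-view InfoNCE loss over PSG embeddings; the batch layout
# (a dict with "psg" and "psg_aug" tensors) and the temperature value are
# assumptions made purely for this sketch.
# ---------------------------------------------------------------------------
class ExamplePSGContrastivePretrainModel(BasePretrainModel):
    def __init__(self, temperature: float = 0.07, **kwargs):
        super().__init__(**kwargs)
        self.temperature = temperature
        self.psg_encoder = PSGModalityEncoder(
            encoder_name=self.psg_encoder_name,
            proj_out=self.proj_out,
            proj_hidden=self.proj_hidden,
        )

    def shared_step(self, batch, batch_idx):
        # Encode two augmented views of the same PSG epochs (assumed batch keys).
        z1 = self.psg_encoder(batch["psg"])        # (B, proj_out), L2-normalized
        z2 = self.psg_encoder(batch["psg_aug"])    # (B, proj_out), L2-normalized

        # Symmetric InfoNCE: matching views should be mutual nearest neighbours.
        logits = z1 @ z2.t() / self.temperature
        targets = torch.arange(z1.size(0), device=z1.device)
        loss = 0.5 * (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets))

        with torch.no_grad():
            view_acc = (logits.argmax(dim=-1) == targets).float().mean()

        return {"loss": loss}, {"view_retrieval_acc": view_acc}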