# OSF-Base: osf/models/base_pretrain_model.py
# Commit 8f8716a (verified) by ztshuaiUCLA: "Upload folder using huggingface_hub"
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from osf.backbone.vit1d import vit_nano, vit_tiny, vit_small, vit_middle, vit_base
# Registry of the supported 1-D ViT backbone constructors, keyed by the
# string that PSGModalityEncoder accepts as ``encoder_name``.
VIT_FACTORIES = dict(
    vit_nano=vit_nano,
    vit_tiny=vit_tiny,
    vit_small=vit_small,
    vit_middle=vit_middle,
    vit_base=vit_base,
)
class PSGModalityEncoder(nn.Module):
    """Encode PSG signal windows with a 1-D ViT backbone.

    Pipeline: ``backbone`` -> optional MLP projection head -> optional L2
    normalisation over the last dimension.

    Args:
        encoder_name: Key into ``VIT_FACTORIES`` selecting the backbone.
        proj_out: Output width of the projection head.
        proj_hidden: Hidden width of the projection head.
        freq: Sampling rate (presumably Hz); ``freq * win_sec`` is passed to
            the backbone as ``seq_len`` and stored as ``token_len``.
        win_sec: Window length (presumably seconds); see ``freq``.
        channel: Number of input channels, passed as ``num_leads``.
        lead_wise: Forwarded unchanged to the backbone factory.
        patch_size: Forwarded to the backbone and stored on the module.
        patch_size_ch: Forwarded unchanged to the backbone factory.
        use_lead_embedding: Forwarded unchanged to the backbone factory.
        is_proj_head: The projection head is built only when this equals 1;
            any other value leaves ``proj_head`` set to ``None``.

    Raises:
        ValueError: If ``encoder_name`` is not a key of ``VIT_FACTORIES``.
    """

    def __init__(self, *,
                 encoder_name: str,
                 proj_out: int = 256,
                 proj_hidden: int = 512,
                 freq: int = 64,
                 win_sec: int = 30,
                 channel: int = 11,
                 lead_wise=0,
                 patch_size=40,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 is_proj_head=1):
        super().__init__()
        self.token_len = freq * win_sec
        self.patch_size = patch_size

        if encoder_name not in VIT_FACTORIES:
            raise ValueError(
                f"Unknown encoder_name: {encoder_name}. "
                f"Choose from {list(VIT_FACTORIES.keys())}"
            )
        make_backbone = VIT_FACTORIES[encoder_name]
        self.backbone = make_backbone(
            num_leads=channel,
            seq_len=self.token_len,
            patch_size=patch_size,
            lead_wise=lead_wise,
            patch_size_ch=patch_size_ch,
            use_lead_embedding=use_lead_embedding,
        )

        width = self.backbone.width
        # A positional nn.Sequential keeps the checkpoint keys
        # (proj_head.0.*, proj_head.1.*, ...) stable.
        self.proj_head = (
            nn.Sequential(
                nn.Linear(width, proj_hidden),
                nn.LayerNorm(proj_hidden),
                nn.ReLU(inplace=True),
                nn.Linear(proj_hidden, proj_out),
                nn.LayerNorm(proj_out),
            )
            if is_proj_head == 1
            else None
        )

    def forward(self, x, normalize=True):
        """Embed a batch of signals.

        Args:
            x: Signal batch, [B, C, T].
            normalize: When truthy, L2-normalise the result along the last
                dimension.

        Returns:
            Backbone features ([B, D]), passed through ``proj_head`` when one
            exists ([B, proj_out]).
        """
        feats = self.backbone(x)
        if self.proj_head is not None:
            feats = self.proj_head(feats)
        return F.normalize(feats, dim=-1) if normalize else feats
class BasePretrainModel(LightningModule):
    """Common LightningModule scaffolding for the PSG pretraining models.

    Visible responsibilities of this class:

    * store the shared hyperparameters (also recorded via
      ``save_hyperparameters``);
    * build AdamW with a per-step linear-warmup + cosine-decay LR schedule;
    * run ``shared_step`` in the train/val/test hooks and log its outputs.

    ``shared_step(batch, batch_idx)`` is not defined here; it is expected to
    be provided by a subclass and to return ``(loss_dict, metrics_dict)``
    with the training loss under ``loss_dict['loss']``.

    Args:
        psg_encoder_name: PSG encoder name (stored; unused in this class).
        text_encoder_name: Text encoder name (stored; unused in this class).
        fusion_decoder_name: Fusion decoder name (stored; unused in this class).
        shared_emb_dim: Shared embedding width; also copied to ``proj_out``.
        lr: Peak learning rate for AdamW.
        weight_decay: AdamW weight decay, applied to every parameter.
        training_steps_per_epoch: Optimizer steps per epoch, used to size the
            LR schedule; must be greater than 1.
        max_epochs: Number of epochs, used to size the LR schedule.
        *args, **kwargs: Accepted and otherwise unused in this class.

    Raises:
        ValueError: If ``training_steps_per_epoch`` is not greater than 1.
    """

    def __init__(self,
                 psg_encoder_name: str = "vit_base",
                 text_encoder_name: str = "google/flan-t5-base",
                 fusion_decoder_name: str = 'cross-attn',
                 shared_emb_dim: int = 256,
                 lr: float = 2e-4,
                 weight_decay: float = 0.2,
                 training_steps_per_epoch: int = 7000,
                 max_epochs: int = 100,
                 *args, **kwargs):
        super().__init__()
        # Must be called directly from __init__: Lightning inspects this frame.
        self.save_hyperparameters()
        self.psg_encoder_name = psg_encoder_name
        self.text_encoder_name = text_encoder_name
        self.fusion_decoder_name = fusion_decoder_name
        self.shared_emb_dim = shared_emb_dim
        self.lr = lr
        self.weight_decay = weight_decay
        self.training_steps_per_epoch = training_steps_per_epoch
        self.max_epochs = max_epochs
        # Warmup spans the first 10% of training.
        self.warmup_epochs = 0.1 * self.max_epochs
        # Projection-head sizes, presumably consumed by subclasses.
        self.proj_out = shared_emb_dim
        self.proj_hidden = 256
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        if not self.training_steps_per_epoch > 1:
            raise ValueError(
                "training_steps_per_epoch must be greater than 1, "
                f"got {self.training_steps_per_epoch!r}"
            )

    def configure_optimizers(self):
        """Build AdamW and a per-step LR schedule.

        The LR warms up linearly from 1% of ``lr`` to ``lr`` over
        ``round(training_steps_per_epoch * warmup_epochs)`` steps. It then
        follows cosine annealing down to ``1e-8`` over the remaining steps
        of ``training_steps_per_epoch * max_epochs``. With zero warmup steps
        the cosine schedule starts immediately.

        NOTE(review): the schedule length assumes ``training_steps_per_epoch``
        matches the trainer's real optimizer steps per epoch. Past ``T_max``,
        CosineAnnealingLR keeps following the cosine curve and the LR rises.
        """
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
            betas=(0.9, 0.95),
        )
        total_steps = int(self.training_steps_per_epoch * self.max_epochs)
        warmup_steps = int(round(self.training_steps_per_epoch * self.warmup_epochs))
        warmup_steps = max(0, warmup_steps)
        # CosineAnnealingLR requires a positive T_max.
        decay_steps = max(1, total_steps - warmup_steps)
        if warmup_steps > 0:
            # The warmup scheduler is constructed before the cosine one; keep
            # this order, since both write the optimizer's lr when created.
            warmup = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
            cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
            sched = torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[warmup, cosine], milestones=[warmup_steps])
        else:
            sched = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=decay_steps, eta_min=1e-8)
        # "interval": "step" makes Lightning step the scheduler per batch.
        return [optimizer], [{"scheduler": sched, "interval": "step", "frequency": 1}]

    def _log_stage_results(self, stage, loss_dict, metrics_dict, on_step):
        """Log every entry of ``loss_dict`` and then ``metrics_dict`` as ``{stage}/{key}``.

        All three stages share the same options (epoch aggregation, progress
        bar, cross-device sync). Only ``on_step`` differs between stages.
        """
        for values in (loss_dict, metrics_dict):
            for key, value in values.items():
                self.log(f"{stage}/{key}", value, on_step=on_step, on_epoch=True,
                         prog_bar=True, sync_dist=True)

    def training_step(self, batch, batch_idx):
        """Run ``shared_step``, log per step and per epoch, and return ``loss_dict['loss']``."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_stage_results("train", loss_dict, metrics_dict, on_step=True)
        return loss_dict['loss']

    def validation_step(self, batch, batch_idx):
        """Run ``shared_step`` without gradients, log epoch aggregates, return the losses."""
        with torch.no_grad():
            loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_stage_results("val", loss_dict, metrics_dict, on_step=False)
        return loss_dict

    def test_step(self, batch, batch_idx):
        """Run ``shared_step``, log epoch aggregates under ``test/``, and return the losses."""
        loss_dict, metrics_dict = self.shared_step(batch, batch_idx)
        self._log_stage_results("test", loss_dict, metrics_dict, on_step=False)
        return loss_dict