# OSF-Base / osf/backbone/vit1d.py
# Uploaded by ztshuaiUCLA via huggingface_hub (revision 8f8716a, verified).
"""
1D Vision Transformer for time-series signals.
Patchify modes:
- lead_wise=0: 1D patchify (all channels in one patch), no lead embedding
- lead_wise=1: 2D patchify (channel groups), with lead embedding by default
"""
import torch
import torch.nn as nn
from einops import rearrange
class DropPath(nn.Module):
    """Stochastic depth: zero out whole samples on the residual path.

    At inference time (or when ``drop_prob <= 0``) this is the identity.
    During training each sample in the batch is kept with probability
    ``1 - drop_prob``; kept samples are rescaled by ``1 / keep_prob`` when
    ``scale_by_keep`` is set, so the expected activation is unchanged.
    """

    def __init__(self, drop_prob: float, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        # Identity when dropping is disabled or we are evaluating.
        if not self.training or self.drop_prob <= 0.:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One Bernoulli draw per sample, broadcast over all remaining dims.
        mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        if self.scale_by_keep and keep_prob > 0.0:
            mask = mask / keep_prob
        return x * mask
class PreNorm(nn.Module):
    """Wrap a module so its input is LayerNorm-ed first (pre-norm transformer)."""

    def __init__(self, dim: int, fn: nn.Module):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
class FeedForward(nn.Module):
    """Transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, drop_out_rate=0.):
        super().__init__()
        # Layer order matters for checkpoint keys (net.0 .. net.4); keep it fixed.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(drop_out_rate),
            nn.Linear(hidden_dim, output_dim),
            nn.Dropout(drop_out_rate),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Projects to Q/K/V with a single linear layer, attends per head, then
    merges heads and (unless heads==1 and dim_head==input_dim) projects back
    to ``output_dim``.
    """

    def __init__(self, input_dim: int, output_dim: int, heads: int = 8, dim_head: int = 64,
                 qkv_bias: bool = True, drop_out_rate: float = 0., attn_drop_out_rate: float = 0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(attn_drop_out_rate)
        self.to_qkv = nn.Linear(input_dim, inner_dim * 3, bias=qkv_bias)
        # Skip the output projection only in the degenerate single-head case
        # where the attention output already has the input width.
        needs_projection = not (heads == 1 and dim_head == input_dim)
        if needs_projection:
            self.to_out = nn.Sequential(nn.Linear(inner_dim, output_dim),
                                        nn.Dropout(drop_out_rate))
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        b, n, _ = x.shape
        # Split QKV and reshape each to [B, heads, N, dim_head]
        # (equivalent to einops 'b n (h d) -> b h n d').
        q, k, v = (t.view(b, n, self.heads, -1).transpose(1, 2)
                   for t in self.to_qkv(x).chunk(3, dim=-1))
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        out = torch.matmul(weights, v)
        # Merge heads back: [B, heads, N, d] -> [B, N, heads*d].
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each with a
    residual connection and optional stochastic depth (DropPath)."""

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, heads: int = 8,
                 dim_head: int = 32, qkv_bias: bool = True, drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0., drop_path_rate: float = 0.):
        super().__init__()
        # Attention is constructed before FeedForward (keeps parameter-init RNG
        # consumption identical to checkpointed models).
        self.attn = PreNorm(input_dim, Attention(input_dim, output_dim, heads, dim_head,
                                                 qkv_bias, drop_out_rate, attn_drop_out_rate))
        self.ff = PreNorm(output_dim, FeedForward(output_dim, output_dim, hidden_dim, drop_out_rate))
        use_droppath = drop_path_rate > 0
        self.droppath1 = DropPath(drop_path_rate) if use_droppath else nn.Identity()
        self.droppath2 = DropPath(drop_path_rate) if use_droppath else nn.Identity()

    def forward(self, x):
        x = x + self.droppath1(self.attn(x))
        x = x + self.droppath2(self.ff(x))
        return x
class ViT(nn.Module):
    """1D Vision Transformer over multi-lead time-series signals ([B, C, T]).

    Patchify modes (``lead_wise``):
      * ``0`` -- 1D patchify: a single Conv1d folds all ``num_leads`` channels
        into each temporal patch; no lead embedding.
      * ``1`` (any nonzero) -- 2D patchify: a Conv2d groups ``patch_size_ch``
        leads per patch, and an optional learned lead embedding tags each
        lead-group row.

    Notes:
      * The classification ``head`` is ``nn.Identity`` until ``reset_head``
        is called; a ``num_classes`` keyword passed by the factory helpers
        is absorbed by ``**kwargs`` and ignored here by design.
      * Blocks are registered as ``block{i}`` (not a ModuleList) to keep
        checkpoint state-dict keys stable.
    """

    def __init__(self,
                 num_leads: int,
                 seq_len: int,
                 patch_size: int,
                 lead_wise=0,
                 patch_size_ch=4,
                 use_lead_embedding: bool = True,
                 width: int = 768,
                 depth: int = 12,
                 mlp_dim: int = 3072,
                 heads: int = 12,
                 dim_head: int = 64,
                 qkv_bias: bool = True,
                 drop_out_rate: float = 0.,
                 attn_drop_out_rate: float = 0.,
                 drop_path_rate: float = 0.,
                 **kwargs):
        super().__init__()
        assert seq_len % patch_size == 0, 'seq_len must be divisible by patch_size'
        num_patches = seq_len // patch_size
        self.lead_wise = lead_wise
        self.use_lead_embedding = use_lead_embedding
        if lead_wise == 0:
            # All leads folded into every temporal patch; N = seq_len // patch_size.
            self.to_patch_embedding = nn.Conv1d(num_leads, width, kernel_size=patch_size,
                                                stride=patch_size, bias=False)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, width))
        else:
            # Fail fast: a non-divisible lead count would make the Conv2d silently
            # drop trailing leads while (num_patches * num_leads) // patch_size_ch
            # sizes pos_embedding differently -> position/lead embedding mismatch.
            assert num_leads % patch_size_ch == 0, 'num_leads must be divisible by patch_size_ch'
            self.to_patch_embedding = nn.Conv2d(1, width, kernel_size=(patch_size_ch, patch_size),
                                                stride=(patch_size_ch, patch_size), bias=False)
            self.pos_embedding = nn.Parameter(torch.randn(1, num_patches * num_leads // patch_size_ch, width))
            if use_lead_embedding:
                # One learned vector per group of patch_size_ch leads.
                self.lead_emb = nn.Embedding(num_leads // patch_size_ch, width)
            else:
                self.lead_emb = None
        self.dropout = nn.Dropout(drop_out_rate)
        self.depth = depth
        self.width = width
        # Stochastic-depth rate grows linearly from 0 to drop_path_rate over depth.
        drop_path_rate_list = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        for i in range(depth):
            block = TransformerBlock(width, width, mlp_dim, heads, dim_head, qkv_bias,
                                     drop_out_rate, attn_drop_out_rate, drop_path_rate_list[i])
            self.add_module(f'block{i}', block)
        self.norm = nn.LayerNorm(width)
        self.head = nn.Identity()

    def _patchify_and_embed(self, series: torch.Tensor) -> torch.Tensor:
        """Patchify input and add positional/lead embeddings. [B, C, T] -> [B, N, D]."""
        if self.lead_wise == 0:
            x = self.to_patch_embedding(series)       # [B, D, N]
            x = x.transpose(1, 2)                     # [B, N, D]
            # pos_embedding is a registered Parameter, so it is already on
            # the module's device; slice in case N < configured num_patches.
            x = x + self.pos_embedding[:, :x.size(1), :]
        else:
            x = self.to_patch_embedding(series.unsqueeze(1))  # [B, D, Lr, Nt]
            lead_rows, time_cols = x.shape[-2], x.shape[-1]
            # Flatten the (lead-row, time) grid with lead-row as the outer index
            # (equivalent to einops 'b d lr nt -> b (lr nt) d').
            x = x.flatten(2).transpose(1, 2)          # [B, Lr*Nt, D]
            x = x + self.pos_embedding[:, :x.size(1), :]
            if self.use_lead_embedding and self.lead_emb is not None:
                # Token i belongs to lead-row i // Nt.
                row_ids = torch.arange(lead_rows, device=x.device).repeat_interleave(time_cols)
                x = x + self.lead_emb(row_ids).unsqueeze(0)
        return x

    def forward_encoding(self, series: torch.Tensor) -> torch.Tensor:
        """Encode series; returns mean-pooled, LayerNorm-ed features of shape [B, D]."""
        x = self.dropout(self._patchify_and_embed(series))
        for i in range(self.depth):
            x = getattr(self, f'block{i}')(x)
        # Pool over tokens first, then normalize the pooled vector.
        return self.norm(x.mean(dim=1))

    def forward(self, series):
        return self.head(self.forward_encoding(series))

    def reset_head(self, num_classes: int = 1):
        """Replace the (initially Identity) head with a fresh linear classifier."""
        del self.head
        self.head = nn.Linear(self.width, num_classes)
def vit_nano(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Nano backbone: width 128, depth 6, 4 heads, MLP dim 512."""
    cfg = dict(width=128, depth=6, heads=4, mlp_dim=512)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **cfg, **kwargs)
def vit_tiny(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Tiny backbone: width 192, depth 12, 3 heads, MLP dim 768."""
    cfg = dict(width=192, depth=12, heads=3, mlp_dim=768)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **cfg, **kwargs)
def vit_small(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Small backbone: width 384, depth 12, 6 heads, MLP dim 1536."""
    cfg = dict(width=384, depth=12, heads=6, mlp_dim=1536)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **cfg, **kwargs)
def vit_middle(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Middle backbone: width 512, depth 12, 8 heads, MLP dim 2048."""
    cfg = dict(width=512, depth=12, heads=8, mlp_dim=2048)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **cfg, **kwargs)
def vit_base(num_leads, num_classes=1, seq_len=5000, patch_size=50, **kwargs):
    """ViT-Base backbone: width 768, depth 12, 12 heads, MLP dim 3072."""
    cfg = dict(width=768, depth=12, heads=12, mlp_dim=3072)
    return ViT(num_leads=num_leads, num_classes=num_classes, seq_len=seq_len,
               patch_size=patch_size, **cfg, **kwargs)