| | """ |
| | Biosignals-Text CoCa Model |
| | |
| | Adapted from the original CoCa model to work with biosignals (time series) data |
| | instead of images. This model is designed for biosignals-text contrastive learning. |
| | """ |
| |
|
| | from typing import Dict, List, Optional, Union, Tuple |
| | import torch |
| | from torch import nn |
| | from torch.nn import functional as F |
| | import numpy as np |
| | import math |
| | from dataclasses import dataclass, field |
| |
|
| | from .transformer import ( |
| | LayerNormFp32, |
| | LayerNorm, |
| | QuickGELU, |
| | MultimodalTransformer, |
| | ConcatMultimodalTransformer, |
| | ) |
| | from .model import CLIPTextCfg, _build_text_tower |
| | from .coca_model import MultimodalCfg, _build_text_decoder_tower, _token_to_tensor |
| |
|
# Optional dependency: HF transformers supplies logits warpers / beam search
# utilities used by text generation. The model itself works without it.
try:
    from transformers.generation.beam_search import BeamSearchScorer
    from transformers.generation.logits_process import (
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
    )
    from transformers.generation.stopping_criteria import (
        MaxLengthCriteria,
        EosTokenCriteria,
        StoppingCriteriaList,
    )

    # Maps generation-strategy names to their logits-warper classes;
    # "beam_search" is a sentinel string (handled by a separate code path,
    # not a warper class).
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search"
    }
    _has_transformers = True
except ImportError as e:
    # transformers unavailable: keep the same keys so lookups don't
    # KeyError, but the sampling warpers are unusable (None).
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search"
    }
    _has_transformers = False
| |
|
| |
|
| | |
| | |
| | |
| |
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE).

    Encodes position by rotating consecutive (even, odd) feature pairs of
    queries/keys by a position-dependent angle, so the attention dot product
    depends on relative position.

    Args:
        dim: Per-head feature dimension (must be even — features rotate in pairs).
        theta: Base of the geometric frequency spectrum (10000 as in RoFormer).
        learned_freq: If True, the dim//2 frequencies are learned parameters
            instead of the fixed theta-based spectrum (used for channel RoPE).
    """
    def __init__(self, dim: int, theta: float = 10000.0, learned_freq: bool = False):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.learned_freq = learned_freq

        if learned_freq:
            # Small random init so initial rotations are near-identity.
            self.freqs = nn.Parameter(torch.randn(dim // 2) * 0.02)
        else:
            # Classic RoPE spectrum: theta^(-2i/dim) for i in [0, dim/2).
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
            self.register_buffer('freqs', freqs)

    def rotate_queries_or_keys(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """
        Apply rotary embeddings to queries or keys.

        Args:
            x: (batch_size, num_heads, seq_len, head_dim)
            position_ids: (seq_len,) or (batch_size, seq_len) - position indices.
                When 2D, only row 0 is used (positions assumed shared across
                the batch). Defaults to arange(seq_len).
        Returns:
            Rotated tensor of the same shape as ``x``.
        """
        batch_size, num_heads, seq_len, head_dim = x.shape
        assert head_dim == self.dim, f"head_dim {head_dim} != self.dim {self.dim}"

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=x.device, dtype=torch.float)
        elif position_ids.ndim == 2:
            # Positions assumed identical across the batch; use row 0.
            position_ids = position_ids[0].float()
        else:
            position_ids = position_ids.float()

        # (seq_len, dim//2): one angle per position/frequency pair.
        angles = torch.einsum('s,d->sd', position_ids, self.freqs)

        # Broadcast over batch and head dims: (1, 1, seq_len, dim//2).
        # FIX: the previous repeat_interleave(2)-then-stride-slice built two
        # identical copies of every cos/sin value; computing them once on the
        # half-width grid is mathematically identical and cheaper.
        cos = torch.cos(angles).unsqueeze(0).unsqueeze(0)
        sin = torch.sin(angles).unsqueeze(0).unsqueeze(0)

        # Rotate each (even, odd) feature pair by its angle.
        x1 = x[..., 0::2]
        x2 = x[..., 1::2]

        x_rotated = torch.empty_like(x)
        x_rotated[..., 0::2] = x1 * cos - x2 * sin
        x_rotated[..., 1::2] = x1 * sin + x2 * cos

        return x_rotated
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no mean-centering, no bias)."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dimension, computed via rsqrt.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        # Normalize in float32 for numerical stability, cast back, then scale.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
| |
|
| |
|
class SwiGLU(nn.Module):
    """Gated SiLU feed-forward gate: silu(W1 x) * (W2 x)."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out, bias=bias)
        self.w2 = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return gate * value
| |
|
| |
|
class MLP(nn.Module):
    """Feed-forward block with a configurable activation (swiglu / gelu / relu).

    Dropout is applied twice: once after the (gated) activation and once
    after the down projection, matching the original implementation.
    """

    def __init__(self,
                 dim: int,
                 hidden_dim: int,
                 dropout: float = 0.0,
                 activation: str = "swiglu",
                 bias: bool = False):
        super().__init__()
        self.activation = activation

        if activation == "swiglu":
            # Gated projection replaces the separate up-projection + act.
            self.gate_proj = SwiGLU(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        else:
            self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
            self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
            if activation == "gelu":
                self.act_fn = nn.GELU()
            elif activation == "relu":
                self.act_fn = nn.ReLU()
            else:
                raise ValueError(f"Unknown activation: {activation}")

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.activation == "swiglu":
            hidden = self.dropout(self.gate_proj(x))
        else:
            hidden = self.dropout(self.act_fn(self.up_proj(x)))
        return self.dropout(self.down_proj(hidden))
| |
|
| |
|
class ChannelPatching(nn.Module):
    """Split each channel's signal into non-overlapping patches via Conv1d.

    A single shared Conv1d (kernel = stride = patch_size, no padding) embeds
    every channel independently, so all channels share the same
    patch-embedding weights.
    """

    def __init__(self,
                 patch_size: int = 32,
                 conv_embed_dim: int = 256,
                 num_channels: int = 21):
        super().__init__()
        self.patch_size = patch_size
        self.conv_embed_dim = conv_embed_dim
        self.num_channels = num_channels

        # Non-overlapping patching: kernel_size == stride, padding 0.
        self.conv_patching = nn.Conv1d(
            in_channels=1,
            out_channels=conv_embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
        )

    def forward(self, x):
        """
        Args:
            x: (batch_size, num_channels, signal_length) - multi-channel signal
        Returns:
            (batch_size, num_channels, num_patches, conv_embed_dim)
        """
        bsz, n_ch, sig_len = x.shape

        # Fold channels into the batch so one conv embeds all of them at once.
        flat = x.reshape(bsz * n_ch, 1, sig_len)
        patched = self.conv_patching(flat)          # (B*C, E, P)

        n_patches = patched.shape[-1]
        patched = patched.reshape(bsz, n_ch, self.conv_embed_dim, n_patches)

        # (B, C, P, E): patches along dim 2, embedding last.
        return patched.transpose(2, 3)
| |
|
| |
|
class DualRoPEAttention(nn.Module):
    """Multi-head self-attention whose Q/K receive RoPE position encoding.

    ``attention_type`` selects the RoPE flavor:
      - "temporal": fixed-frequency RoPE over time positions.
      - "channel": learned-frequency RoPE over channel indices, optionally
        shared across blocks via ``shared_channel_rope``.
    """

    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 attention_type: str = "temporal",
                 num_channels: int = 21,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attention_type = attention_type

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Bias-free Q/K/V projections; output projection keeps its bias.
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        if attention_type == "temporal":
            # Fixed sinusoidal frequency spectrum over time steps.
            self.rotary_emb = RotaryEmbedding(dim=self.head_dim, theta=10000, learned_freq=False)
        elif attention_type == "channel":
            # Learned frequencies over channel positions; may be shared
            # across blocks to keep channel geometry consistent.
            if shared_channel_rope is not None:
                self.rotary_emb = shared_channel_rope
            else:
                self.rotary_emb = RotaryEmbedding(dim=self.head_dim, theta=10000, learned_freq=True)
        else:
            raise ValueError(f"Unknown attention_type: {attention_type}")

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, position_ids=None):
        """
        Args:
            x: (batch_size, seq_len, embed_dim)
            position_ids: optional custom position indices for RoPE,
                shape (batch_size, seq_len) or (seq_len,).
        Returns:
            (batch_size, seq_len, embed_dim)
        """
        bsz, seq_len, _ = x.shape

        def split_heads(t):
            # (B, S, E) -> (B, H, S, Dh)
            return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        # Rotate queries and keys by position before the dot product.
        q = self.rotary_emb.rotate_queries_or_keys(q, position_ids=position_ids)
        k = self.rotary_emb.rotate_queries_or_keys(k, position_ids=position_ids)

        # Scaled dot-product attention with dropout on the attention weights.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.dropout(F.softmax(scores, dim=-1))
        context = torch.matmul(weights, v)

        # (B, H, S, Dh) -> (B, S, E)
        context = context.transpose(1, 2).contiguous().view(bsz, seq_len, self.embed_dim)
        return self.out_proj(context)
| |
|
| |
|
class DualTransformerBlock(nn.Module):
    """Biosignal transformer block with channel and temporal attention using dual RoPE.

    Factorized attention: one channel-attention sub-layer (attends across
    channels at each patch position) followed by ``num_temporal_layers``
    temporal sub-layers (attend across patches within each channel).
    All residual connections use post-norm (norm after the residual add).
    """
    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 num_temporal_layers: int = 2,
                 dropout: float = 0.1,
                 mlp_ratio: float = 4.0,
                 num_channels: int = 21,
                 activation: str = "swiglu",
                 norm_type: str = "rmsnorm",
                 mlp_bias: bool = False,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_temporal_layers = num_temporal_layers

        # Factory so every sub-layer uses the same norm flavor.
        def create_norm(dim):
            if norm_type == "rmsnorm":
                return RMSNorm(dim)
            elif norm_type == "layernorm":
                return nn.LayerNorm(dim)
            else:
                raise ValueError(f"Unknown norm_type: {norm_type}")

        # Channel attention: learned-frequency RoPE over channel indices,
        # optionally shared across blocks via shared_channel_rope.
        self.channel_attention = DualRoPEAttention(
            embed_dim, num_heads, dropout,
            attention_type="channel", num_channels=num_channels,
            shared_channel_rope=shared_channel_rope
        )
        self.channel_norm = create_norm(embed_dim)

        # Temporal attention stack: fixed-frequency RoPE over patch positions.
        self.temporal_attention_layers = nn.ModuleList([
            DualRoPEAttention(embed_dim, num_heads, dropout, attention_type="temporal")
            for _ in range(num_temporal_layers)
        ])
        self.temporal_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

        # Position-wise MLPs: one for the channel path, one per temporal layer.
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.channel_mlp = MLP(
            dim=embed_dim,
            hidden_dim=mlp_hidden_dim,
            dropout=dropout,
            activation=activation,
            bias=mlp_bias
        )

        self.temporal_mlps = nn.ModuleList([
            MLP(
                dim=embed_dim,
                hidden_dim=mlp_hidden_dim,
                dropout=dropout,
                activation=activation,
                bias=mlp_bias
            ) for _ in range(num_temporal_layers)
        ])

        self.channel_mlp_norm = create_norm(embed_dim)
        self.temporal_mlp_norms = nn.ModuleList([
            create_norm(embed_dim)
            for _ in range(num_temporal_layers)
        ])

    def forward(self, x, temporal_position_ids=None):
        """
        Args:
            x: (batch_size, num_channels, num_patches, embed_dim)
            temporal_position_ids: (batch_size, num_patches) or (num_patches,) - position indices for temporal RoPE
        Returns:
            (batch_size, num_channels, num_patches, embed_dim)
        """
        batch_size, num_channels, num_patches, embed_dim = x.shape

        # --- Channel attention: fold patches into batch, attend over channels.
        # (B, C, T, D) -> (B*T, C, D)
        x_for_channel_attn = x.permute(0, 2, 1, 3).contiguous().reshape(batch_size * num_patches, num_channels, embed_dim)

        channel_attn_out = self.channel_attention(x_for_channel_attn)

        # Post-norm residual.
        x_for_channel_attn = self.channel_norm(x_for_channel_attn + channel_attn_out)

        channel_mlp_out = self.channel_mlp(x_for_channel_attn)
        x_for_channel_attn = self.channel_mlp_norm(x_for_channel_attn + channel_mlp_out)

        # Restore (B, C, T, D) layout.
        x = x_for_channel_attn.reshape(batch_size, num_patches, num_channels, embed_dim).permute(0, 2, 1, 3)

        # --- Temporal attention: fold channels into batch, attend over patches.
        # (B, C, T, D) -> (B*C, T, D)
        x_for_temporal_attn = x.reshape(batch_size * num_channels, num_patches, embed_dim)

        # Temporal RoPE positions are assumed shared across the batch; use row 0.
        if temporal_position_ids is not None:
            if temporal_position_ids.ndim == 2:
                temporal_pos_ids_expanded = temporal_position_ids[0]
            else:
                temporal_pos_ids_expanded = temporal_position_ids
        else:
            temporal_pos_ids_expanded = None

        # Stack of temporal attention + MLP sub-layers (post-norm residuals).
        for i in range(self.num_temporal_layers):
            temporal_attn_out = self.temporal_attention_layers[i](x_for_temporal_attn, position_ids=temporal_pos_ids_expanded)
            x_for_temporal_attn = self.temporal_norms[i](x_for_temporal_attn + temporal_attn_out)

            temporal_mlp_out = self.temporal_mlps[i](x_for_temporal_attn)
            x_for_temporal_attn = self.temporal_mlp_norms[i](x_for_temporal_attn + temporal_mlp_out)

        # Back to (B, C, T, D).
        x = x_for_temporal_attn.reshape(batch_size, num_channels, num_patches, embed_dim)

        return x
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def _build_signal_tower(
    embed_dim: int,
    signal_cfg,
    output_tokens: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Build a biosignals encoder tower

    Args:
        embed_dim: Output embedding dimension
        signal_cfg: BiosignalsCfg or dict with configuration
        output_tokens: Whether to output tokens for multimodal decoder
        cast_dtype: Optional dtype for casting

    Returns:
        Biosignals encoder (either BiosignalsEncoder or PureTransformerBiosignalsEncoder)

    Raises:
        ValueError: If cfg.architecture is not a recognized architecture name.
    """
    if isinstance(signal_cfg, dict):
        signal_cfg = BiosignalsCfg(**signal_cfg)

    # Local import keeps module import order unchanged; logging is stdlib.
    import logging
    architecture = getattr(signal_cfg, 'architecture', 'conv_transformer')
    # FIX: use lazy %-style logging args (no eager f-string formatting) and
    # drop f-prefixes from placeholder-free literals (ruff F541).
    logging.info("Building biosignals encoder with architecture: %s", architecture)

    if architecture == "pure_transformer":
        signal_encoder = PureTransformerBiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info("Pure Transformer architecture:")
        logging.info(" Patch size: %s", signal_cfg.patch_size)
        logging.info(" Conv embed dim: %s", signal_cfg.conv_embed_dim)
        logging.info(" Transformer blocks: %s", signal_cfg.transformer_layers)
        logging.info(" Temporal layers per block: %s", signal_cfg.num_temporal_layers)
        logging.info(" Activation: %s", signal_cfg.activation)
        logging.info(" Norm type: %s", signal_cfg.norm_type)
        logging.info(" Share channel RoPE: %s", signal_cfg.share_channel_rope)
    elif architecture == "conv_transformer":
        signal_encoder = BiosignalsEncoder(
            biosignals_cfg=signal_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            cast_dtype=cast_dtype
        )
        logging.info("Conv-Transformer architecture:")
        logging.info(" Conv layers: %s", signal_cfg.conv_layers)
        logging.info(" Kernel sizes: %s", signal_cfg.kernel_sizes)
        logging.info(" Strides: %s", signal_cfg.strides)
        logging.info(" Transformer layers: %s", signal_cfg.transformer_layers)
    else:
        raise ValueError(f"Unknown architecture: {architecture}. Must be 'conv_transformer' or 'pure_transformer'")

    return signal_encoder
| |
|
| |
|
def _build_text_decoder_tower_v2(
    embed_dim,
    multimodal_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    decoder_type: str = "cross_attention",
    prefix_len: int = 0,
):
    """Build text decoder tower with support for different decoder types.

    Args:
        embed_dim: Embedding dimension
        multimodal_cfg: MultimodalCfg config (or dict coerced into one)
        quick_gelu: Whether to use QuickGELU
        cast_dtype: Optional dtype for casting
        decoder_type: "cross_attention" or "concat"
            - "cross_attention": Uses separate cross-attention layers (default CoCa)
            - "concat": Concatenates image/biosignals and text tokens
        prefix_len: Number of prefix tokens (condition embeddings) prepended to text
            Used to pre-build prefix-causal attention mask

    Raises:
        ValueError: If ``decoder_type`` is not a known decoder kind.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    act_layer = QuickGELU if quick_gelu else nn.GELU
    # Fp32 norm under half-precision casting for numerical stability.
    if cast_dtype in (torch.float16, torch.bfloat16):
        norm_layer = LayerNormFp32
    else:
        norm_layer = LayerNorm

    decoder_classes = {
        "cross_attention": MultimodalTransformer,
        "concat": ConcatMultimodalTransformer,
    }
    if decoder_type not in decoder_classes:
        raise ValueError(f"Unknown decoder_type: {decoder_type}. Must be 'cross_attention' or 'concat'")

    # Both decoder kinds share the same constructor signature.
    decoder_cls = decoder_classes[decoder_type]
    return decoder_cls(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        mlp_ratio=multimodal_cfg.mlp_ratio,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=act_layer,
        norm_layer=norm_layer,
        prefix_len=prefix_len,
    )
| |
|
| |
|
@dataclass
class BiosignalsCfg:
    """Configuration for biosignals encoder"""
    # --- Input signal geometry ---
    input_channels: int = 12      # number of signal channels (e.g. ECG leads)
    signal_length: int = 1000     # samples per channel
    sampling_rate: int = 500      # Hz; informational — not used for shapes here

    # Encoder variant: "conv_transformer" or "pure_transformer".
    architecture: str = "conv_transformer"

    # --- conv_transformer-only settings (defaults filled in __post_init__) ---
    conv_layers: Optional[List[int]] = None    # out-channels per conv stage
    kernel_sizes: Optional[List[int]] = None   # kernel size per conv stage
    strides: Optional[List[int]] = None        # stride per conv stage

    # --- pure_transformer-only settings ---
    patch_size: int = 32            # samples per temporal patch
    conv_embed_dim: int = 256       # per-patch embedding dim before projection
    num_temporal_layers: int = 2    # temporal attention layers per dual block
    activation: str = "swiglu"      # MLP activation: swiglu / gelu / relu
    norm_type: str = "rmsnorm"      # rmsnorm or layernorm
    mlp_bias: bool = False          # bias in MLP linear layers
    share_channel_rope: bool = True # share one learned channel RoPE across blocks
    decoder_tokens: int = 32        # n_query of the decoder attentional pooler

    # --- Shared transformer settings ---
    transformer_layers: int = 6
    transformer_width: int = 768
    transformer_heads: int = 12
    mlp_ratio: float = 4.0

    # --- Pooling / regularization ---
    pool_type: str = 'attn'         # cls / avg / max / attn
    dropout: float = 0.1

    def __post_init__(self):
        # Fill conv-stack defaults only for the conv_transformer variant;
        # the pure_transformer variant never reads these fields.
        if self.architecture == "conv_transformer":
            if self.conv_layers is None:
                # out-channels per stage
                self.conv_layers = [64, 128, 256, 512]
            if self.kernel_sizes is None:
                # decreasing kernel sizes: wider receptive field early on
                self.kernel_sizes = [7, 5, 3, 3]
            if self.strides is None:
                # overall temporal downsampling factor = 2*2*2*2 = 16
                self.strides = [2, 2, 2, 2]
| |
|
| |
|
class BaseBiosignalsEncoder(nn.Module):
    """
    Base class for biosignals encoders: shared pooling + projection logic.

    Subclasses implement ``_encode`` to produce a feature sequence; this
    class pools it (cls / avg / max / attn) and projects to ``embed_dim``.
    The CLS token convention throughout is the LAST sequence position.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int,
        output_tokens: bool,
        transformer_width: int,
        cast_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.biosignals_cfg = biosignals_cfg
        self.embed_dim = embed_dim
        self.output_tokens = output_tokens
        self.transformer_width = transformer_width
        self.pool_type = biosignals_cfg.pool_type

        # Final projection from transformer width to the shared embed space.
        self.proj_to_embed = nn.Linear(transformer_width, embed_dim)

        # Attention pooling needs its own MHA module.
        if self.pool_type == 'attn':
            self.attn_pool = nn.MultiheadAttention(
                transformer_width,
                biosignals_cfg.transformer_heads,
                batch_first=True
            )

    def _pool_features(self, x: torch.Tensor, has_cls_token: bool) -> torch.Tensor:
        """
        Pool features using the configured pooling method.

        Args:
            x: Features of shape (batch_size, seq_len, width)
            has_cls_token: Whether the sequence includes a CLS token at the last position

        Returns:
            pooled: Pooled features of shape (batch_size, width)
        """
        pool = self.pool_type
        if pool == 'cls':
            # CLS convention: last position holds the summary token.
            return x[:, -1]
        if pool == 'attn':
            # Use the last token as query against the rest of the sequence.
            query, keys = x[:, -1:], x[:, :-1]
            pooled, _ = self.attn_pool(query, keys, keys)
            return pooled.squeeze(1)

        # avg/max pool over content tokens only (exclude a CLS if present).
        body = x[:, :-1] if has_cls_token else x
        if pool == 'avg':
            return body.mean(dim=1)
        if pool == 'max':
            return body.max(dim=1)[0]
        raise ValueError(f"Unknown pool_type: {pool}")

    def _encode(self, biosignals: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Encode biosignals to features. Must be implemented by child classes.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        raise NotImplementedError("Child classes must implement _encode()")

    def forward(self, biosignals: torch.Tensor):
        """
        Forward pass with encoding, pooling, and projection.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            embedding: Global embedding (batch_size, embed_dim)
            tokens_for_decoder: Optional tokens for decoder (batch_size, seq_len, transformer_width),
                returned only when ``output_tokens`` is True.
        """
        features, has_cls_token = self._encode(biosignals)
        pooled = self._pool_features(features, has_cls_token)
        embedding = self.proj_to_embed(pooled)

        if not self.output_tokens:
            return embedding

        # Decoder tokens exclude the CLS summary slot when present.
        tokens_for_decoder = features[:, :-1] if has_cls_token else features
        return embedding, tokens_for_decoder

    def set_grad_checkpointing(self, enable=True):
        # Hook for subclasses; the base encoder has nothing to checkpoint.
        pass
| |
|
| |
|
class Conv1dBlock(nn.Module):
    """Conv1d -> norm -> activation -> dropout(0.1).

    Padding is kernel_size // 2, which preserves length for odd kernels at
    stride 1.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=kernel_size // 2,
        )
        self.norm = norm_layer(out_channels)
        self.act = act_layer()
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        return self.dropout(self.act(self.norm(self.conv(x))))
| |
|
| |
|
class BiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Biosignals encoder that converts time series data to embeddings.
    Uses a combination of 1D convolutions and transformers.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        # Base class owns pooling and the projection to the embed space.
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # Strided conv stack that downsamples the raw signal in time.
        conv_layers = []
        in_channels = biosignals_cfg.input_channels

        for i, (out_channels, kernel_size, stride) in enumerate(
            zip(biosignals_cfg.conv_layers, biosignals_cfg.kernel_sizes, biosignals_cfg.strides)
        ):
            conv_layers.append(
                Conv1dBlock(in_channels, out_channels, kernel_size, stride)
            )
            in_channels = out_channels

        self.conv_layers = nn.Sequential(*conv_layers)

        # Probe the conv stack with a dummy input to discover the output
        # sequence length (depends on kernel/stride/padding arithmetic).
        with torch.no_grad():
            dummy_input = torch.randn(1, biosignals_cfg.input_channels, biosignals_cfg.signal_length)
            dummy_output = self.conv_layers(dummy_input)
            conv_output_length = dummy_output.shape[2]

        self.conv_output_length = conv_output_length
        self.conv_output_dim = biosignals_cfg.conv_layers[-1]

        # Project conv features to the transformer width.
        self.proj_conv_to_transformer = nn.Linear(
            self.conv_output_dim, biosignals_cfg.transformer_width
        )

        # Learned absolute positional embedding; fixed length implies
        # fixed-size inputs (signal_length in the config).
        self.pos_embed = nn.Parameter(
            torch.randn(1, conv_output_length, biosignals_cfg.transformer_width)
        )

        # CLS token appended at the END of the sequence (not the start) —
        # the base class pools from the last position.
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, biosignals_cfg.transformer_width)
        )

        # Fp32 LayerNorm under half-precision casting for numerical stability.
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        act_layer = QuickGELU

        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                biosignals_cfg.transformer_width,
                biosignals_cfg.transformer_heads,
                biosignals_cfg.mlp_ratio,
                act_layer=act_layer,
                norm_layer=norm_layer,
                dropout=biosignals_cfg.dropout
            )
            for _ in range(biosignals_cfg.transformer_layers)
        ])

        # Final norm over the full sequence (including CLS, when present).
        self.ln_final = norm_layer(biosignals_cfg.transformer_width)

    def _encode(self, biosignals):
        """
        Encode biosignals to features before pooling.

        Args:
            biosignals: Tensor of shape (batch_size, channels, signal_length)
        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        batch_size = biosignals.shape[0]

        # (B, C, L) -> (B, conv_dim, L'): temporal downsampling conv stack.
        x = self.conv_layers(biosignals)

        # (B, conv_dim, L') -> (B, L', conv_dim) for the transformer.
        x = x.transpose(1, 2)

        x = self.proj_conv_to_transformer(x)

        # Add learned absolute positions.
        x = x + self.pos_embed

        # Append a CLS token only for pooling modes that use a dedicated
        # summary slot ('cls' reads it directly, 'attn' uses it as query).
        if self.pool_type in ['cls', 'attn']:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([x, cls_tokens], dim=1)
            has_cls_token = True
        else:
            has_cls_token = False

        # Vanilla self-attention transformer stack.
        for layer in self.transformer_layers:
            x = layer(x)

        x = self.ln_final(x)

        return x, has_cls_token
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Post-norm transformer block: self-attention then MLP, each residual
    followed by a norm layer (norm applied AFTER the residual add)."""

    def __init__(
        self,
        width: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer=QuickGELU,
        norm_layer=LayerNorm,
        dropout: float = 0.1
    ):
        super().__init__()
        hidden = int(width * mlp_ratio)
        self.attention = nn.MultiheadAttention(width, heads, dropout=dropout, batch_first=True)
        self.ln_1 = norm_layer(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden),
            act_layer(),
            nn.Dropout(dropout),
            nn.Linear(hidden, width),
            nn.Dropout(dropout),
        )
        self.ln_2 = norm_layer(width)

    def forward(self, x):
        # Self-attention sub-layer (post-norm residual).
        attn_out, _ = self.attention(x, x, x)
        x = self.ln_1(x + attn_out)
        # MLP sub-layer (post-norm residual).
        x = self.ln_2(x + self.mlp(x))
        return x
| |
|
| |
|
class AttnPooler(nn.Module):
    """
    CoCa-style attentional pooler.

    ``n_query`` learned query vectors attend over the encoder sequence
    (keys = values = the sequence), producing a fixed-size summary:
      - n_query = 1  -> single global embedding for the contrastive loss
      - n_query = N  -> compressed token set for decoder cross-attention
    CoCa uses such task-specific pooling (1 query contrastive, many-query
    generative) on top of one shared encoder.
    """

    def __init__(self, dim: int, num_heads: int, n_query: int):
        super().__init__()
        self.n_query = n_query
        # Small-scale init keeps the initial queries near zero.
        self.query_tokens = nn.Parameter(torch.randn(1, n_query, dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """Pool (B, L, D) -> (B, n_query, D) via learned-query attention."""
        batch = x_seq.size(0)
        queries = self.query_tokens.expand(batch, -1, -1)
        pooled, _ = self.attn(queries, x_seq, x_seq)
        return pooled
| |
|
| |
|
| | class PureTransformerBiosignalsEncoder(BaseBiosignalsEncoder): |
| | """ |
| | Pure Transformer encoder for biosignals with channel+temporal attention. |
| | |
| | Updated to use CoCa-style task-specific attentional pooling: |
| | - contrastive_pooler (n_query=1) → 1 global token for contrastive / CLS |
| | - decoder_pooler (n_query=N_dec) → small set of summary tokens for text decoder |
| | |
| | We still: |
| | 1. Patch each channel independently |
| | 2. Alternate channel-attn and temporal-attn in DualTransformerBlocks (factorized attention) |
| | 3. Keep (B, C, T, D) internally (cheap attention along channel or time separately) |
| | 4. Flatten to (B, C*T, D) only at the end |
| | 5. Run two poolers: |
| | - 1-query pooler -> global token |
| | - multi-query pooler -> decoder tokens |
| | 6. Append the 1-query pooled token to the end of x_seq so BaseBiosignalsEncoder |
| | can keep using pool_type='cls' or 'attn' the same way. |
| | 7. Save the multi-query pooled tokens so, when output_tokens=True, we can hand |
| | them to the text decoder instead of the full ~C*T sequence. |
| | |
    This mirrors CoCa's "task-specific attentional pooling," where the same encoder
    supports both contrastive global alignment and caption-style generation with
    minimal extra cost (see the CoCa paper, Yu et al., 2022).
| | """ |
| |
|
    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        """Build the patch embedding, factorized attention stack, and the two
        CoCa-style attentional poolers (contrastive + decoder)."""
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # RoPE rotates (even, odd) feature pairs, so head_dim must be even.
        assert biosignals_cfg.transformer_width % biosignals_cfg.transformer_heads == 0, (
            f"transformer_width ({biosignals_cfg.transformer_width}) must be divisible by "
            f"transformer_heads ({biosignals_cfg.transformer_heads})"
        )
        head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
        assert head_dim % 2 == 0, (
            f"head_dim ({head_dim}) must be even for RoPE. "
            f"Got transformer_width={biosignals_cfg.transformer_width}, "
            f"transformer_heads={biosignals_cfg.transformer_heads}"
        )

        # Per-channel patch embedding (conv weights shared across channels).
        self.patching = ChannelPatching(
            patch_size=biosignals_cfg.patch_size,
            conv_embed_dim=biosignals_cfg.conv_embed_dim,
            num_channels=biosignals_cfg.input_channels
        )

        # Floor division: a trailing partial patch is dropped by the conv.
        self.num_patches = biosignals_cfg.signal_length // biosignals_cfg.patch_size

        # Patch embedding -> transformer width.
        self.embed_projection = nn.Linear(
            biosignals_cfg.conv_embed_dim,
            biosignals_cfg.transformer_width
        )

        # Learned per-channel identity embedding.
        self.channel_id_embed = nn.Embedding(
            num_embeddings=biosignals_cfg.input_channels,
            embedding_dim=biosignals_cfg.transformer_width,
        )

        # Optionally share ONE learned channel-RoPE across all blocks so
        # every block agrees on channel geometry.
        if biosignals_cfg.share_channel_rope:
            shared_head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
            self.shared_channel_rope = RotaryEmbedding(
                dim=shared_head_dim,
                theta=10000,
                learned_freq=True
            )
        else:
            self.shared_channel_rope = None

        # Stack of factorized (channel + temporal) attention blocks.
        self.transformer_blocks = nn.ModuleList([
            DualTransformerBlock(
                embed_dim=biosignals_cfg.transformer_width,
                num_heads=biosignals_cfg.transformer_heads,
                num_temporal_layers=biosignals_cfg.num_temporal_layers,
                dropout=biosignals_cfg.dropout,
                mlp_ratio=biosignals_cfg.mlp_ratio,
                num_channels=biosignals_cfg.input_channels,
                activation=biosignals_cfg.activation,
                norm_type=biosignals_cfg.norm_type,
                mlp_bias=biosignals_cfg.mlp_bias,
                shared_channel_rope=self.shared_channel_rope if biosignals_cfg.share_channel_rope else None
            ) for _ in range(biosignals_cfg.transformer_layers)
        ])

        # Final norm matches the configured flavor; fp32 LayerNorm is used
        # under half-precision casting for numerical stability.
        norm_layer = (
            LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        )
        if biosignals_cfg.norm_type == "rmsnorm":
            self.ln_final = RMSNorm(biosignals_cfg.transformer_width)
        else:
            self.ln_final = norm_layer(biosignals_cfg.transformer_width)

        # CoCa-style task-specific poolers: 1 query for the contrastive
        # embedding, N queries for a compressed decoder token set.
        n_decoder_tokens = getattr(biosignals_cfg, "decoder_tokens", 32)

        self.contrastive_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=1
        )

        self.decoder_pooler = AttnPooler(
            dim=biosignals_cfg.transformer_width,
            num_heads=biosignals_cfg.transformer_heads,
            n_query=n_decoder_tokens
        )
| |
|
| |
|
| | def _encode(self, biosignals: torch.Tensor): |
| | """ |
| | Returns: |
| | features: (B, N_dec + 1, D) |
| | first N_dec tokens = pooled decoder tokens |
| | last token = global pooled token (contrastive CLS) |
| | has_cls_token: True |
| | """ |
| | B = biosignals.shape[0] |
| | device = biosignals.device |
| |
|
| | |
| | x = self.patching(biosignals) |
| |
|
| | |
| | x = self.embed_projection(x) |
| |
|
| | |
| | _, C, T, D = x.shape |
| | channel_ids = torch.arange(C, device=device) |
| | channel_bias = self.channel_id_embed(channel_ids) |
| | channel_bias = channel_bias.view(1, C, 1, D).expand(B, C, T, D) |
| | x = x + channel_bias |
| |
|
| | |
| | pos_ids = torch.arange(self.num_patches, device=device) |
| |
|
| | |
| | for block in self.transformer_blocks: |
| | x = block(x, temporal_position_ids=pos_ids) |
| |
|
| | |
| | x = self.ln_final(x) |
| |
|
| | |
| | x_seq = x.reshape(B, C * T, D) |
| |
|
| | |
| | |
| | |
| | global_token = self.contrastive_pooler(x_seq) |
| | dec_tokens = self.decoder_pooler(x_seq) |
| |
|
| | |
| | |
| | |
| | |
| | features = torch.cat([dec_tokens, global_token], dim=1) |
| |
|
| | has_cls_token = True |
| | return features, has_cls_token |
| |
|
| |
|
class SignalReconstructionDecoder(nn.Module):
    """Lightweight transformer head that reconstructs biosignals.

    Runs a couple of self-attention layers over encoder features, mean-pools
    them, and maps the pooled vector to a flat signal that is reshaped to
    (channels, length). Uses ``nn.TransformerEncoder`` (self-attention only)
    because no cross-attention is required.
    """

    def __init__(
        self,
        input_dim: int = 768,
        num_layers: int = 2,
        num_heads: int = 4,
        output_channels: int = 10,
        output_length: int = 1920,
    ):
        super().__init__()

        # Pre-norm transformer layers over the encoder token sequence.
        layer = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=input_dim * 2,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer, num_layers)

        # Two-layer MLP projecting the pooled feature to the flat signal.
        self.to_signal = nn.Sequential(
            nn.Linear(input_dim, input_dim // 2),
            nn.ReLU(),
            nn.Linear(input_dim // 2, output_channels * output_length),
        )

        self.output_channels = output_channels
        self.output_length = output_length

    def forward(self, encoder_features):
        """Reconstruct the signal from encoder features.

        Args:
            encoder_features: (B, seq_len, input_dim) unprojected encoder
                features.

        Returns:
            (B, output_channels, output_length) reconstructed signal tensor.
        """
        batch = encoder_features.shape[0]

        # Self-attention refinement, then mean-pool over the sequence axis.
        refined = self.transformer(encoder_features)
        pooled = refined.mean(dim=1)

        # Project to a flat signal and unflatten into (channels, length).
        flat = self.to_signal(pooled)
        return flat.reshape(batch, self.output_channels, self.output_length)
| |
|
| |
|
class BiosignalsCoCa(nn.Module):
    """
    CoCa model adapted for biosignals-text contrastive learning.
    Replaces the vision tower with a biosignals encoder.

    Supports two decoder types:
    - "cross_attention": Separate cross-attention between text and biosignals (default CoCa)
    - "concat": Concatenate biosignals and text tokens with prefix-causal masking
    """

    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
        text_cfg: CLIPTextCfg,
        biosignals_cfg: BiosignalsCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
        decoder_type: str = "cross_attention",
        num_caption_channels: int = 12,
        prefix_len: int = 0,
        use_signal_decoder: bool = False,
    ):
        """Build the text tower, biosignals tower, multimodal text decoder,
        learned channel-condition embeddings, and (optionally) a signal
        reconstruction decoder.

        Args:
            embed_dim: Joint contrastive embedding dimension.
            multimodal_cfg: Text-decoder configuration (also provides
                context_length and decoder width).
            text_cfg: Text-encoder configuration (possibly an HF model).
            biosignals_cfg: Biosignals-encoder configuration.
            quick_gelu: Use QuickGELU activation in the towers.
            init_logit_scale: Initial value of the (log) contrastive logit scale.
            init_logit_bias: Optional initial logit bias (SigLIP-style).
            nonscalar_logit_scale: Store the scale/bias as shape-[1] tensors.
            cast_dtype: Optional low-precision dtype for tower construction.
            pad_id: Padding token id for text.
            decoder_type: "cross_attention" or "concat" (see class docstring).
            num_caption_channels: Number of learnable channel-condition
                embeddings for caption conditioning.
            prefix_len: Length of the condition-embedding prefix fed to the
                decoder.
            use_signal_decoder: Attach a SignalReconstructionDecoder head.
        """
        super().__init__()
        # Accept plain dicts for the cfgs (e.g. configs loaded from JSON).
        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        biosignals_cfg = BiosignalsCfg(**biosignals_cfg) if isinstance(biosignals_cfg, dict) else biosignals_cfg

        self.decoder_type = decoder_type
        self.num_channels = num_caption_channels
        self.use_signal_decoder = use_signal_decoder

        import logging
        logging.info(f"BiosignalsCoCa initialized with num_caption_channels={num_caption_channels}, prefix_len={prefix_len}")
        if use_signal_decoder:
            logging.info(f"Signal reconstruction decoder enabled")

        # Unimodal text encoder (CLIP-style or HF-backed).
        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        # HF towers expose their own vocab size; otherwise use the cfg value.
        vocab_size = (
            self.text.vocab_size
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        # Biosignals tower, replacing CoCa's vision tower; output_tokens=True
        # so it returns (latent, token embeddings).
        self.biosignals = _build_signal_tower(
            embed_dim=embed_dim,
            signal_cfg=biosignals_cfg,
            output_tokens=True,
            cast_dtype=cast_dtype,
        )

        # Multimodal captioning decoder (cross-attention or concat variant).
        self.text_decoder = _build_text_decoder_tower_v2(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
            decoder_type=decoder_type,
            prefix_len=prefix_len,
        )

        # Contrastive temperature (and optional bias), scalar or shape-[1].
        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

        # Learnable per-channel condition embeddings used to prefix the
        # decoder input (one row per caption channel/modality).
        self.channel_embeddings = nn.Parameter(
            torch.randn(num_caption_channels, multimodal_cfg.width) * 0.02
        )

        # Learnable embedding used for padded (-1) condition positions; it is
        # trained rather than hard-masked.
        self.padding_embedding = nn.Parameter(
            torch.randn(multimodal_cfg.width) * 0.02
        )

        self.decoder_width = multimodal_cfg.width

        # Optional auxiliary head reconstructing the input signal from the
        # biosignals token embeddings.
        if use_signal_decoder:
            self.signal_decoder = SignalReconstructionDecoder(
                input_dim=biosignals_cfg.transformer_width,
                num_layers=2,
                num_heads=biosignals_cfg.transformer_heads,
                output_channels=biosignals_cfg.input_channels,
                output_length=biosignals_cfg.signal_length,
            )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Toggle gradient checkpointing on all three towers."""
        self.biosignals.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Lock the text encoder, optionally leaving the last N layers unlocked.

        Args:
            unlocked_layers: Number of layers to leave unlocked (from the end)
            freeze_layer_norm: Whether to freeze LayerNorm parameters in locked layers
        """
        if hasattr(self.text, 'lock'):
            # HF-style tower: delegate layer freezing to its own lock().
            self.text.lock(unlocked_layers, freeze_layer_norm)

            # If the vocabulary was extended, keep only the newly added token
            # embeddings trainable: re-enable grads on the embedding matrix
            # and zero out gradients for the original (frozen) rows via a hook.
            if hasattr(self.text, 'original_vocab_size'):
                import logging
                embedding_module = self.text.transformer.get_input_embeddings()
                original_size = self.text.original_vocab_size
                current_size = embedding_module.weight.shape[0]

                if current_size > original_size:
                    # The whole matrix must require grad for the hook to fire.
                    embedding_module.weight.requires_grad = True

                    # Remember where the new tokens start (used elsewhere).
                    self.text._new_token_start_idx = original_size

                    # The embedding matrix may be padded beyond the vocab
                    # (e.g. rounded up for efficiency); those rows stay frozen.
                    actual_embedding_size = embedding_module.weight.shape[0]
                    new_vocab_size = self.text.vocab_size

                    # Closure captures sizes; registered before DDP wrapping
                    # so the mask applies to the local gradient.
                    def _zero_grad_frozen_tokens(grad):
                        """Zero out gradients for old (frozen) tokens and padding, keep only new tokens."""
                        if grad is not None:
                            # Freeze pretrained rows [0, original_size).
                            grad[:original_size] = 0
                            # Freeze padding rows beyond the real vocab.
                            if actual_embedding_size > new_vocab_size:
                                grad[new_vocab_size:] = 0
                        return grad

                    embedding_module.weight.register_hook(_zero_grad_frozen_tokens)

                    num_new_tokens = new_vocab_size - original_size
                    num_padding_tokens = actual_embedding_size - new_vocab_size
                    logging.info(f"Embedding layer configuration:")
                    logging.info(f"  Trainable new tokens: {num_new_tokens} (indices {original_size}:{new_vocab_size})")
                    logging.info(f"  Frozen pretrained tokens: {original_size} (indices 0:{original_size})")
                    if num_padding_tokens > 0:
                        logging.info(f"  Frozen padding tokens: {num_padding_tokens} (indices {new_vocab_size}:{actual_embedding_size})")
                    logging.info(f"  Total embedding size: {actual_embedding_size}")
                    logging.info(f"Registered gradient masking hook before DDP wrapping")
                    logging.info(f"NOTE: Optimizer uses weight_decay=0 for embedding layer")
        else:
            # Deliberately unsupported path for non-HF text towers.
            # NOTE(review): the two lines below are unreachable dead code
            # because the assert always fails first — consider removing them.
            assert False, "BiosignalsCoCa does not support locking standard TextTransformer"
            from .transformer import lock_text_tower
            lock_text_tower(self, unlocked_layers)

    def _encode_biosignals(self, biosignals, normalize: bool = True):
        """Encode biosignals; returns (optionally normalized latent, token embeddings)."""
        biosignals_latent, tokens_embs = self.biosignals(biosignals)
        biosignals_latent = F.normalize(biosignals_latent, dim=-1) if normalize else biosignals_latent
        return biosignals_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Encode text; returns (optionally normalized latent, token embeddings)."""
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, biosignals, normalize: bool = True):
        """CLIP-compatible alias: 'image' features are the biosignals latent."""
        biosignals_latent, _ = self._encode_biosignals(biosignals, normalize=normalize)
        return biosignals_latent

    def encode_text(self, text, normalize: bool = True):
        """Return the (optionally normalized) text latent only."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def _get_channel_condition_embs(self, channel_indices: torch.Tensor) -> torch.Tensor:
        """Convert channel/modality indices to embeddings with learnable padding.

        Args:
            channel_indices: (batch_size, prefix_len) tensor of indices
                - Individual mode: indices into 23 channel embeddings (22 channels + 1 stage_event)
                - Modality mode: indices into 5 modality embeddings (4 modalities + 1 stage_event)
                - Padded with -1 for variable length (uses learnable padding_embedding for -1)

        Returns:
            condition_embs: (batch_size, prefix_len, decoder_width)
                Embeddings for all positions. -1 positions use learnable padding_embedding
                that learns to be neutral/ignored during training.
        """
        batch_size, prefix_len = channel_indices.shape

        # Output buffer on the same device/dtype as the embedding table.
        condition_embs = torch.zeros(batch_size, prefix_len, self.decoder_width,
                                     dtype=self.channel_embeddings.dtype,
                                     device=self.channel_embeddings.device)

        # Valid entries are >= 0; -1 marks padding positions.
        valid_mask = channel_indices >= 0
        padding_mask = channel_indices == -1

        # Clamp so gather never sees a negative index; padded positions are
        # overwritten below anyway.
        indices_safe = channel_indices.clamp(min=0)

        # (num_caption_channels, D) -> (B, num_caption_channels, D)
        expanded_embeddings = self.channel_embeddings.unsqueeze(0).expand(batch_size, -1, -1)

        # Gather the per-position embeddings along the channel axis.
        indices_expanded = indices_safe.unsqueeze(-1).expand(-1, -1, self.decoder_width)
        gathered_embs = torch.gather(expanded_embeddings, 1, indices_expanded)

        # Copy only the valid positions into the output.
        condition_embs[valid_mask] = gathered_embs[valid_mask]

        # Padded positions get the single learnable padding embedding.
        if padding_mask.any():
            condition_embs[padding_mask] = self.padding_embedding

        return condition_embs

    def forward(
        self,
        biosignals,
        text: Optional[torch.Tensor] = None,
        biosignals_latent: Optional[torch.Tensor] = None,
        biosignals_embs: Optional[torch.Tensor] = None,

        channel_indices: Optional[torch.Tensor] = None,
        output_labels: bool = True,
    ):
        """Forward pass for BiosignalsCoCa model.

        Args:
            biosignals: Input biosignals tensor
            text: Optional text token ids
            biosignals_latent: Optional pre-computed biosignals latent features
            biosignals_embs: Optional pre-computed biosignals token embeddings

            channel_indices: Optional (batch_size, num_selected_channels) tensor of channel indices
                Used to select channel-specific condition embeddings. If provided, overrides condition_embs.
            output_labels: Whether to output labels for loss computation
        """
        # Reuse cached encoder outputs when the caller supplies them
        # (e.g. repeated decoding steps during generation).
        if biosignals_latent is None or biosignals_embs is None:
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

        # Encoder-only mode: no text means only biosignal features are needed.
        if text is None:
            return {"image_features": biosignals_latent, "image_embs": biosignals_embs}

        text_latent, token_embs = self._encode_text(text)

        # Teacher forcing: labels are the tokens shifted left by one; the
        # matching input drops the final token so positions align.
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # Drop the last input token to align with the shifted labels.
            token_embs = token_embs[:, :-1]

        # Optional channel/modality conditioning prefix for the decoder.
        if channel_indices is not None:
            condition_embs = self._get_channel_condition_embs(channel_indices)
        else:
            condition_embs = None

        logits = self.text_decoder(biosignals_embs, token_embs, condition_embs=condition_embs)
        out_dict = {
            "image_features": biosignals_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias

        # Auxiliary reconstruction loss inputs when the head is enabled.
        if self.use_signal_decoder:
            reconstructed_signal = self.signal_decoder(biosignals_embs)
            out_dict["reconstructed_signal"] = reconstructed_signal
            out_dict["original_signal"] = biosignals

        return out_dict

    def generate(
        self,
        biosignals,
        text=None,
        seq_len=30,
        max_seq_len=256,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False,
        condition_embs=None,
        channel_indices=None,
    ):
        """Autoregressively generate caption tokens for biosignals.

        Supports "beam_search" (diverse group beam search via HF scorer) and
        sampling with "top_p" / "top_k" warpers. `fixed_output_length` pads
        (or keeps sampling) so every output has exactly `seq_len` tokens.
        NOTE(review): the sampling path compares tokens against pad_token_id,
        which may be None if the caller does not pass it — confirm callers
        always supply pad/eos/sot ids for non-beam generation.
        """
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = biosignals.device

        with torch.no_grad():
            sot_token_id = _token_to_tensor(sot_token_id, device=device)
            eos_token_id = _token_to_tensor(eos_token_id, device=device)
            # NOTE(review): self-assignment is a no-op — pad_token_id is
            # intentionally NOT converted to a tensor (unlike sot/eos)?
            pad_token_id = pad_token_id
            logit_processor = LogitsProcessorList(
                [
                    MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                    RepetitionPenaltyLogitsProcessor(repetition_penalty),
                ]
            )

            if stopping_criteria is None:
                stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
            stopping_criteria = StoppingCriteriaList(stopping_criteria)

            if generation_type == "beam_search":
                output = self._generate_beamsearch(
                    biosignals_inputs=biosignals,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    num_beams=num_beams,
                    num_beam_groups=num_beam_groups,
                    min_seq_len=min_seq_len,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                    channel_indices=channel_indices,
                )
                # Right-pad beam output to seq_len when a fixed length is
                # requested.
                if fixed_output_length and output.shape[1] < seq_len:
                    pad_len = seq_len - output.shape[1]
                    return torch.cat((
                        output,
                        torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                    ),
                        dim=1
                    )
                return output

            elif generation_type == "top_p":
                logit_warper = GENERATION_TYPES[generation_type](top_p)
            elif generation_type == "top_k":
                logit_warper = GENERATION_TYPES[generation_type](top_k)
            else:
                raise ValueError(
                    f"generation_type has to be one of "
                    f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                )

            # Encode once; reuse the cached features on every decode step.
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

            # Seed generation with the start-of-text token when no prompt given.
            if text is None:
                text = torch.ones((biosignals.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

            was_training = self.training
            num_dims = len(text.shape)

            # Promote a 1-D prompt to a batch of one.
            if num_dims == 1:
                text = text[None, :]

            self.eval()
            out = text

            # Greedy/sampled decoding loop: append one token per iteration.
            while True:
                # Truncate the context to the decoder's maximum window.
                x = out[:, -max_seq_len:]
                cur_len = x.shape[1]
                logits = self(
                    biosignals,
                    x,
                    biosignals_latent=biosignals_latent,
                    biosignals_embs=biosignals_embs,
                    channel_indices=channel_indices,
                    output_labels=False,
                )["logits"][:, -1]
                # Rows already finished (ended with eos/pad) stay padded.
                mask = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
                sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

                if mask.all():
                    # Everything finished; keep padding if a fixed length is
                    # required, otherwise stop.
                    if not fixed_output_length:
                        break
                else:
                    # Warp/filter logits only for unfinished rows, then sample.
                    logits = logits[~mask, :]
                    filtered_logits = logit_processor(x[~mask, :], logits)
                    filtered_logits = logit_warper(x[~mask, :], filtered_logits)
                    probs = F.softmax(filtered_logits / temperature, dim=-1)

                    # Force EOS on the final position so sequences terminate.
                    if (cur_len + 1 == seq_len):
                        sample[~mask, :] = torch.ones((sum(~mask), 1), device=device, dtype=torch.long) * eos_token_id
                    else:
                        sample[~mask, :] = torch.multinomial(probs, 1)

                out = torch.cat((out, sample), dim=-1)

                cur_len += 1

                if all(stopping_criteria(out, None)):
                    break

            # Undo the batch-of-one promotion.
            if num_dims == 1:
                out = out.squeeze(0)

            self.train(was_training)
            return out

    def _generate_beamsearch(
        self,
        biosignals_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,
        channel_indices=None,
    ):
        """Diverse group beam search over caption tokens.

        Mirrors HF's group beam search driven by BeamSearchScorer; the
        biosignal encoder is run once on beam-expanded inputs, then the
        decoder is re-run per step on the growing `input_ids`.
        NOTE(review): the `logit_warper` parameter is accepted but never used.
        """
        device = biosignals_inputs.device
        batch_size = biosignals_inputs.shape[0]
        # Expand each example to `num_beams` copies and encode once.
        biosignals_inputs = torch.repeat_interleave(biosignals_inputs, num_beams, dim=0)
        biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals_inputs)

        # Channel conditioning must be expanded identically to the signals.
        if channel_indices is not None:
            channel_indices = torch.repeat_interleave(channel_indices, num_beams, dim=0)

        # Every beam starts with the start-of-text token.
        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )

        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        # Re-derive bookkeeping sizes from the scorer (it may normalize them).
        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size, cur_len = input_ids.shape
        beam_indices = None

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        # Initialise scores so only the first sub-beam of each group is live;
        # all other beams start at -1e9 to avoid duplicate hypotheses.
        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)

        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        while True:

            # Tokens chosen this step across all beams (filled per group).
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # Beam reordering indices for this step (filled per group).
            reordering_indices = torch.zeros(batch_size * num_beams, dtype=torch.long, device=device)

            # One decoder pass over every beam; groups share this output.
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, biosignals_inputs=biosignals_inputs)
            outputs = self(
                model_inputs['biosignals'],
                model_inputs['text'],
                biosignals_latent=biosignals_latent,
                biosignals_embs=biosignals_embs,
                channel_indices=channel_indices,
                output_labels=False,
            )

            for beam_group_idx in range(num_beam_groups):
                group_start_idx = beam_group_idx * num_sub_beams
                group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
                group_size = group_end_idx - group_start_idx

                # Flat indices of this group's beams across the whole batch.
                batch_group_indices = []

                for batch_idx in range(batch_size):
                    batch_group_indices.extend(
                        [batch_idx * num_beams + idx for idx in range(group_start_idx, group_end_idx)]
                    )
                group_input_ids = input_ids[batch_group_indices]

                # Next-token logits for this group only.
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                # NOTE(review): expand_as is a no-op here — the shapes already
                # match after the broadcasted addition above.
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # Fold (beam, vocab) into one axis to take a joint top-k.
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)

                # Keep 2*group_size candidates so finished (eos) hypotheses
                # don't starve the beam.
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )

                # Decompose flat candidate index into (beam index, token id).
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                # NOTE(review): beam_indices is never assigned, so this is
                # always None — beam-index tracking is effectively disabled.
                process_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=process_beam_indices,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                # Reorder this group's histories and record the chosen tokens.
                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

                # Map group-local beam indices back to global beam positions
                # (used by callers that reorder per-beam caches).
                reordering_indices[batch_group_indices] = (
                    num_beams * torch.div(beam_idx, group_size, rounding_mode="floor") + group_start_idx + (beam_idx % group_size)
                )

            # Append this step's tokens for all beams at once.
            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            cur_len = cur_len + 1
            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        # NOTE(review): next_tokens/next_indices here come from the LAST group
        # of the final step — this matches HF's reference implementation.
        final_beam_indices = sum(beam_indices, ()) if beam_indices is not None else None
        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=final_beam_indices,
        )
        return sequence_outputs['sequences']
| |
|
| |
|
def prepare_inputs_for_generation(input_ids, biosignals_inputs, past=None, **kwargs):
    """Assemble the per-step input dict for autoregressive generation.

    Args:
        input_ids: (B, T) token ids generated so far.
        biosignals_inputs: Biosignals tensor (passed through unchanged).
        past: Optional cached past key/values; when truthy, only the last
            token is fed to the model.
        **kwargs: May contain ``attention_mask`` and/or ``position_ids``.

    Returns:
        Dict with keys ``text``, ``biosignals``, ``past_key_values``,
        ``position_ids``, ``attention_mask``.
    """
    if past:
        # With a KV cache only the newest token needs to be re-processed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if attention_mask is not None and position_ids is None:
        # Derive position ids on the fly for batched generation: count only
        # non-padding tokens; padded slots get a dummy position of 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
    # Fix: an explicitly supplied position_ids is now passed through instead
    # of being silently reset to None by the previous `else` branch.

    return {
        "text": input_ids,
        "biosignals": biosignals_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }
| |
|