# SleepLM-Base: src/open_clip/biosignals_coca_model.py
# Commit 06acd95 (verified): "Add source code" by zongzhex
"""
Biosignals-Text CoCa Model
Adapted from the original CoCa model to work with biosignals (time series) data
instead of images. This model is designed for biosignals-text contrastive learning.
"""
from typing import Dict, List, Optional, Union, Tuple
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import math
from dataclasses import dataclass, field
from .transformer import (
LayerNormFp32,
LayerNorm,
QuickGELU,
MultimodalTransformer,
ConcatMultimodalTransformer,
)
from .model import CLIPTextCfg, _build_text_tower
from .coca_model import MultimodalCfg, _build_text_decoder_tower, _token_to_tensor
# Optional dependency: Hugging Face ``transformers`` supplies the logits warpers
# and stopping criteria used for text generation. When it (or any of these
# names) is unavailable, the sampling-based entries are registered as None.
try:
    from transformers.generation.beam_search import BeamSearchScorer
    from transformers.generation.logits_process import (
        LogitsProcessorList,
        TopPLogitsWarper,
        TopKLogitsWarper,
        RepetitionPenaltyLogitsProcessor,
        MinLengthLogitsProcessor,
    )
    from transformers.generation.stopping_criteria import (
        MaxLengthCriteria,
        EosTokenCriteria,
        StoppingCriteriaList,
    )
except ImportError:
    _has_transformers = False
    GENERATION_TYPES = {
        "top_k": None,
        "top_p": None,
        "beam_search": "beam_search",
    }
else:
    _has_transformers = True
    GENERATION_TYPES = {
        "top_k": TopKLogitsWarper,
        "top_p": TopPLogitsWarper,
        "beam_search": "beam_search",
    }
# ============================================================================
# Pure Transformer Architecture Components (from PureTransformerMAE)
# ============================================================================
class RotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) applied to interleaved feature pairs.

    Feature pair ``(2i, 2i + 1)`` is rotated by ``position * freqs[i]``.
    With ``learned_freq=False`` the frequencies follow the fixed schedule
    ``theta ** (-2i / dim)`` and are stored as a buffer; with
    ``learned_freq=True`` they are a trainable parameter initialised from a
    small random normal (std 0.02).
    """

    def __init__(self, dim: int, theta: float = 10000.0, learned_freq: bool = False):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.learned_freq = learned_freq
        if not learned_freq:
            exponents = torch.arange(0, dim, 2).float() / dim
            self.register_buffer('freqs', 1.0 / (theta ** exponents))
        else:
            self.freqs = nn.Parameter(torch.randn(dim // 2) * 0.02)

    def rotate_queries_or_keys(self, x: torch.Tensor, position_ids: Optional[torch.Tensor] = None):
        """Rotate query or key vectors by their positions.

        Args:
            x: (batch_size, num_heads, seq_len, head_dim); head_dim must equal ``self.dim``.
            position_ids: optional positions, (seq_len,) or (batch_size, seq_len).
                Defaults to ``0 .. seq_len - 1``. For a 2-D tensor only the first
                row is used, i.e. all batch elements are assumed to share positions.

        Returns:
            Tensor with the same shape and dtype as ``x``.
        """
        _, _, seq_len, head_dim = x.shape
        assert head_dim == self.dim, f"head_dim {head_dim} != self.dim {self.dim}"

        if position_ids is None:
            positions = torch.arange(seq_len, device=x.device, dtype=torch.float)
        else:
            positions = (position_ids[0] if position_ids.ndim == 2 else position_ids).float()

        # (seq_len, dim // 2): one angle per (position, feature pair).
        angles = torch.einsum('s,d->sd', positions, self.freqs)
        cos_a = torch.cos(angles)
        sin_a = torch.sin(angles)

        even = x[..., 0::2]
        odd = x[..., 1::2]
        # 2x2 rotation of each pair, then re-interleave back to head_dim.
        rotated_even = even * cos_a - odd * sin_a
        rotated_odd = even * sin_a + odd * cos_a
        return torch.stack((rotated_even, rotated_odd), dim=-1).flatten(-2).to(x.dtype)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation with a learnable per-feature gain.

    The statistic is computed in float32 and the result is cast back to the
    input dtype before the gain is applied.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalised = self._norm(x.float()).type_as(x)
        return normalised * self.weight
class SwiGLU(nn.Module):
    """Gated linear unit with a SiLU gate: ``silu(w1(x)) * w2(x)``."""

    def __init__(self, dim_in: int, dim_out: int, bias: bool = False):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out, bias=bias)
        self.w2 = nn.Linear(dim_in, dim_out, bias=bias)

    def forward(self, x):
        gate = F.silu(self.w1(x))
        value = self.w2(x)
        return gate * value
class MLP(nn.Module):
    """Feed-forward block: expand to ``hidden_dim``, activate, project back to ``dim``.

    ``activation="swiglu"`` uses a gated ``SwiGLU`` expansion (``gate_proj``);
    ``"gelu"`` / ``"relu"`` use ``up_proj`` followed by the activation. Dropout
    is applied after the activation and again after ``down_proj``.

    Raises:
        ValueError: for any other activation name.
    """

    def __init__(self,
                 dim: int,
                 hidden_dim: int,
                 dropout: float = 0.0,
                 activation: str = "swiglu",  # "swiglu", "gelu", "relu"
                 bias: bool = False):
        super().__init__()
        self.activation = activation
        uses_gate = activation == "swiglu"
        if uses_gate:
            self.gate_proj = SwiGLU(dim, hidden_dim, bias=bias)
        else:
            self.up_proj = nn.Linear(dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, dim, bias=bias)
        if not uses_gate:
            act_classes = {"gelu": nn.GELU, "relu": nn.ReLU}
            if activation not in act_classes:
                raise ValueError(f"Unknown activation: {activation}")
            self.act_fn = act_classes[activation]()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        if self.activation == "swiglu":
            hidden = self.gate_proj(x)
        else:
            hidden = self.act_fn(self.up_proj(x))
        projected = self.down_proj(self.dropout(hidden))
        return self.dropout(projected)
class ChannelPatching(nn.Module):
    """Tokenise each channel into non-overlapping patches with one shared Conv1d.

    The same single-input-channel convolution (kernel = stride = ``patch_size``,
    no padding) is applied to every channel independently, so trailing samples
    that do not fill a whole patch are not covered by any patch.
    """

    def __init__(self,
                 patch_size: int = 32,
                 conv_embed_dim: int = 256,
                 num_channels: int = 21):
        super().__init__()
        self.patch_size = patch_size
        self.conv_embed_dim = conv_embed_dim
        # Stored for reference; forward() takes the channel count from the input.
        self.num_channels = num_channels
        self.conv_patching = nn.Conv1d(
            in_channels=1,
            out_channels=conv_embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
        )

    def forward(self, x):
        """Map (batch_size, num_channels, signal_length) to
        (batch_size, num_channels, num_patches, conv_embed_dim)."""
        b, c, length = x.shape
        # Fold channels into the batch so every channel is convolved on its own.
        tokens = self.conv_patching(x.reshape(b * c, 1, length))  # (b*c, E, T)
        tokens = tokens.reshape(b, c, *tokens.shape[1:])  # (b, c, E, T)
        return tokens.transpose(2, 3)
class DualRoPEAttention(nn.Module):
    """Multi-head self-attention with rotary position embedding on Q and K.

    ``attention_type="temporal"`` uses a fixed-frequency RoPE.
    ``attention_type="channel"`` uses ``shared_channel_rope`` when given,
    otherwise its own learnable-frequency RoPE. No attention mask is applied.
    """

    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 dropout: float = 0.1,
                 attention_type: str = "temporal",  # "temporal" or "channel"
                 num_channels: int = 21,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.attention_type = attention_type
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        if attention_type == "temporal":
            self.rotary_emb = RotaryEmbedding(dim=self.head_dim, theta=10000, learned_freq=False)
        elif attention_type == "channel":
            if shared_channel_rope is None:
                shared_channel_rope = RotaryEmbedding(dim=self.head_dim, theta=10000, learned_freq=True)
            self.rotary_emb = shared_channel_rope
        else:
            raise ValueError(f"Unknown attention_type: {attention_type}")

        self.dropout = nn.Dropout(dropout)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, position_ids=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: (batch_size, seq_len, embed_dim)
            position_ids: optional RoPE positions, (seq_len,) or (batch_size, seq_len).

        Returns:
            (batch_size, seq_len, embed_dim)
        """
        bsz, seq_len, embed_dim = x.shape

        def to_heads(t):
            # (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim)
            return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        rope = self.rotary_emb
        q = rope.rotate_queries_or_keys(to_heads(self.q_proj(x)), position_ids=position_ids)
        k = rope.rotate_queries_or_keys(to_heads(self.k_proj(x)), position_ids=position_ids)
        v = to_heads(self.v_proj(x))

        scores = (q @ k.transpose(-2, -1)) * self.scale
        probs = self.dropout(scores.softmax(dim=-1))
        context = (probs @ v).transpose(1, 2).contiguous().view(bsz, seq_len, embed_dim)
        return self.out_proj(context)
class DualTransformerBlock(nn.Module):
    """Factorised biosignal block: channel attention followed by temporal attention.

    Input/output shape: (batch_size, num_channels, num_patches, embed_dim).

    Stage 1 treats every (sample, patch) as a sequence over channels and applies
    one post-norm attention + MLP sub-block with channel RoPE.
    Stage 2 treats every (sample, channel) as a sequence over patches and applies
    ``num_temporal_layers`` post-norm attention + MLP sub-blocks with temporal RoPE.
    """

    def __init__(self,
                 embed_dim: int = 256,
                 num_heads: int = 8,
                 num_temporal_layers: int = 2,
                 dropout: float = 0.1,
                 mlp_ratio: float = 4.0,
                 num_channels: int = 21,
                 activation: str = "swiglu",
                 norm_type: str = "rmsnorm",
                 mlp_bias: bool = False,
                 shared_channel_rope: Optional[nn.Module] = None):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_temporal_layers = num_temporal_layers
        hidden_dim = int(embed_dim * mlp_ratio)

        def create_norm(dim):
            if norm_type == "layernorm":
                return nn.LayerNorm(dim)
            if norm_type == "rmsnorm":
                return RMSNorm(dim)
            raise ValueError(f"Unknown norm_type: {norm_type}")

        def create_mlp():
            return MLP(
                dim=embed_dim,
                hidden_dim=hidden_dim,
                dropout=dropout,
                activation=activation,
                bias=mlp_bias,
            )

        def per_temporal_layer(factory):
            return nn.ModuleList([factory() for _ in range(num_temporal_layers)])

        # Construction order is kept stable (it fixes parameter init order).
        self.channel_attention = DualRoPEAttention(
            embed_dim, num_heads, dropout,
            attention_type="channel", num_channels=num_channels,
            shared_channel_rope=shared_channel_rope,
        )
        self.channel_norm = create_norm(embed_dim)
        self.temporal_attention_layers = per_temporal_layer(
            lambda: DualRoPEAttention(embed_dim, num_heads, dropout, attention_type="temporal")
        )
        self.temporal_norms = per_temporal_layer(lambda: create_norm(embed_dim))
        self.channel_mlp = create_mlp()
        self.temporal_mlps = per_temporal_layer(create_mlp)
        self.channel_mlp_norm = create_norm(embed_dim)
        self.temporal_mlp_norms = per_temporal_layer(lambda: create_norm(embed_dim))

    def forward(self, x, temporal_position_ids=None):
        """
        Args:
            x: (batch_size, num_channels, num_patches, embed_dim)
            temporal_position_ids: optional (num_patches,) or (batch_size, num_patches)
                positions for the temporal RoPE; for 2-D input only the first row is used.
        Returns:
            (batch_size, num_channels, num_patches, embed_dim)
        """
        b, c, t, d = x.shape

        # Stage 1: attention across channels, one sequence per (sample, patch).
        h = x.permute(0, 2, 1, 3).contiguous().reshape(b * t, c, d)
        h = self.channel_norm(h + self.channel_attention(h))
        h = self.channel_mlp_norm(h + self.channel_mlp(h))
        x = h.reshape(b, t, c, d).permute(0, 2, 1, 3)

        # Stage 2: attention across patches, one sequence per (sample, channel).
        if temporal_position_ids is not None and temporal_position_ids.ndim == 2:
            temporal_position_ids = temporal_position_ids[0]
        h = x.reshape(b * c, t, d)
        layers = zip(
            self.temporal_attention_layers,
            self.temporal_norms,
            self.temporal_mlps,
            self.temporal_mlp_norms,
        )
        for attn, attn_norm, mlp, mlp_norm in layers:
            h = attn_norm(h + attn(h, position_ids=temporal_position_ids))
            h = mlp_norm(h + mlp(h))
        return h.reshape(b, c, t, d)
# ============================================================================
# End of Pure Transformer Architecture Components
# ============================================================================
def _build_signal_tower(
    embed_dim: int,
    signal_cfg,
    output_tokens: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
):
    """Construct the biosignals encoder selected by ``signal_cfg.architecture``.

    Args:
        embed_dim: output embedding dimension of the encoder.
        signal_cfg: ``BiosignalsCfg`` instance, or a dict of its fields.
        output_tokens: whether the encoder also returns tokens for the decoder.
        cast_dtype: optional dtype, forwarded to the encoder.

    Returns:
        ``PureTransformerBiosignalsEncoder`` for ``"pure_transformer"`` or
        ``BiosignalsEncoder`` for ``"conv_transformer"`` (the default when the
        config has no ``architecture`` attribute).

    Raises:
        ValueError: for any other architecture name.
    """
    import logging

    cfg = BiosignalsCfg(**signal_cfg) if isinstance(signal_cfg, dict) else signal_cfg
    architecture = getattr(cfg, 'architecture', 'conv_transformer')
    logging.info(f"Building biosignals encoder with architecture: {architecture}")

    encoder_kwargs = dict(
        biosignals_cfg=cfg,
        embed_dim=embed_dim,
        output_tokens=output_tokens,
        cast_dtype=cast_dtype,
    )
    if architecture == "pure_transformer":
        signal_encoder = PureTransformerBiosignalsEncoder(**encoder_kwargs)
        summary = [
            "Pure Transformer architecture:",
            f" Patch size: {cfg.patch_size}",
            f" Conv embed dim: {cfg.conv_embed_dim}",
            f" Transformer blocks: {cfg.transformer_layers}",
            f" Temporal layers per block: {cfg.num_temporal_layers}",
            f" Activation: {cfg.activation}",
            f" Norm type: {cfg.norm_type}",
            f" Share channel RoPE: {cfg.share_channel_rope}",
        ]
    elif architecture == "conv_transformer":
        signal_encoder = BiosignalsEncoder(**encoder_kwargs)
        summary = [
            "Conv-Transformer architecture:",
            f" Conv layers: {cfg.conv_layers}",
            f" Kernel sizes: {cfg.kernel_sizes}",
            f" Strides: {cfg.strides}",
            f" Transformer layers: {cfg.transformer_layers}",
        ]
    else:
        raise ValueError(
            f"Unknown architecture: {architecture}. Must be 'conv_transformer' or 'pure_transformer'"
        )

    for line in summary:
        logging.info(line)
    return signal_encoder
def _build_text_decoder_tower_v2(
    embed_dim,
    multimodal_cfg,
    quick_gelu: bool = False,
    cast_dtype: Optional[torch.dtype] = None,
    decoder_type: str = "cross_attention",
    prefix_len: int = 0,
):
    """Build the multimodal text decoder.

    Args:
        embed_dim: decoder output dimension.
        multimodal_cfg: ``MultimodalCfg`` instance or a dict of its fields.
        quick_gelu: use ``QuickGELU`` instead of ``nn.GELU``.
        cast_dtype: fp16/bf16 selects ``LayerNormFp32``; otherwise ``LayerNorm``.
        decoder_type: ``"cross_attention"`` builds ``MultimodalTransformer``
            (separate cross-attention layers, as in CoCa); ``"concat"`` builds
            ``ConcatMultimodalTransformer`` (conditioning tokens concatenated
            with text tokens).
        prefix_len: number of prefix (condition) tokens prepended to the text,
            forwarded to the decoder for its prefix-causal mask.

    Raises:
        ValueError: for an unknown ``decoder_type``.
    """
    if isinstance(multimodal_cfg, dict):
        multimodal_cfg = MultimodalCfg(**multimodal_cfg)

    if decoder_type == "cross_attention":
        decoder_cls = MultimodalTransformer
    elif decoder_type == "concat":
        decoder_cls = ConcatMultimodalTransformer
    else:
        raise ValueError(f"Unknown decoder_type: {decoder_type}. Must be 'cross_attention' or 'concat'")

    low_precision = cast_dtype in (torch.float16, torch.bfloat16)
    return decoder_cls(
        context_length=multimodal_cfg.context_length,
        width=multimodal_cfg.width,
        heads=multimodal_cfg.heads,
        layers=multimodal_cfg.layers,
        mlp_ratio=multimodal_cfg.mlp_ratio,
        ls_init_value=multimodal_cfg.ls_init_value,
        output_dim=embed_dim,
        act_layer=QuickGELU if quick_gelu else nn.GELU,
        norm_layer=LayerNormFp32 if low_precision else LayerNorm,
        prefix_len=prefix_len,
    )
@dataclass
class BiosignalsCfg:
    """Configuration for the biosignals encoder towers.

    ``architecture`` selects ``"conv_transformer"`` (``BiosignalsEncoder``) or
    ``"pure_transformer"`` (``PureTransformerBiosignalsEncoder``). For
    ``conv_transformer``, any of ``conv_layers`` / ``kernel_sizes`` / ``strides``
    left as ``None`` is filled with a fresh default list in ``__post_init__``;
    for other architectures they stay ``None``.
    """
    input_channels: int = 12  # Number of input channels (e.g., 12-lead ECG)
    signal_length: int = 1000  # Length of input time series
    sampling_rate: int = 500  # Sampling rate in Hz
    # Architecture selection
    architecture: str = "conv_transformer"  # "conv_transformer" or "pure_transformer"
    # Architecture parameters for conv_transformer (None -> defaults in __post_init__)
    conv_layers: Optional[List[int]] = None  # Conv layer output channels
    kernel_sizes: Optional[List[int]] = None  # Kernel sizes for conv layers
    strides: Optional[List[int]] = None  # Strides for conv layers
    # Architecture parameters for pure_transformer
    patch_size: int = 32  # Patch size for pure_transformer
    conv_embed_dim: int = 256  # Conv embedding dimension for pure_transformer
    num_temporal_layers: int = 2  # Number of temporal attention layers per block
    activation: str = "swiglu"  # "swiglu", "gelu", "relu" (for pure_transformer)
    norm_type: str = "rmsnorm"  # "rmsnorm", "layernorm" (for pure_transformer)
    mlp_bias: bool = False  # Whether to use bias in MLP layers (for pure_transformer)
    share_channel_rope: bool = True  # Share channel RoPE across blocks (for pure_transformer)
    decoder_tokens: int = 32  # Number of decoder tokens for dual-axis transformer (pure_transformer)
    # Transformer parameters (shared)
    transformer_layers: int = 6  # Number of transformer layers/blocks
    transformer_width: int = 768  # Transformer width
    transformer_heads: int = 12  # Number of attention heads
    mlp_ratio: float = 4.0  # MLP expansion ratio
    # Pooling and output
    pool_type: str = 'attn'  # 'avg', 'max', 'cls', 'attn'
    dropout: float = 0.1

    def __post_init__(self):
        # Defaults are created per instance, so instances never share lists.
        if self.architecture == "conv_transformer":
            if self.conv_layers is None:
                self.conv_layers = [64, 128, 256, 512]
            if self.kernel_sizes is None:
                self.kernel_sizes = [7, 5, 3, 3]
            if self.strides is None:
                self.strides = [2, 2, 2, 2]
class BaseBiosignalsEncoder(nn.Module):
    """
    Shared pooling and projection logic for biosignals encoders.

    Subclasses implement ``_encode()`` returning ``(features, has_cls_token)``,
    where ``features`` is (batch_size, seq_len, transformer_width) and, when
    ``has_cls_token`` is True, the CLS token is at the LAST position.
    ``forward()`` pools according to ``biosignals_cfg.pool_type`` and projects
    the pooled vector to ``embed_dim``.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int,
        output_tokens: bool,
        transformer_width: int,
        cast_dtype: Optional[torch.dtype] = None
    ):
        super().__init__()
        self.biosignals_cfg = biosignals_cfg
        self.embed_dim = embed_dim
        self.output_tokens = output_tokens
        self.transformer_width = transformer_width
        self.pool_type = biosignals_cfg.pool_type
        # cast_dtype is accepted for signature parity with subclasses; unused here.

        # Projection to output embedding dimension
        self.proj_to_embed = nn.Linear(transformer_width, embed_dim)

        # Attention pooling if needed
        if self.pool_type == 'attn':
            self.attn_pool = nn.MultiheadAttention(
                transformer_width,
                biosignals_cfg.transformer_heads,
                batch_first=True
            )

    def _pool_features(self, x: torch.Tensor, has_cls_token: bool) -> torch.Tensor:
        """
        Pool a feature sequence into one vector per sample.

        Args:
            x: (batch_size, seq_len, width)
            has_cls_token: whether the last position of ``x`` is a CLS token.

        Returns:
            (batch_size, width)

        Notes:
            'cls' returns the last position and 'attn' uses the last position as
            the query over the rest; both therefore expect a CLS token.
        """
        if self.pool_type == 'cls':
            # Use class token (last position)
            pooled = x[:, -1]
        elif self.pool_type == 'avg':
            pooled = x[:, :-1].mean(dim=1) if has_cls_token else x.mean(dim=1)
        elif self.pool_type == 'max':
            pooled = x[:, :-1].max(dim=1)[0] if has_cls_token else x.max(dim=1)[0]
        elif self.pool_type == 'attn':
            # CLS token (last position) attends over the content tokens.
            query = x[:, -1:]
            content = x[:, :-1]
            # The attention map is discarded, so don't ask MHA to compute it.
            pooled, _ = self.attn_pool(query, content, content, need_weights=False)
            pooled = pooled.squeeze(1)
        else:
            raise ValueError(f"Unknown pool_type: {self.pool_type}")
        return pooled

    def _encode(self, biosignals: torch.Tensor) -> Tuple[torch.Tensor, bool]:
        """
        Encode biosignals to features. Must be implemented by child classes.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            features: Encoded features of shape (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        raise NotImplementedError("Child classes must implement _encode()")

    def forward(self, biosignals: torch.Tensor):
        """
        Encode, pool and project.

        Args:
            biosignals: Input biosignals tensor

        Returns:
            embedding: (batch_size, embed_dim)
            tokens_for_decoder: only when ``output_tokens`` is True — the encoded
                sequence without the CLS token, (batch_size, seq_len, transformer_width)
        """
        features, has_cls_token = self._encode(biosignals)
        embedding = self.proj_to_embed(self._pool_features(features, has_cls_token))

        if not self.output_tokens:
            return embedding
        # Exclude the CLS token from the tokens handed to the decoder.
        tokens_for_decoder = features[:, :-1] if has_cls_token else features
        return embedding, tokens_for_decoder

    def set_grad_checkpointing(self, enable=True):
        # No-op; kept for API compatibility with other towers.
        pass
class Conv1dBlock(nn.Module):
    """Conv1d -> norm -> activation -> dropout.

    Padding is ``kernel_size // 2``, so an odd kernel with stride 1 preserves
    the sequence length.

    Args:
        in_channels / out_channels / kernel_size / stride: Conv1d parameters.
        norm_layer: normalisation class, called as ``norm_layer(out_channels)``.
        act_layer: activation class, called with no arguments.
        dropout: dropout probability (default 0.1, the previously hard-coded value).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 norm_layer=nn.BatchNorm1d, act_layer=nn.ReLU, dropout: float = 0.1):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=kernel_size // 2
        )
        self.norm = norm_layer(out_channels)
        self.act = act_layer()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.dropout(x)
        return x
class BiosignalsEncoder(BaseBiosignalsEncoder):
    """
    Conv + transformer biosignals encoder.

    A stack of ``Conv1dBlock`` layers downsamples the (batch, channels, length)
    signal. The result is projected to ``transformer_width`` and given learned
    absolute positional embeddings. It is optionally extended with an appended
    CLS token (for 'cls' / 'attn' pooling) and passed through post-norm
    ``TransformerBlock`` layers plus a final norm. Pooling and output projection
    are inherited from ``BaseBiosignalsEncoder``.
    """

    def __init__(
        self,
        biosignals_cfg: BiosignalsCfg,
        embed_dim: int = 512,
        output_tokens: bool = False,
        cast_dtype: Optional[torch.dtype] = None
    ):
        # Initialize base class with common pooling/projection logic
        super().__init__(
            biosignals_cfg=biosignals_cfg,
            embed_dim=embed_dim,
            output_tokens=output_tokens,
            transformer_width=biosignals_cfg.transformer_width,
            cast_dtype=cast_dtype
        )

        # Convolutional feature extraction. zip() stops at the shortest of the
        # three per-layer lists.
        conv_blocks = []
        in_channels = biosignals_cfg.input_channels
        for out_channels, kernel_size, stride in zip(
            biosignals_cfg.conv_layers, biosignals_cfg.kernel_sizes, biosignals_cfg.strides
        ):
            conv_blocks.append(Conv1dBlock(in_channels, out_channels, kernel_size, stride))
            in_channels = out_channels
        self.conv_layers = nn.Sequential(*conv_blocks)

        # Sequence length after the conv stack, derived from the built layers.
        # A dummy forward pass is avoided on purpose: it would run BatchNorm in
        # training mode (updating its running statistics with random data) and
        # advance the global RNG.
        conv_output_length = self._conv_output_length(
            biosignals_cfg.signal_length, self.conv_layers
        )
        self.conv_output_length = conv_output_length
        # Channel count actually produced by the last conv block.
        self.conv_output_dim = in_channels

        # Projection to transformer dimension
        self.proj_conv_to_transformer = nn.Linear(
            self.conv_output_dim, biosignals_cfg.transformer_width
        )

        # Positional embeddings for sequence positions (excluding CLS token);
        # the CLS token gets no positional embedding.
        self.pos_embed = nn.Parameter(
            torch.randn(1, conv_output_length, biosignals_cfg.transformer_width)
        )

        # Class token (only appended for 'cls' and 'attn' pooling)
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, biosignals_cfg.transformer_width)
        )

        # Transformer layers
        norm_layer = LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
        act_layer = QuickGELU
        self.transformer_layers = nn.ModuleList([
            TransformerBlock(
                biosignals_cfg.transformer_width,
                biosignals_cfg.transformer_heads,
                biosignals_cfg.mlp_ratio,
                act_layer=act_layer,
                norm_layer=norm_layer,
                dropout=biosignals_cfg.dropout
            )
            for _ in range(biosignals_cfg.transformer_layers)
        ])

        # Final layer norm
        self.ln_final = norm_layer(biosignals_cfg.transformer_width)

    @staticmethod
    def _conv_output_length(length: int, conv_stack: nn.Sequential) -> int:
        """Return the temporal length after ``conv_stack`` for an input of ``length`` samples.

        Applies the Conv1d output-size formula
        ``floor((L + 2p - d(k - 1) - 1) / s) + 1`` with each block's ``conv``
        hyper-parameters (norm, activation and dropout keep the length).

        Raises:
            ValueError: if the signal is too short to survive the stack.
        """
        for block in conv_stack:
            conv = block.conv
            length = (
                length
                + 2 * conv.padding[0]
                - conv.dilation[0] * (conv.kernel_size[0] - 1)
                - 1
            ) // conv.stride[0] + 1
            if length < 1:
                raise ValueError(
                    "signal_length is too short for the configured conv layers "
                    f"(output length {length} after {conv})"
                )
        return length

    def _encode(self, biosignals):
        """
        Encode biosignals to features before pooling.

        Args:
            biosignals: Tensor of shape (batch_size, channels, signal_length);
                signal_length must match the configured one (pos_embed is fixed-size).

        Returns:
            features: (batch_size, seq_len, transformer_width)
            has_cls_token: Whether the sequence includes a CLS token at the last position
        """
        batch_size = biosignals.shape[0]

        x = self.conv_layers(biosignals)  # (batch_size, conv_dim, conv_length)
        x = x.transpose(1, 2)  # (batch_size, conv_length, conv_dim)
        x = self.proj_conv_to_transformer(x)  # (batch_size, conv_length, transformer_width)
        x = x + self.pos_embed

        # Append (not prepend) the CLS token, for consistency with the causal text encoder.
        if self.pool_type in ['cls', 'attn']:
            cls_tokens = self.cls_token.expand(batch_size, -1, -1)
            x = torch.cat([x, cls_tokens], dim=1)  # (batch_size, conv_length + 1, transformer_width)
            has_cls_token = True
        else:
            has_cls_token = False

        for layer in self.transformer_layers:
            x = layer(x)

        x = self.ln_final(x)
        return x, has_cls_token
class TransformerBlock(nn.Module):
    """Post-norm transformer block: ``x = ln_1(x + MHA(x)); x = ln_2(x + MLP(x))``.

    Args:
        width: model width.
        heads: number of attention heads.
        mlp_ratio: hidden size of the MLP as a multiple of ``width``.
        act_layer: activation class for the MLP.
        norm_layer: normalisation class, called as ``norm_layer(width)``.
        dropout: dropout used in attention and in the MLP.
    """

    def __init__(
        self,
        width: int,
        heads: int,
        mlp_ratio: float = 4.0,
        act_layer=QuickGELU,
        norm_layer=LayerNorm,
        dropout: float = 0.1
    ):
        super().__init__()
        self.attention = nn.MultiheadAttention(width, heads, dropout=dropout, batch_first=True)
        self.ln_1 = norm_layer(width)
        self.mlp = nn.Sequential(
            nn.Linear(width, int(width * mlp_ratio)),
            act_layer(),
            nn.Dropout(dropout),
            nn.Linear(int(width * mlp_ratio), width),
            nn.Dropout(dropout)
        )
        self.ln_2 = norm_layer(width)

    def forward(self, x):
        # The attention map is discarded, so skip computing/averaging it
        # (need_weights=False also permits PyTorch's optimised attention path).
        attn_out, _ = self.attention(x, x, x, need_weights=False)
        x = self.ln_1(x + attn_out)
        x = self.ln_2(x + self.mlp(x))
        return x
class AttnPooler(nn.Module):
    """
    CoCa-style attentional pooler.

    ``n_query`` learned query tokens attend over the encoder sequence (used as
    both keys and values) through one multi-head attention layer:
      - n_query = 1  -> a single global token (e.g. for the contrastive loss)
      - n_query = N  -> N summary tokens (e.g. memory for a text decoder)

    Reference: Yu et al., 2022, "CoCa: Contrastive Captioners are Image-Text
    Foundation Models", which uses task-specific attentional pooling with a
    single query for the contrastive objective and 256 queries for the
    generative one.
    """

    def __init__(self, dim: int, num_heads: int, n_query: int):
        super().__init__()
        self.n_query = n_query
        self.query_tokens = nn.Parameter(torch.randn(1, n_query, dim) * 0.02)
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True
        )

    def forward(self, x_seq: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x_seq: (B, L, D) encoder tokens.

        Returns:
            (B, n_query, D) pooled tokens.
        """
        queries = self.query_tokens.expand(x_seq.size(0), -1, -1)  # (B, n_query, D)
        # Only the pooled output is used, so skip the attention-weight map.
        pooled, _ = self.attn(queries, x_seq, x_seq, need_weights=False)
        return pooled
class PureTransformerBiosignalsEncoder(BaseBiosignalsEncoder):
"""
Pure Transformer encoder for biosignals with channel+temporal attention.
Updated to use CoCa-style task-specific attentional pooling:
- contrastive_pooler (n_query=1) → 1 global token for contrastive / CLS
- decoder_pooler (n_query=N_dec) → small set of summary tokens for text decoder
We still:
1. Patch each channel independently
2. Alternate channel-attn and temporal-attn in DualTransformerBlocks (factorized attention)
3. Keep (B, C, T, D) internally (cheap attention along channel or time separately)
4. Flatten to (B, C*T, D) only at the end
5. Run two poolers:
- 1-query pooler -> global token
- multi-query pooler -> decoder tokens
6. Append the 1-query pooled token to the end of x_seq so BaseBiosignalsEncoder
can keep using pool_type='cls' or 'attn' the same way.
7. Save the multi-query pooled tokens so, when output_tokens=True, we can hand
them to the text decoder instead of the full ~C*T sequence.
This mirrors CoCa's "task-specific attentional pooling," where the same encoder
supports both contrastive global alignment and caption-style generation with
minimal extra cost. [oai_citation:3‡Medium](https://medium.com/%40arithmancylabs/coca-contrastive-captioners-are-image-textfoundation-models-324022377630?utm_source=chatgpt.com)
"""
def __init__(
self,
biosignals_cfg: BiosignalsCfg,
embed_dim: int = 512,
output_tokens: bool = False,
cast_dtype: Optional[torch.dtype] = None
):
super().__init__(
biosignals_cfg=biosignals_cfg,
embed_dim=embed_dim,
output_tokens=output_tokens,
transformer_width=biosignals_cfg.transformer_width,
cast_dtype=cast_dtype
)
# --- Sanity checks for RoPE dimensions ---
assert biosignals_cfg.transformer_width % biosignals_cfg.transformer_heads == 0, (
f"transformer_width ({biosignals_cfg.transformer_width}) must be divisible by "
f"transformer_heads ({biosignals_cfg.transformer_heads})"
)
head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
assert head_dim % 2 == 0, (
f"head_dim ({head_dim}) must be even for RoPE. "
f"Got transformer_width={biosignals_cfg.transformer_width}, "
f"transformer_heads={biosignals_cfg.transformer_heads}"
)
# 1. Channel patching (Conv1d tokenizer per channel)
self.patching = ChannelPatching(
patch_size=biosignals_cfg.patch_size,
conv_embed_dim=biosignals_cfg.conv_embed_dim,
num_channels=biosignals_cfg.input_channels
)
# number of temporal patches per channel
self.num_patches = biosignals_cfg.signal_length // biosignals_cfg.patch_size
# 2. Project patch embeddings to transformer_width
self.embed_projection = nn.Linear(
biosignals_cfg.conv_embed_dim,
biosignals_cfg.transformer_width
)
# 2a. Channel ID embedding (categorical channel identity)
self.channel_id_embed = nn.Embedding(
num_embeddings=biosignals_cfg.input_channels,
embedding_dim=biosignals_cfg.transformer_width,
)
# 3. Shared learnable RoPE for channel attention (optional)
if biosignals_cfg.share_channel_rope:
shared_head_dim = biosignals_cfg.transformer_width // biosignals_cfg.transformer_heads
self.shared_channel_rope = RotaryEmbedding(
dim=shared_head_dim,
theta=10000,
learned_freq=True # learnable for channel axis
)
else:
self.shared_channel_rope = None
# 4. Dual-axis Transformer blocks (channel attention + temporal attention)
self.transformer_blocks = nn.ModuleList([
DualTransformerBlock(
embed_dim=biosignals_cfg.transformer_width,
num_heads=biosignals_cfg.transformer_heads,
num_temporal_layers=biosignals_cfg.num_temporal_layers,
dropout=biosignals_cfg.dropout,
mlp_ratio=biosignals_cfg.mlp_ratio,
num_channels=biosignals_cfg.input_channels,
activation=biosignals_cfg.activation,
norm_type=biosignals_cfg.norm_type,
mlp_bias=biosignals_cfg.mlp_bias,
shared_channel_rope=self.shared_channel_rope if biosignals_cfg.share_channel_rope else None
) for _ in range(biosignals_cfg.transformer_layers)
])
# 5. Final norm
norm_layer = (
LayerNormFp32 if cast_dtype in (torch.float16, torch.bfloat16) else LayerNorm
)
if biosignals_cfg.norm_type == "rmsnorm":
self.ln_final = RMSNorm(biosignals_cfg.transformer_width)
else:
self.ln_final = norm_layer(biosignals_cfg.transformer_width)
# 6. CoCa-style attentional poolers
# - contrastive_pooler: n_query = 1 for global CLS token (contrastive head)
# - decoder_pooler: n_query = decoder_tokens (e.g. 32) for compressed memory
#
# We'll add a new config field on BiosignalsCfg: decoder_tokens (int, default 32).
n_decoder_tokens = getattr(biosignals_cfg, "decoder_tokens", 32)
self.contrastive_pooler = AttnPooler(
dim=biosignals_cfg.transformer_width,
num_heads=biosignals_cfg.transformer_heads,
n_query=1
)
self.decoder_pooler = AttnPooler(
dim=biosignals_cfg.transformer_width,
num_heads=biosignals_cfg.transformer_heads,
n_query=n_decoder_tokens
)
def _encode(self, biosignals: torch.Tensor):
    """Encode raw biosignals into pooled decoder tokens plus one global token.

    Args:
        biosignals: Raw signal batch consumed by ``self.patching``
            (presumably (B, C, L) — TODO confirm against the patching module).

    Returns:
        features: (B, N_dec + 1, D)
            first N_dec tokens = pooled decoder tokens
            last token = global pooled token (contrastive CLS)
        has_cls_token: True
    """
    B = biosignals.shape[0]
    device = biosignals.device

    # 1. Patch per channel -> (B, C, T, conv_dim)
    x = self.patching(biosignals)
    # 2. Project to model dim -> (B, C, T, D)
    x = self.embed_projection(x)
    _, C, T, D = x.shape

    # 2a. Add a learned per-channel ID embedding. Broadcasting over the batch
    # and time axes makes an explicit expand unnecessary.
    channel_ids = torch.arange(C, device=device)  # (C,)
    x = x + self.channel_id_embed(channel_ids).view(1, C, 1, D)

    # 3. Temporal RoPE positions: one per patch actually produced. Taking the
    # length from x (not the configured self.num_patches) keeps positions
    # aligned with the data when the input length differs from the config.
    pos_ids = torch.arange(T, device=device)  # (T,)

    # 4. Dual-axis transformer blocks (channel-attn + temporal-attn)
    for block in self.transformer_blocks:
        x = block(x, temporal_position_ids=pos_ids)  # stays (B, C, T, D)

    # 5. Final norm
    x = self.ln_final(x)  # (B, C, T, D)

    # 6. Flatten channels x time into one sequence for pooling (not for the decoder!)
    x_seq = x.reshape(B, C * T, D)  # (B, L, D) with L = C*T

    # 7. Task-specific attentional pooling (CoCa-style)
    global_token = self.contrastive_pooler(x_seq)  # (B, 1, D)
    dec_tokens = self.decoder_pooler(x_seq)  # (B, N_dec, D)

    # 8. [decoder tokens..., global token]:
    #    features[:, :-1] -> decoder cross-attention memory
    #    features[:, -1]  -> contrastive / CLS pooling
    features = torch.cat([dec_tokens, global_token], dim=1)  # (B, N_dec+1, D)
    has_cls_token = True
    return features, has_cls_token
class SignalReconstructionDecoder(nn.Module):
    """
    Small self-attention decoder that maps encoder tokens back to raw signals.

    Tokens pass through a pre-norm ``nn.TransformerEncoder`` stack
    (self-attention + FFN only; no cross-attention is needed here), are
    mean-pooled over the token axis, and a two-layer MLP emits every sample of
    every output channel in one shot.
    """

    def __init__(
        self,
        input_dim: int = 768,
        num_layers: int = 2,
        num_heads: int = 4,  # Reduced from 8 for efficiency
        output_channels: int = 10,
        output_length: int = 1920,
    ):
        """
        Args:
            input_dim: Width of the incoming encoder tokens.
            num_layers: Number of stacked transformer encoder layers.
            num_heads: Attention heads per encoder layer.
            output_channels: Channels in the reconstructed signal.
            output_length: Samples per channel in the reconstructed signal.
        """
        super().__init__()
        mlp_hidden = input_dim // 2

        # Lighter than the usual 4x FFN: the feedforward width is 2x the model
        # width (1536 for input_dim=768).
        layer_template = nn.TransformerEncoderLayer(
            d_model=input_dim,
            nhead=num_heads,
            dim_feedforward=2 * input_dim,
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(layer_template, num_layers)

        # Head: pooled token -> flattened (channels * length) signal, with a
        # halved intermediate width to keep the head cheap.
        self.to_signal = nn.Sequential(
            nn.Linear(input_dim, mlp_hidden),
            nn.ReLU(),
            nn.Linear(mlp_hidden, output_channels * output_length),
        )

        self.output_channels = output_channels
        self.output_length = output_length

    def forward(self, encoder_features):
        """
        Args:
            encoder_features: (B, seq_len, input_dim) unprojected encoder features.

        Returns:
            (B, output_channels, output_length) reconstructed signal.
        """
        contextual = self.transformer(encoder_features)  # (B, seq_len, input_dim)
        summary = contextual.mean(dim=1)  # (B, input_dim)
        flat_signal = self.to_signal(summary)  # (B, output_channels * output_length)
        batch = flat_signal.shape[0]
        return flat_signal.reshape(batch, self.output_channels, self.output_length)
class BiosignalsCoCa(nn.Module):
    """
    CoCa model adapted for biosignals-text contrastive learning.

    Replaces the vision tower with a biosignals encoder.
    Supports two decoder types:
    - "cross_attention": Separate cross-attention between text and biosignals (default CoCa)
    - "concat": Concatenate biosignals and text tokens with prefix-causal masking

    Optional extras: learnable channel/modality condition embeddings fed to the
    text decoder, and a signal-reconstruction head (``use_signal_decoder=True``).
    """

    def __init__(
        self,
        embed_dim,
        multimodal_cfg: MultimodalCfg,
        text_cfg: CLIPTextCfg,
        biosignals_cfg: BiosignalsCfg,
        quick_gelu: bool = False,
        init_logit_scale: float = np.log(1 / 0.07),
        init_logit_bias: Optional[float] = None,
        nonscalar_logit_scale: bool = False,
        cast_dtype: Optional[torch.dtype] = None,
        pad_id: int = 0,
        decoder_type: str = "cross_attention",
        num_caption_channels: int = 12,
        prefix_len: int = 0,
        use_signal_decoder: bool = False,
    ):
        """Build the towers, the text decoder and the optional heads.

        Args:
            embed_dim: Dimension of the shared contrastive embedding space.
            multimodal_cfg: Text-decoder config (``MultimodalCfg`` or dict).
            text_cfg: Text-encoder config (``CLIPTextCfg`` or dict).
            biosignals_cfg: Biosignals-encoder config (``BiosignalsCfg`` or dict).
            quick_gelu: Use QuickGELU activations in the text towers.
            init_logit_scale: Initial (log) contrastive logit scale.
            init_logit_bias: Initial logit bias; ``None`` disables the bias parameter.
            nonscalar_logit_scale: Store logit scale/bias with shape ``[1]`` instead of ``[]``.
            cast_dtype: Optional dtype forwarded to the tower builders.
            pad_id: Padding token id; also the default ``pad_token_id`` of :meth:`generate`.
            decoder_type: ``"cross_attention"`` or ``"concat"``.
            num_caption_channels: Number of learnable channel/modality condition
                embeddings (rows of ``self.channel_embeddings``).
            prefix_len: Condition-prefix length forwarded to the text decoder builder.
            use_signal_decoder: Attach a ``SignalReconstructionDecoder`` head.
        """
        super().__init__()
        import logging

        multimodal_cfg = MultimodalCfg(**multimodal_cfg) if isinstance(multimodal_cfg, dict) else multimodal_cfg
        text_cfg = CLIPTextCfg(**text_cfg) if isinstance(text_cfg, dict) else text_cfg
        biosignals_cfg = BiosignalsCfg(**biosignals_cfg) if isinstance(biosignals_cfg, dict) else biosignals_cfg

        self.decoder_type = decoder_type
        self.num_channels = num_caption_channels
        self.use_signal_decoder = use_signal_decoder

        # Debug logging for channel configuration
        logging.info(f"BiosignalsCoCa initialized with num_caption_channels={num_caption_channels}, prefix_len={prefix_len}")
        if use_signal_decoder:
            logging.info("Signal reconstruction decoder enabled")

        self.text = _build_text_tower(
            embed_dim=embed_dim,
            text_cfg=text_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
        )

        vocab_size = (
            self.text.vocab_size  # for hf models
            if hasattr(text_cfg, "hf_model_name") and text_cfg.hf_model_name is not None
            else text_cfg.vocab_size
        )

        # Replace visual tower with biosignals tower
        self.biosignals = _build_signal_tower(
            embed_dim=embed_dim,
            signal_cfg=biosignals_cfg,
            output_tokens=True,  # Need tokens for multimodal decoder
            cast_dtype=cast_dtype,
        )

        self.text_decoder = _build_text_decoder_tower_v2(
            vocab_size,
            multimodal_cfg=multimodal_cfg,
            quick_gelu=quick_gelu,
            cast_dtype=cast_dtype,
            decoder_type=decoder_type,
            prefix_len=prefix_len,
        )

        lshape = [1] if nonscalar_logit_scale else []
        self.logit_scale = nn.Parameter(torch.ones(lshape) * init_logit_scale)
        if init_logit_bias is not None:
            self.logit_bias = nn.Parameter(torch.ones(lshape) * init_logit_bias)
        else:
            self.logit_bias = None
        self.pad_id = pad_id

        self.context_length = multimodal_cfg.context_length

        # Learnable channel/modality condition embeddings: one row per index
        # accepted by _get_channel_condition_embs. The width matches the text
        # decoder input (multimodal_cfg.width).
        self.channel_embeddings = nn.Parameter(
            torch.randn(num_caption_channels, multimodal_cfg.width) * 0.02
        )
        # Learnable embedding used wherever a channel index is -1 (padding).
        self.padding_embedding = nn.Parameter(
            torch.randn(multimodal_cfg.width) * 0.02
        )
        self.decoder_width = multimodal_cfg.width

        # Optional signal reconstruction decoder
        if use_signal_decoder:
            self.signal_decoder = SignalReconstructionDecoder(
                input_dim=biosignals_cfg.transformer_width,
                num_layers=2,  # Lightweight: 2 transformer layers
                num_heads=biosignals_cfg.transformer_heads,
                output_channels=biosignals_cfg.input_channels,
                output_length=biosignals_cfg.signal_length,
            )

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable: bool = True):
        """Toggle gradient checkpointing on the biosignals, text and decoder towers."""
        self.biosignals.set_grad_checkpointing(enable)
        self.text.set_grad_checkpointing(enable)
        self.text_decoder.set_grad_checkpointing(enable)

    def lock_text_tower(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Lock the text encoder, optionally leaving the last N layers unlocked.

        Only text towers exposing ``lock`` (the HF text encoder path) are
        supported. If the tower records ``original_vocab_size`` and its input
        embedding matrix has grown past it, the embedding weight is re-enabled
        for training. A gradient hook then keeps only the rows of the newly
        added tokens trainable.

        Args:
            unlocked_layers: Number of layers to leave unlocked (from the end)
            freeze_layer_norm: Whether to freeze LayerNorm parameters in locked layers

        Raises:
            NotImplementedError: If the text tower has no ``lock`` method.
        """
        if not hasattr(self.text, 'lock'):
            # Standard TextTransformer locking is not supported. Raise explicitly
            # (an ``assert`` would be stripped under ``python -O``).
            raise NotImplementedError("BiosignalsCoCa does not support locking standard TextTransformer")

        # For HFTextEncoder (Pythia, etc.)
        self.text.lock(unlocked_layers, freeze_layer_norm)

        # IMPORTANT: Unfreeze newly added token embeddings (e.g., <pad>, <coca_cls>).
        # These were randomly initialized and need to be trained.
        if not hasattr(self.text, 'original_vocab_size'):
            return

        import logging
        embedding_module = self.text.transformer.get_input_embeddings()
        original_size = self.text.original_vocab_size
        # Actual embedding rows (may be padded for Tensor Cores)
        actual_embedding_size = embedding_module.weight.shape[0]
        if actual_embedding_size <= original_size:
            return

        # Enable gradients for the embedding layer
        embedding_module.weight.requires_grad = True
        # Store metadata for optimizer configuration (zero weight decay)
        self.text._new_token_start_idx = original_size
        new_vocab_size = self.text.vocab_size  # Actual number of tokens (not padded)

        # Parameter-level hook that masks gradients of frozen rows.
        # IMPORTANT: This is registered BEFORE DDP wrapping to ensure it persists
        def _zero_grad_frozen_tokens(grad):
            """Return a copy of ``grad`` with only the new-token rows left non-zero."""
            # Tensor hooks must not modify their argument in place; work on a copy.
            grad = grad.clone()
            # Zero out pretrained tokens [0:original_size]
            grad[:original_size] = 0
            # Zero out padding tokens [new_vocab_size:actual_embedding_size]
            if actual_embedding_size > new_vocab_size:
                grad[new_vocab_size:] = 0
            return grad

        embedding_module.weight.register_hook(_zero_grad_frozen_tokens)

        num_new_tokens = new_vocab_size - original_size
        num_padding_tokens = actual_embedding_size - new_vocab_size
        logging.info(f"Embedding layer configuration:")
        logging.info(f"  Trainable new tokens: {num_new_tokens} (indices {original_size}:{new_vocab_size})")
        logging.info(f"  Frozen pretrained tokens: {original_size} (indices 0:{original_size})")
        if num_padding_tokens > 0:
            logging.info(f"  Frozen padding tokens: {num_padding_tokens} (indices {new_vocab_size}:{actual_embedding_size})")
        logging.info(f"  Total embedding size: {actual_embedding_size}")
        logging.info(f"Registered gradient masking hook before DDP wrapping")
        logging.info(f"NOTE: Optimizer uses weight_decay=0 for embedding layer")

    def _encode_biosignals(self, biosignals, normalize: bool = True):
        """Run the biosignals tower.

        Returns:
            (latent, token_embs); ``latent`` is L2-normalized over the last dim
            when ``normalize`` is True.
        """
        biosignals_latent, tokens_embs = self.biosignals(biosignals)
        biosignals_latent = F.normalize(biosignals_latent, dim=-1) if normalize else biosignals_latent
        return biosignals_latent, tokens_embs

    def _encode_text(self, text, normalize: bool = True):
        """Run the text tower.

        Returns:
            (latent, token_emb); ``latent`` is L2-normalized over the last dim
            when ``normalize`` is True.
        """
        text_latent, token_emb = self.text(text)
        text_latent = F.normalize(text_latent, dim=-1) if normalize else text_latent
        return text_latent, token_emb

    def encode_image(self, biosignals, normalize: bool = True):
        """Return the pooled biosignals embedding.

        Keeps the image-model method name, presumably for open_clip API
        compatibility.
        """
        biosignals_latent, _ = self._encode_biosignals(biosignals, normalize=normalize)
        return biosignals_latent

    def encode_text(self, text, normalize: bool = True):
        """Return the pooled text embedding."""
        text_latent, _ = self._encode_text(text, normalize=normalize)
        return text_latent

    def _get_channel_condition_embs(self, channel_indices: torch.Tensor) -> torch.Tensor:
        """Convert channel/modality indices to condition embeddings.

        Args:
            channel_indices: (batch_size, prefix_len) integer tensor of indices into
                ``self.channel_embeddings``. ``-1`` marks padding and maps to the
                learnable ``padding_embedding``. Any other negative value maps to a
                zero vector. The tensor may live on any device and use any
                integer dtype.

        Returns:
            condition_embs: (batch_size, prefix_len, decoder_width), in the dtype
            and on the device of ``self.channel_embeddings``.
        """
        table = self.channel_embeddings  # (num_caption_channels, decoder_width)
        # Index on the table's device and as int64 so CPU / int32 indices work.
        indices = channel_indices.to(device=table.device, dtype=torch.long)

        # Clamp negatives to 0 for safe lookup; those rows are replaced below.
        gathered = table[indices.clamp(min=0)]  # (batch_size, prefix_len, decoder_width)

        valid = (indices >= 0).unsqueeze(-1)
        padding = (indices == -1).unsqueeze(-1)
        condition_embs = torch.where(valid, gathered, torch.zeros_like(gathered))
        condition_embs = torch.where(
            padding, self.padding_embedding.to(condition_embs.dtype), condition_embs
        )
        return condition_embs

    def forward(
        self,
        biosignals,
        text: Optional[torch.Tensor] = None,
        biosignals_latent: Optional[torch.Tensor] = None,
        biosignals_embs: Optional[torch.Tensor] = None,
        channel_indices: Optional[torch.Tensor] = None,
        output_labels: bool = True,
    ):
        """Forward pass for BiosignalsCoCa model.

        Args:
            biosignals: Input biosignals tensor
            text: Optional text token ids
            biosignals_latent: Optional pre-computed biosignals latent features
            biosignals_embs: Optional pre-computed biosignals token embeddings
                (both cached tensors must be given, otherwise both are recomputed)
            channel_indices: Optional (batch_size, num_selected_channels) tensor of
                channel indices, converted to condition embeddings for the decoder.
            output_labels: Whether to output labels for loss computation

        Returns:
            dict. With ``text=None``: ``image_features`` and ``image_embs``.
            Otherwise: ``image_features``, ``text_features``, ``logits`` and
            ``logit_scale`` (exponentiated). It also holds ``labels`` when
            ``output_labels``, ``logit_bias`` when a bias exists, and
            ``reconstructed_signal`` / ``original_signal`` when the signal
            decoder is enabled.
        """
        if biosignals_latent is None or biosignals_embs is None:
            biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

        if text is None:
            return {"image_features": biosignals_latent, "image_embs": biosignals_embs}

        text_latent, token_embs = self._encode_text(text)

        # FIXME this isn't an ideal solution, would like to improve -RW
        labels: Optional[torch.Tensor] = text[:, 1:] if output_labels else None
        if output_labels:
            # align text_embs and thus logits with labels for teacher-forcing caption loss
            token_embs = token_embs[:, :-1]

        condition_embs = (
            self._get_channel_condition_embs(channel_indices) if channel_indices is not None else None
        )

        logits = self.text_decoder(biosignals_embs, token_embs, condition_embs=condition_embs)
        out_dict = {
            "image_features": biosignals_latent,
            "text_features": text_latent,
            "logits": logits,
            "logit_scale": self.logit_scale.exp()
        }
        if labels is not None:
            out_dict["labels"] = labels
        if self.logit_bias is not None:
            out_dict["logit_bias"] = self.logit_bias

        # Optional signal reconstruction
        if self.use_signal_decoder:
            out_dict["reconstructed_signal"] = self.signal_decoder(biosignals_embs)
            out_dict["original_signal"] = biosignals

        return out_dict

    def generate(
        self,
        biosignals,
        text=None,
        seq_len=30,
        max_seq_len=256,
        temperature=1.,
        generation_type="beam_search",
        top_p=0.1,
        top_k=1,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        repetition_penalty=1.0,
        fixed_output_length=False,
        condition_embs=None,
        channel_indices=None,
    ):
        """Generate captions for ``biosignals``.

        The model runs in eval mode under ``torch.no_grad()`` for the whole call,
        including biosignal encoding. The previous train/eval mode is restored
        afterwards, even if an exception is raised.

        Args:
            biosignals: Input biosignals batch.
            text: Optional prompt ids, 1-D or 2-D (top-k / top-p only).
                Defaults to one SOT token per row.
            seq_len: Maximum generated length (default stopping criterion and,
                for sampling, position at which EOS is forced).
            max_seq_len: Context window fed to the model when sampling.
            temperature: Softmax temperature for sampling.
            generation_type: ``"beam_search"``, ``"top_p"`` or ``"top_k"``.
            pad_token_id: Padding id; defaults to ``self.pad_id`` when None.
            eos_token_id / sot_token_id: Special token ids (converted with
                ``_token_to_tensor``).
            num_beams / num_beam_groups: Group beam search settings.
            min_seq_len: Minimum length before EOS is allowed.
            stopping_criteria: Optional list of HF stopping criteria.
            repetition_penalty: Passed to ``RepetitionPenaltyLogitsProcessor``.
            fixed_output_length: Pad outputs up to ``seq_len``.
            condition_embs: Accepted for backward compatibility only; ignored.
            channel_indices: Forwarded to :meth:`forward` at every step.

        Raises:
            ValueError: For an unknown ``generation_type``.
        """
        # taking many ideas and components from HuggingFace GenerationMixin
        # https://huggingface.co/docs/transformers/main/en/main_classes/text_generation
        assert _has_transformers, "Please install transformers for generate functionality. `pip install transformers`."
        assert seq_len > min_seq_len, "seq_len must be larger than min_seq_len"
        device = biosignals.device
        if pad_token_id is None:
            # Otherwise ``torch.ones(...) * None`` below would fail.
            pad_token_id = self.pad_id

        # Note: condition_embs is kept for backward compatibility only. We pass
        # channel_indices directly to forward(), which handles the conversion.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                sot_token_id = _token_to_tensor(sot_token_id, device=device)
                eos_token_id = _token_to_tensor(eos_token_id, device=device)
                logit_processor = LogitsProcessorList(
                    [
                        MinLengthLogitsProcessor(min_seq_len, eos_token_id),
                        RepetitionPenaltyLogitsProcessor(repetition_penalty),
                    ]
                )
                if stopping_criteria is None:
                    stopping_criteria = [MaxLengthCriteria(max_length=seq_len)]
                stopping_criteria = StoppingCriteriaList(stopping_criteria)

                if generation_type == "beam_search":
                    output = self._generate_beamsearch(
                        biosignals_inputs=biosignals,
                        pad_token_id=pad_token_id,
                        eos_token_id=eos_token_id,
                        sot_token_id=sot_token_id,
                        num_beams=num_beams,
                        num_beam_groups=num_beam_groups,
                        min_seq_len=min_seq_len,
                        stopping_criteria=stopping_criteria,
                        logit_processor=logit_processor,
                        channel_indices=channel_indices,
                    )
                    if fixed_output_length and output.shape[1] < seq_len:
                        pad_len = seq_len - output.shape[1]
                        padding = torch.ones(output.shape[0], pad_len, device=device, dtype=output.dtype) * pad_token_id
                        output = torch.cat((output, padding), dim=1)
                    return output

                if generation_type == "top_p":
                    logit_warper = GENERATION_TYPES[generation_type](top_p)
                elif generation_type == "top_k":
                    logit_warper = GENERATION_TYPES[generation_type](top_k)
                else:
                    raise ValueError(
                        f"generation_type has to be one of "
                        f"{'| ' + ' | '.join(list(GENERATION_TYPES.keys())) + ' |'}."
                    )

                return self._generate_sampling(
                    biosignals,
                    text=text,
                    seq_len=seq_len,
                    max_seq_len=max_seq_len,
                    temperature=temperature,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    sot_token_id=sot_token_id,
                    fixed_output_length=fixed_output_length,
                    stopping_criteria=stopping_criteria,
                    logit_processor=logit_processor,
                    logit_warper=logit_warper,
                    channel_indices=channel_indices,
                )
        finally:
            self.train(was_training)

    def _generate_sampling(
        self,
        biosignals,
        text,
        seq_len,
        max_seq_len,
        temperature,
        pad_token_id,
        eos_token_id,
        sot_token_id,
        fixed_output_length,
        stopping_criteria,
        logit_processor,
        logit_warper,
        channel_indices,
    ):
        """Top-k / top-p sampling loop used by :meth:`generate`.

        Expects to run inside ``generate``'s eval-mode / no-grad scope.

        Returns:
            Generated ids: 2-D, or 1-D if ``text`` was given as 1-D.
        """
        device = biosignals.device
        biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals)

        if text is None:
            text = torch.ones((biosignals.shape[0], 1), device=device, dtype=torch.long) * sot_token_id

        num_dims = len(text.shape)
        if num_dims == 1:
            text = text[None, :]

        out = text
        while True:
            x = out[:, -max_seq_len:]
            cur_len = x.shape[1]
            logits = self(
                biosignals,
                x,
                biosignals_latent=biosignals_latent,
                biosignals_embs=biosignals_embs,
                channel_indices=channel_indices,
                output_labels=False,
            )["logits"][:, -1]
            # Rows whose last token is EOS or PAD are finished and only get padding.
            finished = (out[:, -1] == eos_token_id) | (out[:, -1] == pad_token_id)
            sample = torch.ones((out.shape[0], 1), device=device, dtype=torch.long) * pad_token_id

            if finished.all():
                if not fixed_output_length:
                    break
            else:
                active = ~finished
                filtered_logits = logit_processor(x[active, :], logits[active, :])
                filtered_logits = logit_warper(x[active, :], filtered_logits)
                probs = F.softmax(filtered_logits / temperature, dim=-1)

                if cur_len + 1 == seq_len:
                    # Last allowed position: force EOS for every active row.
                    num_active = int(active.sum())
                    sample[active, :] = torch.ones((num_active, 1), device=device, dtype=torch.long) * eos_token_id
                else:
                    sample[active, :] = torch.multinomial(probs, 1)

            out = torch.cat((out, sample), dim=-1)
            if all(stopping_criteria(out, None)):
                break

        if num_dims == 1:
            out = out.squeeze(0)
        return out

    def _generate_beamsearch(
        self,
        biosignals_inputs,
        pad_token_id=None,
        eos_token_id=None,
        sot_token_id=None,
        num_beams=6,
        num_beam_groups=3,
        min_seq_len=5,
        stopping_criteria=None,
        logit_processor=None,
        logit_warper=None,
        channel_indices=None,
    ):
        """Group (diverse) beam search over the text decoder.

        ``stopping_criteria`` must be a ``StoppingCriteriaList`` (it is called
        each step and its ``max_length`` is read at the end). ``logit_warper``
        is accepted for signature compatibility but unused.

        Returns:
            The ``sequences`` tensor from ``BeamSearchScorer.finalize``.
        """
        device = biosignals_inputs.device
        batch_size = biosignals_inputs.shape[0]
        biosignals_inputs = torch.repeat_interleave(biosignals_inputs, num_beams, dim=0)
        biosignals_latent, biosignals_embs = self._encode_biosignals(biosignals_inputs)

        # Repeat channel indices per beam; forward() converts them to condition embeddings.
        if channel_indices is not None:
            channel_indices = torch.repeat_interleave(channel_indices, num_beams, dim=0)

        input_ids = torch.ones((batch_size * num_beams, 1), device=device, dtype=torch.long)
        input_ids = input_ids * sot_token_id
        beam_scorer = BeamSearchScorer(
            batch_size=batch_size,
            num_beams=num_beams,
            device=device,
            num_beam_groups=num_beam_groups,
        )
        # instantiate logits processors
        logits_processor = (
            LogitsProcessorList([MinLengthLogitsProcessor(min_seq_len, eos_token_id=eos_token_id)])
            if logit_processor is None
            else logit_processor
        )

        num_beams = beam_scorer.num_beams
        num_beam_groups = beam_scorer.num_beam_groups
        num_sub_beams = num_beams // num_beam_groups
        batch_size = len(beam_scorer._beam_hyps) // num_beam_groups
        batch_beam_size = input_ids.shape[0]

        if num_beams * batch_size != batch_beam_size:
            raise ValueError(
                f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
            )

        beam_scores = torch.full((batch_size, num_beams), -1e9, dtype=torch.float, device=device)
        # The first beam of each group starts at score 0 and the rest at -1e9, so
        # beams in the same group don't produce the same tokens every time.
        beam_scores[:, ::num_sub_beams] = 0
        beam_scores = beam_scores.view((batch_size * num_beams,))

        # Flat (batch * beam) row indices of each beam group. They are identical
        # at every decoding step, so build them once.
        groups = []
        for beam_group_idx in range(num_beam_groups):
            group_start_idx = beam_group_idx * num_sub_beams
            group_end_idx = min(group_start_idx + num_sub_beams, num_beams)
            batch_group_indices = [
                batch_idx * num_beams + idx
                for batch_idx in range(batch_size)
                for idx in range(group_start_idx, group_end_idx)
            ]
            groups.append((group_end_idx - group_start_idx, batch_group_indices))

        while True:
            # predicted tokens in this step
            current_tokens = torch.zeros(batch_size * num_beams, dtype=input_ids.dtype, device=device)

            # do one decoder step on all beams of all sentences in batch
            model_inputs = prepare_inputs_for_generation(input_ids=input_ids, biosignals_inputs=biosignals_inputs)
            outputs = self(
                model_inputs['biosignals'],
                model_inputs['text'],
                biosignals_latent=biosignals_latent,
                biosignals_embs=biosignals_embs,
                channel_indices=channel_indices,
                output_labels=False,
            )

            for beam_group_idx, (group_size, batch_group_indices) in enumerate(groups):
                group_input_ids = input_ids[batch_group_indices]

                # select outputs of beams of the current group only
                next_token_logits = outputs['logits'][batch_group_indices, -1, :]
                vocab_size = next_token_logits.shape[-1]

                next_token_scores_processed = logits_processor(
                    group_input_ids, next_token_logits, current_tokens=current_tokens, beam_group_idx=beam_group_idx
                )
                next_token_scores = next_token_scores_processed + beam_scores[batch_group_indices].unsqueeze(-1)
                next_token_scores = next_token_scores.expand_as(next_token_scores_processed)

                # reshape for beam search
                next_token_scores = next_token_scores.view(batch_size, group_size * vocab_size)
                next_token_scores, next_tokens = torch.topk(
                    next_token_scores, 2 * group_size, dim=1, largest=True, sorted=True
                )
                next_indices = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                next_tokens = next_tokens % vocab_size

                beam_outputs = beam_scorer.process(
                    group_input_ids,
                    next_token_scores,
                    next_tokens,
                    next_indices,
                    pad_token_id=pad_token_id,
                    eos_token_id=eos_token_id,
                    beam_indices=None,
                    group_index=beam_group_idx,
                )
                beam_scores[batch_group_indices] = beam_outputs["next_beam_scores"]
                beam_next_tokens = beam_outputs["next_beam_tokens"]
                beam_idx = beam_outputs["next_beam_indices"]

                input_ids[batch_group_indices] = group_input_ids[beam_idx]
                group_input_ids = torch.cat([group_input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
                current_tokens[batch_group_indices] = group_input_ids[:, -1]

            input_ids = torch.cat([input_ids, current_tokens.unsqueeze(-1)], dim=-1)

            if beam_scorer.is_done or all(stopping_criteria(input_ids, None)):
                break

        sequence_outputs = beam_scorer.finalize(
            input_ids,
            beam_scores,
            next_tokens,
            next_indices,
            pad_token_id=pad_token_id,
            eos_token_id=eos_token_id,
            max_length=stopping_criteria.max_length,
            beam_indices=None,
        )
        return sequence_outputs['sequences']
def prepare_inputs_for_generation(input_ids, biosignals_inputs, past=None, **kwargs):
    """Assemble the model inputs for one decoding step.

    Args:
        input_ids: (batch, seq) token ids generated so far.
        biosignals_inputs: Biosignals batch, passed through unchanged.
        past: Optional cached key/values; when truthy only the newest token is kept.
        **kwargs: ``attention_mask`` and ``position_ids`` are read if present.

    Returns:
        dict with keys ``text``, ``biosignals``, ``past_key_values``,
        ``position_ids`` and ``attention_mask``. Caller-supplied
        ``position_ids`` are returned as-is. They are only synthesized from
        ``attention_mask`` when missing.
    """
    if past:
        # With a cache, only the last token needs to be fed.
        input_ids = input_ids[:, -1].unsqueeze(-1)

    attention_mask = kwargs.get("attention_mask", None)
    position_ids = kwargs.get("position_ids", None)

    if position_ids is None and attention_mask is not None:
        # Create position_ids on the fly for batch generation: running count of
        # real tokens; masked slots get the dummy position 1.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)

    return {
        "text": input_ids,
        "biosignals": biosignals_inputs,
        "past_key_values": past,
        "position_ids": position_ids,
        "attention_mask": attention_mask,
    }