"""Transformer class."""

import logging
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Literal, Optional

import torch
import yaml
from torch import nn

# NOTE(review): this appends a path relative to the *current working directory*
# at import time, so its effect depends on where the process was launched.
# It is kept because other code may rely on it; installing the external repo as
# a package would be more robust.
sys.path.append(os.path.join(os.getcwd(), "External_Repos", "trackastra"))

from ..utils import blockwise_causal_norm  # noqa: E402
from .model_parts import (  # noqa: E402
    FeedForward,
    PositionalEncoding,
    RelativePositionalAttention,
)

logger = logging.getLogger(__name__)


class EncoderLayer(nn.Module):
    """Transformer encoder layer: relative-positional self-attention + MLP.

    Args:
        coord_dim: Number of spatial coordinate dimensions (forwarded to the
            attention module).
        d_model: Embedding size of the tokens.
        num_heads: Number of attention heads.
        dropout: Dropout probability forwarded to the attention module.
        cutoff_spatial: Forwarded to the attention module as ``cutoff_spatial``.
        window: Forwarded to the attention module as both ``cutoff_temporal``
            and ``n_temporal``.
        positional_bias: Forwarded to the attention module as ``mode``.
        positional_bias_n_spatial: Forwarded to the attention module as
            ``n_spatial``.
        attn_dist_mode: Forwarded to the attention module unchanged.
    """

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply self-attention followed by the feed-forward block.

        Args:
            x: Token embeddings; presumably (B, N, d_model) -- TODO confirm.
            coords: Per-token coordinates forwarded to the attention module.
            padding_mask: Optional mask forwarded to the attention module.

        Returns:
            The updated token embeddings (input plus attention and MLP
            residuals).
        """
        x = self.norm1(x)

        # Passing coords=None is meant to disable the positional bias. Note that
        # ``positional_bias`` is a string, so even "none" is truthy and coords
        # are forwarded; the attention module receives ``mode=positional_bias``.
        a = self.attn(
            x,
            x,
            x,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )

        # The residual is added to the *normalized* input (x was reassigned
        # above); changing this would alter what existing checkpoints compute.
        x = x + a
        x = x + self.mlp(self.norm2(x))
        return x


class DecoderLayer(nn.Module):
    """Transformer decoder layer: relative-positional cross-attention + MLP.

    Args:
        coord_dim: Number of spatial coordinate dimensions (forwarded to the
            attention module).
        d_model: Embedding size of the tokens.
        num_heads: Number of attention heads.
        dropout: Dropout probability forwarded to the attention module.
        window: Forwarded to the attention module as both ``cutoff_temporal``
            and ``n_temporal``.
        cutoff_spatial: Forwarded to the attention module as ``cutoff_spatial``.
        positional_bias: Forwarded to the attention module as ``mode``.
        positional_bias_n_spatial: Forwarded to the attention module as
            ``n_spatial``.
        attn_dist_mode: Forwarded to the attention module unchanged.
    """

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Cross-attend from ``x`` to ``y``, then apply the feed-forward block.

        The attention module is called as ``attn(x, y, y)``, so ``x`` presumably
        supplies the queries and ``y`` the keys/values.

        Args:
            x: Decoder token embeddings.
            y: Tokens attended to (e.g. encoder output).
            coords: Per-token coordinates forwarded to the attention module.
            padding_mask: Optional mask forwarded to the attention module.

        Returns:
            The updated decoder embeddings (normalized ``x`` plus attention and
            MLP residuals).
        """
        x = self.norm1(x)
        y = self.norm2(y)

        # Cross attention. Passing coords=None is meant to disable the
        # positional bias; as in EncoderLayer, the string "none" is truthy.
        a = self.attn(
            x,
            y,
            y,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )

        # Residual on the normalized x, as in EncoderLayer.
        x = x + a
        x = x + self.mlp(self.norm3(x))
        return x
class TrackingTransformer(torch.nn.Module):
    """Encoder-decoder transformer that scores associations between objects.

    Each object is a token built from its coordinates (time followed by
    ``coord_dim`` spatial values) and optional features. ``forward`` returns a
    (B, N, N) matrix of association logits, and ``normalize_output`` turns
    these into normalized scores. ``self.config`` stores the constructor
    arguments so that ``save`` / ``from_folder`` can round-trip the model.

    Args:
        coord_dim: Number of spatial dimensions; coords carry ``1 + coord_dim``
            columns (time first).
        feat_dim: Number of per-object feature channels (0 for none).
        d_model: Hidden size of the transformer.
        nhead: Number of attention heads per layer.
        num_encoder_layers: Number of EncoderLayer blocks.
        num_decoder_layers: Number of DecoderLayer blocks.
        dropout: Dropout probability forwarded to every layer.
        pos_embed_per_dim: Positional-encoding size for each of the
            ``1 + coord_dim`` coordinates.
        feat_embed_per_dim: If > 1, each feature channel is expanded by a
            PositionalEncoding of this size; otherwise features are used as-is.
        window: Temporal cutoff for the positional encoding; also forwarded to
            the layers as ``window``.
        spatial_pos_cutoff: Spatial cutoff for the positional encoding and the
            layers; ``normalize_output`` marks pairs farther apart as invalid.
        attn_positional_bias: Positional-bias mode forwarded to the layers.
        attn_positional_bias_n_spatial: Forwarded to the layers.
        causal_norm: Output normalization used by ``normalize_output``
            ("none" means element-wise sigmoid).
        attn_dist_mode: Forwarded to the layers.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()

        # Serialized by ``save`` and fed back into ``cls(**config)`` on load.
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )

        # Input = positional encoding of (time + spatial) plus embedded features.
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim,
            d_model,
        )
        self.norm = nn.LayerNorm(d_model)

        self.encoder = nn.ModuleList([
            EncoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_decoder_layers)
        ])

        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)

        if feat_embed_per_dim > 1:
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()

        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Compute association logits between all pairs of objects.

        Args:
            coords: Tensor of shape (B, N, 1 + n_spatial) with time in column 0
                (3 or 4 columns).
            features: Optional per-object features (presumably (B, N, feat_dim)
                -- TODO confirm). Ignored when None or empty.
            padding_mask: Optional boolean mask (B, N); True marks padded
                entries.
            attn_feat: Optional embedding added to the projected features
                before the input LayerNorm.

        Returns:
            Tensor of shape (B, N, N) with
            ``A[b, n, m] = <head_x(x)[b, n], head_y(y)[b, m]>``.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)

        if padding_mask is not None:
            # Move padded entries to the global maximum so they cannot affect
            # the minimum below; clone so the caller's tensor is not mutated.
            coords = coords.clone()
            coords[padding_mask] = coords.max()

        # Remove the temporal offset. min_time has shape (B, 1, 1), so it is
        # subtracted from *every* column (spatial ones included); spatial
        # differences are unaffected.
        min_time = coords[:, :, :1].min(dim=1, keepdim=True).values
        coords = coords - min_time

        pos = self.pos_embed(coords)
        if features is None or features.numel() == 0:
            features = pos
        else:
            features = torch.cat((pos, self.feat_embed(features)), dim=-1)

        features = self.proj(features)
        if attn_feat is not None:
            features = features + attn_feat
        features = self.norm(features)

        # Encoder: self-attention over all objects.
        x = features
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)

        # Decoder: starts again from the input tokens, cross-attending to x.
        y = features
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)

        x = self.head_x(x)
        y = self.head_y(y)

        # Outer product of the two heads is the association matrix (logits).
        return torch.einsum("bnd,bmd->bnm", x, y)

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        Pairs whose spatial distance exceeds ``config["spatial_pos_cutoff"]``
        are treated as invalid.

        Args:
            A: Tensor of shape B, N, N
            timepoints: Tensor of shape B, N
            coords: Tensor of shape B, N, (time + n_spatial)

        Returns:
            If ``config["causal_norm"] == "none"``: ``sigmoid(A)`` with invalid
            pairs set to 0 (a new tensor; ``A`` itself is not modified).
            Otherwise: ``blockwise_causal_norm`` applied to each batch element
            with ``mode=config["causal_norm"]``, stacked along dim 0.
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]

        spatial = coords[:, :, 1:]
        invalid = torch.cdist(spatial, spatial) > self.config["spatial_pos_cutoff"]

        if self.config["causal_norm"] == "none":
            A = torch.sigmoid(A)
            A[invalid] = 0
            return A

        return torch.stack([
            blockwise_causal_norm(
                _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
            )
            for _A, _t, _m in zip(A, timepoints, invalid)
        ])

    def save(self, folder):
        """Write ``config.yaml`` and the state dict ``model.pt`` into ``folder``.

        The folder (and parents) is created if needed.
        """
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        # Use a context manager so the YAML file is flushed and closed.
        with open(folder / "config.yaml", "w", encoding="utf-8") as f:
            yaml.safe_dump(self.config, f)
        torch.save(self.state_dict(), folder / "model.pt")

    @staticmethod
    def _read_config(path):
        """Load a YAML model config from ``path``, closing the file afterwards."""
        # NOTE(review): yaml.FullLoader is not meant for untrusted input. Configs
        # written by ``save`` contain only plain types, so yaml.safe_load would
        # suffice if model folders may come from external sources.
        with open(path, encoding="utf-8") as f:
            return yaml.load(f, Loader=yaml.FullLoader)

    @staticmethod
    def _check_args_match(config, args):
        """Raise ValueError listing every key where ``config`` and ``args`` differ.

        ``args`` is converted with ``vars`` (e.g. an argparse Namespace); only
        keys present in both are compared. Does nothing when ``args`` is falsy.
        """
        if not args:
            return
        args = vars(args)
        errors = [
            f"Loaded model config {k}={v}, but current argument {k}={args[k]}."
            for k, v in config.items()
            if k in args and v != args[k]
        ]
        if errors:
            raise ValueError("\n".join(errors))

    @classmethod
    def from_folder(
        cls, folder, map_location=None, args=None, checkpoint_path: str = "model.pt"
    ):
        """Build a model from ``folder/config.yaml`` and load its weights.

        Args:
            folder: Directory containing ``config.yaml`` and the checkpoint.
            map_location: Forwarded to ``torch.load``.
            args: Optional namespace whose overlapping entries must equal the
                loaded config.
            checkpoint_path: Checkpoint file name, relative to ``folder``.

        Returns:
            The model with its state dict loaded.

        Raises:
            ValueError: If ``args`` disagrees with the config (all mismatching
                keys are listed).
        """
        folder = Path(folder)
        config = cls._read_config(folder / "config.yaml")
        cls._check_args_match(config, args)
        model = cls(**config)

        fpath = folder / checkpoint_path
        logger.info("Loading model state from %s", fpath)
        state = torch.load(fpath, map_location=map_location, weights_only=True)

        # A training checkpoint stores the weights under "state_dict" with a
        # "model." prefix; keep only those entries and strip the prefix.
        if "state_dict" in state:
            prefix = "model."
            state = OrderedDict(
                (k[len(prefix):], v)
                for k, v in state["state_dict"].items()
                if k.startswith(prefix)
            )
        model.load_state_dict(state)
        return model

    @classmethod
    def from_cfg(cls, cfg_path, args=None):
        """Build an (untrained) model from a YAML config file.

        Raises:
            ValueError: If ``args`` disagrees with the config (all mismatching
                keys are listed).
        """
        config = cls._read_config(Path(cfg_path))
        cls._check_args_match(config, args)
        return cls(**config)