VisionLanguageGroup's picture
clean up
86072ea
raw
history blame
20.2 kB
"""Transformer class."""
import logging
import math
from collections import OrderedDict
from pathlib import Path
from typing import Literal, Tuple
import torch
import torch.nn.functional as F
import yaml
from torch import nn
import sys, os
from .utils import blockwise_causal_norm
logger = logging.getLogger(__name__)
def _pos_embed_fourier1d_init(
cutoff: float = 256, n: int = 32, cutoff_start: float = 1
):
return (
torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n))
.unsqueeze(0)
.unsqueeze(0)
)
def _rope_pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32):
# Maximum initial frequency is 1
return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
"""Rotate pairs of scalars as 2d vectors by pi/2."""
x = x.unflatten(-1, (-1, 2))
x1, x2 = x.unbind(dim=-1)
return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)
class RotaryPositionalEncoding(nn.Module):
    """Rotary positional encoding (RoPE) over an arbitrary number of
    coordinate dimensions, each with its own frequency bank."""

    def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)):
        super().__init__()
        assert len(cutoffs) == len(n_pos)
        if any(n % 2 != 0 for n in n_pos):
            raise ValueError("n_pos must be even")
        self._n_dim = len(cutoffs)
        # one learnable frequency bank per coordinate dimension; half the
        # channels each, since channels are rotated in (even, odd) pairs
        self.freqs = nn.ParameterList(
            nn.Parameter(_rope_pos_embed_fourier1d_init(cutoff, n // 2))
            for cutoff, n in zip(cutoffs, n_pos)
        )

    def get_co_si(self, coords: torch.Tensor):
        """Cosine/sine tables for coords (B, N, D), each (B, N, sum(n_pos)//2)."""
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        angles = [
            0.5 * math.pi * x.unsqueeze(-1) * freq
            for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
        ]
        # NOTE: len(freq) is the size of dim 0 of a (1, 1, m) parameter, i.e.
        # always 1 here, so this sqrt normalization is effectively a no-op —
        # kept verbatim for numerical parity.
        scales = [math.sqrt(len(freq)) for freq in self.freqs]
        co = torch.cat([torch.cos(a) / s for a, s in zip(angles, scales)], dim=-1)
        si = torch.cat([torch.sin(a) / s for a, s in zip(angles, scales)], dim=-1)
        return co, si

    def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor):
        """Rotate q and k (B, H, N, C) according to coords (B, N, D)."""
        _B, _N, D = coords.shape
        _B, _H, _N, _C = q.shape
        if D != self._n_dim:
            raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}")
        co, si = self.get_co_si(coords)
        # duplicate every frequency for its (even, odd) channel pair, add head dim
        co = co.unsqueeze(1).repeat_interleave(2, dim=-1)
        si = si.unsqueeze(1).repeat_interleave(2, dim=-1)
        q_rot = q * co + _rotate_half(q) * si
        k_rot = k * co + _rotate_half(k) * si
        return q_rot, k_rot
class FeedForward(nn.Module):
    """Two-layer MLP with GELU: d_model -> int(d_model * expand) -> d_model."""

    def __init__(self, d_model, expand: float = 2, bias: bool = True):
        super().__init__()
        hidden = int(d_model * expand)
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model, bias=bias)
        self.act = nn.GELU()

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.act(hidden)
        return self.fc2(hidden)
class PositionalEncoding(nn.Module):
    """Fourier (sin/cos) positional encoding over D coordinate dimensions.

    Each coordinate dimension d contributes n_pos[d] features: n_pos[d] // 2
    sine plus n_pos[d] // 2 cosine channels, with log-spaced frequencies
    between 1 / cutoffs_start[d] and 1 / cutoffs[d].
    """

    def __init__(
        self,
        cutoffs: Tuple[float] = (256,),
        n_pos: Tuple[int] = (32,),
        cutoffs_start=None,
    ):
        super().__init__()
        if cutoffs_start is None:
            # default start cutoff of 1 gives a maximum frequency of 1
            cutoffs_start = (1,) * len(cutoffs)
        assert len(cutoffs) == len(n_pos)
        # BUGFIX: cutoff_start was previously zipped but never forwarded to
        # _pos_embed_fourier1d_init, so callers passing cutoffs_start
        # (e.g. the 0.01 used for feature embeddings) had no effect.
        self.freqs = nn.ParameterList([
            nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2, cutoff_start))
            for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start)
        ])

    def forward(self, coords: torch.Tensor):
        """Embed coords (B, N, D) to (B, N, sum(n_pos)).

        NOTE(review): len(freq) is the size of dim 0 of a (1, 1, m) parameter,
        i.e. 1, so the sqrt normalization is a no-op — kept for parity.
        """
        _B, _N, D = coords.shape
        assert D == len(self.freqs)
        embed = torch.cat(
            tuple(
                torch.cat(
                    (
                        torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq),
                        torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq),
                    ),
                    axis=-1,
                )
                / math.sqrt(len(freq))
                for x, freq in zip(coords.moveaxis(-1, 0), self.freqs)
            ),
            axis=-1,
        )
        return embed
def _bin_init_exp(cutoff: float, n: int):
return torch.exp(torch.linspace(0, math.log(cutoff + 1), n))
def _bin_init_linear(cutoff: float, n: int):
return torch.linspace(-cutoff, cutoff, n)
class RelativePositionalBias(nn.Module):
    """Learned per-head attention bias looked up from binned pairwise distances.

    Spatial distances are binned on an exponential grid, signed temporal
    offsets on a linear grid; each (temporal, spatial) bin pair owns a
    learned bias vector of length n_head.
    """

    def __init__(
        self,
        n_head: int,
        cutoff_spatial: float,
        cutoff_temporal: float,
        n_spatial: int = 32,
        n_temporal: int = 16,
    ):
        super().__init__()
        self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial)
        self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1)
        # registered as buffers so they follow .to()/state_dict
        self.register_buffer("spatial_bins", self._spatial_bins)
        self.register_buffer("temporal_bins", self._temporal_bins)
        self.n_spatial = n_spatial
        self.n_head = n_head
        n_bins = (2 * n_temporal + 1) * n_spatial
        # uniform init in [-0.5, 0.5)
        self.bias = nn.Parameter(-0.5 + torch.rand(n_bins, n_head))

    def forward(self, coords: torch.Tensor):
        """Bias for coords (B, N, 1 + spatial), time first; returns (B, n_head, N, N)."""
        _B, _N, _D = coords.shape
        times = coords[..., 0]
        positions = coords[..., 1:]
        # signed temporal offsets, unsigned euclidean spatial distances
        temporal_offset = times.unsqueeze(-1) - times.unsqueeze(-2)
        spatial_distance = torch.cdist(positions, positions)
        s_idx = torch.bucketize(spatial_distance, self.spatial_bins)
        s_idx.clamp_(max=len(self.spatial_bins) - 1)
        t_idx = torch.bucketize(temporal_offset, self.temporal_bins)
        t_idx.clamp_(max=len(self.temporal_bins) - 1)
        # row-major flattening of the (temporal, spatial) bin grid
        flat_idx = s_idx.flatten() + t_idx.flatten() * self.n_spatial
        out = self.bias.index_select(0, flat_idx)
        out = out.view((*s_idx.shape, self.n_head))
        # (B, N, N, H) -> (B, H, N, N)
        return out.transpose(-1, 1)
class RelativePositionalAttention(nn.Module):
    """Multi-head attention with optional relative positional terms.

    mode selects the positional mechanism:
      - "bias" (or True): learned relative positional bias added to the logits
      - "rope": rotary embeddings applied to q/k plus a distance-decay term on
        the logits (attn_dist_mode "v0": full time+space distance, "v1":
        spatial distance scaled by cutoff_spatial)
      - "none": no positional term
      - None/False: like "none", with a warning
    Pairs farther apart than cutoff_spatial are hard-masked. Passing
    coords=None at forward time disables the mask and all positional terms.
    """

    def __init__(
        self,
        coord_dim: int,
        embed_dim: int,
        n_head: int,
        cutoff_spatial: float = 256,
        cutoff_temporal: float = 16,
        n_spatial: int = 32,
        n_temporal: int = 16,
        dropout: float = 0.0,
        mode: Literal["bias", "rope", "none"] = "bias",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        if embed_dim % (2 * n_head) != 0:
            raise ValueError(
                f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}"
            )
        self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = dropout
        self.n_head = n_head
        self.embed_dim = embed_dim
        self.cutoff_spatial = cutoff_spatial
        self.attn_dist_mode = attn_dist_mode
        if mode == "bias" or mode is True:
            self.pos_bias = RelativePositionalBias(
                n_head=n_head,
                cutoff_spatial=cutoff_spatial,
                cutoff_temporal=cutoff_temporal,
                n_spatial=n_spatial,
                n_temporal=n_temporal,
            )
        elif mode == "rope":
            # split the per-head channels between the temporal dim and each
            # spatial dim; n_split is forced even as RoPE rotates channel pairs
            n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head))
            self.rot_pos_enc = RotaryPositionalEncoding(
                cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim),
                n_pos=(embed_dim // n_head - coord_dim * n_split,)
                + (n_split,) * coord_dim,
            )
        elif mode == "none":
            pass
        elif mode is None or mode is False:
            logger.warning(
                "attn_positional_bias is not set (None or False), no positional bias."
            )
        else:
            raise ValueError(f"Unknown mode {mode}")
        self._mode = mode

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Attend query/key/value (B, N, D) and return (B, N, D).

        Args:
            coords: (B, N, 1 + spatial) time + space coordinates, or None to
                disable the spatial cutoff mask and all positional terms.
            padding_mask: optional (B, N) bool mask of padded positions.
        """
        B, N, D = query.size()
        q = self.q_pro(query)
        k = self.k_pro(key)
        v = self.v_pro(value)
        # (B, N, D) -> (B, n_head, N, D // n_head)
        k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2)
        attn_mask = torch.zeros(
            (B, self.n_head, N, N), device=query.device, dtype=q.dtype
        )
        # large negative rather than -inf, avoiding NaNs on fully-masked rows
        attn_ignore_val = -1e3
        # BUGFIX: the spatial cutoff mask used to be computed before the
        # coords-None check, crashing whenever positional terms were disabled
        # by passing coords=None (as Encoder/DecoderLayer do).
        if coords is not None:
            yx = coords[..., 1:]
            spatial_dist = torch.cdist(yx, yx)
            spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1)
            attn_mask.masked_fill_(spatial_mask, attn_ignore_val)
            if self._mode == "bias":
                attn_mask = attn_mask + self.pos_bias(coords)
            elif self._mode == "rope":
                q, k = self.rot_pos_enc(q, k, coords)
                if self.attn_dist_mode == "v0":
                    # decay over the full (time + space) distance
                    dist = torch.cdist(coords, coords, p=2)
                    attn_mask += torch.exp(-0.1 * dist.unsqueeze(1))
                elif self.attn_dist_mode == "v1":
                    attn_mask += torch.exp(
                        -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial
                    )
                else:
                    raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}")
        if padding_mask is not None:
            # ignore attention both to and from padded positions
            ignore_mask = torch.logical_or(
                padding_mask.unsqueeze(1), padding_mask.unsqueeze(2)
            ).unsqueeze(1)
            attn_mask.masked_fill_(ignore_mask, attn_ignore_val)
        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0
        )
        y = y.transpose(1, 2).contiguous().view(B, N, D)
        y = self.proj(y)
        return y
class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: positional self-attention + MLP."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Self-attention and feed-forward, each with a residual connection."""
        normed = self.norm1(x)
        # a falsy positional_bias forwards coords=None, disabling positional terms
        attended = self.attn(
            normed,
            normed,
            normed,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        out = normed + attended
        return out + self.mlp(self.norm2(out))
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder block: positional cross-attention + MLP."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Cross-attend x (queries) over y (keys/values), then feed-forward."""
        queries = self.norm1(x)
        context = self.norm2(y)
        # a falsy positional_bias forwards coords=None, disabling positional terms
        attended = self.attn(
            queries,
            context,
            context,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )
        out = queries + attended
        return out + self.mlp(self.norm3(out))
class TrackingTransformer(torch.nn.Module):
    """Encoder-decoder transformer over detections in a temporal window.

    forward() embeds (time + space) coordinates and optional per-detection
    features, runs the encoder and decoder stacks, and returns the pairwise
    association logit matrix A (B, N, N) as an outer product of two heads.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        # full constructor config, persisted by save() and restored by from_folder()
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )
        # input projection from concatenated positional + feature embeddings
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
        )
        self.norm = nn.LayerNorm(d_model)
        self.encoder = nn.ModuleList([
            EncoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(
                coord_dim,
                d_model,
                nhead,
                dropout,
                window=window,
                cutoff_spatial=spatial_pos_cutoff,
                positional_bias=attn_positional_bias,
                positional_bias_n_spatial=attn_positional_bias_n_spatial,
                attn_dist_mode=attn_dist_mode,
            )
            for _ in range(num_decoder_layers)
        ])
        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)
        # features get a Fourier embedding only when more than one channel
        # per feature dimension is requested; otherwise they pass through
        if feat_embed_per_dim > 1:
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()
        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Return association logits of shape (B, N, N).

        Args:
            coords: (B, N, 1 + coord_dim) with time in channel 0.
            features: optional (B, N, feat_dim) detection features. If None or
                empty, only positional embeddings are used
                (NOTE(review): this assumes feat_dim == 0 — confirm).
            padding_mask: optional (B, N) bool mask marking padded entries.
            attn_feat: optional (B, N, d_model) additive embedding.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
        _B, _N, _D = coords.shape
        if padding_mask is not None:
            # overwrite padded coords with the max so they cannot win the min below
            coords = coords.clone()
            coords[padding_mask] = coords.max()
        # shift times so every batch element starts at t = 0
        min_time = coords[:, :, :1].min(dim=1, keepdim=True).values
        coords = coords - min_time
        pos = self.pos_embed(coords)
        if features is None or features.numel() == 0:
            features = pos
        else:
            features = self.feat_embed(features)
            features = torch.cat((pos, features), dim=-1)
        features = self.proj(features)
        if attn_feat is not None:
            # additive attention embedding supplied by the caller
            features = features + attn_feat
        features = self.norm(features)
        x = features
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)
        # decoder cross-attends the encoder output
        y = features
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)
        x = self.head_x(x)
        y = self.head_y(y)
        # outer product of the two heads yields the association logits
        A = torch.einsum("bnd,bmd->bnm", x, y)
        return A

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        Args:
            A: Tensor of shape B, N, N
            timepoints: Tensor of shape B, N
            coords: Tensor of shape B, N, (time + n_spatial)
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]
        # pairs beyond the spatial cutoff are never associated
        dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:])
        invalid = dist > self.config["spatial_pos_cutoff"]
        if self.config["causal_norm"] == "none":
            # elementwise sigmoid; spatially distant entries are zeroed
            A = torch.sigmoid(A)
            A[invalid] = 0
        else:
            # blockwise causal normalization per batch element
            return torch.stack([
                blockwise_causal_norm(
                    _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
                )
                for _A, _t, _m in zip(A, timepoints, invalid)
            ])
        return A

    def save(self, folder):
        """Write config.yaml and model.pt into folder (created if needed)."""
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        with open(folder / "config.yaml", "w") as f:
            yaml.safe_dump(self.config, f)
        torch.save(self.state_dict(), folder / "model.pt")

    @classmethod
    def from_folder(
        cls, folder, map_location=None, checkpoint_path: str = "model.pt"
    ):
        """Instantiate from a folder written by save() (or a training checkpoint).

        Args:
            folder: directory containing config.yaml and the checkpoint file.
            map_location: forwarded to torch.load.
            checkpoint_path: checkpoint filename inside folder.
        """
        folder = Path(folder)
        # NOTE(review): FullLoader can construct python-tagged objects; configs
        # written by save() only need safe_load — confirm before tightening.
        with open(folder / "config.yaml") as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        model = cls(**config)
        fpath = folder / checkpoint_path
        logger.info(f"Loading model state from {fpath}")
        state = torch.load(fpath, map_location=map_location, weights_only=True)
        # if state is a checkpoint, extract the weights and strip the "model." prefix
        if "state_dict" in state:
            state = state["state_dict"]
            state = OrderedDict(
                (k[6:], v) for k, v in state.items() if k.startswith("model.")
            )
        model.load_state_dict(state)
        return model

    @classmethod
    def from_cfg(cls, cfg_path):
        """Instantiate an untrained model from a yaml config file."""
        cfg_path = Path(cfg_path)
        with open(cfg_path) as f:
            config = yaml.load(f, Loader=yaml.FullLoader)
        model = cls(**config)
        return model