forensics-grpo / code /libs /modeling /backbones.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
8.43 kB
import torch
from torch import nn
from torch.nn import functional as F
from .models import register_backbone
from .blocks import (get_sinusoid_encoding, TransformerBlock, MaskedConv1D,
ConvBlock, LayerNorm)
@register_backbone("convTransformer")
class ConvTransformerBackbone(nn.Module):
    """
    A backbone that combines convolutions with transformers.

    Pipeline:
      1. ``arch[0]`` masked 1D convs, each followed by LayerNorm (if
         ``with_ln``) or Identity, then ReLU;
      2. optional absolute (sinusoidal) position embedding;
      3. ``arch[1]`` stem TransformerBlocks built with ``n_ds_strides=(1, 1)``;
      4. ``arch[2]`` branch TransformerBlocks built with
         ``n_ds_strides=(scale_factor, scale_factor)``.

    ``forward`` returns a pair of tuples ``(feats, masks)`` with
    ``1 + arch[2]`` entries each: the stem output followed by the output of
    every branch block.
    """

    def __init__(
        self,
        n_in,                 # input feature dimension
        n_embd,               # embedding dimension (after convolution)
        n_head,               # number of heads for self-attention in transformers
        n_embd_ks,            # conv kernel size of the embedding network
        max_len,              # max sequence length
        arch=(2, 2, 5),       # (#convs, #stem transformers, #branch transformers)
        mha_win_size=None,    # local window size for mha: 1 (stem) + arch[2] (branch) entries
        scale_factor=2,       # downsampling rate for the branch
        with_ln=False,        # if to attach layernorm after conv
        attn_pdrop=0.0,       # dropout rate for the attention map
        proj_pdrop=0.0,       # dropout rate for the projection / MLP
        path_pdrop=0.0,       # dropout rate for drop path
        use_abs_pe=False,     # use absolute position embedding
        use_rel_pe=False,     # use relative position embedding
    ):
        """Build the conv embedding, transformer stem and downsampling branch.

        ``mha_win_size`` defaults to ``[-1] * (1 + arch[2])``; with the
        default ``arch`` this is the previous default ``[-1] * 6``.
        Index 0 is used by every stem block, index ``1 + i`` by branch
        block ``i``.
        """
        super().__init__()
        assert len(arch) == 3, "arch must be (#convs, #stem, #branch)"
        if mha_win_size is None:
            # NOTE(review): -1 is presumably "no local window" (global attention)
            # in TransformerBlock -- confirm against blocks.py.
            mha_win_size = [-1] * (1 + arch[2])
        # copy: never alias a caller-owned list (or a shared default) on self
        mha_win_size = list(mha_win_size)
        assert len(mha_win_size) == (1 + arch[2]), (
            "mha_win_size needs 1 + arch[2] entries "
            "(one for the stem, one per branch block)")

        self.arch = arch
        self.mha_win_size = mha_win_size
        self.max_len = max_len
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor
        self.use_abs_pe = use_abs_pe
        self.use_rel_pe = use_rel_pe

        # position embedding (1, C, T), rescaled by 1/sqrt(n_embd);
        # non-persistent buffer, i.e. excluded from the state_dict
        if self.use_abs_pe:
            pos_embd = get_sinusoid_encoding(self.max_len, n_embd) / (n_embd ** 0.5)
            self.register_buffer("pos_embd", pos_embd, persistent=False)

        # embedding network using convs; each conv is paired with a norm
        # (modules are created in the same order as before, so parameter
        # initialisation consumes the RNG identically)
        self.embd = nn.ModuleList()
        self.embd_norm = nn.ModuleList()
        for idx in range(arch[0]):
            in_channels = n_in if idx == 0 else n_embd
            # the conv bias is disabled whenever a LayerNorm follows it
            self.embd.append(MaskedConv1D(
                in_channels, n_embd, n_embd_ks,
                stride=1, padding=n_embd_ks // 2, bias=(not with_ln)
            ))
            self.embd_norm.append(LayerNorm(n_embd) if with_ln else nn.Identity())

        # stem network using (vanilla) transformers; all share window size [0]
        self.stem = nn.ModuleList([
            TransformerBlock(
                n_embd, n_head,
                n_ds_strides=(1, 1),
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop,
                path_pdrop=path_pdrop,
                mha_win_size=self.mha_win_size[0],
                use_rel_pe=self.use_rel_pe,
            )
            for _ in range(arch[1])
        ])

        # main branch using transformers with downsampling strides
        self.branch = nn.ModuleList([
            TransformerBlock(
                n_embd, n_head,
                n_ds_strides=(self.scale_factor, self.scale_factor),
                attn_pdrop=attn_pdrop,
                proj_pdrop=proj_pdrop,
                path_pdrop=path_pdrop,
                mha_win_size=self.mha_win_size[1 + idx],
                use_rel_pe=self.use_rel_pe,
            )
            for idx in range(arch[2])
        ])

        # init weights
        self.apply(self.__init_weights__)

    def __init_weights__(self, module):
        # set nn.Linear / nn.Conv1d bias terms to 0
        if isinstance(module, (nn.Linear, nn.Conv1d)):
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 0.)

    def _add_abs_pe(self, x, mask, T):
        """Return ``x`` plus the absolute position embedding, zeroed where ``mask`` is False.

        Training: ``T`` must not exceed ``max_len``; the fixed table is sliced.
        Inference: over-length inputs (``T > max_len``) get the table linearly
        re-interpolated to length ``T``; shorter inputs just slice it.
        """
        if self.training:
            assert T <= self.max_len, "Reached max length."
            pe = self.pos_embd
        elif T > self.max_len:
            pe = F.interpolate(
                self.pos_embd, T, mode='linear', align_corners=False)
        else:
            # the table is long enough; interpolating to an equal length
            # would be an identity, so slicing is all that is needed
            pe = self.pos_embd
        return x + pe[:, :, :T] * mask.to(x.dtype)

    def forward(self, x, mask):
        """Run the backbone and return ``(out_feats, out_masks)`` tuples.

        x: batch size, feature channel, sequence length
        mask: batch size, 1, sequence length (bool)
        """
        # unpacking also rejects inputs that are not 3-D
        _, _, T = x.size()

        # embedding network: masked conv -> (LayerNorm | Identity) -> ReLU
        for conv, norm in zip(self.embd, self.embd_norm):
            x, mask = conv(x, mask)
            x = self.relu(norm(x))

        if self.use_abs_pe:
            x = self._add_abs_pe(x, mask, T)

        # stem transformer
        for block in self.stem:
            x, mask = block(x, mask)

        # 1x resolution first, then one entry per branch (downsampling) block
        out_feats, out_masks = [x], [mask]
        for block in self.branch:
            x, mask = block(x, mask)
            out_feats.append(x)
            out_masks.append(mask)
        return tuple(out_feats), tuple(out_masks)
@register_backbone("conv")
class ConvBackbone(nn.Module):
    """
    A backbone built only from convolutions.

    Layout: ``arch[0]`` masked embedding convs (each followed by LayerNorm
    when ``with_ln`` else Identity, then ReLU), ``arch[1]`` stem ConvBlocks
    created with a third argument of 1, and ``arch[2]`` branch ConvBlocks
    created with ``scale_factor`` as the third argument (the downsampling
    rate of the branch).

    ``forward`` yields ``(feats, masks)``, two tuples of ``1 + arch[2]``
    entries: the stem output followed by every branch block's output.
    """

    def __init__(
        self,
        n_in,              # input feature dimension
        n_embd,            # embedding dimension (after convolution)
        n_embd_ks,         # conv kernel size of the embedding network
        arch = (2, 2, 5),  # (#convs, #stem convs, #branch convs)
        scale_factor = 2,  # downsampling rate for the branch
        with_ln=False,     # if to use layernorm
    ):
        super().__init__()
        assert len(arch) == 3
        self.arch = arch
        self.relu = nn.ReLU(inplace=True)
        self.scale_factor = scale_factor

        # embedding convs; conv and norm are built pairwise, in layer order
        self.embd = nn.ModuleList()
        self.embd_norm = nn.ModuleList()
        for layer_idx in range(arch[0]):
            conv_in = n_embd if layer_idx else n_in
            # no conv bias when a LayerNorm comes right after
            self.embd.append(MaskedConv1D(
                conv_in, n_embd, n_embd_ks,
                stride=1, padding=n_embd_ks // 2, bias=not with_ln,
            ))
            self.embd_norm.append(LayerNorm(n_embd) if with_ln else nn.Identity())

        # stem: ConvBlocks at full resolution
        # NOTE(review): ConvBlock's positional args are presumably
        # (channels, kernel_size, downsample stride) -- confirm in blocks.py
        self.stem = nn.ModuleList([ConvBlock(n_embd, 3, 1) for _ in range(arch[1])])
        # branch: ConvBlocks that downsample by scale_factor
        self.branch = nn.ModuleList(
            [ConvBlock(n_embd, 3, self.scale_factor) for _ in range(arch[2])]
        )

        # zero the biases of linear / conv layers
        self.apply(self.__init_weights__)

    def __init_weights__(self, module):
        # zero the bias of every nn.Linear / nn.Conv1d that has one
        if not isinstance(module, (nn.Linear, nn.Conv1d)):
            return
        if module.bias is not None:
            torch.nn.init.constant_(module.bias, 0.)

    def forward(self, x, mask):
        """Return ``(feats, masks)`` tuples for the stem and each branch level.

        x: batch size, feature channel, sequence length
        mask: batch size, 1, sequence length (bool)
        """
        # unpacking into three names rejects inputs that are not 3-D
        _batch, _channels, _length = x.size()

        for conv, norm in zip(self.embd, self.embd_norm):
            x, mask = conv(x, mask)
            x = self.relu(norm(x))

        for stem_block in self.stem:
            x, mask = stem_block(x, mask)

        # full-resolution output first, then one per branch block
        feats = [x]
        masks = [mask]
        for branch_block in self.branch:
            x, mask = branch_block(x, mask)
            feats.append(x)
            masks.append(mask)
        return tuple(feats), tuple(masks)