|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as functional |
|
|
|
|
|
import random |
|
|
|
|
|
|
|
|
class TransNetV2(nn.Module):
    """TransNet V2 network (PyTorch port).

    ``forward`` takes a uint8 tensor of shape ``[batch, frames, 27, 48, 3]`` and
    returns per-frame logits of shape ``[batch, frames, 1]``; when
    ``use_many_hot_targets`` is set it returns ``(one_hot, {"many_hot": logits})``.

    Args:
        F: filters of the first StackedDDCNNV2 stage; stage ``i`` uses
            ``F * 2 ** i`` filters and outputs four times as many channels.
        L: number of StackedDDCNNV2 stages.
        S: number of DilatedDCNNV2 blocks per stage.
        D: width of the hidden fully connected layer ``fc1``.
        use_many_hot_targets: add the second classification head ``cls_layer2``.
        use_frame_similarity: append 128 FrameSimilarity features per frame.
        use_color_histograms: append 128 ColorHistograms features per frame.
        use_mean_pooling: average the last stage over its spatial dims instead
            of flattening them.
        dropout_rate: dropout applied after ``fc1``; ``None`` disables it.
        use_convex_comb_reg, use_resnet_features, use_resnet_like_top,
        frame_similarity_on_last_layer: unsupported in this port; must be falsy.

    Raises:
        NotImplementedError: if any unsupported option is enabled.
    """

    def __init__(self,
                 F=16, L=3, S=2, D=1024,
                 use_many_hot_targets=True,
                 use_frame_similarity=True,
                 use_color_histograms=True,
                 use_mean_pooling=False,
                 dropout_rate=0.5,
                 use_convex_comb_reg=False,
                 use_resnet_features=False,
                 use_resnet_like_top=False,
                 frame_similarity_on_last_layer=False):
        super(TransNetV2, self).__init__()

        if use_resnet_features or use_resnet_like_top or use_convex_comb_reg or frame_similarity_on_last_layer:
            # `NotImplemented` is a constant, not an exception class; calling it raised TypeError.
            raise NotImplementedError("Some options not implemented in Pytorch version of Transnet!")

        self.SDDCNN = nn.ModuleList(
            [StackedDDCNNV2(in_filters=3, n_blocks=S, filters=F, stochastic_depth_drop_prob=0.)] +
            [StackedDDCNNV2(in_filters=(F * 2 ** (i - 1)) * 4, n_blocks=S, filters=F * 2 ** i) for i in range(1, L)]
        )

        # Frame similarity consumes the pooled features of every stage.
        self.frame_sim_layer = FrameSimilarity(
            sum([(F * 2 ** i) * 4 for i in range(L)]), lookup_window=101, output_dim=128, similarity_dim=128, use_bias=True
        ) if use_frame_similarity else None

        self.color_hist_layer = ColorHistograms(
            lookup_window=101, output_dim=128
        ) if use_color_histograms else None

        self.dropout = nn.Dropout(dropout_rate) if dropout_rate is not None else None

        last_channels = (F * 2 ** (L - 1)) * 4
        if use_mean_pooling:
            # Spatial dims are averaged away in forward(); only channels remain.
            output_dim = last_channels
        else:
            # forward() requires 27x48 frames and every stage floor-halves height
            # and width with a (1, 2, 2) pool, giving 3x6 for the default L=3.
            height, width = 27, 48
            for _ in range(L):
                height, width = height // 2, width // 2
            output_dim = last_channels * height * width
        if use_frame_similarity:
            output_dim += 128
        if use_color_histograms:
            output_dim += 128

        self.fc1 = nn.Linear(output_dim, D)
        self.cls_layer1 = nn.Linear(D, 1)
        self.cls_layer2 = nn.Linear(D, 1) if use_many_hot_targets else None

        self.use_mean_pooling = use_mean_pooling
        # The module is constructed in evaluation mode.
        self.eval()

    def forward(self, inputs):
        """Return per-frame logits for a uint8 ``[batch, frames, 27, 48, 3]`` tensor."""
        assert isinstance(inputs, torch.Tensor) and list(inputs.shape[2:]) == [27, 48, 3] and inputs.dtype == torch.uint8, \
            "incorrect input type and/or shape"

        # [B, T, H, W, C] -> [B, C, T, H, W], scaled to [0, 1]; `.float()` copies,
        # so the in-place division does not touch `inputs`.
        x = inputs.permute([0, 4, 1, 2, 3]).float()
        x = x.div_(255.)

        block_features = []
        for block in self.SDDCNN:
            x = block(x)
            block_features.append(x)

        if self.use_mean_pooling:
            x = torch.mean(x, dim=[3, 4])
            x = x.permute(0, 2, 1)
        else:
            x = x.permute(0, 2, 3, 4, 1)
            x = x.reshape(x.shape[0], x.shape[1], -1)

        if self.frame_sim_layer is not None:
            x = torch.cat([self.frame_sim_layer(block_features), x], 2)

        if self.color_hist_layer is not None:
            x = torch.cat([self.color_hist_layer(inputs), x], 2)

        x = self.fc1(x)
        x = functional.relu(x)

        if self.dropout is not None:
            x = self.dropout(x)

        one_hot = self.cls_layer1(x)

        if self.cls_layer2 is not None:
            return one_hot, {"many_hot": self.cls_layer2(x)}

        return one_hot
|
|
|
|
|
|
|
|
class StackedDDCNNV2(nn.Module):
    """Stack of DilatedDCNNV2 blocks with an optional residual shortcut, then spatial pooling.

    The output of the first block is kept as the shortcut. After the last block
    a ReLU is applied, the shortcut is added (subject to stochastic depth), and
    the result is pooled with a (1, 2, 2) kernel.

    Args:
        in_filters: channels of the input tensor.
        n_blocks: number of DilatedDCNNV2 blocks.
        filters: filters per dilation branch; each block outputs ``filters * 4`` channels.
        shortcut: add the first block's output to the final activation.
        use_octave_conv: unsupported in this port; must be falsy.
        pool_type: ``"avg"`` or ``"max"``.
        stochastic_depth_drop_prob: probability of dropping the residual branch
            during training; at eval time the branch is scaled by ``1 - p``.

    Raises:
        NotImplementedError: if ``use_octave_conv`` is set.
    """

    def __init__(self,
                 in_filters,
                 n_blocks,
                 filters,
                 shortcut=True,
                 use_octave_conv=False,
                 pool_type="avg",
                 stochastic_depth_drop_prob=0.0):
        super(StackedDDCNNV2, self).__init__()

        if use_octave_conv:
            # `NotImplemented` is not callable; the intended exception is NotImplementedError.
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")

        assert pool_type == "max" or pool_type == "avg"

        self.shortcut = shortcut
        # Only the last block skips its activation; the ReLU is applied in forward()
        # after the stack.
        self.DDCNN = nn.ModuleList([
            DilatedDCNNV2(in_filters if i == 1 else filters * 4, filters, octave_conv=use_octave_conv,
                          activation=functional.relu if i != n_blocks else None) for i in range(1, n_blocks + 1)
        ])
        self.pool = nn.MaxPool3d(kernel_size=(1, 2, 2)) if pool_type == "max" else nn.AvgPool3d(kernel_size=(1, 2, 2))
        self.stochastic_depth_drop_prob = stochastic_depth_drop_prob

    def forward(self, inputs):
        x = inputs
        shortcut = None

        for block in self.DDCNN:
            x = block(x)
            if shortcut is None:
                # The first block's output is the residual source.
                shortcut = x

        x = functional.relu(x)

        # `self.shortcut` is a bool; the previous `is not None` test was always true,
        # so shortcut=False was ignored.
        if self.shortcut and shortcut is not None:
            if self.stochastic_depth_drop_prob != 0.:
                if self.training:
                    if random.random() < self.stochastic_depth_drop_prob:
                        x = shortcut
                    else:
                        x = x + shortcut
                else:
                    x = (1 - self.stochastic_depth_drop_prob) * x + shortcut
            else:
                # Out-of-place: ReLU keeps its output for the backward pass, so an
                # in-place add would invalidate it.
                x = x + shortcut

        x = self.pool(x)
        return x
|
|
|
|
|
|
|
|
class DilatedDCNNV2(nn.Module):
    """Four parallel Conv3DConfigurable branches with temporal dilations 1, 2, 4 and 8.

    The branch outputs are concatenated on the channel axis (``filters * 4``
    channels), optionally batch-normalised, then passed through ``activation``.

    Args:
        in_filters: channels of the input tensor.
        filters: output channels of each branch.
        batch_norm: apply BatchNorm3d (and drop the conv biases).
        activation: callable applied last, or ``None`` for no activation.
        octave_conv: unsupported in this port; must be falsy.

    Raises:
        NotImplementedError: if ``octave_conv`` is set.
    """

    def __init__(self,
                 in_filters,
                 filters,
                 batch_norm=True,
                 activation=None,
                 octave_conv=False):
        super(DilatedDCNNV2, self).__init__()

        if octave_conv:
            # `NotImplemented` is not callable; the intended exception is NotImplementedError.
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")

        # Biases are redundant when BatchNorm follows the convolutions.
        self.Conv3D_1 = Conv3DConfigurable(in_filters, filters, 1, use_bias=not batch_norm)
        self.Conv3D_2 = Conv3DConfigurable(in_filters, filters, 2, use_bias=not batch_norm)
        self.Conv3D_4 = Conv3DConfigurable(in_filters, filters, 4, use_bias=not batch_norm)
        self.Conv3D_8 = Conv3DConfigurable(in_filters, filters, 8, use_bias=not batch_norm)

        self.bn = nn.BatchNorm3d(filters * 4, eps=1e-3) if batch_norm else None
        self.activation = activation

    def forward(self, inputs):
        x = torch.cat([
            self.Conv3D_1(inputs),
            self.Conv3D_2(inputs),
            self.Conv3D_4(inputs),
            self.Conv3D_8(inputs),
        ], dim=1)

        if self.bn is not None:
            x = self.bn(x)

        if self.activation is not None:
            x = self.activation(x)

        return x
|
|
|
|
|
|
|
|
class Conv3DConfigurable(nn.Module):
    """3-D convolution with a 3-tap kernel dilated only along the first spatial (time) axis.

    With ``separable=True`` it is a (1, 3, 3) convolution to ``2 * filters``
    channels followed by a (3, 1, 1) temporal convolution. Otherwise it is a single
    3x3x3 convolution. Padding is chosen so that every output dimension equals
    the input dimension.

    Args:
        in_filters: channels of the input tensor.
        filters: channels of the output tensor.
        dilation_rate: dilation (and padding) of the temporal axis.
        separable: use the two-step factorised convolution.
        octave: unsupported in this port; must be falsy.
        use_bias: add a bias to the final convolution.
        kernel_initializer: unsupported in this port; must be ``None``.

    Raises:
        NotImplementedError: if ``octave`` or ``kernel_initializer`` is set.
    """

    def __init__(self,
                 in_filters,
                 filters,
                 dilation_rate,
                 separable=True,
                 octave=False,
                 use_bias=True,
                 kernel_initializer=None):
        super(Conv3DConfigurable, self).__init__()

        # `NotImplemented` is not callable; the intended exception is NotImplementedError.
        if octave:
            raise NotImplementedError("Octave convolution not implemented in Pytorch version of Transnet!")
        if kernel_initializer is not None:
            raise NotImplementedError("Kernel initializers are not implemented in Pytorch version of Transnet!")

        if separable:
            conv1 = nn.Conv3d(in_filters, 2 * filters, kernel_size=(1, 3, 3),
                              dilation=(1, 1, 1), padding=(0, 1, 1), bias=False)
            conv2 = nn.Conv3d(2 * filters, filters, kernel_size=(3, 1, 1),
                              dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 0, 0), bias=use_bias)
            self.layers = nn.ModuleList([conv1, conv2])
        else:
            conv = nn.Conv3d(in_filters, filters, kernel_size=3,
                             dilation=(dilation_rate, 1, 1), padding=(dilation_rate, 1, 1), bias=use_bias)
            self.layers = nn.ModuleList([conv])

    def forward(self, inputs):
        x = inputs
        for layer in self.layers:
            x = layer(x)
        return x
|
|
|
|
|
|
|
|
class FrameSimilarity(nn.Module):
    """Per-frame features built from cosine similarities to neighbouring frames.

    ``forward`` takes a list of 5-D tensors ``[batch, channels_i, time, h_i, w_i]``
    that share batch and time sizes. Each tensor is averaged over its last two
    dims and the results are concatenated on channels, then projected and
    L2-normalised. For every frame ``t`` the similarities to frames
    ``t - half .. t + half`` (``half = (lookup_window - 1) // 2``, zero outside
    the clip) are mapped by ``fc`` + ReLU to ``[batch, time, output_dim]``.

    Args:
        in_filters: total channel count across all input tensors.
        similarity_dim: projection width before the dot products.
        lookup_window: odd number of neighbouring frames to compare against.
        output_dim: width of the returned features.
        stop_gradient: unsupported in this port; must be falsy.
        use_bias: bias for the projection layer.

    Raises:
        NotImplementedError: if ``stop_gradient`` is set.
    """

    def __init__(self,
                 in_filters,
                 similarity_dim=128,
                 lookup_window=101,
                 output_dim=128,
                 stop_gradient=False,
                 use_bias=False):
        super(FrameSimilarity, self).__init__()

        if stop_gradient:
            # `NotImplemented` is not callable; the intended exception is NotImplementedError.
            raise NotImplementedError("Stop gradient not implemented in Pytorch version of Transnet!")

        self.projection = nn.Linear(in_filters, similarity_dim, bias=use_bias)
        self.fc = nn.Linear(lookup_window, output_dim)

        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    def forward(self, inputs):
        # [B, C_i, T, H_i, W_i] -> [B, sum(C_i), T] -> [B, T, sum(C_i)]
        x = torch.cat([torch.mean(x, dim=[3, 4]) for x in inputs], dim=1)
        x = torch.transpose(x, 1, 2)

        x = self.projection(x)
        x = functional.normalize(x, p=2, dim=2)

        batch_size, time_window = x.shape[0], x.shape[1]
        similarities = torch.bmm(x, x.transpose(1, 2))  # [B, T, T]
        half = (self.lookup_window - 1) // 2
        similarities_padded = functional.pad(similarities, [half, half])  # [B, T, T + 2 * half]

        # Padded column t + k holds similarity(t, t + k - half), so row t reads columns t .. t + window - 1.
        lookup_indices = (torch.arange(time_window, device=x.device).view(-1, 1)
                          + torch.arange(self.lookup_window, device=x.device).view(1, -1))
        lookup_indices = lookup_indices.unsqueeze(0).expand(batch_size, -1, -1)
        similarities = torch.gather(similarities_padded, 2, lookup_indices)  # [B, T, lookup_window]

        return functional.relu(self.fc(similarities))
|
|
|
|
|
|
|
|
class ColorHistograms(nn.Module):
    """Per-frame features built from colour-histogram similarities to neighbouring frames.

    Each frame of a ``[batch, time, height, width, 3]`` integer tensor is reduced
    to an L2-normalised 512-bin histogram (3 bits per channel). For every frame
    ``t`` the dot products with frames ``t - half .. t + half``
    (``half = (lookup_window - 1) // 2``, zero outside the clip) are returned,
    optionally mapped by ``fc`` + ReLU to ``output_dim`` features.

    Args:
        lookup_window: odd number of neighbouring frames to compare against.
        output_dim: width of the returned features, or ``None`` to return the
            raw ``[batch, time, lookup_window]`` similarities.
    """

    def __init__(self,
                 lookup_window=101,
                 output_dim=None):
        super(ColorHistograms, self).__init__()

        self.fc = nn.Linear(lookup_window, output_dim) if output_dim is not None else None
        self.lookup_window = lookup_window
        assert lookup_window % 2 == 1, "`lookup_window` must be odd integer"

    @staticmethod
    def compute_color_histograms(frames):
        """Return L2-normalised 512-bin histograms of shape ``[batch, time, 512]``."""
        frames = frames.int()

        def get_bin(frames):
            # Keep the top 3 bits of each channel -> bin index R*64 + G*8 + B in [0, 512).
            R, G, B = frames[:, :, 0], frames[:, :, 1], frames[:, :, 2]
            R, G, B = R >> 5, G >> 5, B >> 5
            return (R << 6) + (G << 3) + B

        batch_size, time_window, height, width, no_channels = frames.shape
        assert no_channels == 3
        # `reshape` rather than `view`: a non-contiguous input (e.g. a permuted or
        # sliced tensor) cannot be viewed with merged spatial dims.
        frames_flatten = frames.reshape(batch_size * time_window, height * width, 3)

        binned_values = get_bin(frames_flatten)
        # Offset each frame's bins by frame_index * 512 so all frames share one flat histogram.
        frame_bin_prefix = (torch.arange(0, batch_size * time_window, device=frames.device) << 9).view(-1, 1)
        binned_values = (binned_values + frame_bin_prefix).view(-1)

        histograms = torch.zeros(batch_size * time_window * 512, dtype=torch.int32, device=frames.device)
        histograms.scatter_add_(0, binned_values, torch.ones(len(binned_values), dtype=torch.int32, device=frames.device))

        histograms = histograms.view(batch_size, time_window, 512).float()
        histograms_normalized = functional.normalize(histograms, p=2, dim=2)
        return histograms_normalized

    def forward(self, inputs):
        x = self.compute_color_histograms(inputs)

        batch_size, time_window = x.shape[0], x.shape[1]
        similarities = torch.bmm(x, x.transpose(1, 2))  # [B, T, T]
        half = (self.lookup_window - 1) // 2
        similarities_padded = functional.pad(similarities, [half, half])  # [B, T, T + 2 * half]

        # Padded column t + k holds similarity(t, t + k - half), so row t reads columns t .. t + window - 1.
        lookup_indices = (torch.arange(time_window, device=x.device).view(-1, 1)
                          + torch.arange(self.lookup_window, device=x.device).view(1, -1))
        lookup_indices = lookup_indices.unsqueeze(0).expand(batch_size, -1, -1)
        similarities = torch.gather(similarities_padded, 2, lookup_indices)  # [B, T, lookup_window]

        if self.fc is not None:
            return functional.relu(self.fc(similarities))
        return similarities
|
|
|