|
|
from einops import rearrange |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from timm.models.layers import trunc_normal_ |
|
|
import math |
|
|
|
|
|
|
|
|
class FeatureResizer(nn.Module):
    """Resize the channel width of a feature tensor.

    The input is passed through a linear projection, optionally through a
    LayerNorm (eps=1e-12), and finally through dropout.

    Args:
        input_feat_size: width of the last dimension of the input.
        output_feat_size: width of the last dimension of the output.
        dropout: dropout probability applied to the result.
        do_ln: when True, apply LayerNorm after the projection.
    """

    def __init__(self, input_feat_size, output_feat_size, dropout, do_ln=True):
        super().__init__()
        self.do_ln = do_ln
        # The LayerNorm is always constructed (even when do_ln is False) so
        # that the set of registered parameters does not depend on the flag.
        self.fc = nn.Linear(input_feat_size, output_feat_size, bias=True)
        self.layer_norm = nn.LayerNorm(output_feat_size, eps=1e-12)
        self.dropout = nn.Dropout(dropout)

    def forward(self, encoder_features):
        """Return the projected (and optionally normalized) features."""
        projected = self.fc(encoder_features)
        normalized = self.layer_norm(projected) if self.do_ln else projected
        return self.dropout(normalized)
|
|
|
|
|
|
|
|
def l2norm(X, dim=-1, eps=1e-12):
    """
    L2-normalize X along `dim` (the last dimension by default).

    `eps` is added to the norm (after the square root) so that all-zero
    slices map to zeros instead of producing a division by zero.
    """
    length = X.pow(2).sum(dim=dim, keepdim=True).sqrt() + eps
    return torch.div(X, length)
|
|
|
|
|
class Mlp(nn.Module):
    """Two-layer feed-forward block: fc1 -> activation -> dropout -> fc2 -> dropout.

    Args:
        in_features: width of the input.
        hidden_features: width of the hidden layer; falls back to
            `in_features` when omitted (or falsy).
        out_features: width of the output; falls back to `in_features`
            when omitted (or falsy).
        act_layer: zero-argument factory for the activation module.
        drop: dropout probability; one dropout module is shared by both
            dropout applications.
    """

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_width = hidden_features or in_features
        output_width = out_features or in_features
        # Creation order (fc1, act, fc2, drop) is kept so seeded
        # initialization matches.
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the feed-forward block to `x`."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
|
|
|
|
def conv_layer(in_dim, out_dim, kernel_size=1, padding=0, stride=1):
    """Build a bias-free Conv2d -> BatchNorm2d -> in-place ReLU stack.

    Note the parameter order of this helper is (kernel_size, padding, stride),
    which differs from nn.Conv2d's positional order; keywords are used below
    to keep the mapping explicit.
    """
    conv = nn.Conv2d(in_dim, out_dim, kernel_size,
                     stride=stride, padding=padding, bias=False)
    norm = nn.BatchNorm2d(out_dim)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
|
|
|
|
|
def hard_softmax(logits, dim):
    """Return a one-hot argmax of `logits` along `dim` with softmax gradients.

    The forward value is the one-hot selection of the largest softmax
    probability; subtracting the detached probabilities and adding them back
    lets gradients flow through the soft distribution (straight-through).
    """
    probs = logits.softmax(dim)

    _, winner = probs.max(dim, keepdim=True)
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    one_hot.scatter_(dim, winner, 1.0)

    return one_hot - probs.detach() + probs
|
|
|
|
|
def gumbel_softmax(logits: torch.Tensor, tau: float = 1, dim: int = -2) -> torch.Tensor:
    """Sample a hard (one-hot along `dim`) assignment with Gumbel noise.

    Standard Gumbel(0, 1) noise is added to `logits`, the sum is divided by
    `tau` and soft-maxed along `dim`. The returned tensor holds the one-hot
    argmax in the forward pass while gradients follow the soft distribution
    (straight-through). Noise is drawn on every call.
    """
    loc = torch.tensor(0., device=logits.device, dtype=logits.dtype)
    scale = torch.tensor(1., device=logits.device, dtype=logits.dtype)
    noise = torch.distributions.gumbel.Gumbel(loc, scale).sample(logits.shape)

    perturbed = (logits + noise) / tau
    probs = perturbed.softmax(dim)

    _, winner = probs.max(dim, keepdim=True)
    one_hot = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format)
    one_hot.scatter_(dim, winner, 1.0)

    return one_hot - probs.detach() + probs
|
|
|
|
|
class Fusion(nn.Module):
    """Concatenate two feature maps along channels and fuse them with two
    3x3 Conv -> BatchNorm -> ReLU stages.

    Args:
        in_dim_1: channel count of the first input.
        in_dim_2: channel count of the second input.
        out_dim: channel count produced by both convolution stages.
        bias: whether the convolutions carry a bias term.
    """

    def __init__(self, in_dim_1, in_dim_2, out_dim, bias=False) -> None:
        super().__init__()

        stages = []
        for stage_in in (in_dim_1 + in_dim_2, out_dim):
            stages.extend([
                nn.Conv2d(stage_in, out_dim, 3, padding=1, bias=bias),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(),
            ])
        self.fusion = nn.Sequential(*stages)

    def forward(self, in_1, in_2):
        """Resize the narrower map to the wider one, then concat and fuse.

        Only the last (width) dimension is compared to decide which input is
        resized; the resize target is the other input's full (H, W).
        """
        width_1, width_2 = in_1.shape[-1], in_2.shape[-1]
        if width_1 < width_2:
            in_1 = F.interpolate(in_1, size=in_2.shape[-2:], mode='bilinear', align_corners=True)
        elif width_1 > width_2:
            in_2 = F.interpolate(in_2, size=in_1.shape[-2:], mode='bilinear', align_corners=True)

        stacked = torch.cat((in_1, in_2), dim=1)
        return self.fusion(stacked)
|
|
|
|
|
class DProjector(nn.Module):
    """Project visual features to a one-channel map using a text-predicted kernel.

    The visual branch upsamples by 4x overall (two bilinear x2 stages, each
    followed by a 3x3 conv block, then a 1x1 convolution). The text branch
    predicts, per sample, `in_dim * kernel_size**2` kernel weights plus one
    bias, which are applied to that sample's visual features through a
    grouped convolution.
    """

    def __init__(self, text_dim=512, in_dim=512, kernel_size=1):
        super().__init__()
        self.in_dim = in_dim
        self.kernel_size = kernel_size

        # Visual branch: (upsample x2 -> conv block) twice, then 1x1 conv.
        visual_stages = []
        for _ in range(2):
            visual_stages.append(nn.Upsample(scale_factor=2, mode='bilinear'))
            visual_stages.append(conv_layer(in_dim, in_dim, 3, padding=1))
        visual_stages.append(nn.Conv2d(in_dim, in_dim, 1))
        self.vis = nn.Sequential(*visual_stages)

        # Text branch: one dynamic kernel (in_dim * k * k weights) plus a bias.
        dynamic_param_count = 1 * in_dim * kernel_size * kernel_size + 1
        self.txt = nn.Linear(text_dim, dynamic_param_count)

    def forward(self, x, text):
        '''
        Original author's shape note: x: b, 512, 104, 104; text: b, 512
        (TODO confirm which stage of the pipeline these sizes refer to).

        Returns a tensor of shape (b, 1, H', W').
        '''
        feat = self.vis(x)
        batch, channels, height, width = feat.size()

        # Fold the batch into the channel axis so one grouped convolution can
        # apply a different kernel to every sample.
        feat = feat.reshape(1, batch * channels, height, width)

        dynamic = self.txt(text)
        kernels = dynamic[:, :-1].reshape(batch, channels, self.kernel_size, self.kernel_size)
        offsets = dynamic[:, -1]

        # NOTE(review): padding is fixed at 1 regardless of kernel_size; with
        # the default kernel_size=1 the output is 2 pixels larger than the
        # upsampled input in each spatial dimension. Confirm this is intended
        # (kernel_size // 2 would preserve the size) before changing it.
        response = F.conv2d(feat, kernels, bias=offsets, padding=1, groups=batch)

        # (1, b, H', W') -> (b, 1, H', W')
        return response.transpose(0, 1)
|
|
|
|
|
|
|
|
class CrossAttn(nn.Module):
    """Multi-head cross-attention with separate query / key / value projections.

    Args:
        q_dim: channel width of the query input.
        kv_dim: channel width of the key / value input.
        hidden_dim: total width of the projected q/k/v (split across heads).
        num_heads: number of attention heads; must divide `hidden_dim`.
        out_dim: width of the output; defaults to `q_dim`.
        qkv_bias: whether the q/k/v projections carry a bias.
        qk_scale: optional override for the logit scale; when falsy the
            scale is `(hidden_dim // num_heads) ** -0.5`.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability on the output projection.
        qkv_fuse: stored on the module but not used by this class.
    """

    def __init__(self,
                 q_dim,
                 kv_dim,
                 hidden_dim,
                 num_heads,
                 out_dim=None,
                 qkv_bias=False,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.,
                 qkv_fuse=False):
        super().__init__()
        if out_dim is None:
            out_dim = q_dim
        self.num_heads = num_heads
        head_dim = hidden_dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.qkv_fuse = qkv_fuse

        self.q_proj = nn.Linear(q_dim, hidden_dim, bias=qkv_bias)
        self.k_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias)
        self.v_proj = nn.Linear(kv_dim, hidden_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(hidden_dim, out_dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _split_heads(self, x):
        """Reshape (b, n, h*c) to (b, h, n, c) using the projected width."""
        batch, length, _ = x.shape
        return x.reshape(batch, length, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, query, key, value=None, mask=None):
        """Attend from `query` (b, n, q_dim) to `key` / `value` (b, s, kv_dim).

        Args:
            query: query sequence.
            key: key sequence.
            value: value sequence; defaults to `key`.
            mask: optional (b, s) tensor where 1 marks a key position to keep
                and 0 a position to suppress.

        Returns:
            Tensor of shape (b, n, out_dim).
        """
        B, N, _ = query.shape
        if value is None:
            value = key
        S = key.size(1)

        # The head width comes from the projected tensors (hidden_dim), not
        # from the query's input width, so q_dim may differ from hidden_dim.
        q = self._split_heads(self.q_proj(query))
        k = self._split_heads(self.k_proj(key))
        v = self._split_heads(self.v_proj(value))

        if mask is not None:
            mask = mask[:, None, :, None].expand(-1, self.num_heads, -1, -1)
            # Zero the suppressed keys/values, then push their logits to -1e4
            # so they receive (numerically) no attention after the softmax.
            k = k * mask
            v = v * mask
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn + (1e4 * mask.transpose(-2, -1) - 1e4)
        else:
            attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        assert attn.shape == (B, self.num_heads, N, S)

        # (b, h, n, c) -> (b, n, h*c)
        out = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        out = self.proj(out)
        out = self.proj_drop(out)
        return out
|
|
|
|
|
class OriLoadToken(nn.Module):
    """Update every token with single-head cross-attention over text features.

    Args:
        token_dim: channel width of the tokens (also the attention hidden
            and output width).
        bias: whether the q/k/v projections carry a bias.
        drop: dropout probability for attention weights and output projection.
        text_dim: channel width of the text features (default 768, the value
            that was previously hard-coded).
    """

    def __init__(self, token_dim, bias, drop, text_dim=768) -> None:
        super().__init__()
        self.cross_attn = CrossAttn(
            q_dim=token_dim,
            kv_dim=text_dim,
            hidden_dim=token_dim,
            num_heads=1,
            out_dim=token_dim,
            qkv_bias=bias,
            attn_drop=drop,
            proj_drop=drop,
        )
        # Pre-attention norms for the query (tokens) and key (text) streams.
        # These were previously constructed twice; the second assignment
        # simply replaced the first, so building them once is equivalent.
        self.normq = nn.LayerNorm(token_dim)
        self.normk = nn.LayerNorm(text_dim)

    def forward(self, tokens, text, pad_mask):
        """Return `tokens` plus their cross-attention update from `text`.

        `text` is permuted with (0, 2, 1) before normalization, i.e. it is
        expected channel-first (presumably (b, text_dim, l) — confirm against
        the caller). The last axis of `pad_mask` is dropped by indexing [..., 0]
        and the result is passed as the attention mask.
        """
        text_seq = self.normk(text.permute(0, 2, 1))
        update = self.cross_attn(query=self.normq(tokens), key=text_seq, mask=pad_mask[..., 0])
        return tokens + update
|
|
|
|
|
|
|
|
class LoadToken(nn.Module):
    """Update only the last token with single-head cross-attention over text.

    Args:
        token_dim: channel width of the tokens (also the attention hidden
            and output width).
        bias: whether the q/k/v projections carry a bias.
        drop: dropout probability for attention weights and output projection.
        text_dim: channel width of the text features (default 768, the value
            that was previously hard-coded).
    """

    def __init__(self, token_dim, bias, drop, text_dim=768) -> None:
        super().__init__()
        self.cross_attn = CrossAttn(
            q_dim=token_dim,
            kv_dim=text_dim,
            hidden_dim=token_dim,
            num_heads=1,
            out_dim=token_dim,
            qkv_bias=bias,
            attn_drop=drop,
            proj_drop=drop,
        )
        self.normq = nn.LayerNorm(token_dim)
        self.normk = nn.LayerNorm(text_dim)
        # `norm` and `mlp` are not used by forward(); they stay registered so
        # parameter names (checkpoint keys) and seeded initialization order
        # are unchanged.
        self.norm = nn.LayerNorm(token_dim)
        self.mlp = Mlp(token_dim, token_dim * 2, token_dim)

    def forward(self, tokens, text, pad_mask):
        """Return `tokens` with the final token refreshed from `text`.

        The token sequence is split along dim 1 into all-but-last and last;
        only the last token attends to the text and receives a residual
        update. `text` is permuted with (0, 2, 1) before normalization
        (presumably (b, text_dim, l) on input — confirm against the caller),
        and `pad_mask[..., 0]` is passed as the attention mask.
        """
        ltoken, ttoken = torch.split(tokens, [tokens.shape[1] - 1, 1], dim=1)
        text_seq = self.normk(text.permute(0, 2, 1))
        ttoken = ttoken + self.cross_attn(query=self.normq(ttoken), key=text_seq, mask=pad_mask[..., 0])
        return torch.cat((ltoken, ttoken), dim=1)
|
|
|
|
|
class LoadLayer(nn.Module):
    """Token pre-processing for one decoder stage.

    Optionally refreshes the tokens from text (only when `pe_shape` exceeds
    the threshold below), applies LayerNorm + MLP, and owns a learned
    positional embedding of shape (pe_shape**2, token_dim).

    Args:
        token_dim: channel width of the tokens.
        drop: dropout probability forwarded to LoadToken.
        bias: q/k/v bias flag forwarded to LoadToken.
        pe_shape: side length of the positional-embedding grid. Required
            despite the `None` default in the signature.

    Raises:
        TypeError: if `pe_shape` is None.
    """

    # Stages whose pe_shape is strictly greater than this value get a
    # LoadToken text-injection step.
    _LOAD_TOKEN_MIN_PE_SHAPE = 30

    def __init__(self, token_dim, drop, bias=False, pe_shape=None) -> None:
        super().__init__()
        if pe_shape is None:
            # Previously this surfaced as an opaque "'>' not supported
            # between NoneType and int" TypeError from the comparison below.
            raise TypeError("LoadLayer requires an integer pe_shape, got None")
        if pe_shape > self._LOAD_TOKEN_MIN_PE_SHAPE:
            self.loadtoken = LoadToken(
                token_dim=token_dim,
                bias=bias,
                drop=drop
            )
        self.norm = nn.LayerNorm(token_dim)
        self.mlp = Mlp(token_dim, token_dim * 2, token_dim)
        self.positional_embedding = nn.Parameter(torch.randn(pe_shape ** 2, token_dim) / token_dim ** 0.5)
        self.pe_shape = pe_shape

    def forward(self, tokens, text, pad_mask):
        """Return (processed tokens, positional embedding).

        Note the MLP output replaces the tokens (there is no residual
        connection around norm + mlp here).
        """
        if self.pe_shape > self._LOAD_TOKEN_MIN_PE_SHAPE:
            tokens = self.loadtoken(tokens, text, pad_mask)
        tokens = self.mlp(self.norm(tokens))
        return tokens, self.positional_embedding
|
|
|
|
|
|
|
|
class MetricFuser(nn.Module):
    """Pool visual features into tokens with a hard Gumbel assignment and
    return one pooled vector per sample.

    Args:
        token_dim: channel width of the tokens.
        vis_dim: channel width of the visual features.
        hidden_dim: width of the q/k/v projections.
        drop: dropout probability inside the MLP.
        bias: whether the q/k/v projections carry a bias.
    """

    def __init__(self, token_dim, vis_dim, hidden_dim, drop=0., bias=True) -> None:
        super().__init__()
        self.norm_v = nn.LayerNorm(vis_dim)
        self.norm_t = nn.LayerNorm(token_dim)

        self.q_proj = nn.Linear(token_dim, hidden_dim, bias=bias)
        self.k_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)
        self.v_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)

        self.proj = nn.Linear(hidden_dim, token_dim)
        self.norm = nn.LayerNorm(token_dim)
        self.mlp = Mlp(token_dim, token_dim * 2, token_dim, drop=drop)
        self.tau = nn.Parameter(torch.ones(1), requires_grad=True)

    def with_pe(self, vis, pe):
        """Add the positional embedding to the visual sequence."""
        return vis + pe

    def forward(self, tokens, vis, pad_mask=None, pe=None):
        """Fuse `vis` into `tokens` and average over tokens.

        Args:
            tokens: (b, c, n) token features (channel-first).
            vis: (b, c, h, w) visual feature map.
            pad_mask: optional multiplier applied to the projected queries;
                must broadcast against (b, n, hidden_dim) (presumably
                (b, n, 1) with 1 for valid tokens — confirm against caller).
                When None the queries are left unmasked.
            pe: optional positional embedding added to the flattened `vis`.

        Returns:
            Tensor of shape (b, token_dim, 1, 1).
        """
        vis = rearrange(vis, 'b c h w -> b (h w) c')
        tokens = rearrange(tokens, 'b c n -> b n c')

        if pe is not None:
            vis = self.with_pe(vis, pe)

        vis_norm = self.norm_v(vis)
        tokens_norm = self.norm_t(tokens)

        q = self.q_proj(tokens_norm)
        if pad_mask is not None:
            # The signature allows pad_mask=None; multiplying by None used to
            # raise a TypeError, so the masking is now skipped in that case.
            q = q * pad_mask
        k = self.k_proj(vis_norm)
        v = self.v_proj(vis_norm)

        # Unit-length q and k make the raw attention a cosine similarity.
        q = l2norm(q, dim=-1)
        k = l2norm(k, dim=-1)

        raw_attn = (q @ k.transpose(-2, -1))

        # exp(min(tau, 0)) keeps the temperature in (0, 1].
        # NOTE(review): tau is initialised to 1, where the clamp is active;
        # confirm the parameter is expected to receive gradient there.
        tau = torch.clamp(self.tau, max=0).exp()

        # dim=-2: each visual position is assigned to exactly one token.
        attn = gumbel_softmax(raw_attn / math.sqrt(q.shape[-1]), dim=-2, tau=tau)

        # Average the assigned positions per token; the +1 keeps the
        # denominator non-zero for tokens that received no position.
        attn = attn / (attn.sum(dim=-1, keepdim=True) + 1)
        new_tokens = attn @ v
        new_tokens = self.proj(new_tokens)
        new_tokens = self.mlp(self.norm(new_tokens + tokens))

        return torch.mean(new_tokens, dim=1).unsqueeze(-1).unsqueeze(-1)
|
|
|
|
|
|
|
|
class PositionEmbeddingSine1D(nn.Module):
    """Sinusoidal position embedding over a 1-D (text) sequence.

    Args:
        num_pos_feats: number of embedding channels produced.
        temperature: base of the geometric frequency progression.
        normalize: when True, positions are rescaled to the range (0, scale].
        scale: multiplier used with `normalize`; defaults to 2*pi.

    Raises:
        ValueError: if `scale` is given while `normalize` is False.
    """

    def __init__(self, num_pos_feats=256, temperature=10000, normalize=False, scale=None):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is None:
            scale = 2 * math.pi
        elif normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        self.scale = scale

    def forward(self, text, mask):
        """Return embeddings of shape (b, num_pos_feats, l).

        Args:
            text: 3-D tensor; only its rank and device are used.
            mask: (b, l) boolean tensor, True at positions to exclude from
                the running position count (it is inverted before cumsum).
        """
        _batch, _channels, _length = text.shape  # enforces a 3-D input
        positions = (~mask).cumsum(1, dtype=torch.float32)

        if self.normalize:
            eps = 1e-6
            positions = positions / (positions[:, -1:] + eps) * self.scale

        channel_idx = torch.arange(self.num_pos_feats, dtype=torch.float32, device=text.device)
        # Channels are paired: (2i, 2i+1) share one frequency.
        frequencies = self.temperature ** (2 * (channel_idx // 2) / self.num_pos_feats)

        angles = positions[:, :, None] / frequencies
        sines = angles[:, :, 0::2].sin()
        cosines = angles[:, :, 1::2].cos()
        # Interleave sin/cos so even channels hold sin and odd channels cos.
        embedding = torch.stack((sines, cosines), dim=3).flatten(2)
        return embedding.permute(0, 2, 1)
|
|
|
|
|
|
|
|
class VisionLanguageFusionModule(nn.Module):
    """Fuse text into a visual feature map with multi-head cross-attention.

    Visual positions are the queries; text positions are keys and values.
    The attended text is added residually to the projected visual features.

    Args:
        d_model: attention width; the visual map is projected to this width
            and the text features must already have it.
        nhead: number of attention heads.
        dropout: attention dropout probability.
        vis_dim: channel count of the incoming visual map (default 1024, the
            value that was previously hard-coded).
    """

    def __init__(self, d_model=768, nhead=8, dropout=0.1, vis_dim=1024):
        super().__init__()
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

        self.text_pos_encoder = PositionEmbeddingSine1D(d_model, normalize=True)
        self.vis_proj = nn.Linear(vis_dim, d_model)

    def with_pos_embed(self, tensor, pos: torch.Tensor = None):
        """Return `tensor + pos`, or `tensor` unchanged when `pos` is None."""
        return tensor if pos is None else tensor + pos

    def forward(self, text, visual, text_key_padding_mask, text_pe=False):
        """Return the fused visual map of shape (b, d_model, h, w).

        Args:
            text: (b, d_model, l) text features (channel-first).
            visual: (b, vis_dim, h, w) visual feature map.
            text_key_padding_mask: tensor with a trailing singleton axis that
                is squeezed and cast to bool, then passed as
                `key_padding_mask`. nn.MultiheadAttention ignores key
                positions where this mask is True.
                NOTE(review): other modules in this file treat mask == 1 as
                a *valid* token; confirm the caller supplies the polarity
                expected here.
            text_pe: when True, add a sinusoidal position embedding to the
                keys (not to the values).
        """
        _, _, H, W = visual.shape
        visual = self.vis_proj(visual.permute(0, 2, 3, 1))
        visual = rearrange(visual, 'b h w c -> (h w) b c')
        text_key_padding_mask = text_key_padding_mask.squeeze(-1).bool()

        if text_pe:
            text_pos = self.text_pos_encoder(text, text_key_padding_mask)
            text_pos = rearrange(text_pos, 'b c l -> l b c')
        else:
            text_pos = None

        text = rearrange(text, 'b c l -> l b c')
        # Queries and keys are each pre-scaled by 1/sqrt(width) before being
        # handed to the attention module.
        query = self.with_pos_embed(visual, None) / math.sqrt(visual.shape[-1])
        key = self.with_pos_embed(text, text_pos) / math.sqrt(text.shape[-1])

        fused_visual = self.multihead_attn(
            query=query,
            key=key,
            value=text,
            key_padding_mask=text_key_padding_mask
        )[0]
        visual = visual + fused_visual

        visual = rearrange(visual, '(h w) b c -> b c h w', h=H, w=W)
        return visual
|
|
|
|
|
|
|
|
class CGAttention(nn.Module):
    """Group visual positions onto tokens with a hard Gumbel assignment and
    update the tokens from their assigned positions.

    Args:
        token_dim: channel width of the tokens.
        vis_dim: channel width of the visual features.
        hidden_dim: width of the q/k/v projections.
        drop: dropout probability for the output projection and the MLP.
        bias: whether the q/k/v projections carry a bias.
    """

    def __init__(self, token_dim, vis_dim, hidden_dim, drop=0., bias=True) -> None:
        super().__init__()
        self.norm_v = nn.LayerNorm(vis_dim)
        self.norm_t = nn.LayerNorm(token_dim)
        self.q_proj = nn.Linear(token_dim, hidden_dim, bias=bias)
        self.k_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)
        self.v_proj = nn.Linear(vis_dim, hidden_dim, bias=bias)
        self.proj = nn.Linear(hidden_dim, token_dim)
        self.proj_drop = nn.Dropout(drop)
        self.norm = nn.LayerNorm(token_dim)
        self.mlp = Mlp(token_dim, token_dim * 2, token_dim, drop=drop)
        self.tau = nn.Parameter(torch.ones(1), requires_grad=True)

    def with_pe(self, vis, pe):
        """Add the positional embedding to the visual sequence."""
        return vis + pe

    def forward(self, tokens, vis, pe=None):
        """Return (updated tokens, hit map).

        Args:
            tokens: (b, n, token_dim) token features.
            vis: (b, c, h, w) visual feature map.
            pe: optional positional embedding added to the flattened `vis`.

        Returns:
            A tuple of the updated tokens and the hard assignment reshaped
            to (b, n, h, w).
        """
        batch, _, height, width = vis.shape
        vis_seq = rearrange(vis, 'b c h w -> b (h w) c')
        if pe is not None:
            vis_seq = self.with_pe(vis_seq, pe)
        vis_seq = self.norm_v(vis_seq)

        # Unit-length queries and keys make the similarity a cosine.
        query = l2norm(self.q_proj(self.norm_t(tokens)), dim=-1)
        key = l2norm(self.k_proj(vis_seq), dim=-1)
        value = self.v_proj(vis_seq)

        similarity = query @ key.transpose(-2, -1)
        # exp(min(tau, 0)) keeps the temperature in (0, 1].
        temperature = torch.clamp(self.tau, max=0).exp()
        # dim=-2: each visual position picks exactly one token.
        assignment = gumbel_softmax(similarity, dim=-2, tau=temperature)

        # Average over assigned positions; +1 avoids division by zero for
        # tokens that received no position.
        weights = assignment / (assignment.sum(dim=-1, keepdim=True) + 1)
        updated = self.proj_drop(self.proj(weights @ value))
        updated = self.mlp(self.norm(updated + tokens))
        return updated, assignment.reshape(batch, -1, height, width)
|
|
|
|
|
class Decoder_fuse(nn.Module):
    """Coarse-to-fine decoder that fuses multi-scale visual features, groups
    them onto learned tokens, and projects the result with the last token.

    `args` must provide `token_dim`, `num_token` and `fuse_mode`.
    `fuse_mode` selects the optional metric-tensor branch:
    'simple_attn' (MetricFuser), any string containing 'lang_tf_attn'
    (VisionLanguageFusionModule), or anything else for no branch.
    """

    def __init__(self, args) -> None:
        super().__init__()
        # Original author's note on the backbone feature shapes:
        #   c1: 128, 120, 120
        #   c2: 256, 60, 60
        #   c3: 512, 30, 30
        #   c4: 1024, 15, 15
        token_dim = args.token_dim
        self.tokens = nn.Embedding(args.num_token, token_dim)
        trunc_normal_(self.tokens.weight, std=0.02)

        dims = [1024, 512, 256, 128]
        pe_shapes = [30, 60, 120]

        # Construction and registration order below is kept as-is
        # (LoadLayers built first, but registered after the two CGAttention
        # modules) so seeded initialization and parameter ordering match
        # existing checkpoints.
        load_layers = [
            LoadLayer(token_dim, drop=.1, bias=False, pe_shape=pe_shape)
            for pe_shape in pe_shapes
        ]
        self.cgattention1 = CGAttention(token_dim=token_dim,
                                        vis_dim=token_dim,
                                        hidden_dim=token_dim,
                                        drop=.1,
                                        bias=True)
        self.cgattention2 = CGAttention(token_dim=token_dim,
                                        vis_dim=token_dim,
                                        hidden_dim=token_dim,
                                        drop=.1,
                                        bias=True)
        self.layers = nn.ModuleList(load_layers)
        # Each Fusion concatenates the running feature (first input) with the
        # next backbone level, so its first width is dims[0] for the initial
        # stage and the entries below for the later ones.
        self.fuses = nn.ModuleList([
            Fusion(dim, token_dim, token_dim, bias=True)
            for dim in (dims[0], dims[2], dims[3])
        ])
        self.proj = DProjector(text_dim=token_dim, in_dim=token_dim)

        self.fuse_mode = args.fuse_mode
        if args.fuse_mode == 'simple_attn':
            self.metric_tensor_generator = MetricFuser(768, dims[0], token_dim)
        elif 'lang_tf_attn' in args.fuse_mode:
            self.metric_tensor_generator = VisionLanguageFusionModule(d_model=768)
        else:
            self.metric_tensor_generator = None

    def forward(self, vis, text, pad_mask):
        """Decode a prediction from multi-scale features and text.

        Args:
            vis: sequence of four feature maps, unpacked as
                (x_c4, x_c3, x_c2, x_c1).
            text: text features forwarded to the token loaders and the
                metric-tensor branch.
            pad_mask: text mask forwarded unchanged to both.

        Returns:
            (out, maps, metric_tensor): the projected output, the list of
            per-stage hit maps, and the metric tensor — None when
            `fuse_mode` does not select one of the three handled modes.
        """
        x_c4, x_c3, x_c2, x_c1 = vis
        tokens = self.tokens.weight[None, ...].expand(x_c1.shape[0], -1, -1)

        # Default to None so an unhandled fuse_mode no longer raises
        # UnboundLocalError at the return statement.
        # NOTE(review): the same pad_mask is multiplied into keys/queries
        # elsewhere (1 = keep) and is also cast to a key_padding_mask in the
        # 'lang_tf_attn' branch, where True means "ignore"; confirm polarity.
        metric_tensor = None
        if self.fuse_mode == 'simple_attn':
            metric_tensor = self.metric_tensor_generator(text, x_c4, pad_mask)
        elif self.fuse_mode == 'lang_tf_attn_wope':
            metric_tensor = self.metric_tensor_generator(text, x_c4, pad_mask, False)
        elif self.fuse_mode == 'lang_tf_attn_wpe':
            metric_tensor = self.metric_tensor_generator(text, x_c4, pad_mask, True)

        maps = []
        v = x_c4
        # NOTE(review): cgattention2 is shared by the second and third
        # stages; confirm the weight sharing is intentional.
        attentions = [self.cgattention1, self.cgattention2, self.cgattention2]
        skips = [x_c3, x_c2, x_c1]
        for load, layer, fuse, v_ in zip(self.layers, attentions, self.fuses, skips):
            v = fuse(v, v_)
            tokens, pe = load(tokens, text, pad_mask)
            tokens, hitmap = layer(tokens, v, pe=pe)
            maps.append(hitmap)
        out = self.proj(v, tokens[:, -1])
        return out, maps, metric_tensor
|
|
|