raster2seq / models /deformable_points.py
anas
Initial deployment of Raster2Seq floor plan vectorization API
fadb92b
import einops
import torch
import torch.nn.functional as F
from torch import nn
class LayerNormProxy(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of a ``b c h w`` tensor.

    ``nn.LayerNorm`` normalizes the trailing axis, so the input is moved to
    channels-last, normalized, and moved back to channels-first.

    Args:
        dim: Size of the channel axis, passed to ``nn.LayerNorm``.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` normalized over its channel axis, still in ``b c h w`` layout."""
        channels_last = einops.rearrange(x, "b c h w -> b h w c")
        normalized = self.norm(channels_last)
        channels_first = einops.rearrange(normalized, "b h w c -> b c h w")
        return channels_first
class MSDeformablePoints(nn.Module):
    """Sample a strided grid of offset-shifted points from every feature level.

    For each level, a small conv stack predicts a 2-channel (y, x) offset on a
    down-sampled grid. The offsets are added to a regular grid of normalized
    reference points, and the level's features are bilinearly sampled at the
    resulting positions with ``F.grid_sample``.

    Args:
        embed_dim: Channel width of the input features. ``embed_dim // n_heads``
            channels are assigned to each head.
        n_levels: Number of feature levels; one offset network and one query
            projection is created per level.
        n_heads: Number of heads the channels are split into.
        offset_range_factor: When ``>= 0``, offsets are squashed with ``tanh``
            and scaled by ``offset_range_factor / (grid size)``. When negative
            (the default), raw offsets are used and the final sampling
            positions are clamped to ``[-1, 1]`` instead.
    """

    def __init__(
        self,
        embed_dim,
        n_levels,
        n_heads,
        offset_range_factor=-1,
    ):
        super().__init__()
        self.n_head_channels = embed_dim // n_heads
        self.scale = self.n_head_channels**-0.5
        self.n_heads = n_heads
        self.nc = self.n_head_channels * n_heads
        self.offset_range_factor = offset_range_factor
        # Odd kernel sizes shrinking by 2 per level, e.g. [7, 5, 3, 1] when n_levels == 4.
        self.kernel_sizes = [(n_levels - 1 - i) * 2 + 1 for i in range(n_levels)]
        # Strides halving per level, e.g. [16, 8, 4, 2] when n_levels == 4.
        self.strides = [2 ** (n_levels - i) for i in range(n_levels)]
        # Per-level offset predictor: strided grouped conv -> LayerNorm -> GELU -> 1x1 conv to 2 channels.
        self.conv_offset = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        self.n_head_channels,
                        self.n_head_channels,
                        kernel_size,
                        stride,
                        kernel_size // 2,
                        groups=self.n_heads,
                    ),
                    LayerNormProxy(self.n_head_channels),
                    nn.GELU(),
                    nn.Conv2d(self.n_head_channels, 2, 1, 1, 0, bias=False),
                )
                for kernel_size, stride in zip(self.kernel_sizes, self.strides)
            ]
        )
        # Per-level 1x1 projection applied before the offset predictor.
        self.proj_q = nn.ModuleList(
            [nn.Conv2d(self.nc, self.nc, kernel_size=1, stride=1, padding=0) for _ in range(n_levels)]
        )

    @torch.no_grad()
    def _get_ref_points(self, H_key, W_key, B, dtype, device):
        """Build a regular grid of reference points normalized to ``(-1, 1)``.

        Args:
            H_key: Grid height.
            W_key: Grid width.
            B: Batch size; the grid is expanded to ``B * n_heads`` entries.
            dtype: Dtype of the returned tensor.
            device: Device of the returned tensor.

        Returns:
            Tensor of shape ``[B * n_heads, H_key, W_key, 2]`` whose last axis
            is ordered ``(y, x)``.
        """
        ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing="ij",
        )
        ref = torch.stack((ref_y, ref_x), -1)
        # Map cell centres from [0.5, size - 0.5] to (-1, 1), in place on the stacked tensor.
        ref[..., 1].div_(W_key).mul_(2.0).sub_(1.0)
        ref[..., 0].div_(H_key).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_heads, -1, -1, -1)  # [B*g, H, W, 2]
        return ref

    def _bound_offset(self, offset, dtype):
        """Squash raw offsets with ``tanh`` and scale them relative to the grid size.

        Args:
            offset: Raw offsets of shape ``[B*g, 2, Hk, Wk]`` with channels
                ordered ``(y, x)``.
            dtype: Dtype for the scaling tensor. Using the input dtype keeps
                the result from being promoted to float32 for half-precision
                inputs, which would make the sampling grid's dtype differ
                from the sampled features'.

        Returns:
            Tensor of the same shape with each component bounded by
            ``offset_range_factor / Hk`` (y) or ``offset_range_factor / Wk`` (x).
        """
        Hk, Wk = offset.size(2), offset.size(3)
        offset_range = torch.tensor([1.0 / Hk, 1.0 / Wk], dtype=dtype, device=offset.device).reshape(1, 2, 1, 1)
        return offset.tanh().mul(offset_range).mul(self.offset_range_factor)

    def forward(self, x, spatial_shapes, level_start_index):
        """Sample deformed points from each level of a flattened multi-level feature map.

        Args:
            x: Features of shape ``[B, sum(H_l * W_l), C]``, with levels
                concatenated along dim 1 and ``C`` matching the channel count
                of the per-level projections.
            spatial_shapes: Iterable of ``(H_l, W_l)`` pairs, one per level
                (a tensor or a sequence of pairs).
            level_start_index: Unused; kept for interface compatibility.

        Returns:
            Tensor of shape ``[B, sum(Hk_l * Wk_l), C]`` holding the sampled
            features of all levels concatenated along dim 1.
        """
        B = x.size(0)
        dtype = x.dtype
        # Normalize to plain ints once so split/einops/reshape see the same values.
        shapes = [(int(h), int(w)) for h, w in spatial_shapes]
        x_list = x.split([h * w for h, w in shapes], dim=1)

        out = []
        for i, (cur_x, (H, W)) in enumerate(zip(x_list, shapes)):
            q = self.proj_q[i](einops.rearrange(cur_x, "b (h w) c -> b c h w", h=H, w=W))
            q_off = einops.rearrange(q, "b (g c) h w -> (b g) c h w", g=self.n_heads, c=self.n_head_channels)
            offset = self.conv_offset[i](q_off).contiguous()  # [B*g, 2, Hk, Wk]
            Hk, Wk = offset.size(2), offset.size(3)

            if self.offset_range_factor >= 0:
                offset = self._bound_offset(offset, dtype)

            offset = einops.rearrange(offset, "b two h w -> b h w two")
            reference = self._get_ref_points(Hk, Wk, B, dtype, x.device)

            if self.offset_range_factor >= 0:
                pos = offset + reference
            else:
                # Unbounded offsets: keep sampling positions inside the feature map.
                pos = (offset + reference).clamp(-1.0, +1.0)

            # NOTE(review): cur_x is laid out as [B, H*W, C]; this reshape reinterprets
            # that buffer as [B*g, Cg, H, W] without permuting axes, which is not the
            # same as rearranging "b (h w) (g c) -> (b g) c h w". Left unchanged because
            # altering it would change outputs for weights trained with this layout;
            # confirm whether it is intentional.
            x_sampled = F.grid_sample(
                input=cur_x.reshape(B * self.n_heads, self.n_head_channels, H, W),  # [B*g, Cg, H, W]
                grid=pos[..., (1, 0)],  # y, x -> x, y: [B*g, Hk, Wk, 2]
                mode="bilinear",
                align_corners=True,
            )  # [B*g, Cg, Hk, Wk]

            x_sampled = einops.rearrange(x_sampled, "(B g) C H W -> B (H W) (g C)", B=B)
            out.append(x_sampled)

        return torch.cat(out, dim=1)