| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | import einops |
| | import math |
| |
|
| | class FuzzyEmbedding(nn.Module): |
| | def __init__(self, grid_size, scale, width, apply_fuzzy=False): |
| | super(FuzzyEmbedding, self).__init__() |
| | assert grid_size == 1024, "grid_size must be 1024 for now" |
| | |
| | self.grid_size = grid_size |
| | self.scale = scale |
| | self.width = width |
| | self.apply_fuzzy = apply_fuzzy |
| | |
| | |
| | self.positional_embedding = nn.Parameter( |
| | scale * torch.randn(grid_size, width)) |
| | |
| | self.class_positional_embedding = nn.Parameter( |
| | scale * torch.randn(1, width)) |
| |
|
| | @torch.cuda.amp.autocast(enabled=False) |
| | def forward(self, grid_height, grid_width, train=True, dtype=torch.float32): |
| | meshx, meshy = torch.meshgrid( |
| | torch.tensor(list(range(grid_height)), device=self.positional_embedding.device), |
| | torch.tensor(list(range(grid_width)), device=self.positional_embedding.device) |
| | ) |
| | meshx = meshx.to(dtype) |
| | meshy = meshy.to(dtype) |
| |
|
| | |
| | meshx = 2 * (meshx / (grid_height - 1)) - 1 |
| | meshy = 2 * (meshy / (grid_width - 1)) - 1 |
| | |
| | if self.apply_fuzzy: |
| | |
| | if train: |
| | noise_x = torch.rand_like(meshx) * 0.0008 - 0.0004 |
| | noise_y = torch.rand_like(meshy) * 0.0008 - 0.0004 |
| | else: |
| | noise_x = torch.zeros_like(meshx) |
| | noise_y = torch.zeros_like(meshy) |
| |
|
| | |
| | meshx = meshx + noise_x |
| | meshy = meshy + noise_y |
| | |
| | grid = torch.stack((meshy, meshx), 2).to(self.positional_embedding.device) |
| | grid = grid.unsqueeze(0) |
| | |
| | positional_embedding = einops.rearrange(self.positional_embedding, "(h w) d -> d h w", h=int(math.sqrt(self.grid_size)), w=int(math.sqrt(self.grid_size))) |
| | positional_embedding = positional_embedding.to(dtype) |
| | positional_embedding = positional_embedding.unsqueeze(0) |
| |
|
| | fuzzy_embedding = F.grid_sample(positional_embedding, grid, align_corners=False) |
| | fuzzy_embedding = fuzzy_embedding.to(dtype) |
| | fuzzy_embedding = einops.rearrange(fuzzy_embedding, "b d h w -> b (h w) d").squeeze(0) |
| |
|
| | final_embedding = torch.cat([self.class_positional_embedding, fuzzy_embedding], dim=0) |
| | return final_embedding |
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test. grid_size=256 (as originally written) trips the
    # constructor's assertion and crashes; 1024 satisfies it. grid_size
    # only controls the internal table — the output shape depends on the
    # target grid, so the printed shape is (1 + 16*32, 1024) = (513, 1024).
    model = FuzzyEmbedding(1024, 1.0, 1024)
    grid_height = 16
    grid_width = 32
    # Distinct name: the original rebound `fuzzy_embedding` from the
    # module instance to its output tensor.
    embedding = model(grid_height, grid_width, dtype=torch.bfloat16)
    print(embedding.shape)