| | import torch |
| | from torch import nn |
| | from .ldm.modules.attention import CrossAttention |
| | from inspect import isfunction |
| | import comfy.ops |
| | ops = comfy.ops.manual_cast |
| |
|
| | def exists(val): |
| | return val is not None |
| |
|
| |
|
| | def uniq(arr): |
| | return{el: True for el in arr}.keys() |
| |
|
| |
|
| | def default(val, d): |
| | if exists(val): |
| | return val |
| | return d() if isfunction(d) else d |
| |
|
| |
|
| | |
| | class GEGLU(nn.Module): |
| | def __init__(self, dim_in, dim_out): |
| | super().__init__() |
| | self.proj = ops.Linear(dim_in, dim_out * 2) |
| |
|
| | def forward(self, x): |
| | x, gate = self.proj(x).chunk(2, dim=-1) |
| | return x * torch.nn.functional.gelu(gate) |
| |
|
| |
|
| | class FeedForward(nn.Module): |
| | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): |
| | super().__init__() |
| | inner_dim = int(dim * mult) |
| | dim_out = default(dim_out, dim) |
| | project_in = nn.Sequential( |
| | ops.Linear(dim, inner_dim), |
| | nn.GELU() |
| | ) if not glu else GEGLU(dim, inner_dim) |
| |
|
| | self.net = nn.Sequential( |
| | project_in, |
| | nn.Dropout(dropout), |
| | ops.Linear(inner_dim, dim_out) |
| | ) |
| |
|
| | def forward(self, x): |
| | return self.net(x) |
| |
|
| |
|
| | class GatedCrossAttentionDense(nn.Module): |
| | def __init__(self, query_dim, context_dim, n_heads, d_head): |
| | super().__init__() |
| |
|
| | self.attn = CrossAttention( |
| | query_dim=query_dim, |
| | context_dim=context_dim, |
| | heads=n_heads, |
| | dim_head=d_head, |
| | operations=ops) |
| | self.ff = FeedForward(query_dim, glu=True) |
| |
|
| | self.norm1 = ops.LayerNorm(query_dim) |
| | self.norm2 = ops.LayerNorm(query_dim) |
| |
|
| | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) |
| | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) |
| |
|
| | |
| | |
| | |
| | self.scale = 1 |
| |
|
| | def forward(self, x, objs): |
| |
|
| | x = x + self.scale * \ |
| | torch.tanh(self.alpha_attn) * self.attn(self.norm1(x), objs, objs) |
| | x = x + self.scale * \ |
| | torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) |
| |
|
| | return x |
| |
|
| |
|
| | class GatedSelfAttentionDense(nn.Module): |
| | def __init__(self, query_dim, context_dim, n_heads, d_head): |
| | super().__init__() |
| |
|
| | |
| | |
| | self.linear = ops.Linear(context_dim, query_dim) |
| |
|
| | self.attn = CrossAttention( |
| | query_dim=query_dim, |
| | context_dim=query_dim, |
| | heads=n_heads, |
| | dim_head=d_head, |
| | operations=ops) |
| | self.ff = FeedForward(query_dim, glu=True) |
| |
|
| | self.norm1 = ops.LayerNorm(query_dim) |
| | self.norm2 = ops.LayerNorm(query_dim) |
| |
|
| | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) |
| | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) |
| |
|
| | |
| | |
| | |
| | self.scale = 1 |
| |
|
| | def forward(self, x, objs): |
| |
|
| | N_visual = x.shape[1] |
| | objs = self.linear(objs) |
| |
|
| | x = x + self.scale * torch.tanh(self.alpha_attn) * self.attn( |
| | self.norm1(torch.cat([x, objs], dim=1)))[:, 0:N_visual, :] |
| | x = x + self.scale * \ |
| | torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) |
| |
|
| | return x |
| |
|
| |
|
| | class GatedSelfAttentionDense2(nn.Module): |
| | def __init__(self, query_dim, context_dim, n_heads, d_head): |
| | super().__init__() |
| |
|
| | |
| | |
| | self.linear = ops.Linear(context_dim, query_dim) |
| |
|
| | self.attn = CrossAttention( |
| | query_dim=query_dim, context_dim=query_dim, dim_head=d_head, operations=ops) |
| | self.ff = FeedForward(query_dim, glu=True) |
| |
|
| | self.norm1 = ops.LayerNorm(query_dim) |
| | self.norm2 = ops.LayerNorm(query_dim) |
| |
|
| | self.register_parameter('alpha_attn', nn.Parameter(torch.tensor(0.))) |
| | self.register_parameter('alpha_dense', nn.Parameter(torch.tensor(0.))) |
| |
|
| | |
| | |
| | |
| | self.scale = 1 |
| |
|
| | def forward(self, x, objs): |
| |
|
| | B, N_visual, _ = x.shape |
| | B, N_ground, _ = objs.shape |
| |
|
| | objs = self.linear(objs) |
| |
|
| | |
| | size_v = math.sqrt(N_visual) |
| | size_g = math.sqrt(N_ground) |
| | assert int(size_v) == size_v, "Visual tokens must be square rootable" |
| | assert int(size_g) == size_g, "Grounding tokens must be square rootable" |
| | size_v = int(size_v) |
| | size_g = int(size_g) |
| |
|
| | |
| | out = self.attn(self.norm1(torch.cat([x, objs], dim=1)))[ |
| | :, N_visual:, :] |
| | out = out.permute(0, 2, 1).reshape(B, -1, size_g, size_g) |
| | out = torch.nn.functional.interpolate( |
| | out, (size_v, size_v), mode='bicubic') |
| | residual = out.reshape(B, -1, N_visual).permute(0, 2, 1) |
| |
|
| | |
| | x = x + self.scale * torch.tanh(self.alpha_attn) * residual |
| | x = x + self.scale * \ |
| | torch.tanh(self.alpha_dense) * self.ff(self.norm2(x)) |
| |
|
| | return x |
| |
|
| |
|
| | class FourierEmbedder(): |
| | def __init__(self, num_freqs=64, temperature=100): |
| |
|
| | self.num_freqs = num_freqs |
| | self.temperature = temperature |
| | self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) |
| |
|
| | @torch.no_grad() |
| | def __call__(self, x, cat_dim=-1): |
| | "x: arbitrary shape of tensor. dim: cat dim" |
| | out = [] |
| | for freq in self.freq_bands: |
| | out.append(torch.sin(freq * x)) |
| | out.append(torch.cos(freq * x)) |
| | return torch.cat(out, cat_dim) |
| |
|
| |
|
| | class PositionNet(nn.Module): |
| | def __init__(self, in_dim, out_dim, fourier_freqs=8): |
| | super().__init__() |
| | self.in_dim = in_dim |
| | self.out_dim = out_dim |
| |
|
| | self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs) |
| | self.position_dim = fourier_freqs * 2 * 4 |
| |
|
| | self.linears = nn.Sequential( |
| | ops.Linear(self.in_dim + self.position_dim, 512), |
| | nn.SiLU(), |
| | ops.Linear(512, 512), |
| | nn.SiLU(), |
| | ops.Linear(512, out_dim), |
| | ) |
| |
|
| | self.null_positive_feature = torch.nn.Parameter( |
| | torch.zeros([self.in_dim])) |
| | self.null_position_feature = torch.nn.Parameter( |
| | torch.zeros([self.position_dim])) |
| |
|
| | def forward(self, boxes, masks, positive_embeddings): |
| | B, N, _ = boxes.shape |
| | masks = masks.unsqueeze(-1) |
| | positive_embeddings = positive_embeddings |
| |
|
| | |
| | xyxy_embedding = self.fourier_embedder(boxes) |
| |
|
| | |
| | positive_null = self.null_positive_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) |
| | xyxy_null = self.null_position_feature.to(device=boxes.device, dtype=boxes.dtype).view(1, 1, -1) |
| |
|
| | |
| | positive_embeddings = positive_embeddings * \ |
| | masks + (1 - masks) * positive_null |
| | xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null |
| |
|
| | objs = self.linears( |
| | torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) |
| | assert objs.shape == torch.Size([B, N, self.out_dim]) |
| | return objs |
| |
|
| |
|
| | class Gligen(nn.Module): |
| | def __init__(self, modules, position_net, key_dim): |
| | super().__init__() |
| | self.module_list = nn.ModuleList(modules) |
| | self.position_net = position_net |
| | self.key_dim = key_dim |
| | self.max_objs = 30 |
| | self.current_device = torch.device("cpu") |
| |
|
| | def _set_position(self, boxes, masks, positive_embeddings): |
| | objs = self.position_net(boxes, masks, positive_embeddings) |
| | def func(x, extra_options): |
| | key = extra_options["transformer_index"] |
| | module = self.module_list[key] |
| | return module(x, objs.to(device=x.device, dtype=x.dtype)) |
| | return func |
| |
|
| | def set_position(self, latent_image_shape, position_params, device): |
| | batch, c, h, w = latent_image_shape |
| | masks = torch.zeros([self.max_objs], device="cpu") |
| | boxes = [] |
| | positive_embeddings = [] |
| | for p in position_params: |
| | x1 = (p[4]) / w |
| | y1 = (p[3]) / h |
| | x2 = (p[4] + p[2]) / w |
| | y2 = (p[3] + p[1]) / h |
| | masks[len(boxes)] = 1.0 |
| | boxes += [torch.tensor((x1, y1, x2, y2)).unsqueeze(0)] |
| | positive_embeddings += [p[0]] |
| | append_boxes = [] |
| | append_conds = [] |
| | if len(boxes) < self.max_objs: |
| | append_boxes = [torch.zeros( |
| | [self.max_objs - len(boxes), 4], device="cpu")] |
| | append_conds = [torch.zeros( |
| | [self.max_objs - len(boxes), self.key_dim], device="cpu")] |
| |
|
| | box_out = torch.cat( |
| | boxes + append_boxes).unsqueeze(0).repeat(batch, 1, 1) |
| | masks = masks.unsqueeze(0).repeat(batch, 1) |
| | conds = torch.cat(positive_embeddings + |
| | append_conds).unsqueeze(0).repeat(batch, 1, 1) |
| | return self._set_position( |
| | box_out.to(device), |
| | masks.to(device), |
| | conds.to(device)) |
| |
|
| | def set_empty(self, latent_image_shape, device): |
| | batch, c, h, w = latent_image_shape |
| | masks = torch.zeros([self.max_objs], device="cpu").repeat(batch, 1) |
| | box_out = torch.zeros([self.max_objs, 4], |
| | device="cpu").repeat(batch, 1, 1) |
| | conds = torch.zeros([self.max_objs, self.key_dim], |
| | device="cpu").repeat(batch, 1, 1) |
| | return self._set_position( |
| | box_out.to(device), |
| | masks.to(device), |
| | conds.to(device)) |
| |
|
| |
|
| | def load_gligen(sd): |
| | sd_k = sd.keys() |
| | output_list = [] |
| | key_dim = 768 |
| | for a in ["input_blocks", "middle_block", "output_blocks"]: |
| | for b in range(20): |
| | k_temp = filter(lambda k: "{}.{}.".format(a, b) |
| | in k and ".fuser." in k, sd_k) |
| | k_temp = map(lambda k: (k, k.split(".fuser.")[-1]), k_temp) |
| |
|
| | n_sd = {} |
| | for k in k_temp: |
| | n_sd[k[1]] = sd[k[0]] |
| | if len(n_sd) > 0: |
| | query_dim = n_sd["linear.weight"].shape[0] |
| | key_dim = n_sd["linear.weight"].shape[1] |
| |
|
| | if key_dim == 768: |
| | n_heads = 8 |
| | d_head = query_dim // n_heads |
| | else: |
| | d_head = 64 |
| | n_heads = query_dim // d_head |
| |
|
| | gated = GatedSelfAttentionDense( |
| | query_dim, key_dim, n_heads, d_head) |
| | gated.load_state_dict(n_sd, strict=False) |
| | output_list.append(gated) |
| |
|
| | if "position_net.null_positive_feature" in sd_k: |
| | in_dim = sd["position_net.null_positive_feature"].shape[0] |
| | out_dim = sd["position_net.linears.4.weight"].shape[0] |
| |
|
| | class WeightsLoader(torch.nn.Module): |
| | pass |
| | w = WeightsLoader() |
| | w.position_net = PositionNet(in_dim, out_dim) |
| | w.load_state_dict(sd, strict=False) |
| |
|
| | gligen = Gligen(output_list, w.position_net, key_dim) |
| | return gligen |
| |
|