# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/encoder.py
from abc import ABC
from functools import partial
import torch
import torch.nn as nn
from einops import repeat
from .torch_modules import CrossAttentionBlock, SelfAttentionBlock
class EncoderBase(ABC, nn.Module):
    """Common input preparation for encoders built around a learned latent array.

    The base class owns three pieces of state:

    * ``entity_embedding`` - the caller-supplied embedding module,
    * ``latents`` - a learned ``(num_latents, dim_latent)`` parameter,
    * ``mlp`` - a two-layer MLP mapping ``dim_context -> dim_latent ->
      dim_context``, where ``dim_context = dim_input + embedding_dim``.

    Subclasses are expected to call :meth:`prepare_inputs` and then consume
    the returned context and latents.

    Args:
        dim_input: Size of the last dimension of the raw input ``x``.
        dim_latent: Feature size of each latent vector (also the hidden
            width of the input MLP).
        num_latents: Number of learned latent vectors.
        entity_embedding: Module mapping entity ids to embeddings. It must
            expose an ``embedding_dim`` attribute (``nn.Embedding`` does).
        dropout_latent: Probability of zeroing out a whole latent vector
            during training.
        act: Zero-argument factory returning the activation module (it is
            *called* to build the layer, so pass a class or ``partial``,
            not an instantiated module).
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        num_latents: int,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__()
        self.entity_embedding = entity_embedding
        # The latents reach the dropout layer as a 3-D (B, N, D) tensor.
        # nn.Dropout2d only handles that shape through a deprecated fallback
        # (it warns and may reinterpret the input as unbatched (C, H, W) in a
        # future release). nn.Dropout1d is the supported module with the very
        # same semantics: it drops entire channels, i.e. whole latent vectors.
        self.dropout_latent = nn.Dropout1d(dropout_latent)

        self.dim_input = dim_input
        # Context features = raw input features + entity-embedding features.
        self.dim_context = dim_input + self.entity_embedding.embedding_dim

        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))

        self.mlp = nn.Sequential(
            nn.Linear(self.dim_context, dim_latent),
            act(),
            nn.Linear(dim_latent, self.dim_context),
        )

    def prepare_inputs(self, x, entities):
        """Build the attention context and the per-sample latent array.

        Args:
            x: Input features; the first dimension is used as the batch size
                and the last dimension is concatenated with the entity
                embeddings (presumably ``(B, M, dim_input)`` - confirm
                against the caller).
            entities: Entity identifiers passed to ``entity_embedding``.

        Returns:
            A tuple ``(context, latents)`` where ``context`` is the MLP
            output over ``[x, entity_embedding(entities)]`` (last dimension
            ``dim_context``) and ``latents`` is the learned latent array
            broadcast to ``(B, num_latents, dim_latent)`` with latent
            dropout applied.
        """
        entity_embeddings = self.entity_embedding(entities)
        x = torch.cat([x, entity_embeddings], dim=-1)
        x = self.mlp(x)

        # One copy of the learned latents per batch element.
        latents = repeat(self.latents, "N D -> B N D", B=x.shape[0])
        latents = self.dropout_latent(latents)
        return x, latents
class Encoder(EncoderBase):
    """Encoder that reads inputs into latents, then refines the latents.

    The latents first pass through ``num_block_cross`` cross-attention
    blocks, each using the prepared inputs as context, and then through
    ``num_block_attn`` self-attention blocks.

    Args:
        dim_input: Size of the last dimension of the raw input.
        dim_latent: Feature size of each latent vector.
        dim_head_cross: Per-head size handed to the cross-attention blocks.
        dim_head_latent: Per-head size handed to the self-attention blocks.
        num_latents: Number of learned latent vectors.
        num_head_cross: Head count of each cross-attention block.
        num_head_latent: Head count of each self-attention block.
        num_block_cross: Number of cross-attention blocks.
        num_block_attn: Number of self-attention blocks.
        qk_norm: Forwarded unchanged to every attention block.
        entity_embedding: Module mapping entity ids to embeddings.
        dropout_latent: Dropout probability applied to the latents.
        act: Zero-argument activation factory, forwarded to the base class
            and to every attention block.
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        dim_head_cross: int,
        dim_head_latent: int,
        num_latents: int,
        num_head_cross: int,
        num_head_latent: int,
        num_block_cross: int,
        num_block_attn: int,
        qk_norm: bool,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__(
            dim_input=dim_input,
            dim_latent=dim_latent,
            num_latents=num_latents,
            entity_embedding=entity_embedding,
            act=act,
            dropout_latent=dropout_latent,
        )

        # Cross-attention stack: latents are the queries, inputs the context.
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=self.dim_context,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )

        # Self-attention stack operating on the latents alone.
        self.blocks_attn = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )

    def forward(self, x, entities, mask=None):
        """Encode ``x`` and ``entities`` into the latent array.

        Args:
            x: Input features, forwarded to ``prepare_inputs``.
            entities: Entity identifiers, forwarded to ``prepare_inputs``.
            mask: Optional mask handed to every cross-attention block; the
                self-attention blocks do not receive it.

        Returns:
            The latents after all cross- and self-attention blocks.
        """
        context, latents = self.prepare_inputs(x, entities)

        for block in self.cross_attn_blocks:
            latents = block(latents, context=context, mask=mask)

        for block in self.blocks_attn:
            latents = block(latents)

        return latents