# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/encoder.py
from abc import ABC
from functools import partial
from typing import Callable

import torch
import torch.nn as nn
from einops import repeat

from .torch_modules import CrossAttentionBlock, SelfAttentionBlock


class EncoderBase(ABC, nn.Module):
    """Shared input preparation for latent-array encoders.

    Holds a learned array of ``num_latents`` latent vectors, an entity
    embedding, and an MLP applied to the inputs after they have been
    concatenated with their entity embeddings.

    Args:
        dim_input: Feature size of the raw input ``x``.
        dim_latent: Feature size of each latent vector (also the hidden width
            of the input MLP).
        num_latents: Number of learned latent vectors.
        entity_embedding: Module mapping ``entities`` to embeddings; it must
            expose an ``embedding_dim`` attribute.
        dropout_latent: Probability of zeroing an entire latent vector during
            training.
        act: Zero-argument factory returning the activation module (called
            once per place an activation is needed).
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        num_latents: int,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: Callable[..., nn.Module] = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__()
        self.entity_embedding = entity_embedding
        # The latents are a 3-D tensor (B, N, D). nn.Dropout2d only handles
        # that through a legacy code path that behaves like 1-D channel
        # dropout, emits a UserWarning on every call, and is documented to
        # change meaning in a future release. nn.Dropout1d (torch >= 1.12)
        # gives the same result explicitly: whole latent vectors are dropped
        # per sample.
        self.dropout_latent = nn.Dropout1d(dropout_latent)
        self.dim_input = dim_input
        # Context width = raw features + concatenated entity embedding.
        self.dim_context = dim_input + self.entity_embedding.embedding_dim
        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))
        # The MLP maps back to dim_context, so downstream blocks see the same
        # context width as the concatenated input.
        self.mlp = nn.Sequential(
            nn.Linear(self.dim_context, dim_latent),
            act(),
            nn.Linear(dim_latent, self.dim_context),
        )

    def prepare_inputs(self, x, entities):
        """Build the context tensor and the per-sample latent array.

        Args:
            x: Input features; the last dimension is concatenated with the
                entity embeddings and the first dimension is used as the
                batch size.
            entities: Input forwarded to ``entity_embedding``.

        Returns:
            A tuple ``(x, latents)`` where ``x`` is the MLP-transformed
            concatenation of inputs and entity embeddings, and ``latents`` is
            the learned latent array repeated to ``(B, N, D)`` with latent
            dropout applied.
        """
        entity_embeddings = self.entity_embedding(entities)
        x = torch.cat([x, entity_embeddings], dim=-1)
        x = self.mlp(x)
        latents = repeat(self.latents, "N D -> B N D", B=x.shape[0])
        latents = self.dropout_latent(latents)
        return x, latents


class Encoder(EncoderBase):
    """Encoder that cross-attends latents to the inputs, then self-attends.

    Args:
        dim_input: Feature size of the raw input ``x``.
        dim_latent: Feature size of each latent vector.
        dim_head_cross: Per-head dimension of the cross-attention blocks.
        dim_head_latent: Per-head dimension of the self-attention blocks.
        num_latents: Number of learned latent vectors.
        num_head_cross: Number of heads in each cross-attention block.
        num_head_latent: Number of heads in each self-attention block.
        num_block_cross: Number of cross-attention blocks.
        num_block_attn: Number of self-attention blocks.
        qk_norm: Forwarded to every attention block as ``qk_norm``.
        entity_embedding: Module mapping ``entities`` to embeddings; it must
            expose an ``embedding_dim`` attribute.
        dropout_latent: Probability of zeroing an entire latent vector during
            training.
        act: Zero-argument factory returning the activation module; also
            forwarded to every attention block.
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        dim_head_cross: int,
        dim_head_latent: int,
        num_latents: int,
        num_head_cross: int,
        num_head_latent: int,
        num_block_cross: int,
        num_block_attn: int,
        qk_norm: bool,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: Callable[..., nn.Module] = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__(
            dim_input=dim_input,
            dim_latent=dim_latent,
            num_latents=num_latents,
            entity_embedding=entity_embedding,
            act=act,
            dropout_latent=dropout_latent,
        )

        # Cross-attention blocks are built before the self-attention blocks,
        # which keeps the parameter initialisation order stable.
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=self.dim_context,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )

        self.blocks_attn = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )

    def forward(self, x, entities, mask=None):
        """Encode the inputs into the latent array.

        Args:
            x: Input features (see ``prepare_inputs``).
            entities: Input forwarded to the entity embedding.
            mask: Optional mask passed unchanged to every cross-attention
                block. NOTE(review): its expected shape and polarity are
                defined by ``CrossAttentionBlock`` — confirm there.

        Returns:
            The latents after all cross-attention and self-attention blocks.
        """
        x, latents = self.prepare_inputs(x, entities)

        for cross_attn in self.cross_attn_blocks:
            latents = cross_attn(latents, context=x, mask=mask)

        for self_attn in self.blocks_attn:
            latents = self_attn(latents)

        return latents