# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/decoder.py
from functools import partial
from typing import Callable, Dict

import torch
import torch.nn as nn

from .torch_modules import CrossAttentionBlock, SelfAttentionBlock


class Decoder(nn.Module):
    """Decodes a set of latent tokens back into per-entity outputs.

    Entity IDs are embedded and projected into queries; the latents are
    refined by self-attention, conditioned on the queries via cross-attention,
    and finally read out by a cross-attention block in which the queries
    attend to the latents. One MLP head per entry in ``outputs`` maps the
    result to the requested output dimension.
    """

    def __init__(
        self,
        outputs: Dict[str, int],
        dim_query: int,
        dim_latent: int,
        entity_embedding: nn.Module,
        dim_head_cross: int = 64,
        dim_head_latent: int = 64,
        num_head_cross: int = 1,
        num_head_latent: int = 4,
        num_block_cross: int = 2,
        num_block_attn: int = 4,
        dropout_query: float = 0.1,
        dropout_latent: float = 0.0,
        qk_norm: bool = False,
        act: Callable[..., nn.Module] = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__()
        self.entity_embedding = entity_embedding
        dim_entity_embedding = self.entity_embedding.embedding_dim

        # Project entity embeddings into the query space.
        self.query_mlp = nn.Sequential(
            nn.Dropout(dropout_query),
            nn.Linear(dim_entity_embedding, dim_query),
        )
        self.dropout_latent = nn.Dropout(dropout_latent)

        # Self-attention stack that refines the latent tokens.
        self.self_attn_blocks = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )

        # Cross-attention stack: latents attend to the entity queries.
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=dim_query,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )

        # Readout: entity queries attend to the refined latents.
        self.output_block = CrossAttentionBlock(
            dim=dim_query,
            heads=num_head_cross,
            dim_head=dim_head_cross,
            act=act,
            context_dim=dim_latent,
            qk_norm=qk_norm,
        )

        # One small MLP head per named output.
        self.output_layers = nn.ModuleDict()
        for name, out_dim in outputs.items():
            self.output_layers[name] = nn.Sequential(
                nn.Linear(dim_query, dim_query),
                act(),
                nn.Linear(dim_query, out_dim),
            )

    def queries(self, entities: torch.Tensor) -> torch.Tensor:
        """Embed entity IDs and project them into query vectors."""
        entity_embeddings = self.entity_embedding(entities)
        return self.query_mlp(entity_embeddings)

    def forward(self, latent: torch.Tensor, entities: torch.Tensor) -> Dict[str, torch.Tensor]:
        queries = self.queries(entities)
        latent = self.dropout_latent(latent)
        for block in self.self_attn_blocks:
            latent = block(latent)
        for block in self.cross_attn_blocks:
            latent = block(latent, context=queries)
        out = self.output_block(queries, context=latent)
        return {name: layer(out) for name, layer in self.output_layers.items()}
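

# ---------------------------------------------------------------------------
# Usage sketch (not in the upstream file). A minimal example of how this
# decoder might be instantiated, assuming `entity_embedding` is a plain
# `nn.Embedding` and that the sibling `torch_modules` attention blocks
# preserve sequence shape. All sizes, the vocabulary of 100 entity IDs, and
# the "pos" output head are illustrative assumptions, not values from
# LaM-SLidE. Run as a module (e.g. `python -m src.models.components.decoder`)
# so the relative import resolves.
if __name__ == "__main__":
    decoder = Decoder(
        outputs={"pos": 3},  # hypothetical head predicting 3-D positions
        dim_query=128,
        dim_latent=256,
        entity_embedding=nn.Embedding(100, 64),
    )
    latent = torch.randn(2, 16, 256)  # (batch, num_latents, dim_latent)
    entities = torch.randint(0, 100, (2, 32))  # (batch, num_entities)
    out = decoder(latent, entities)
    print({k: tuple(v.shape) for k, v in out.items()})  # {"pos": (2, 32, 3)}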