Spaces:
Sleeping
Sleeping
# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/decoder.py
from functools import partial
from typing import Dict

import torch
import torch.nn as nn
from einops import repeat
from einops.layers.torch import Rearrange

from .torch_modules import CrossAttentionBlock, SelfAttentionBlock
class Decoder(nn.Module):
    """Perceiver-style decoder that reads per-entity outputs out of a latent.

    Adapted from LaM-SLidE
    (https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/decoder.py).

    Pipeline: the latent tokens are first refined with self-attention, then
    updated by cross-attending to entity-derived query tokens; finally the
    queries cross-attend back into the refined latent and each configured
    output head projects the result.

    Args:
        outputs: Mapping of output name -> output feature dimension; one
            two-layer MLP head is created per entry.
        dim_query: Width of the query tokens built from entity embeddings.
        dim_latent: Width of the latent tokens.
        entity_embedding: Embedding module mapping entity ids to vectors;
            its ``embedding_dim`` attribute is read to size the query MLP.
        dim_head_cross: Per-head width in the cross-attention blocks.
        dim_head_latent: Per-head width in the latent self-attention blocks.
        num_head_cross: Number of heads in the cross-attention blocks.
        num_head_latent: Number of heads in the self-attention blocks.
        num_block_cross: Number of latent<-query cross-attention blocks.
        num_block_attn: Number of latent self-attention blocks.
        dropout_query: Dropout applied to entity embeddings before the
            query projection.
        dropout_latent: Dropout applied to the incoming latent in ``forward``.
        qk_norm: Whether the attention blocks normalize queries/keys.
        act: Activation factory used inside attention blocks and output heads.
    """

    def __init__(
        self,
        outputs: Dict[str, int],
        dim_query: int,
        dim_latent: int,
        entity_embedding: nn.Module,
        dim_head_cross: int = 64,
        dim_head_latent: int = 64,
        num_head_cross: int = 1,
        num_head_latent: int = 4,
        num_block_cross: int = 2,
        num_block_attn: int = 4,
        dropout_query: float = 0.1,
        dropout_latent: float = 0.0,
        qk_norm: bool = False,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__()
        self.entity_embedding = entity_embedding
        # Size the query projection from the supplied embedding module.
        dim_entity_embedding = self.entity_embedding.embedding_dim
        self.query_mlp = nn.Sequential(
            nn.Dropout(dropout_query),
            nn.Linear(dim_entity_embedding, dim_query),
        )
        self.dropout_latent = nn.Dropout(dropout_latent)
        # Latent refinement: a stack of self-attention blocks.
        self.self_attn_blocks = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )
        # Latent <- query fusion: latent tokens attend to the entity queries.
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=dim_query,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )
        # Readout: queries attend back into the refined latent.
        self.output_block = CrossAttentionBlock(
            dim=dim_query,
            heads=num_head_cross,
            dim_head=dim_head_cross,
            act=act,
            context_dim=dim_latent,
            qk_norm=qk_norm,
        )
        # One MLP head per requested output.
        self.output_layers = nn.ModuleDict()
        for name, out_dim in outputs.items():
            self.output_layers[name] = nn.Sequential(
                nn.Linear(dim_query, dim_query),
                act(),
                nn.Linear(dim_query, out_dim),
            )

    def queries(self, entities):
        """Build query tokens for ``entities`` via embedding + projection."""
        entity_embeddings = self.entity_embedding(entities)
        queries = self.query_mlp(entity_embeddings)
        return queries

    def forward(self, latent, entities):
        """Decode ``latent`` into one tensor per configured output head.

        Args:
            latent: Latent token tensor consumed by the attention blocks.
            entities: Entity ids (or whatever ``entity_embedding`` accepts)
                used to construct the query tokens.

        Returns:
            Dict mapping each output name to its head's prediction.
        """
        queries = self.queries(entities)
        latent = self.dropout_latent(latent)
        for block in self.self_attn_blocks:
            latent = block(latent)
        for block in self.cross_attn_blocks:
            latent = block(latent, context=queries)
        out_block = self.output_block(queries, context=latent)
        # Iterate items() once instead of keys() + a lookup per key.
        return {name: layer(out_block) for name, layer in self.output_layers.items()}