score-ae / src /model /decoder.py
hroth's picture
Upload 90 files
b57c46e verified
raw
history blame
3.25 kB
# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/decoder.py
from functools import partial
from typing import Dict
import torch
import torch.nn as nn
from einops import repeat
from einops.layers.torch import Rearrange
from .torch_modules import CrossAttentionBlock, SelfAttentionBlock
class Decoder(nn.Module):
def __init__(
self,
outputs: Dict[str, int],
dim_query: int,
dim_latent: int,
entity_embedding: nn.Module,
dim_head_cross: int = 64,
dim_head_latent: int = 64,
num_head_cross: int = 1,
num_head_latent: int = 4,
num_block_cross: int = 2,
num_block_attn: int = 4,
dropout_query: float = 0.1,
dropout_latent: float = 0.0,
qk_norm: bool = False,
act: nn.Module = partial(nn.GELU, approximate="tanh"),
):
super().__init__()
self.entity_embedding = entity_embedding
dim_entity_embedding = self.entity_embedding.embedding_dim
self.query_mlp = nn.Sequential(
nn.Dropout(dropout_query),
nn.Linear(dim_entity_embedding, dim_query),
)
self.dropout_latent = nn.Dropout(dropout_latent)
self.self_attn_blocks = nn.ModuleList(
[
SelfAttentionBlock(
dim_latent,
heads=num_head_latent,
dim_head=dim_head_latent,
act=act,
qk_norm=qk_norm,
)
for _ in range(num_block_attn)
]
)
self.cross_attn_blocks = nn.ModuleList(
[
CrossAttentionBlock(
dim=dim_latent,
heads=num_head_cross,
dim_head=dim_head_cross,
act=act,
context_dim=dim_query,
qk_norm=qk_norm,
)
for _ in range(num_block_cross)
]
)
self.output_block = CrossAttentionBlock(
dim=dim_query,
heads=num_head_cross,
dim_head=dim_head_cross,
act=act,
context_dim=dim_latent,
qk_norm=qk_norm,
)
self.output_layers = nn.ModuleDict()
for name, out_dim in outputs.items():
self.output_layers[name] = nn.Sequential(
nn.Linear(dim_query, dim_query),
act(),
nn.Linear(dim_query, out_dim),
)
def queries(self, entities):
entity_embeddings = self.entity_embedding(entities)
queries = self.query_mlp(entity_embeddings)
return queries
def forward(self, latent, entities):
queries = self.queries(entities)
latent = self.dropout_latent(latent)
for block in self.self_attn_blocks:
latent = block(latent)
for block in self.cross_attn_blocks:
latent = block(latent, context=queries)
out_block = self.output_block(queries, context=latent)
outputs = {}
for name in self.output_layers.keys():
outputs[name] = self.output_layers[name](out_block)
return outputs