# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/decoder.py

from functools import partial
from typing import Dict

import torch
import torch.nn as nn

from .torch_modules import CrossAttentionBlock, SelfAttentionBlock


class Decoder(nn.Module):
    """Perceiver-style decoder: a latent set is refined by self-attention and
    conditioned on per-entity queries via cross-attention; a final
    cross-attention block lets the entity queries read predictions out of the
    latents, which are then mapped to each named output by a small MLP head."""

    def __init__(
        self,
        outputs: Dict[str, int],
        dim_query: int,
        dim_latent: int,
        entity_embedding: nn.Module,
        dim_head_cross: int = 64,
        dim_head_latent: int = 64,
        num_head_cross: int = 1,
        num_head_latent: int = 4,
        num_block_cross: int = 2,
        num_block_attn: int = 4,
        dropout_query: float = 0.1,
        dropout_latent: float = 0.0,
        qk_norm: bool = False,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),  # factory producing an activation module
    ):
        super().__init__()

        self.entity_embedding = entity_embedding
        dim_entity_embedding = self.entity_embedding.embedding_dim
        self.query_mlp = nn.Sequential(
            nn.Dropout(dropout_query),
            nn.Linear(dim_entity_embedding, dim_query),
        )
        self.dropout_latent = nn.Dropout(dropout_latent)
        self.self_attn_blocks = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=dim_query,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )

        self.output_block = CrossAttentionBlock(
            dim=dim_query,
            heads=num_head_cross,
            dim_head=dim_head_cross,
            act=act,
            context_dim=dim_latent,
            qk_norm=qk_norm,
        )

        self.output_layers = nn.ModuleDict()
        for name, out_dim in outputs.items():
            self.output_layers[name] = nn.Sequential(
                nn.Linear(dim_query, dim_query),
                act(),
                nn.Linear(dim_query, out_dim),
            )

    def queries(self, entities):
        # Map integer entity IDs to learned query vectors.
        entity_embeddings = self.entity_embedding(entities)
        queries = self.query_mlp(entity_embeddings)
        return queries

    def forward(self, latent, entities):
        queries = self.queries(entities)

        latent = self.dropout_latent(latent)
        # Refine the latent set, then condition it on the entity queries.
        for block in self.self_attn_blocks:
            latent = block(latent)
        for block in self.cross_attn_blocks:
            latent = block(latent, context=queries)

        # Entity queries read the decoded state back out of the latents.
        out_block = self.output_block(queries, context=latent)

        return {
            name: layer(out_block) for name, layer in self.output_layers.items()
        }
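# The sketch below illustrates the read-out pattern this decoder uses:
# entity IDs are embedded, projected to queries, and cross-attend into a
# latent set to produce per-entity predictions. It is NOT the module above:
# it substitutes plain torch.nn.MultiheadAttention for the repo's
# CrossAttentionBlock, and all dimensions and names are illustrative.

```python
import torch
import torch.nn as nn

num_entities, dim_query, dim_latent = 10, 32, 64

# Entity IDs -> embeddings -> query vectors (mirrors queries() above).
entity_embedding = nn.Embedding(num_entities, 16)
query_mlp = nn.Linear(16, dim_query)
# Stand-in for a cross-attention block: queries attend into the latent set.
cross_attn = nn.MultiheadAttention(
    embed_dim=dim_query,
    num_heads=1,
    kdim=dim_latent,
    vdim=dim_latent,
    batch_first=True,
)
head = nn.Linear(dim_query, 3)  # e.g. a 3-dimensional output per entity

entities = torch.arange(5).unsqueeze(0)  # (batch=1, 5 entities)
latent = torch.randn(1, 8, dim_latent)   # (batch, num_latents, dim_latent)

queries = query_mlp(entity_embedding(entities))  # (1, 5, dim_query)
out, _ = cross_attn(queries, latent, latent)     # queries read from latents
pred = head(out)
print(pred.shape)  # torch.Size([1, 5, 3])
```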