File size: 3,272 Bytes
b57c46e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Adapted from LaM-SLidE
# https://github.com/ml-jku/LaM-SLidE/blob/main/src/models/components/encoder.py

from abc import ABC
from functools import partial

import torch
import torch.nn as nn
from einops import repeat

from .torch_modules import CrossAttentionBlock, SelfAttentionBlock


class EncoderBase(ABC, nn.Module):
    """Common input preparation for encoders built around a learned latent array.

    The base class owns three pieces of state:

    * ``entity_embedding`` - module that maps entity ids to embedding vectors;
      it must expose an ``embedding_dim`` attribute (e.g. ``nn.Embedding``).
    * ``latents`` - a learnable ``(num_latents, dim_latent)`` parameter,
      initialised from a standard normal distribution.
    * ``mlp`` - a two-layer MLP applied to the concatenated
      ``[features, entity embedding]`` vectors; it maps
      ``dim_context -> dim_latent -> dim_context``.

    Args:
        dim_input: Size of the last dimension of the raw input features.
        dim_latent: Width of each latent vector (also the MLP hidden width).
        num_latents: Number of learned latent vectors.
        entity_embedding: Embedding module for entity ids.
        dropout_latent: Probability of zeroing an entire latent vector during
            training.
        act: Zero-argument factory returning the activation module (it is
            *called* to build the layer, so pass a class or ``partial``, not
            an instance).
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        num_latents: int,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__()
        self.entity_embedding = entity_embedding
        # The latents passed through this layer are 3-D (B, N, D).  Dropout1d
        # drops whole channels (dim 1), i.e. complete latent vectors.  This is
        # the behaviour nn.Dropout2d only provides for 3-D inputs as a
        # deprecated legacy path that emits a warning on every call and is
        # slated to change meaning.
        self.dropout_latent = nn.Dropout1d(dropout_latent)
        self.dim_input = dim_input
        # Width of the context vectors after the entity embedding is appended.
        self.dim_context = dim_input + self.entity_embedding.embedding_dim

        self.latents = nn.Parameter(torch.randn(num_latents, dim_latent))

        self.mlp = nn.Sequential(
            nn.Linear(self.dim_context, dim_latent),
            act(),
            nn.Linear(dim_latent, self.dim_context),
        )

    def prepare_inputs(self, x, entities):
        """Build the context tensor and the per-sample latent array.

        Args:
            x: Input features; the last dimension must be ``dim_input`` and
                the first dimension is used as the batch size.
                Presumably shaped ``(B, M, dim_input)`` - confirm with callers.
            entities: Entity ids fed to ``entity_embedding``; the resulting
                embeddings must be concatenable with ``x`` along the last
                dimension.

        Returns:
            A tuple ``(context, latents)`` where ``context`` is the MLP output
            with last dimension ``dim_context`` and ``latents`` has shape
            ``(B, num_latents, dim_latent)`` with latent dropout applied.
        """
        entity_embeddings = self.entity_embedding(entities)
        x = torch.cat([x, entity_embeddings], dim=-1)
        x = self.mlp(x)
        # Give every sample in the batch its own copy of the learned latents.
        latents = repeat(self.latents, "N D -> B N D", B=x.shape[0])
        latents = self.dropout_latent(latents)
        return x, latents


class Encoder(EncoderBase):
    """Encoder that compresses the inputs into a fixed set of latent vectors.

    The latents first attend to the prepared inputs through
    ``num_block_cross`` cross-attention blocks and are then refined by
    ``num_block_attn`` self-attention blocks.

    Args:
        dim_input: Size of the last dimension of the raw input features.
        dim_latent: Width of each latent vector.
        dim_head_cross: Per-head width inside the cross-attention blocks.
        dim_head_latent: Per-head width inside the self-attention blocks.
        num_latents: Number of learned latent vectors.
        num_head_cross: Number of heads in each cross-attention block.
        num_head_latent: Number of heads in each self-attention block.
        num_block_cross: Number of cross-attention blocks.
        num_block_attn: Number of self-attention blocks.
        qk_norm: Forwarded as ``qk_norm`` to every attention block.
        entity_embedding: Embedding module for entity ids.
        dropout_latent: Dropout probability applied to the latent array.
        act: Zero-argument factory for the activation module; forwarded to the
            base class and to every attention block.
    """

    def __init__(
        self,
        dim_input: int,
        dim_latent: int,
        dim_head_cross: int,
        dim_head_latent: int,
        num_latents: int,
        num_head_cross: int,
        num_head_latent: int,
        num_block_cross: int,
        num_block_attn: int,
        qk_norm: bool,
        entity_embedding: nn.Module,
        dropout_latent: float = 0.0,
        act: nn.Module = partial(nn.GELU, approximate="tanh"),
    ):
        super().__init__(
            dim_input=dim_input,
            dim_latent=dim_latent,
            num_latents=num_latents,
            entity_embedding=entity_embedding,
            dropout_latent=dropout_latent,
            act=act,
        )

        # Latents (queries) read from the prepared inputs (context).
        self.cross_attn_blocks = nn.ModuleList(
            [
                CrossAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_cross,
                    dim_head=dim_head_cross,
                    act=act,
                    context_dim=self.dim_context,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_cross)
            ]
        )

        # Latents exchange information among themselves.
        self.blocks_attn = nn.ModuleList(
            [
                SelfAttentionBlock(
                    dim=dim_latent,
                    heads=num_head_latent,
                    dim_head=dim_head_latent,
                    act=act,
                    qk_norm=qk_norm,
                )
                for _ in range(num_block_attn)
            ]
        )

    def forward(self, x, entities, mask=None):
        """Encode ``x`` into the latent array.

        Args:
            x: Input features, passed to ``prepare_inputs``.
            entities: Entity ids, passed to ``prepare_inputs``.
            mask: Optional mask handed unchanged to every cross-attention
                block; its exact semantics are defined by
                ``CrossAttentionBlock``.

        Returns:
            The latent array after all cross- and self-attention blocks.
        """
        context, latents = self.prepare_inputs(x, entities)

        for block in self.cross_attn_blocks:
            latents = block(latents, context=context, mask=mask)

        for block in self.blocks_attn:
            latents = block(latents)

        return latents