File size: 3,445 Bytes
f27e986
 
14f41f0
f27e986
 
 
14f41f0
f27e986
14f41f0
f27e986
14f41f0
 
f27e986
3966af7
 
 
 
 
 
f27e986
3966af7
14f41f0
 
 
 
3966af7
f27e986
 
14f41f0
f27e986
 
 
 
 
 
 
 
14f41f0
f27e986
 
14f41f0
f27e986
3966af7
14f41f0
 
 
 
f27e986
14f41f0
3966af7
14f41f0
 
f27e986
 
3966af7
14f41f0
 
 
 
f27e986
14f41f0
3966af7
14f41f0
 
f27e986
 
3966af7
14f41f0
 
 
 
 
 
 
 
 
 
f27e986
14f41f0
 
 
 
f27e986
 
 
14f41f0
3966af7
 
14f41f0
 
f27e986
3966af7
 
 
 
 
 
 
 
 
 
14f41f0
 
3966af7
 
 
 
 
 
 
 
 
f27e986
14f41f0
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import torch
import torch.nn as nn
import pytorch_lightning as pl


class ResidualBlock(nn.Module):
    """Two-layer fully connected block with a skip connection.

    The main path is Linear -> ReLU -> Dropout -> Linear; the input is
    added back at the end.  When the input and output widths differ, the
    skip path goes through a learned linear projection, otherwise it is
    the identity.
    """

    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)

        # Match the skip path's width to the main path's output width.
        if in_features != out_features:
            self.projection = nn.Linear(in_features, out_features)
        else:
            self.projection = nn.Identity()

    def forward(self, x):
        # Main path, then add the (possibly projected) skip connection.
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden) + self.projection(x)


class DualEncoderModel(pl.LightningModule):
    """Late-fusion classifier over two feature sources ("lab" and
    "conversation"), each with continuous and categorical parts.

    Continuous features go through a small residual MLP encoder; each
    categorical feature gets its own embedding table.  All resulting
    vectors are concatenated and fed to an MLP classification head.
    """

    def __init__(
        self,
        lab_cont_dim,
        lab_cat_dims,
        conv_cont_dim,
        conv_cat_dims,
        embedding_dim,
        num_classes,
        lr=1e-3,
    ):
        """Build the encoders and the fused classifier head.

        Args:
            lab_cont_dim: Number of continuous lab features; 0 disables
                that encoder entirely.
            lab_cat_dims: Iterable of cardinalities, one per categorical
                lab feature.
            conv_cont_dim: Number of continuous conversation features;
                0 disables that encoder.
            conv_cat_dims: Iterable of cardinalities, one per categorical
                conversation feature.
            embedding_dim: Width of every categorical embedding vector.
            num_classes: Number of output classes of the classifier head.
            lr: Learning rate; stored via ``save_hyperparameters`` —
                presumably consumed by ``configure_optimizers`` defined
                elsewhere (not visible in this chunk).
        """
        super().__init__()
        self.save_hyperparameters()

        # Continuous lab features -> 64-d vector (None when there are none).
        self.lab_cont_encoder = (
            nn.Sequential(ResidualBlock(lab_cont_dim, 64), ResidualBlock(64, 64))
            if lab_cont_dim > 0
            else None
        )

        # One embedding table per categorical lab feature.  `dim + 1` rows:
        # the extra row presumably reserves an index for a missing/unknown
        # value — TODO confirm against the data pipeline.
        self.lab_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
        )

        # Continuous conversation features -> 64-d vector (None when absent).
        self.conv_cont_encoder = (
            nn.Sequential(ResidualBlock(conv_cont_dim, 64), ResidualBlock(64, 64))
            if conv_cont_dim > 0
            else None
        )

        # One embedding table per categorical conversation feature.
        self.conv_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
        )

        # Width of the fused vector seen by the classifier head.  Explicit
        # `is not None` checks: nn.Sequential truthiness goes through
        # __len__, which is not what we mean here.
        total_dim = 0
        if self.lab_cont_encoder is not None:
            total_dim += 64
        if lab_cat_dims:
            total_dim += embedding_dim * len(lab_cat_dims)
        if self.conv_cont_encoder is not None:
            total_dim += 64
        if conv_cat_dims:
            total_dim += embedding_dim * len(conv_cat_dims)

        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
        """Encode every non-empty input, fuse, and classify.

        Each branch is skipped when its encoder is disabled or its tensor
        is empty, so callers may pass empty tensors for absent modalities.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            ValueError: If every branch was skipped and there is nothing
                to fuse (previously surfaced as an opaque RuntimeError
                from ``torch.cat``).
        """
        embeddings = []

        # Lab continuous
        if self.lab_cont_encoder is not None and lab_cont.nelement() > 0:
            embeddings.append(self.lab_cont_encoder(lab_cont))

        # Lab categorical.  NOTE(review): clamp(min=0) folds any negative
        # (missing?) code onto real category index 0, even though the
        # tables reserve an extra row — confirm the missing-value encoding.
        if len(self.lab_cat_embeddings) > 0 and lab_cat.nelement() > 0:
            embeddings.extend(
                [
                    emb(torch.clamp(lab_cat[:, i], min=0))
                    for i, emb in enumerate(self.lab_cat_embeddings)
                ]
            )

        # Conv continuous
        if self.conv_cont_encoder is not None and conv_cont.nelement() > 0:
            embeddings.append(self.conv_cont_encoder(conv_cont))

        # Conv categorical (same clamp caveat as above).
        if len(self.conv_cat_embeddings) > 0 and conv_cat.nelement() > 0:
            embeddings.extend(
                [
                    emb(torch.clamp(conv_cat[:, i], min=0))
                    for i, emb in enumerate(self.conv_cat_embeddings)
                ]
            )

        if not embeddings:
            raise ValueError(
                "DualEncoderModel.forward received no usable inputs: "
                "all encoders were skipped, nothing to fuse."
            )

        fused = torch.cat(embeddings, dim=1)
        return self.classifier(fused)