bohraanuj23 committed
Commit 3966af7 · 1 Parent(s): 31d5be1

Added updated model classes.

Files changed (1):
  1. model.py (+34 -16)
model.py CHANGED
@@ -11,15 +11,19 @@ class ResidualBlock(nn.Module):
         self.dropout = nn.Dropout(dropout)
         self.fc2 = nn.Linear(out_features, out_features)
 
+        self.projection = (
+            nn.Linear(in_features, out_features)
+            if in_features != out_features
+            else nn.Identity()
+        )
+
     def forward(self, x):
-        residual = x
+        residual = self.projection(x)
         out = self.fc1(x)
         out = self.relu(out)
         out = self.dropout(out)
         out = self.fc2(out)
-        if residual.shape == out.shape:
-            out += residual
-        return out
+        return out + residual
 
 
 class DualEncoderModel(pl.LightningModule):
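Note on this hunk: the old block only added the skip connection when shapes happened to match, so any block with in_features != out_features (e.g. the ResidualBlock(lab_cont_dim, 64) instances below) silently degenerated into a plain two-layer MLP. The new projection makes the shortcut unconditional. A minimal standalone sketch of the updated block follows, assuming a constructor signature of ResidualBlock(in_features, out_features, dropout=0.1); the dropout default and the fc1/relu definitions are inferred from the context lines, not shown in full by this diff.

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)
        # Project the input when dimensions differ so the skip
        # connection always applies instead of being silently skipped.
        self.projection = (
            nn.Linear(in_features, out_features)
            if in_features != out_features
            else nn.Identity()
        )

    def forward(self, x):
        residual = self.projection(x)
        out = self.fc2(self.dropout(self.relu(self.fc1(x))))
        return out + residual

# A 10 -> 64 block now gets a real residual path; before this commit
# it fell back to the bare fc2 output.
block = ResidualBlock(10, 64)
assert block(torch.randn(8, 10)).shape == (8, 64)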
@@ -36,26 +40,31 @@ class DualEncoderModel(pl.LightningModule):
         super().__init__()
         self.save_hyperparameters()
 
+        # Lab continuous
         self.lab_cont_encoder = (
             nn.Sequential(ResidualBlock(lab_cont_dim, 64), ResidualBlock(64, 64))
             if lab_cont_dim > 0
             else None
         )
 
+        # Lab categorical
         self.lab_cat_embeddings = nn.ModuleList(
             [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
         )
 
+        # Conversation continuous
         self.conv_cont_encoder = (
             nn.Sequential(ResidualBlock(conv_cont_dim, 64), ResidualBlock(64, 64))
             if conv_cont_dim > 0
             else None
         )
 
+        # Conversation categorical
         self.conv_cat_embeddings = nn.ModuleList(
             [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
         )
 
+        # Calculate total input dimension to classifier
         total_dim = 0
         if self.lab_cont_encoder:
             total_dim += 64
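The context here cuts off after the first total_dim += 64, so the rest of the bookkeeping is not visible in this diff. For the torch.cat in forward() to match the classifier input, the remaining terms presumably accumulate one 64 per active continuous encoder and one embedding_dim per categorical feature. A hedged sketch of that continuation, not the file's actual lines:

# Hypothetical reconstruction of the elided bookkeeping (lines past this
# hunk's context); parameter names mirror the __init__ shown above.
def classifier_input_dim(lab_cont_dim, lab_cat_dims,
                         conv_cont_dim, conv_cat_dims, embedding_dim):
    total_dim = 0
    if lab_cont_dim > 0:
        total_dim += 64                            # lab continuous encoder output
    total_dim += len(lab_cat_dims) * embedding_dim     # one vector per lab categorical
    if conv_cont_dim > 0:
        total_dim += 64                            # conversation encoder output
    total_dim += len(conv_cat_dims) * embedding_dim    # one vector per conv categorical
    return total_dim

# e.g. lab continuous block active, two lab categoricals, no conversation
# continuous block, one conversation categorical, embedding_dim = 8:
assert classifier_input_dim(10, [5, 3], 0, [4], 8) == 64 + 16 + 0 + 8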
@@ -75,23 +84,32 @@ class DualEncoderModel(pl.LightningModule):
 
     def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
         embeddings = []
+
+        # Lab continuous
         if self.lab_cont_encoder and lab_cont.nelement() > 0:
             embeddings.append(self.lab_cont_encoder(lab_cont))
-        embeddings.extend(
-            [
-                emb(torch.clamp(lab_cat[:, i], min=0))
-                for i, emb in enumerate(self.lab_cat_embeddings)
-            ]
-        )
 
+        # Lab categorical
+        if self.lab_cat_embeddings and lab_cat.nelement() > 0:
+            embeddings.extend(
+                [
+                    emb(torch.clamp(lab_cat[:, i], min=0))
+                    for i, emb in enumerate(self.lab_cat_embeddings)
+                ]
+            )
+
+        # Conv continuous
         if self.conv_cont_encoder and conv_cont.nelement() > 0:
             embeddings.append(self.conv_cont_encoder(conv_cont))
+
+        # Conv categorical
+        if self.conv_cat_embeddings and conv_cat.nelement() > 0:
+            embeddings.extend(
+                [
+                    emb(torch.clamp(conv_cat[:, i], min=0))
+                    for i, emb in enumerate(self.conv_cat_embeddings)
+                ]
+            )
 
         fused = torch.cat(embeddings, dim=1)
         return self.classifier(fused)
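Two details in the reworked forward(). First, the guards: the old unconditional extend indexed lab_cat[:, i] even when the batch carried an empty placeholder tensor, which raises an IndexError on a zero-width column dimension; now an empty tensor simply skips the lookup. Second, torch.clamp(..., min=0) pairs with the nn.Embedding(dim + 1, ...) tables from __init__: negative sentinel codes are folded onto row 0 instead of crashing the lookup. A small sketch of that pairing (the -1-for-missing convention is an assumption about the data pipeline, not something this diff states):

import torch
import torch.nn as nn

# Each table is built as nn.Embedding(dim + 1, embedding_dim), so the
# extra row leaves room if real codes run 1..dim with row 0 reserved.
emb = nn.Embedding(5 + 1, 8)              # feature with raw codes up to 5
codes = torch.tensor([3, -1, 5])          # -1 assumed to mark a missing value
vectors = emb(torch.clamp(codes, min=0))  # -1 folds onto row 0
assert vectors.shape == (3, 8)

# The new guard: an empty placeholder batch never reaches the column
# indexing that used to raise IndexError.
lab_cat = torch.empty(4, 0, dtype=torch.long)
assert lab_cat.nelement() == 0            # forward() now skips the lookup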
 