import torch
import torch.nn as nn
import pytorch_lightning as pl
class ResidualBlock(nn.Module):
    """Two-layer MLP block with a skip connection.

    Computes fc2(dropout(relu(fc1(x)))) and adds the input back on top,
    projecting the input through an extra Linear layer whenever the input
    and output widths differ (otherwise the skip path is an identity).
    """

    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)
        # The skip path only needs a learned projection when the widths differ.
        if in_features != out_features:
            self.projection = nn.Linear(in_features, out_features)
        else:
            self.projection = nn.Identity()

    def forward(self, x):
        """Apply the block and return the transformed tensor plus the shortcut."""
        shortcut = self.projection(x)
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden) + shortcut
class DualEncoderModel(pl.LightningModule):
    """Classifier over two feature groups ("lab" and "conversation").

    Each group contributes an optional continuous part (encoded by a pair of
    64-wide ResidualBlocks) and zero or more categorical columns (one
    embedding table per column). All produced vectors are concatenated along
    the feature axis and fed to a small MLP head.
    """

    def __init__(
        self,
        lab_cont_dim,
        lab_cat_dims,
        conv_cont_dim,
        conv_cat_dims,
        embedding_dim,
        num_classes,
        lr=1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()

        def build_cont_encoder(dim):
            # A group with no continuous features gets no encoder at all.
            if dim > 0:
                return nn.Sequential(ResidualBlock(dim, 64), ResidualBlock(64, 64))
            return None

        def build_embeddings(cardinalities):
            # One table per categorical column. The +1 row leaves an extra
            # index; forward clamps negatives to 0, so index 0 presumably
            # doubles as a missing-value slot — TODO confirm against the data.
            return nn.ModuleList(
                nn.Embedding(card + 1, embedding_dim) for card in cardinalities
            )

        self.lab_cont_encoder = build_cont_encoder(lab_cont_dim)
        self.lab_cat_embeddings = build_embeddings(lab_cat_dims)
        self.conv_cont_encoder = build_cont_encoder(conv_cont_dim)
        self.conv_cat_embeddings = build_embeddings(conv_cat_dims)

        # The head's input width is the sum of every branch that exists.
        total_dim = sum(
            [
                64 if self.lab_cont_encoder else 0,
                embedding_dim * len(lab_cat_dims),
                64 if self.conv_cont_encoder else 0,
                embedding_dim * len(conv_cat_dims),
            ]
        )
        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
        """Encode both feature groups, fuse, and return classifier logits.

        A branch is skipped when its encoder is absent or its tensor arrives
        empty. NOTE(review): the classifier width is fixed at init time, so an
        empty tensor for a branch that was configured at construction will
        make the concatenation narrower than the head expects.
        """
        pieces = []

        if self.lab_cont_encoder and lab_cont.nelement() > 0:
            pieces.append(self.lab_cont_encoder(lab_cont))

        if lab_cat.nelement() > 0:
            for col, table in enumerate(self.lab_cat_embeddings):
                # Negative codes are clamped up to index 0 before lookup.
                pieces.append(table(torch.clamp(lab_cat[:, col], min=0)))

        if self.conv_cont_encoder and conv_cont.nelement() > 0:
            pieces.append(self.conv_cont_encoder(conv_cont))

        if conv_cat.nelement() > 0:
            for col, table in enumerate(self.conv_cat_embeddings):
                pieces.append(table(torch.clamp(conv_cat[:, col], min=0)))

        fused = torch.cat(pieces, dim=1)
        return self.classifier(fused)