# NOTE(review): paste artifact from a Hugging Face Spaces page (the "Spaces:" /
# "Runtime error" banner) — not Python code; kept as a comment so the file parses.
| import torch | |
| import torch.nn as nn | |
| import pytorch_lightning as pl | |
class ResidualBlock(nn.Module):
    """Two-layer MLP block with a skip connection.

    The main path is Linear -> ReLU -> Dropout -> Linear. When
    ``in_features != out_features`` the shortcut is linearly projected so
    the two branches can be summed; otherwise it passes through unchanged.
    """

    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)
        # Project the shortcut only when the widths differ.
        if in_features != out_features:
            self.projection = nn.Linear(in_features, out_features)
        else:
            self.projection = nn.Identity()

    def forward(self, x):
        shortcut = self.projection(x)
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden) + shortcut
class DualEncoderModel(pl.LightningModule):
    """Classifier that fuses lab and conversation features.

    Each of up to four modalities contributes a fixed-width slice to the
    fused vector: 64 dims per continuous encoder and ``embedding_dim`` per
    categorical feature. ``total_dim`` is computed once at construction, so
    ``forward`` must always emit exactly that width.

    Bug fixed: the previous ``forward`` *skipped* any modality whose batch
    tensor was empty (``nelement() == 0``), so the concatenated vector was
    narrower than the classifier's expected ``total_dim`` and the first
    ``nn.Linear`` raised a shape error at runtime (and ``torch.cat([])``
    itself raised when every input was empty). Missing modalities are now
    zero-padded to their fixed widths, which is backward-compatible: fully
    populated inputs produce byte-identical results.
    """

    # Output width of each continuous ResidualBlock encoder.
    _CONT_WIDTH = 64

    def __init__(
        self,
        lab_cont_dim,
        lab_cat_dims,
        conv_cont_dim,
        conv_cat_dims,
        embedding_dim,
        num_classes,
        lr=1e-3,
    ):
        """
        Args:
            lab_cont_dim: number of continuous lab features (0 disables the encoder).
            lab_cat_dims: cardinality of each categorical lab feature.
            conv_cont_dim: number of continuous conversation features (0 disables).
            conv_cat_dims: cardinality of each categorical conversation feature.
            embedding_dim: width of every categorical embedding.
            num_classes: classifier output size.
            lr: learning rate, stored in hparams for the optimizer config.
        """
        super().__init__()
        self.save_hyperparameters()

        # Continuous encoders: two stacked residual blocks, or None when the
        # modality has no continuous features.
        self.lab_cont_encoder = (
            nn.Sequential(
                ResidualBlock(lab_cont_dim, self._CONT_WIDTH),
                ResidualBlock(self._CONT_WIDTH, self._CONT_WIDTH),
            )
            if lab_cont_dim > 0
            else None
        )
        self.conv_cont_encoder = (
            nn.Sequential(
                ResidualBlock(conv_cont_dim, self._CONT_WIDTH),
                ResidualBlock(self._CONT_WIDTH, self._CONT_WIDTH),
            )
            if conv_cont_dim > 0
            else None
        )

        # One embedding table per categorical feature. "+ 1" reserves an extra
        # row; negative codes are clamped to index 0 in forward().
        # NOTE(review): indices are assumed to already be < dim + 1 — there is
        # no upper clamp, so out-of-range codes still raise. Confirm upstream.
        self.lab_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
        )
        self.conv_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
        )

        # Fixed fused width: forward() must always produce exactly this.
        total_dim = 0
        if self.lab_cont_encoder is not None:
            total_dim += self._CONT_WIDTH
        total_dim += embedding_dim * len(lab_cat_dims)
        if self.conv_cont_encoder is not None:
            total_dim += self._CONT_WIDTH
        total_dim += embedding_dim * len(conv_cat_dims)

        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def _cont_features(self, encoder, x, batch, device):
        """Return the (possibly zero-padded) slice for one continuous modality."""
        if encoder is None:
            return []
        if x.nelement() > 0:
            return [encoder(x)]
        # Modality configured but absent from this batch: keep the width stable.
        return [torch.zeros(batch, self._CONT_WIDTH, device=device)]

    def _cat_features(self, embeddings, x, batch, device):
        """Return one embedding slice per categorical feature, zero-padded if absent."""
        if len(embeddings) == 0:
            return []
        if x.nelement() > 0:
            # Negative codes (e.g. "missing") are clamped to index 0.
            return [
                emb(torch.clamp(x[:, i], min=0))
                for i, emb in enumerate(embeddings)
            ]
        width = embeddings[0].embedding_dim
        return [torch.zeros(batch, width, device=device) for _ in embeddings]

    def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
        """Fuse all available modalities and classify.

        Empty tensors (``nelement() == 0``) are replaced by zero slices of the
        correct width, so the fused vector always matches the classifier input.

        Raises:
            ValueError: if every input tensor is empty (batch size unknown).
        """
        inputs = (lab_cont, lab_cat, conv_cont, conv_cat)
        batch = next((t.shape[0] for t in inputs if t.nelement() > 0), None)
        if batch is None:
            raise ValueError(
                "DualEncoderModel.forward received only empty tensors; "
                "cannot infer batch size"
            )
        device = self.classifier[0].weight.device

        pieces = []
        pieces += self._cont_features(self.lab_cont_encoder, lab_cont, batch, device)
        pieces += self._cat_features(self.lab_cat_embeddings, lab_cat, batch, device)
        pieces += self._cont_features(self.conv_cont_encoder, conv_cont, batch, device)
        pieces += self._cat_features(self.conv_cat_embeddings, conv_cat, batch, device)

        fused = torch.cat(pieces, dim=1)
        return self.classifier(fused)