import torch
import torch.nn as nn
import pytorch_lightning as pl


class ResidualBlock(nn.Module):
    """Two-layer MLP block with a skip (residual) connection.

    Computes ``fc2(dropout(relu(fc1(x)))) + projection(x)``. When
    ``in_features != out_features`` the identity skip path is replaced by a
    linear projection so the two branches have matching shapes.

    Args:
        in_features: Size of the input feature dimension.
        out_features: Size of the output feature dimension.
        dropout: Dropout probability applied after the first activation.
    """

    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)
        # Only project the residual when the dimensions disagree; otherwise
        # Identity keeps the skip path cost-free.
        self.projection = (
            nn.Linear(in_features, out_features)
            if in_features != out_features
            else nn.Identity()
        )

    def forward(self, x):
        """Return the residual-connected transformation of ``x``."""
        residual = self.projection(x)
        out = self.fc1(x)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out + residual


class DualEncoderModel(pl.LightningModule):
    """Classifier fusing lab and conversation features, each with continuous
    and categorical components.

    Continuous features from each source go through a small residual-MLP
    encoder (output dim 64); each categorical feature gets its own embedding
    table. All available representations are concatenated and fed to a shared
    MLP classifier head.

    Args:
        lab_cont_dim: Number of continuous lab features (0 disables encoder).
        lab_cat_dims: Cardinalities of categorical lab features.
        conv_cont_dim: Number of continuous conversation features
            (0 disables encoder).
        conv_cat_dims: Cardinalities of categorical conversation features.
        embedding_dim: Embedding size used for every categorical feature.
        num_classes: Number of output classes.
        lr: Learning rate used by :meth:`configure_optimizers`.

    Raises:
        ValueError: If every feature dimension is zero/empty, so the
            classifier would have a zero-width input.
    """

    def __init__(
        self,
        lab_cont_dim,
        lab_cat_dims,
        conv_cont_dim,
        conv_cat_dims,
        embedding_dim,
        num_classes,
        lr=1e-3,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Lab continuous: 2-block residual MLP, absent when there are no
        # continuous lab features.
        self.lab_cont_encoder = (
            nn.Sequential(ResidualBlock(lab_cont_dim, 64), ResidualBlock(64, 64))
            if lab_cont_dim > 0
            else None
        )
        # Lab categorical: one embedding table per feature. `dim + 1` rows —
        # presumably one extra slot for a missing/unknown code; TODO confirm
        # against the data pipeline.
        self.lab_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
        )
        # Conversation continuous: same architecture as the lab encoder.
        self.conv_cont_encoder = (
            nn.Sequential(ResidualBlock(conv_cont_dim, 64), ResidualBlock(64, 64))
            if conv_cont_dim > 0
            else None
        )
        # Conversation categorical embeddings.
        self.conv_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
        )

        # Total width of the fused representation seen by the classifier.
        # Use `is not None` rather than module truthiness: nn.Module truth
        # value is not a reliable presence test.
        total_dim = 0
        if self.lab_cont_encoder is not None:
            total_dim += 64
        if lab_cat_dims:
            total_dim += embedding_dim * len(lab_cat_dims)
        if self.conv_cont_encoder is not None:
            total_dim += 64
        if conv_cat_dims:
            total_dim += embedding_dim * len(conv_cat_dims)

        if total_dim == 0:
            # Without this guard we would silently build nn.Linear(0, 128),
            # which can never produce a meaningful prediction.
            raise ValueError(
                "DualEncoderModel requires at least one non-empty feature "
                "group (all of lab_cont_dim, lab_cat_dims, conv_cont_dim, "
                "conv_cat_dims are empty/zero)."
            )

        self.classifier = nn.Sequential(
            nn.Linear(total_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, num_classes),
        )

    def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
        """Encode each available modality, fuse, and classify.

        Each input may be an empty tensor (``nelement() == 0``) to signal the
        modality is absent for this batch. NOTE(review): the classifier input
        width is fixed at construction time, so a modality that was present
        when the model was built must also be present at call time — verify
        callers never drop modalities per-batch.

        Raises:
            ValueError: If every modality is absent/empty.
        """
        embeddings = []

        # Lab continuous.
        if self.lab_cont_encoder is not None and lab_cont.nelement() > 0:
            embeddings.append(self.lab_cont_encoder(lab_cont))
        # Lab categorical. clamp(min=0) presumably maps a -1 "missing"
        # sentinel onto embedding row 0 — TODO confirm with the data loader.
        if self.lab_cat_embeddings and lab_cat.nelement() > 0:
            embeddings.extend(
                [
                    emb(torch.clamp(lab_cat[:, i], min=0))
                    for i, emb in enumerate(self.lab_cat_embeddings)
                ]
            )
        # Conversation continuous.
        if self.conv_cont_encoder is not None and conv_cont.nelement() > 0:
            embeddings.append(self.conv_cont_encoder(conv_cont))
        # Conversation categorical.
        if self.conv_cat_embeddings and conv_cat.nelement() > 0:
            embeddings.extend(
                [
                    emb(torch.clamp(conv_cat[:, i], min=0))
                    for i, emb in enumerate(self.conv_cat_embeddings)
                ]
            )

        if not embeddings:
            # torch.cat on an empty list raises an opaque RuntimeError;
            # fail with an actionable message instead.
            raise ValueError(
                "All inputs to DualEncoderModel.forward are empty; at least "
                "one modality must contain data."
            )

        fused = torch.cat(embeddings, dim=1)
        return self.classifier(fused)

    def configure_optimizers(self):
        """Return the optimizer used by the Lightning Trainer.

        `lr` is captured by save_hyperparameters() in __init__; without this
        hook the hyperparameter would never be consumed and Trainer.fit()
        would fail for lack of an optimizer.
        """
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)