# SeverityAnalysis/model.py
# Author: bohraanuj23
# Commit 3966af7: "Added updated model classes."
import torch
import torch.nn as nn
import pytorch_lightning as pl
class ResidualBlock(nn.Module):
    """Two-layer MLP with a skip connection.

    Computes ``fc2(dropout(relu(fc1(x)))) + projection(x)``.

    The skip path is a learned ``nn.Linear`` when the input and output widths
    differ, and ``nn.Identity`` when they match. No activation is applied
    after the sum.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
        dropout: Dropout probability applied between the two linear layers.
    """

    def __init__(self, in_features, out_features, dropout=0.2):
        super().__init__()
        # Registration order (fc1, relu, dropout, fc2, projection) fixes the
        # parameter order and how a seeded RNG is consumed during init.
        self.fc1 = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        self.fc2 = nn.Linear(out_features, out_features)
        if in_features == out_features:
            skip = nn.Identity()
        else:
            # Widths differ, so the shortcut needs a learned projection.
            skip = nn.Linear(in_features, out_features)
        self.projection = skip

    def forward(self, x):
        """Return the main-path output plus the (projected) shortcut of ``x``."""
        shortcut = self.projection(x)
        main = self.fc2(self.dropout(self.relu(self.fc1(x))))
        return main + shortcut
class DualEncoderModel(pl.LightningModule):
    """Classify a sample from lab and conversation features.

    Each modality ("lab" and "conv") can contribute two feature groups:

    * Continuous features, encoded by two stacked ``ResidualBlock`` layers.
      The encoder is built only when the matching ``*_cont_dim`` is > 0.
    * Categorical features, with one ``nn.Embedding`` per entry of
      ``*_cat_dims``. Each table has ``dim + 1`` rows.

    All encoder outputs are concatenated along dim 1 and fed to an MLP head
    whose last layer has ``num_classes`` outputs. No softmax is applied.

    Args:
        lab_cont_dim: Number of continuous lab features (0 disables the
            encoder).
        lab_cat_dims: One entry per categorical lab feature, presumably its
            number of categories. Each table gets ``dim + 1`` rows.
        conv_cont_dim: Number of continuous conversation features (0
            disables the encoder).
        conv_cat_dims: Like ``lab_cat_dims``, for conversation features.
        embedding_dim: Width of every categorical embedding.
        num_classes: Output width of the classifier head.
        lr: Learning rate, recorded in ``self.hparams`` by
            ``save_hyperparameters()``.
        hidden_dim: Output width of each continuous encoder. Defaults to 64,
            the previously hard-coded value.
        classifier_hidden_dim: Hidden width of the classifier head. Defaults
            to 128, the previously hard-coded value.
        classifier_dropout: Dropout probability inside the classifier head.
            Defaults to 0.3, the previously hard-coded value.

    Raises:
        ValueError: If the configuration yields a fused width of 0, meaning
            no feature group contributes anything.
    """

    def __init__(
        self,
        lab_cont_dim,
        lab_cat_dims,
        conv_cont_dim,
        conv_cat_dims,
        embedding_dim,
        num_classes,
        lr=1e-3,
        *,
        hidden_dim=64,
        classifier_hidden_dim=128,
        classifier_dropout=0.3,
    ):
        super().__init__()
        self.save_hyperparameters()

        # Submodules are created in the original order, so parameter names
        # (state_dict keys) and seeded initialisation are unchanged.
        self.lab_cont_encoder = self._make_cont_encoder(lab_cont_dim, hidden_dim)
        self.lab_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in lab_cat_dims]
        )
        self.conv_cont_encoder = self._make_cont_encoder(conv_cont_dim, hidden_dim)
        self.conv_cat_embeddings = nn.ModuleList(
            [nn.Embedding(dim + 1, embedding_dim) for dim in conv_cat_dims]
        )

        # Width of the fused vector, mirroring the groups forward() can emit.
        # The built ModuleLists are counted, rather than len(*_cat_dims), so
        # any iterable of dims works.
        total_dim = 0
        if self.lab_cont_encoder is not None:
            total_dim += hidden_dim
        total_dim += embedding_dim * len(self.lab_cat_embeddings)
        if self.conv_cont_encoder is not None:
            total_dim += hidden_dim
        total_dim += embedding_dim * len(self.conv_cat_embeddings)
        if total_dim == 0:
            # Without this check, construction succeeds and every forward
            # pass fails inside torch.cat on an empty list.
            raise ValueError(
                "DualEncoderModel needs at least one feature group: set "
                "lab_cont_dim/conv_cont_dim > 0 or pass non-empty "
                "lab_cat_dims/conv_cat_dims (with embedding_dim > 0)."
            )

        self.classifier = nn.Sequential(
            nn.Linear(total_dim, classifier_hidden_dim),
            nn.ReLU(),
            nn.Dropout(classifier_dropout),
            nn.Linear(classifier_hidden_dim, num_classes),
        )

    @staticmethod
    def _make_cont_encoder(in_dim, hidden_dim):
        """Build the two-block continuous encoder, or return None if ``in_dim <= 0``."""
        if in_dim <= 0:
            return None
        return nn.Sequential(
            ResidualBlock(in_dim, hidden_dim),
            ResidualBlock(hidden_dim, hidden_dim),
        )

    @staticmethod
    def _embed_categorical(tables, codes):
        """Look up column ``i`` of ``codes`` in ``tables[i]``, one output per table.

        Negative codes are clamped to 0 before the lookup. Codes above a
        table's last row are not clamped; ``nn.Embedding`` rejects them.
        """
        return [
            table(torch.clamp(codes[:, i], min=0))
            for i, table in enumerate(tables)
        ]

    def forward(self, lab_cont, lab_cat, conv_cont, conv_cat):
        """Encode every available feature group and classify the fused vector.

        A group is skipped when its encoder was not built or its input tensor
        has no elements. Inputs for groups without an encoder are never
        touched, so ``None`` is accepted there.

        Args:
            lab_cont: Continuous lab features. Presumably shaped
                (batch, lab_cont_dim); TODO confirm.
            lab_cat: Integer lab category codes. Column ``lab_cat[:, i]``
                feeds the i-th lab embedding.
            conv_cont: Continuous conversation features. Presumably shaped
                (batch, conv_cont_dim); TODO confirm.
            conv_cat: Integer conversation category codes, laid out like
                ``lab_cat``.

        Returns:
            Classifier output with ``num_classes`` entries in the last
            dimension. No softmax is applied.

        Raises:
            RuntimeError: If no group produced features. Also raised if a
                skipped empty input leaves the fused width different from
                the width the classifier was built for.
        """
        features = []
        if self.lab_cont_encoder is not None and lab_cont.nelement() > 0:
            features.append(self.lab_cont_encoder(lab_cont))
        if len(self.lab_cat_embeddings) > 0 and lab_cat.nelement() > 0:
            features.extend(self._embed_categorical(self.lab_cat_embeddings, lab_cat))
        if self.conv_cont_encoder is not None and conv_cont.nelement() > 0:
            features.append(self.conv_cont_encoder(conv_cont))
        if len(self.conv_cat_embeddings) > 0 and conv_cat.nelement() > 0:
            features.extend(self._embed_categorical(self.conv_cat_embeddings, conv_cat))

        if not features:
            raise RuntimeError(
                "DualEncoderModel.forward received no usable features: every "
                "configured group's input tensor is empty."
            )
        fused = torch.cat(features, dim=1)

        # Skipping an empty input silently shrinks the fused width. Report
        # that clearly instead of surfacing an opaque matmul shape error.
        expected_dim = self.classifier[0].in_features
        if fused.shape[-1] != expected_dim:
            raise RuntimeError(
                f"Fused feature width {fused.shape[-1]} does not match the "
                f"classifier input width {expected_dim}; an input tensor was "
                "probably empty, so its feature group was skipped."
            )
        return self.classifier(fused)