| import torch.nn as nn |
|
|
| |
class DisorderPredictor(nn.Module):
    """Predict a per-position probability from a sequence of embeddings.

    The forward pass is: multi-head self-attention -> projection to
    ``hidden_dim`` (a no-op when ``input_dim == hidden_dim``) -> Transformer
    encoder stack -> linear head with one output unit -> sigmoid.

    All sub-modules use the batch-first layout, i.e. batched inputs are
    ``(batch, seq_len, input_dim)``.

    Args:
        input_dim: Feature size of the incoming embeddings. Must be divisible
            by ``num_heads`` (requirement of ``nn.MultiheadAttention``).
        hidden_dim: Model width (``d_model``) of the Transformer encoder.
            Must be divisible by ``num_heads``.
        num_heads: Number of attention heads, used both for the
            self-attention block and for every encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability for the attention block and the
            encoder layers.
    """

    def __init__(self, input_dim, hidden_dim, num_heads, num_layers, dropout):
        super().__init__()
        self.embedding_dim = input_dim

        # batch_first=True keeps this block consistent with the encoder below;
        # without it, dim 0 of a batched input would be treated as the
        # sequence axis and attention would be computed across samples.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # The encoder operates at width hidden_dim while the attention output
        # has width input_dim. Bridge the two when they differ; nn.Identity
        # has no parameters, so the equal-width case keeps the same
        # state_dict keys as before.
        if input_dim == hidden_dim:
            self.input_projection = nn.Identity()
        else:
            self.input_projection = nn.Linear(input_dim, hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers
        )

        # The head consumes the encoder output, whose width is hidden_dim.
        self.classifier = nn.Linear(hidden_dim, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, embeddings):
        """Return per-position probabilities in ``[0, 1]``.

        Args:
            embeddings: Tensor whose last dimension is ``input_dim``;
                batched inputs are ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor with the trailing feature dimension removed, e.g.
            ``(batch, seq_len)`` for a batched input.
        """
        # Self-attention: query, key and value are all the input embeddings.
        attn_out, _ = self.self_attention(embeddings, embeddings, embeddings)
        transformer_out = self.transformer_encoder(self.input_projection(attn_out))
        logits = self.classifier(transformer_out)
        # Drop the size-1 output dimension before squashing to probabilities.
        probs = self.sigmoid(logits.squeeze(-1))
        return probs
| |