import torch
import torch.nn as nn
from transformers import AutoModel

class MultiModalSciBERT(nn.Module):
    """SciBERT text encoder fused with optional auxiliary feature streams for 2-way classification."""

    def __init__(self, scibert_model="allenai/scibert_scivocab_uncased",
                 use_classification_emb=True, use_proximity_emb=True,
                 use_similarity_features=True):
        super().__init__()
        # Ablation flags: each auxiliary input stream can be switched off independently.
        self.USE_CLASSIFICATION_EMB = use_classification_emb
        self.USE_PROXIMITY_EMB = use_proximity_emb
        self.USE_SIMILARITY_FEATURES = use_similarity_features

        self.scibert = AutoModel.from_pretrained(scibert_model)
        hidden_size = self.scibert.config.hidden_size  # 768 for SciBERT-base

        # Classifier input width grows with each enabled feature stream.
        total_dim = hidden_size
        if self.USE_CLASSIFICATION_EMB:
            total_dim += 768  # classification embedding (768-dim)
        if self.USE_PROXIMITY_EMB:
            total_dim += 768  # proximity embedding (768-dim)
        if self.USE_SIMILARITY_FEATURES:
            total_dim += 2    # two scalar similarity scores

        # Three-layer MLP head with dropout: total_dim -> 512 -> 128 -> 2 logits.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(total_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2),
        )
    def forward(self, input_ids, attention_mask,
                classification_emb=None, proximity_emb=None, similarity_features=None):
        outputs = self.scibert(input_ids=input_ids, attention_mask=attention_mask)
        # Use the [CLS] token representation as the pooled text feature.
        cls_output = outputs.last_hidden_state[:, 0, :]

        # Concatenate whichever auxiliary streams are both enabled and provided.
        feats = [cls_output]
        if self.USE_CLASSIFICATION_EMB and classification_emb is not None:
            feats.append(classification_emb)
        if self.USE_PROXIMITY_EMB and proximity_emb is not None:
            feats.append(proximity_emb)
        if self.USE_SIMILARITY_FEATURES and similarity_features is not None:
            feats.append(similarity_features)
        combined = torch.cat(feats, dim=1)
        return self.classifier(combined)
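

# Minimal usage sketch (illustrative, not from the original file): exercises one
# forward pass with random auxiliary inputs whose dimensions match the sizes
# hard-coded above (768-dim embeddings, 2 similarity scores); batch size 2 and
# max_length 16 are arbitrary assumptions for the smoke test.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("allenai/scibert_scivocab_uncased")
    model = MultiModalSciBERT()
    model.eval()

    enc = tokenizer(
        ["a sample sentence", "another sample"],
        padding="max_length", truncation=True, max_length=16,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(
            input_ids=enc["input_ids"],
            attention_mask=enc["attention_mask"],
            classification_emb=torch.randn(2, 768),   # assumed 768-dim, per total_dim above
            proximity_emb=torch.randn(2, 768),        # assumed 768-dim, per total_dim above
            similarity_features=torch.randn(2, 2),    # two scalar scores per example
        )
    print(logits.shape)  # expected: torch.Size([2, 2])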