# scibert-multimodal-novelty / modeling_multimodal_scibert.py
# Uploaded by JasonYan777 ("Upload folder using huggingface_hub").
# Commit 130ecf1 (verified).
import torch
import torch.nn as nn
from transformers import AutoModel
class MultiModalSciBERT(nn.Module):
    """Two-class classifier over a transformer encoding plus optional side features.

    The text is encoded with a pretrained encoder loaded through
    ``AutoModel.from_pretrained``; the hidden state of the first token
    (the ``[CLS]`` position for BERT-style encoders) is concatenated with up
    to three optional, externally computed feature tensors and passed
    through a three-layer MLP head that emits 2 logits.

    Which optional inputs are part of the model is fixed at construction
    time, because the first linear layer of the head is sized for exactly
    the enabled set.
    """

    def __init__(
        self,
        scibert_model: str = "allenai/scibert_scivocab_uncased",
        use_classification_emb: bool = True,
        use_proximity_emb: bool = True,
        use_similarity_features: bool = True,
        classification_emb_dim: int = 768,
        proximity_emb_dim: int = 768,
        similarity_feature_dim: int = 2,
    ):
        """Build the encoder and the classification head.

        Args:
            scibert_model: Model name or path handed to
                ``AutoModel.from_pretrained``.
            use_classification_emb: Include ``classification_emb`` in the
                fused feature vector.
            use_proximity_emb: Include ``proximity_emb`` in the fused
                feature vector.
            use_similarity_features: Include ``similarity_features`` in the
                fused feature vector.
            classification_emb_dim: Width of ``classification_emb``.
            proximity_emb_dim: Width of ``proximity_emb``.
            similarity_feature_dim: Width of ``similarity_features``.
        """
        super().__init__()
        # Attribute names are kept upper-case as in the original code so that
        # external code reading these flags keeps working.
        self.USE_CLASSIFICATION_EMB = use_classification_emb
        self.USE_PROXIMITY_EMB = use_proximity_emb
        self.USE_SIMILARITY_FEATURES = use_similarity_features

        self.scibert = AutoModel.from_pretrained(scibert_model)

        # Input width of the head = encoder hidden size + every enabled
        # side-feature width.
        total_dim = self.scibert.config.hidden_size
        if self.USE_CLASSIFICATION_EMB:
            total_dim += classification_emb_dim
        if self.USE_PROXIMITY_EMB:
            total_dim += proximity_emb_dim
        if self.USE_SIMILARITY_FEATURES:
            total_dim += similarity_feature_dim

        # Layer order (and therefore state-dict keys) must stay stable so
        # previously saved checkpoints still load.
        self.classifier = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(total_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 2),
        )

    def forward(
        self,
        input_ids,
        attention_mask,
        classification_emb=None,
        proximity_emb=None,
        similarity_features=None,
    ):
        """Encode the text, fuse the enabled side features, and classify.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.
            classification_emb: 2-D tensor ``(batch, dim)``; required when
                the model was built with ``use_classification_emb=True``,
                ignored otherwise.
            proximity_emb: 2-D tensor ``(batch, dim)``; required when the
                model was built with ``use_proximity_emb=True``, ignored
                otherwise.
            similarity_features: 2-D tensor ``(batch, dim)``; required when
                the model was built with ``use_similarity_features=True``,
                ignored otherwise.

        Returns:
            The output of the classification head (2 logits per example).

        Raises:
            ValueError: If a feature enabled at construction time is passed
                as ``None``.
        """
        optional_inputs = (
            ("classification_emb", self.USE_CLASSIFICATION_EMB, classification_emb),
            ("proximity_emb", self.USE_PROXIMITY_EMB, proximity_emb),
            ("similarity_features", self.USE_SIMILARITY_FEATURES, similarity_features),
        )

        # The head was sized for every enabled feature, so skipping one would
        # only surface later as an opaque matmul shape error. Fail early,
        # before paying for the encoder pass, and name the culprit.
        missing = [
            name for name, enabled, value in optional_inputs
            if enabled and value is None
        ]
        if missing:
            raise ValueError(
                "Missing input(s) for enabled feature(s): "
                + ", ".join(missing)
                + ". Pass the tensor(s) or disable the feature(s) when "
                "constructing the model."
            )

        outputs = self.scibert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state at sequence position 0 for every example in the batch.
        cls_output = outputs.last_hidden_state[:, 0, :]

        feats = [cls_output]
        feats.extend(value for _, enabled, value in optional_inputs if enabled)

        combined = torch.cat(feats, dim=1)
        return self.classifier(combined)