import torch
import torch.nn as nn
from transformers import BertModel, AutoModel


class bert_labeler(nn.Module):
    def __init__(self, p=0.1, clinical=False, freeze_embeddings=False, pretrain_path=None):
        """ Init the labeler module
        @param p (float): p to use for dropout in the linear heads, 0.1 by default is consistent with
                          transformers.BertForSequenceClassification
        @param clinical (boolean): True if Bio_ClinicalBERT is desired, False otherwise. Ignored if
                                   pretrain_path is not None
        @param freeze_embeddings (boolean): True to freeze BERT embeddings during training
        @param pretrain_path (string): path to load a checkpoint from
        """
        super(bert_labeler, self).__init__()
        if pretrain_path is not None:
            self.bert = BertModel.from_pretrained(pretrain_path)
        elif clinical:
            self.bert = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
        else:
            self.bert = BertModel.from_pretrained('bert-base-uncased')

        if freeze_embeddings:
            for param in self.bert.embeddings.parameters():
                param.requires_grad = False

        self.dropout = nn.Dropout(p)
        # size of the output of the transformer's last layer
        hidden_size = self.bert.pooler.dense.in_features
        # classes: present, absent, unknown, blank for 12 conditions + support devices
        self.linear_heads = nn.ModuleList([nn.Linear(hidden_size, 4, bias=True) for _ in range(13)])
        # classes: yes, no for the 'no finding' observation
        self.linear_heads.append(nn.Linear(hidden_size, 2, bias=True))

    def forward(self, source_padded, attention_mask):
        """ Forward pass of the labeler
        @param source_padded (torch.LongTensor): Tensor of word indices with padding, shape (batch_size, max_len)
        @param attention_mask (torch.Tensor): Mask to avoid attention on padding tokens, shape (batch_size, max_len)
        @returns out (List[torch.Tensor]): A list of size 14 containing tensors. The first 13 have shape
                                           (batch_size, 4) and the last has shape (batch_size, 2)
        """
        # shape (batch_size, max_len, hidden_size)
        final_hidden = self.bert(source_padded, attention_mask=attention_mask)[0]
        # take the [CLS] token representation; integer indexing already drops the
        # sequence dimension, giving shape (batch_size, hidden_size)
        cls_hidden = final_hidden[:, 0, :]
        cls_hidden = self.dropout(cls_hidden)
        out = []
        for i in range(14):
            out.append(self.linear_heads[i](cls_hidden))
        return out
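

# Minimal usage sketch, assuming the default 'bert-base-uncased' weights, the
# transformers BertTokenizer, and an illustrative report string; the exact
# tokenizer/checkpoint pairing depends on how the model was trained.
if __name__ == "__main__":
    from transformers import BertTokenizer

    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    model = bert_labeler()
    model.eval()

    # tokenize a toy report into (batch_size, max_len) index/mask tensors
    encoded = tokenizer(["No acute cardiopulmonary abnormality."],
                        padding=True, truncation=True, return_tensors="pt")

    with torch.no_grad():
        out = model(encoded["input_ids"], encoded["attention_mask"])

    # 14 heads: 13 four-class (present/absent/unknown/blank) + one two-class
    print(len(out), out[0].shape, out[-1].shape)  # 14 torch.Size([1, 4]) torch.Size([1, 2])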