MukeshKapoor25's picture
changs
6f2ff70
import torch
import torch.nn as nn
from .bert import BERT
class BERTLog(nn.Module):
"""
BERT Log Model
"""
def __init__(self, bert: BERT, vocab_size):
"""
:param bert: BERT model which should be trained
:param vocab_size: total vocab size for masked_lm
"""
super().__init__()
self.bert = bert
self.mask_lm = MaskedLogModel(self.bert.hidden, vocab_size)
self.time_lm = TimeLogModel(self.bert.hidden)
# self.fnn_cls = LinearCLS(self.bert.hidden)
# self.cls_lm = LogClassifier(self.bert.hidden)
def forward(self, x, time_info):
x = self.bert(x, time_info=time_info) # [batch, seq_len, hidden]
cls_output = x[:, 0] # [CLS] token vector from BERT
return {
"logkey_output": self.mask_lm(x), # [batch, seq_len, vocab_size]
"time_output": self.time_lm(x), # optional
"cls_output": cls_output, # [batch, hidden]
"cls_fnn_output": None, # unused for now
"token_embeddings": x[0] # [seq_len, hidden] for first batch element
}
class MaskedLogModel(nn.Module):
"""
Predicting original token from masked input sequence
"""
def __init__(self, hidden, vocab_size):
super().__init__()
self.linear = nn.Linear(hidden, vocab_size)
self.softmax = nn.LogSoftmax(dim=-1)
def forward(self, x):
return self.softmax(self.linear(x))
class TimeLogModel(nn.Module):
def __init__(self, hidden, time_size=1):
super().__init__()
self.linear = nn.Linear(hidden, time_size)
def forward(self, x):
return self.linear(x)
class LogClassifier(nn.Module):
def __init__(self, hidden):
super().__init__()
self.linear = nn.Linear(hidden, hidden)
def forward(self, cls):
return self.linear(cls)
class LinearCLS(nn.Module):
def __init__(self, hidden):
super().__init__()
self.linear = nn.Linear(hidden, hidden)
def forward(self, x):
return self.linear(x)