Spaces:
Runtime error
Runtime error
File size: 2,115 Bytes
6f2ff70 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 | import torch
import torch.nn as nn
from .bert import BERT
class BERTLog(nn.Module):
    """Wrap a BERT encoder with a masked log-key head and a time head.

    The encoder output is fed to both heads, and a few slices of the raw
    encoder output are returned alongside the head predictions.
    """

    def __init__(self, bert: BERT, vocab_size):
        """Attach the prediction heads to ``bert``.

        :param bert: BERT model which should be trained; its ``hidden``
            attribute is used as the input width of both heads
        :param vocab_size: total vocab size for masked_lm
        """
        super().__init__()
        self.bert = bert

        # Both heads consume vectors of the encoder's hidden width.
        hidden_width = self.bert.hidden
        self.mask_lm = MaskedLogModel(hidden_width, vocab_size)
        self.time_lm = TimeLogModel(hidden_width)
        # NOTE: the LinearCLS / LogClassifier heads are intentionally not
        # attached here; "cls_fnn_output" in forward() is therefore None.

    def forward(self, x, time_info):
        """Encode ``x`` and run both heads on the encoder output.

        :param x: input passed positionally to the wrapped BERT encoder
        :param time_info: forwarded to the encoder as ``time_info=``
        :return: dict with keys ``logkey_output``, ``time_output``,
            ``cls_output``, ``cls_fnn_output`` (always ``None``) and
            ``token_embeddings``
        """
        # Original author's note: encoder output is [batch, seq_len, hidden].
        encoded = self.bert(x, time_info=time_info)

        # Position 0 of every sequence (the [CLS] slot per the original note).
        first_position = encoded[:, 0]

        logkey_scores = self.mask_lm(encoded)
        time_scores = self.time_lm(encoded)

        # Only the first element along dim 0 is exposed as token embeddings.
        first_sample = encoded[0]

        return {
            "logkey_output": logkey_scores,
            "time_output": time_scores,
            "cls_output": first_position,
            "cls_fnn_output": None,
            "token_embeddings": first_sample,
        }
class MaskedLogModel(nn.Module):
    """Head that predicts the original token from a masked input sequence.

    Projects the last dimension from ``hidden`` to ``vocab_size`` and applies
    a log-softmax over that projected dimension.
    """

    def __init__(self, hidden, vocab_size):
        """
        :param hidden: size of the last dimension of the incoming features
        :param vocab_size: number of output classes of the projection
        """
        super().__init__()
        self.linear = nn.Linear(hidden, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        """Return log-probabilities over the vocabulary for ``x``."""
        vocab_logits = self.linear(x)
        log_probs = self.softmax(vocab_logits)
        return log_probs
class TimeLogModel(nn.Module):
    """Linear head mapping ``hidden``-wide features to ``time_size`` values."""

    def __init__(self, hidden, time_size=1):
        """
        :param hidden: size of the last dimension of the incoming features
        :param time_size: width of the output's last dimension (default 1)
        """
        super().__init__()
        self.linear = nn.Linear(hidden, time_size)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        time_values = self.linear(x)
        return time_values
class LogClassifier(nn.Module):
    """Single linear layer that maps ``hidden``-wide vectors to ``hidden``."""

    def __init__(self, hidden):
        """
        :param hidden: input and output width of the linear layer
        """
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, cls):
        """Project ``cls`` through the linear layer and return the result."""
        projected = self.linear(cls)
        return projected
class LinearCLS(nn.Module):
    """Single linear layer that keeps the feature width at ``hidden``."""

    def __init__(self, hidden):
        """
        :param hidden: input and output width of the linear layer
        """
        super().__init__()
        self.linear = nn.Linear(hidden, hidden)

    def forward(self, x):
        """Project ``x`` through the linear layer and return the result."""
        transformed = self.linear(x)
        return transformed
|