Spaces:
Running
Running
| from .base import BaseModel | |
| from .bert_modules.bert import BERT | |
| import torch.nn as nn | |
class BERTModel(BaseModel):
    """Model that runs a ``BERT`` encoder and projects its output to item scores.

    The encoder output is passed through a single ``nn.Linear`` layer whose
    output size is ``args.num_items + 1``.
    """

    def __init__(self, args):
        """Build the encoder and the output projection.

        Args:
            args: configuration object. It is forwarded unchanged to
                ``BaseModel.__init__`` and ``BERT``, and ``args.num_items``
                determines the size of the output layer.
        """
        super().__init__(args)
        self.bert = BERT(args)
        # Input size comes from the encoder's ``hidden`` attribute. The output
        # has one slot beyond num_items; presumably one index (e.g. padding)
        # is reserved in the item vocabulary -- TODO confirm with BERT/dataset.
        self.out = nn.Linear(self.bert.hidden, args.num_items + 1)

    @classmethod
    def code(cls):
        """Return the short string identifier for this model class.

        Declared as a classmethod (it takes ``cls``), so it works both on the
        class itself (``BERTModel.code()``) and on an instance.
        """
        return 'bert'

    def forward(self, x):
        """Encode ``x`` with BERT and project the result to item scores.

        Args:
            x: input accepted by ``self.bert``. Presumably a batch of item-id
                sequences; the exact shape is not established here -- TODO
                confirm against ``BERT.forward``.

        Returns:
            ``self.out`` applied to the encoder output. Its last dimension is
            ``args.num_items + 1``.
        """
        x = self.bert(x)
        return self.out(x)