File size: 586 Bytes
f316449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
from torch import nn
from transformers import BertModel
import logging

# Module-level logger, keyed by this module's import name.
logger = logging.getLogger(name=__name__)


class BaseModel(nn.Module):
    """Base module holding a pretrained BERT encoder and a dropout layer.

    No ``forward`` is defined in this class; presumably subclasses add a task
    head and implement ``forward`` (TODO confirm against subclasses).

    Args:
        bert_model: Model identifier or path passed to
            ``BertModel.from_pretrained``.
        num_labels: Stored on the instance as-is; presumably the number of
            output classes used by a subclass head (verify against callers).
        dropout: Drop probability passed to ``nn.Dropout``.
        num_types: Stored on the instance as-is; its meaning is not visible
            in this class (NOTE(review): confirm against subclasses).

    Attributes:
        bert: The encoder returned by ``BertModel.from_pretrained``.
        dropout: The ``nn.Dropout`` module (not the float rate).
        dropout_rate: The float drop probability given as ``dropout``.
    """

    def __init__(self,
                 bert_model: str = "aubmindlab/bert-base-arabertv2",
                 num_labels: int = 2,
                 dropout: float = 0.1,
                 num_types: int = 0) -> None:
        super().__init__()

        # Keep the raw constructor arguments so subclasses can read them.
        self.bert_model = bert_model
        self.num_labels = num_labels
        self.num_types = num_types
        # Stored under its own name: ``self.dropout`` is reserved for the
        # nn.Dropout module created below, which would otherwise overwrite
        # the plain float value.
        self.dropout_rate = dropout

        self.bert = BertModel.from_pretrained(bert_model)
        # The module form of dropout; the rate is also available as
        # ``self.dropout.p``.
        self.dropout = nn.Dropout(dropout)