import torch
import numpy as np
from transformers import BertForSequenceClassification, BertTokenizer, DataCollatorForTokenClassification

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class BERT:
    """Thin wrapper around a BERT sequence classifier for a 13-class task."""

    def __init__(self):
        self.num_classes = 13
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

        # Request a classification head sized for num_classes up front, so the head is
        # created and moved to the device together with the rest of the model. Replacing
        # model.classifier after .to(device) would leave the new layer on the CPU and
        # leave model.config.num_labels out of sync with the actual head.
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=self.num_classes
        ).to(self.device)
        self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

        # Note: this collator pads per-token label sequences; for single-label sequence
        # classification, DataCollatorWithPadding is the more usual choice.
        self.data_collator = DataCollatorForTokenClassification(self.tokenizer)
    def getModel(self):
        return self.model

    def get_tokenizer(self):
        return self.tokenizer

    def tokenize(self, txt):
        # Returns a dict of PyTorch tensors (input_ids, attention_mask, ...).
        return self.tokenizer(txt, return_tensors="pt")
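
# --- Minimal usage sketch (not part of the original module) -----------------
# Assumes the wrapper above: builds the model, tokenizes one illustrative
# sentence, moves the tensors to the model's device, and runs a forward pass.
if __name__ == "__main__":
    bert = BERT()
    inputs = bert.tokenize("An example sentence to classify.")
    inputs = {k: v.to(bert.device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = bert.getModel()(**inputs).logits
    print(logits.shape)  # expected: torch.Size([1, 13])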