Commit be995d4
Parent(s): 21792c3
Create mian.py
mian.py (ADDED)
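The new file loads three fine-tuned token-classification checkpoints (one XLM-RoBERTa, two RemBERT) from the Misha24-10/MultiCoNER-2-recognition-model repository and ensembles their word-level logits with per-model weights to tag a sentence: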
# Requires: pip install -q transformers

import torch
from transformers import RemBertForTokenClassification, RemBertTokenizerFast
from transformers import XLMRobertaForTokenClassification, XLMRobertaTokenizerFast

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

main_path = "Misha24-10/MultiCoNER-2-recognition-model"

# Three fine-tuned token-classification checkpoints from the same repository.
model_1 = XLMRobertaForTokenClassification.from_pretrained(
    main_path, subfolder="xlm_roberta_large_mountain").to(device).eval()
tokenizer_1 = XLMRobertaTokenizerFast.from_pretrained(
    main_path, subfolder="xlm_roberta_large_mountain")

model_2 = RemBertForTokenClassification.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_v3").to(device).eval()
tokenizer_2 = RemBertTokenizerFast.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_v3")

model_3 = RemBertForTokenClassification.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_sky").to(device).eval()
tokenizer_3 = RemBertTokenizerFast.from_pretrained(
    main_path, subfolder="google-rembert-ft_for_multi_ner_sky")

# Label only the first sub-token of each word; continuation sub-tokens get -100.
label_all_tokens = False

def align_word_ids(word_ids, return_word_ids=False):
    """Mark the first sub-token of every word with 1 and all other positions
    (special tokens, padding, continuation sub-tokens) with -100."""
    previous_word_idx = None
    label_ids = []
    index_list = []
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            # Special tokens and padding carry no word id.
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            # First sub-token of a new word.
            label_ids.append(1)
            index_list.append(idx)
        else:
            # Continuation sub-token of the same word.
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx

    if return_word_ids:
        return label_ids, index_list
    return label_ids

def compute_last_layer_probs(model, tokenizer, sentence):
    """Run one model on the sentence and return its logits at the first
    sub-token of each whitespace-split word."""
    number_of_tokens = tokenizer.encode_plus(sentence, return_tensors='pt')['input_ids'].shape[-1]
    list_of_words = sentence.split()

    inputs = tokenizer(list_of_words, is_split_into_words=True, padding='max_length',
                       max_length=min(number_of_tokens, 512), truncation=True,
                       return_tensors="pt")
    input_ids = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    label_ids = torch.Tensor(align_word_ids(inputs.word_ids()))
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Keep only the positions that start a word.
    return logits[:, (label_ids == 1), :]

# Per-model weights for the ensemble (uniform by default).
weights = {'model_1': 1, 'model_2': 1, 'model_3': 1}

def weighted_voting(sentence):
    """Sum the weighted word-level logits of the three models and decode the
    argmax into label strings (assumes all checkpoints share model_1's label set)."""
    predictions = []
    for idx, (model, tokenizer) in enumerate([(model_1, tokenizer_1),
                                              (model_2, tokenizer_2),
                                              (model_3, tokenizer_3)]):
        logits = compute_last_layer_probs(model, tokenizer, sentence)
        predictions.append(logits * weights[f'model_{idx + 1}'])
    final_logits = sum(predictions)
    final_predictions = torch.argmax(final_logits, dim=2)
    labels = [model_1.config.id2label[i] for i in final_predictions.tolist()[0]]
    return labels

sent_ex = "Elon Musk 's brother sits on the boards of tesla".lower()

weighted_voting(sent_ex)
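A minimal follow-up sketch (not part of the commit) of how the output can be read back against the input, assuming one label per whitespace-split word, which the alignment above produces:

for word, label in zip(sent_ex.split(), weighted_voting(sent_ex)):
    # Print each word next to its ensembled NER tag.
    print(word, label)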