Spaces:
Sleeping
Sleeping
File size: 2,982 Bytes
9e9178e df02cd1 9e9178e df02cd1 9e9178e df02cd1 9e9178e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
from transformers import AutoModelForTokenClassification, AutoTokenizer
from config import NER_MODEL
import torch
# Run on the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the tokenizer and token-classification model for the checkpoint named by
# config.NER_MODEL. use_auth_token=True asks transformers to authenticate with the
# locally stored Hugging Face token (presumably the checkpoint is private or
# gated -- confirm). NOTE(review): newer transformers releases deprecate
# use_auth_token in favour of token=True.
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_auth_token=True)
model = AutoModelForTokenClassification.from_pretrained(
    NER_MODEL,
    use_auth_token=True,
)
model = model.to(device)
# Class index -> BIO tag produced by the token-classification head. A tag's
# position in the tuple is its class id, so the order must match the label ids
# the checkpoint was trained with (presumably mirrors model.config.id2label --
# confirm). Note there is no I-DATE or I-DECISION tag, and "GAZZETE" is spelled
# exactly as in the label strings; both are kept verbatim.
id_to_label = dict(
    enumerate(
        (
            "O",
            "B-COURT",
            "B-DATE",
            "B-DECISION",
            "B-LAW",
            "B-MONEY",
            "B-OFFICIAL GAZZETE",
            "B-PERSON",
            "B-REFERENCE",
            "I-COURT",
            "I-LAW",
            "I-MONEY",
            "I-OFFICIAL GAZZETE",
            "I-PERSON",
            "I-REFERENCE",
        )
    )
)
def _predict_label_ids(inputs):
    """Return the argmax class id of every token in the first sequence of *inputs*."""
    # Send the encoded batch to whichever device the model currently lives on,
    # so calls keep working after an out-of-memory fallback moved it to the CPU.
    target = model.device
    batch = {name: tensor.to(target) for name, tensor in inputs.items()}
    with torch.no_grad():
        logits = model(**batch).logits
    # [0] selects the single sequence. Unlike squeeze(), it always leaves a 1-D
    # tensor, so .tolist() returns a list even for a one-token input.
    return logits.argmax(dim=-1)[0].tolist()


def perform_ner(text):
    """Tag *text* with the module-level NER model and return (token, label) pairs.

    The text is encoded by the module-level ``tokenizer`` with
    ``truncation=True``, so over-long input is cut to the tokenizer's maximum
    length. Each token is paired with ``id_to_label[class_id]`` for its
    highest-scoring class. Tokens found in ``tokenizer.all_special_tokens`` are
    dropped from the result.

    If the forward pass fails with a CUDA out-of-memory error, the model is
    moved to the CPU and the pass is retried there. The model stays on the CPU
    afterwards. Inputs always follow the model's current device, so later calls
    keep working.

    Args:
        text: The string to tag.

    Returns:
        A list of ``(token, label)`` tuples in token order. Tokens are the
        tokenizer's sub-word pieces (presumably WordPiece with a ``##``
        continuation prefix -- see ``merge_entities``).

    Raises:
        RuntimeError: Any runtime error that is not a CUDA out-of-memory error,
            or any error raised by the CPU retry.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    try:
        predictions = _predict_label_ids(inputs)
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        print("Switching to CPU due to memory constraints.")
        model.cpu()  # nn.Module.cpu() moves the parameters in place
        torch.cuda.empty_cache()  # release the GPU blocks the model no longer uses
        predictions = _predict_label_ids(inputs)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    # all_special_tokens builds a fresh list on each access; look it up once.
    special_tokens = set(tokenizer.all_special_tokens)
    return [
        (token, id_to_label[pred])
        for token, pred in zip(tokens, predictions)
        if token not in special_tokens
    ]
# Input for the demo run at the bottom of the module; empty unless edited.
text = ''
def _same_entity(label, previous_label):
    """Return True when a word tagged *label* should extend a span tagged *previous_label*."""
    if label == previous_label:
        return True
    prefixes = ("B-", "I-")
    # Compare entity types exactly (the text after the "B-"/"I-" prefix) rather
    # than with endswith(), which would also match a type that is merely a suffix.
    return (
        label.startswith(prefixes)
        and previous_label.startswith(prefixes)
        and label[2:] == previous_label[2:]
    )


def merge_entities(token_label_pairs):
    """Rebuild words from sub-word tokens, then join adjacent words into labelled spans.

    Step 1: a token starting with ``##`` is appended, without the prefix, to
    the preceding word. Each word keeps the label of its first token. A ``##``
    token with no preceding word starts a new word with its own label.

    Step 2: consecutive words are joined with a single space when they carry
    exactly the same label, or when both labels are ``B-``/``I-`` tags of the
    same entity type. A joined span keeps the label of its first word. This
    deliberately also joins runs of ``O`` words and merges back-to-back ``B-``
    entities of the same type.

    Args:
        token_label_pairs: Iterable of ``(token, label)`` pairs, e.g. the
            output of ``perform_ner``.

    Returns:
        ``(words, labels)``: two lists of equal length.
    """
    words, word_labels = [], []
    for token, label in token_label_pairs:
        is_continuation = token.startswith("##")
        if is_continuation and words:
            words[-1] += token[2:]
        else:
            # A continuation with nothing before it becomes its own word, so it
            # never ends up with a None label (which would crash step 2).
            words.append(token[2:] if is_continuation else token)
            word_labels.append(label)

    final_words, final_labels = [], []
    for word, label in zip(words, word_labels):
        if final_labels and _same_entity(label, final_labels[-1]):
            final_words[-1] += " " + word
        else:
            final_words.append(word)
            final_labels.append(label)
    return final_words, final_labels
if __name__ == "__main__":
    # Demo: tag the module-level ``text`` and print one merged span per line as
    # "<span> ### <label>". Guarded so that importing this module does not run
    # a forward pass.
    results = perform_ner(text)
    words, labels = merge_entities(results)
    for word, label in zip(words, labels):
        print(f"{word} ### {label}")
|