"""Run a token-classification (NER) model over a text and print merged entity spans."""

from transformers import AutoModelForTokenClassification, AutoTokenizer
from config import NER_MODEL
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): newer transformers releases deprecate `use_auth_token` in favour
# of `token`; kept as-is for compatibility with the installed version — confirm.
tokenizer = AutoTokenizer.from_pretrained(NER_MODEL, use_auth_token=True)
model = AutoModelForTokenClassification.from_pretrained(NER_MODEL, use_auth_token=True).to(device)

# Model output index -> BIO tag. There is no I-DATE / I-DECISION entry.
id_to_label = {
    0: 'O',
    1: 'B-COURT',
    2: 'B-DATE',
    3: 'B-DECISION',
    4: 'B-LAW',
    5: 'B-MONEY',
    6: 'B-OFFICIAL GAZZETE',
    7: 'B-PERSON',
    8: 'B-REFERENCE',
    9: 'I-COURT',
    10: 'I-LAW',
    11: 'I-MONEY',
    12: 'I-OFFICIAL GAZZETE',
    13: 'I-PERSON',
    14: 'I-REFERENCE',
}


def _model_device():
    """Return the device the model's parameters currently live on."""
    return next(model.parameters()).device


def _predict(inputs):
    """Run the model on tokenized ``inputs`` and return label ids for the first sequence."""
    with torch.no_grad():
        logits = model(**inputs).logits
    # argmax over the label axis; [0] picks the single sequence in the batch and,
    # unlike .squeeze(), always yields a 1-D list.
    return torch.argmax(logits, dim=2)[0].tolist()


def perform_ner(text):
    """Tag every sub-word token of ``text`` with an NER label.

    Text longer than the model's maximum length is truncated (``truncation=True``),
    so tokens past that limit are not tagged.

    Args:
        text: The input string.

    Returns:
        A list of ``(token, label)`` tuples, excluding the tokenizer's special tokens.

    Raises:
        RuntimeError: Any model error other than a CUDA out-of-memory condition.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    try:
        # Follow the model rather than the global `device`: after an OOM fallback
        # the model stays on CPU, and sending CUDA inputs to it would fail.
        inputs = inputs.to(_model_device())
        predictions = _predict(inputs)
    except RuntimeError as e:
        if "CUDA out of memory" not in str(e):
            raise
        print("Switching to CPU due to memory constraints.")
        model.cpu()  # Run model on CPU (persists for later calls)
        torch.cuda.empty_cache()  # hand the freed GPU blocks back to the driver
        inputs = inputs.to("cpu")
        predictions = _predict(inputs)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    labels = [id_to_label[pred] for pred in predictions]
    # NOTE(review): all_special_tokens typically also contains the unknown token,
    # so [UNK] pieces are dropped along with [CLS]/[SEP]/[PAD] — confirm intended.
    special_tokens = set(tokenizer.all_special_tokens)
    return [
        (token, label)
        for token, label in zip(tokens, labels)
        if token not in special_tokens
    ]


text = ""


def merge_entities(token_label_pairs):
    """Merge sub-word tokens into words, then adjacent related words into spans.

    Args:
        token_label_pairs: ``(token, label)`` pairs as returned by ``perform_ner``;
            continuation pieces are marked with a leading ``"##"``.

    Returns:
        ``(words, labels)``: two equal-length lists. Each word is a space-joined
        span and its label is the label of the span's first word.
    """
    # Pass 1: glue "##" continuation pieces onto the preceding token; a word
    # keeps the label of its first piece.
    merged_words, merged_labels = [], []
    current_word, current_label = "", None
    for token, label in token_label_pairs:
        if token.startswith("##"):
            current_word += token[2:]
            if current_label is None:
                # Orphan continuation at the very start: adopt its own label
                # instead of emitting a None label that breaks pass 2 / printing.
                current_label = label
            continue
        if current_word:
            merged_words.append(current_word)
            merged_labels.append(current_label)
        current_word, current_label = token, label
    if current_word:
        merged_words.append(current_word)
        merged_labels.append(current_label)

    # Pass 2: extend the previous span when the label repeats exactly (this also
    # joins runs of 'O' words) or when a B-/I- tag names the same entity type as
    # the previous span's label. Consecutive same-type B- tags are merged too,
    # presumably because DATE/DECISION have no I- tag — confirm.
    final_words, final_labels = [], []
    for word, label in zip(merged_words, merged_labels):
        if final_labels and (
            label == final_labels[-1]
            or (label.startswith(("B-", "I-")) and final_labels[-1].endswith(label[2:]))
        ):
            final_words[-1] += " " + word
        else:
            final_words.append(word)
            final_labels.append(label)
    return final_words, final_labels


if __name__ == "__main__":
    results = perform_ner(text)
    words, labels = merge_entities(results)
    for word, label in zip(words, labels):
        print(word + " ### " + label)