# (Hugging Face Spaces page header removed — extraction artifact, not code.)
# NER.py — Bangla NER demo: ensemble of BanglaBERT + RemBERT behind a Gradio UI.

from seqeval.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    precision_score,
    recall_score,
)
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize

# Fetch the tokenizer data that word_tokenize needs at runtime.
nltk.download('punkt')
nltk.download('punkt_tab')
# Hugging Face Hub checkpoints for the two fine-tuned NER models.
model_id = "Swaraj66/Banglabert-finetuned-ner"
model_id2 = "Swaraj66/Finetuned_RemBERT"


def _load_checkpoint(checkpoint):
    """Return the (tokenizer, token-classification model) pair for *checkpoint*."""
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForTokenClassification.from_pretrained(checkpoint),
    )


banglabert_tokenizer, banglabert_model = _load_checkpoint(model_id)
rembert_tokenizer, rembert_model = _load_checkpoint(model_id2)
# Helper functions
def get_word_logits(model, tokenizer, tokens):
    """Run *model* over pre-split *tokens* and return one logit row per word.

    Words may be broken into several sub-word pieces by the tokenizer; only
    the logits of the FIRST piece of each word are kept (the standard
    first-subtoken strategy), so the result aligns 1:1 with the input words
    (up to truncation).

    Returns a tensor of shape (num_words, num_labels).
    """
    encoded = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    with torch.no_grad():
        logits = model(**encoded).logits

    # Map each word index to the logits of its first sub-token; dict
    # insertion order preserves the original word order.
    first_piece = {}
    for position, word_idx in enumerate(encoded.word_ids()):
        # word_idx is None for special tokens ([CLS]/[SEP]/padding).
        if word_idx is not None and word_idx not in first_piece:
            first_piece[word_idx] = logits[0, position]
    return torch.stack(list(first_piece.values()))
def ensemble_predict(tokens):
    """Predict one label id per word by summing the two models' logits.

    Returns a plain Python list of integer label ids, one per word
    (possibly fewer than ``len(tokens)`` if truncation differed).
    """
    logits_rem = get_word_logits(rembert_model, rembert_tokenizer, tokens)
    logits_bangla = get_word_logits(banglabert_model, banglabert_tokenizer, tokens)

    # The two tokenizers may truncate differently; align on the shorter output
    # before combining.
    aligned = min(len(logits_rem), len(logits_bangla))
    combined = logits_rem[:aligned] + logits_bangla[:aligned]
    return combined.argmax(dim=-1).tolist()
# Label mapping: index -> BIO tag. Keyed by BOTH int and str so lookups work
# whether predictions arrive as ints or stringified ints.
_LABEL_NAMES = ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"]
id2label = {i: tag for i, tag in enumerate(_LABEL_NAMES)}
id2label.update({str(i): tag for i, tag in enumerate(_LABEL_NAMES)})
# Main NER function
def _bio_to_entities(labeled_words):
    """Collapse (word, BIO-tag) pairs into (entity_text, entity_type) spans.

    A ``B-X`` tag opens a new span; ``I-X`` extends the open span only when
    its type matches; anything else (``O`` or a mismatched ``I-``) closes the
    open span. Stray ``I-`` tags with no matching open span are dropped,
    as in strict BIO decoding.
    """
    entities = []
    current_words = []
    current_type = None

    def _flush():
        # Emit the open span (if any) and reset the accumulator.
        nonlocal current_words, current_type
        if current_words and current_type:
            entities.append((" ".join(current_words), current_type))
        current_words, current_type = [], None

    for word, label in labeled_words:
        if label.startswith("B-"):
            _flush()
            current_words = [word]
            current_type = label[2:]
        elif label.startswith("I-") and current_type == label[2:]:
            current_words.append(word)
        else:
            _flush()
    _flush()
    return entities


def ner_function(user_input):
    """Tokenize *user_input*, run the ensemble, and return entity spans.

    Returns a list of ``(entity_text, entity_type)`` tuples, e.g.
    ``[("ঢাকা", "LOC")]``.
    """
    words = word_tokenize(user_input)
    preds = ensemble_predict(words)
    # id2label holds both int and str keys; predictions are ints, so str()
    # hits the string-keyed entries.
    labels = [id2label[str(p)] for p in preds]
    # zip silently truncates to the shorter sequence, which also covers the
    # case where truncation left fewer predictions than words.
    return _bio_to_entities(list(zip(words, labels)))
# Gradio UI
def build_ui():
    """Assemble and return the Gradio Blocks interface for the NER demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Named Entity Recognition App Using Ensemble Model (RemBERT + BanglaBERT)\nEnter a sentence to detect named entities.")
        with gr.Row():
            sentence_box = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
        with gr.Row():
            analyze_button = gr.Button("Analyze Entities")
        with gr.Row():
            entities_view = gr.JSON(label="Named Entities")
        # Wire the button to the NER pipeline; output rendered as JSON.
        analyze_button.click(fn=ner_function, inputs=sentence_box, outputs=entities_view)
    return demo
# Launch the app. The module-level `app` is the Spaces entry point, so its
# name must not change.
app = build_ui()

if __name__ == "__main__":
    app.launch()