# NER.py
from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize

# Download necessary NLTK data (tokenizer models backing word_tokenize)
nltk.download('punkt')
nltk.download('punkt_tab')

# Load the two models
model_id = "Swaraj66/Banglabert-finetuned-ner"
model_id2 = "Swaraj66/Finetuned_RemBERT"

banglabert_tokenizer = AutoTokenizer.from_pretrained(model_id)
banglabert_model = AutoModelForTokenClassification.from_pretrained(model_id)
rembert_tokenizer = AutoTokenizer.from_pretrained(model_id2)
rembert_model = AutoModelForTokenClassification.from_pretrained(model_id2)

# Inference only: make eval mode explicit (disables dropout et al.).
banglabert_model.eval()
rembert_model.eval()


# Helper functions
def get_word_logits(model, tokenizer, tokens):
    """Return one logits row per input word, using the first-subword scheme.

    The tokenizer may split a word into several sub-tokens; only the logits of
    the FIRST sub-token of each word are kept.

    Args:
        model: a HuggingFace token-classification model.
        tokenizer: the tokenizer matching ``model``.
        tokens: list of pre-split words (a single sentence).

    Returns:
        torch.Tensor of shape (num_kept_words, num_labels). Words dropped by
        truncation are simply absent from the result.
    """
    encodings = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
    word_ids = encodings.word_ids()
    with torch.no_grad():  # inference only — no autograd graph needed
        logits = model(**encodings).logits

    selected_logits = []
    seen = set()
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:
            continue  # special tokens ([CLS]/[SEP]/padding) map to no word
        if word_idx not in seen:
            # Keep only the first sub-token's logits for each word.
            selected_logits.append(logits[0, idx])
            seen.add(word_idx)

    if not selected_logits:
        # Guard: torch.stack raises on an empty list (e.g. empty input).
        return torch.empty((0, logits.shape[-1]))
    return torch.stack(selected_logits)


def ensemble_predict(tokens):
    """Predict one label id per word by summing the two models' logits.

    Summing the logits (argmax-equivalent to averaging) gives an unweighted
    two-model ensemble. If the two tokenizers keep a different number of words
    (due to truncation), both logit tensors are cut to the shorter length so
    they align row-for-row.
    NOTE(review): assumes both fine-tuned models share the same label order —
    TODO confirm against the model configs.

    Args:
        tokens: list of pre-split words.

    Returns:
        list[int] of label ids (keys into ``id2label``), at most
        ``len(tokens)`` entries.
    """
    if not tokens:
        return []  # guard: nothing to predict

    rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
    banglabert_logits = get_word_logits(banglabert_model, banglabert_tokenizer, tokens)

    min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
    rembert_logits = rembert_logits[:min_len]
    banglabert_logits = banglabert_logits[:min_len]

    ensemble_logits = rembert_logits + banglabert_logits
    preds = torch.argmax(ensemble_logits, dim=-1)
    return preds.tolist()


# Label mapping. Int and str keys are BOTH kept deliberately: the UI code
# looks ids up via str(label), while the int keys remain for direct use.
id2label = {
    0: "O", 1: "B-PER", 2: "I-PER", 3: "B-ORG", 4: "I-ORG",
    5: "B-LOC", 6: "I-LOC", 7: "B-MISC", 8: "I-MISC",
    "0": "O", "1": "B-PER", "2": "I-PER", "3": "B-ORG", "4": "I-ORG",
    "5": "B-LOC", "6": "I-LOC", "7": "B-MISC", "8": "I-MISC",
}


# Main NER function
def ner_function(user_input):
    """Tag a sentence with the ensemble model and return merged entities.

    Args:
        user_input: raw sentence text.

    Returns:
        list of (entity_text, entity_type) tuples, e.g. [("...", "PER")].
    """
    words = word_tokenize(user_input)
    if not words:
        return []  # guard: empty / whitespace-only input

    preds = ensemble_predict(words)
    # .get with "O" fallback: an unexpected label id must not crash the UI.
    pred_labels_list = [id2label.get(str(label), "O") for label in preds]
    # zip stops at the shorter sequence, so words beyond the (possibly
    # truncated) prediction list are silently left untagged.
    labeled_words = list(zip(words, pred_labels_list))
    return _merge_bio_spans(labeled_words)


def _merge_bio_spans(labeled_words):
    """Collapse per-word BIO tags into (entity_text, entity_type) spans."""
    entities = []
    current_entity = ""
    current_label = None
    for word, label in labeled_words:
        if label.startswith("B-"):
            # A new span begins: flush any span still open, then start fresh.
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = word
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            # Continuation of the open span of the same type.
            current_entity += " " + word
        else:
            # "O", or a mismatched/orphan I- tag, closes the open span.
            # (Orphan I- words are intentionally not emitted as entities.)
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = ""
            current_label = None
    # Flush a span that runs to the end of the sentence.
    if current_entity and current_label:
        entities.append((current_entity.strip(), current_label))
    return entities


# Gradio UI
def build_ui():
    """Build and return the Gradio Blocks interface for the NER demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Named Entity Recognition App Using Ensemble Model (RemBERT + BanglaBERT)\nEnter a sentence to detect named entities.")
        with gr.Row():
            input_text = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
        with gr.Row():
            submit_btn = gr.Button("Analyze Entities")
        with gr.Row():
            output_json = gr.JSON(label="Named Entities")
        submit_btn.click(fn=ner_function, inputs=input_text, outputs=output_json)
    return demo


# Launch the app. `app` is built at import time so hosting platforms can
# find it; launch() only runs when executed as a script.
app = build_ui()

if __name__ == "__main__":
    app.launch()