import os
import random

import torch
import torch.nn as nn
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize
from transformers import AutoTokenizer, AutoModelForTokenClassification
from huggingface_hub import hf_hub_download

# Set seed for reproducibility
random.seed(42)
torch.manual_seed(42)


# CRF Layer implementation
class CRFLayer(nn.Module):
    """Linear-chain CRF over per-token emission scores.

    Holds learnable tag-to-tag transition scores plus start/end scores.
    ``forward`` returns the Viterbi-decoded best tag sequence for a single
    unbatched ``(seq_len, num_tags)`` emission matrix.
    """

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        # transitions[i, j] = score of moving from tag i to tag j
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions):
        """Decode the most likely tag path for ``emissions`` (seq_len, num_tags)."""
        return self.viterbi_decode(emissions)

    def compute_log_likelihood(self, emissions, tags):
        """Return log P(tags | emissions) for one unbatched sequence.

        Args:
            emissions: ``(seq_len, num_tags)`` float tensor of per-token scores.
            tags: ``(seq_len,)`` long tensor of gold tag indices.

        Returns:
            Scalar tensor: path score minus the log partition function.
        """
        seq_len = emissions.shape[0]

        # Score for the given tag sequence.
        score = self.start_transitions[tags[0]] + emissions[0, tags[0]]
        for i in range(1, seq_len):
            score += self.transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
        score += self.end_transitions[tags[-1]]

        # Partition function via the forward algorithm in log space.
        alphas = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # alphas[prev] + transitions[prev, cur], log-sum-exp'd over prev,
            # then the current emission added.  (The original built this with a
            # redundant unsqueeze(0)/squeeze() pair, which would over-squeeze
            # when num_tags == 1; emissions[i] broadcasts directly.)
            alphas = torch.logsumexp(alphas.unsqueeze(1) + self.transitions, dim=0) + emissions[i]
        log_z = torch.logsumexp(alphas + self.end_transitions, dim=0)

        return score - log_z

    def viterbi_decode(self, emissions):
        """Return the highest-scoring tag sequence as a plain list of ints."""
        seq_len = emissions.shape[0]
        backpointers = []

        viterbi_vars = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # broadcast_score[prev, cur] = best score ending in prev + transition
            broadcast_score = viterbi_vars.unsqueeze(1) + self.transitions
            best_score, best_tag = torch.max(broadcast_score, dim=0)
            viterbi_vars = best_score + emissions[i]
            backpointers.append(best_tag)

        best_tag = torch.argmax(viterbi_vars + self.end_transitions).item()

        # Backtrace from the last position.
        best_path = [best_tag]
        for bptrs in reversed(backpointers):
            best_tag = bptrs[best_tag].item()
            best_path.insert(0, best_tag)
        return best_path
# --- Checkpoints ---
banglabert_checkpoint = "Swaraj66/BNER_Finetuned_BanglaBERT"
rembert_checkpoint = "Swaraj66/BNER_Finetuned_RemBERT"
crf_assets_checkpoint = "Swaraj66/BNER_CRF_Layer"

# --- Load BanglaBERT ---
banglabert_tokenizer = AutoTokenizer.from_pretrained(banglabert_checkpoint, use_fast=True)
banglabert_model = AutoModelForTokenClassification.from_pretrained(banglabert_checkpoint)

# --- Load RemBERT ---
rembert_tokenizer = AutoTokenizer.from_pretrained(rembert_checkpoint)
rembert_model = AutoModelForTokenClassification.from_pretrained(rembert_checkpoint)

# --- Download CRF model weights from private repo ---
model_path = hf_hub_download(
    repo_id="Swaraj66/BNER_CRF_Layer",
    filename="crf_model.pt"  # <- must match the filename in repo
)

# --- Load CRF model with weights ---
CRFmodel = CRFLayer(num_tags=9)
CRFmodel.load_state_dict(torch.load(model_path, map_location="cpu"))
CRFmodel.eval()
print("✅ CRF model loaded from Hugging Face private repo")


def get_word_logits(model, tokenizer, tokens):
    """Return one logit row per input word (first-subword pooling).

    Args:
        model: token-classification model yielding (1, num_subwords, num_labels) logits.
        tokenizer: matching fast tokenizer (must support ``word_ids()``).
        tokens: list of pre-split words.

    Returns:
        ``(num_words, num_labels)`` tensor where each word is represented by
        the logits of its first sub-token.  Words lost to truncation are
        silently dropped.
    """
    encodings = tokenizer(tokens, is_split_into_words=True,
                          return_tensors="pt", padding=True, truncation=True)
    word_ids = encodings.word_ids()

    with torch.no_grad():
        logits = model(**encodings).logits

    selected_logits = []
    seen = set()
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:        # special tokens ([CLS]/[SEP]/padding)
            continue
        if word_idx not in seen:    # keep only the first sub-token of each word
            selected_logits.append(logits[0, idx])
            seen.add(word_idx)
    return torch.stack(selected_logits)  # (num_words, num_labels)


def ensemble_predict(tokens, rembert_model, rembert_tokenizer,
                     Current_banglabert_model, Current_banglabert_tokenizer, CRFmodel):
    """Sum word-level logits from both models and CRF-decode the result.

    Args:
        tokens: list of pre-split words.
        rembert_model / rembert_tokenizer: RemBERT token classifier + tokenizer.
        Current_banglabert_model / Current_banglabert_tokenizer: BanglaBERT
            token classifier + tokenizer.
        CRFmodel: trained ``CRFLayer`` used for structured decoding.

    Returns:
        Viterbi-decoded list of label ids, one per word (truncated to the
        shorter of the two models' word coverage).
    """
    rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
    banglabert_logits = get_word_logits(Current_banglabert_model,
                                        Current_banglabert_tokenizer, tokens)

    # The two tokenizers may truncate differently; align on the shorter output.
    min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
    ensemble_logits = rembert_logits[:min_len] + banglabert_logits[:min_len]

    # Decode the summed emissions with the CRF.  (The original wrapped this in
    # a one-element list loop and computed an unused argmax; both removed.)
    with torch.no_grad():
        return CRFmodel(ensemble_logits)
model_checkpoint_Base = "csebuetnlp/banglabert"
banglabert_tokenizer_base = AutoTokenizer.from_pretrained(model_checkpoint_Base, use_fast=True)

# CoNLL-style BIO label ids.  (The original duplicated every entry under both
# int and str keys to support a str(label) lookup hack; int keys suffice.)
id2label = {
    0: "O",
    1: "B-PER", 2: "I-PER",
    3: "B-ORG", 4: "I-ORG",
    5: "B-LOC", 6: "I-LOC",
    7: "B-MISC", 8: "I-MISC",
}

# Make sure to download punkt if you haven't already
nltk.download('punkt')
nltk.download('punkt_tab')


def ner_function(user_input):
    """Tokenize ``user_input``, tag it with the ensemble+CRF pipeline and
    group BIO labels into ``(entity_text, entity_type)`` pairs.

    Args:
        user_input: raw sentence string (Bangla expected — TODO confirm).

    Returns:
        List of ``(entity, label)`` tuples, e.g. ``("ঢাকা", "LOC")``.
    """
    words = word_tokenize(user_input)
    print("words -> ", words)

    preds = ensemble_predict(words, rembert_model, rembert_tokenizer,
                             banglabert_model, banglabert_tokenizer_base, CRFmodel)
    pred_labels_list = [id2label[int(label)] for label in preds]
    print("Labels----->", pred_labels_list)

    # Merge BIO-tagged words into whole entities.  zip() also handles the case
    # where truncation left fewer predictions than words.
    entities = []
    current_entity = ""
    current_label = None
    for word, label in zip(words, pred_labels_list):
        if label.startswith("B-"):
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = word
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            current_entity += " " + word
        else:
            # NOTE(review): a stray "I-X" with no open "X" entity falls through
            # here and is dropped like an "O" — confirm this is intended.
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = ""
            current_label = None
    if current_entity and current_label:
        entities.append((current_entity.strip(), current_label))

    return entities


# Gradio app
def build_ui():
    """Build and return the Gradio Blocks UI for the NER demo."""
    with gr.Blocks() as demo:
        gr.Markdown("# Named Entity Recognition App Using Transformer Ensembles with CRF (RemBERT and Banglabert)\nEnter a sentence to detect named entities.")
        with gr.Row():
            input_text = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
        with gr.Row():
            submit_btn = gr.Button("Analyze Entities")
        with gr.Row():
            output_json = gr.JSON(label="Named Entities")
        submit_btn.click(fn=ner_function, inputs=input_text, outputs=output_json)
    return demo
# Build the app at import time so hosting platforms (e.g. HF Spaces) find `app`.
app = build_ui()

# For local running (comment this out when deploying if you want)
if __name__ == "__main__":
    app.launch()