# NER demo app: ensemble of fine-tuned RemBERT and BanglaBERT token classifiers
# with a CRF decoding layer, served through a Gradio UI.

import os
import random

import torch
import torch.nn as nn
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize
from transformers import AutoTokenizer, AutoModelForTokenClassification
from huggingface_hub import hf_hub_download

# Fix random seeds so the randomly initialised CRF parameters (overwritten
# below by the trained checkpoint) are reproducible.
random.seed(42)
torch.manual_seed(42)

class CRFLayer(nn.Module):
    """Linear-chain CRF over per-token emission scores.

    Learns tag-to-tag transition scores plus start/end transition scores,
    and decodes the highest-scoring tag sequence with the Viterbi algorithm.
    """

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        # transitions[i, j] is the score of moving from tag i to tag j.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions):
        # Inference-only forward pass: return the best tag sequence.
        return self.viterbi_decode(emissions)

    def compute_log_likelihood(self, emissions, tags):
        """Log-likelihood of a gold tag sequence for (seq_len, num_tags) emissions."""
        seq_len = emissions.shape[0]

        # Score of the gold path: start + emissions + pairwise transitions + end.
        score = self.start_transitions[tags[0]] + emissions[0, tags[0]]
        for i in range(1, seq_len):
            score += self.transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
        score += self.end_transitions[tags[-1]]

        # Log-partition function Z via the forward algorithm.
        alphas = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # For each next tag: logsumexp over previous tags, plus emission.
            alphas = torch.logsumexp(alphas.unsqueeze(1) + self.transitions, dim=0) + emissions[i]
        Z = torch.logsumexp(alphas + self.end_transitions, dim=0)

        return score - Z

    def viterbi_decode(self, emissions):
        """Return the highest-scoring tag sequence as a list of tag indices."""
        seq_len = emissions.shape[0]
        backpointers = []

        viterbi_vars = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # For every current tag, pick the best previous tag.
            broadcast_score = viterbi_vars.unsqueeze(1) + self.transitions
            best_score, best_tag = torch.max(broadcast_score, dim=0)
            viterbi_vars = best_score + emissions[i]
            backpointers.append(best_tag)

        best_score = viterbi_vars + self.end_transitions
        best_tag = torch.argmax(best_score).item()

        # Walk the backpointers to recover the full path.
        best_path = [best_tag]
        for bptrs in reversed(backpointers):
            best_tag = bptrs[best_tag].item()
            best_path.insert(0, best_tag)
        return best_path

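# A minimal standalone sketch of how CRFLayer behaves on its own, assuming
# random emissions (illustrative only, not part of the app's control flow):
#
#     crf = CRFLayer(num_tags=9)
#     demo_emissions = torch.randn(5, 9)   # (seq_len, num_tags)
#     crf(demo_emissions)                  # -> list of 5 tag indices
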
# Hugging Face Hub checkpoints: the two fine-tuned encoders plus the CRF weights.
banglabert_checkpoint = "Swaraj66/BNER_Finetuned_BanglaBERT"
rembert_checkpoint = "Swaraj66/BNER_Finetuned_RemBERT"
crf_assets_checkpoint = "Swaraj66/BNER_CRF_Layer"

# Load the fine-tuned encoders. from_pretrained returns models already in
# eval mode, so no explicit .eval() call is needed here.
banglabert_tokenizer = AutoTokenizer.from_pretrained(banglabert_checkpoint, use_fast=True)
banglabert_model = AutoModelForTokenClassification.from_pretrained(banglabert_checkpoint)

# get_word_logits below relies on word_ids(), which needs a fast tokenizer.
rembert_tokenizer = AutoTokenizer.from_pretrained(rembert_checkpoint, use_fast=True)
rembert_model = AutoModelForTokenClassification.from_pretrained(rembert_checkpoint)

# Download the trained CRF weights from the Hub and restore them.
model_path = hf_hub_download(repo_id=crf_assets_checkpoint, filename="crf_model.pt")

CRFmodel = CRFLayer(num_tags=9)  # O plus B-/I- for PER, ORG, LOC, MISC
CRFmodel.load_state_dict(torch.load(model_path, map_location="cpu"))
CRFmodel.eval()

print("✅ CRF model loaded from Hugging Face private repo")

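# Because load_state_dict runs in strict mode, the checkpoint must contain
# exactly the three parameter tensors defined in CRFLayer, e.g.:
#
#     sd = torch.load(model_path, map_location="cpu")
#     sorted(sd.keys())  # ['end_transitions', 'start_transitions', 'transitions']
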
def get_word_logits(model, tokenizer, tokens):
    """Run a token-classification model and pool sub-word logits to word level.

    Each word keeps the logits of its first sub-word piece (first-subtoken
    pooling), so the result has one row per input word.
    """
    encodings = tokenizer(tokens, is_split_into_words=True, return_tensors="pt",
                          padding=True, truncation=True)
    word_ids = encodings.word_ids()

    with torch.no_grad():
        logits = model(**encodings).logits

    # Keep only the first sub-token's logits for every word.
    selected_logits = []
    seen = set()
    for idx, word_idx in enumerate(word_ids):
        if word_idx is None:  # special tokens such as [CLS] / [SEP]
            continue
        if word_idx not in seen:
            selected_logits.append(logits[0, idx])
            seen.add(word_idx)

    return torch.stack(selected_logits)

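# First-subtoken pooling in a nutshell: when the tokenizer splits a word into
# several pieces, only the first piece's logits stand in for the word. A
# hypothetical alignment for a two-word input might look like:
#
#     word_ids = [None, 0, 1, 1, None]   # [CLS], word 0, word 1 (two pieces), [SEP]
#     # -> rows kept from logits: token indices 1 and 2
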
def ensemble_predict(tokens, rembert_model, rembert_tokenizer,
                     banglabert_model, banglabert_tokenizer, crf_model):
    """Ensemble RemBERT and BanglaBERT word-level logits, then decode with the CRF."""
    rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
    banglabert_logits = get_word_logits(banglabert_model, banglabert_tokenizer, tokens)

    # The two tokenizers can disagree on length after truncation; trim both
    # logit matrices to the shorter one before combining.
    min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
    ensemble_logits = rembert_logits[:min_len] + banglabert_logits[:min_len]

    # (A CRF-free baseline would be torch.argmax(ensemble_logits, dim=-1);
    # the CRF decode below replaces that per-token argmax.)
    with torch.no_grad():
        return crf_model(ensemble_logits)

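# Design note: summing the two word-level logit matrices is an unweighted
# ensemble of the encoders; the combined scores then serve directly as CRF
# emission scores. A weighted variant (hypothetical) would be:
#
#     ensemble_logits = w * rembert_logits[:min_len] + (1 - w) * banglabert_logits[:min_len]
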
# Tokenizer from the base BanglaBERT checkpoint (same vocabulary as the
# fine-tuned model), used at inference time alongside banglabert_model.
model_checkpoint_base = "csebuetnlp/banglabert"
banglabert_tokenizer_base = AutoTokenizer.from_pretrained(model_checkpoint_base, use_fast=True)

# Map tag indices to BIO labels (indices match the CRF's 9 tags).
id2label = {
    0: "O",
    1: "B-PER",
    2: "I-PER",
    3: "B-ORG",
    4: "I-ORG",
    5: "B-LOC",
    6: "I-LOC",
    7: "B-MISC",
    8: "I-MISC",
}

# Tokenizer data required by nltk.word_tokenize.
nltk.download('punkt')
nltk.download('punkt_tab')

def ner_function(user_input):
    """Tokenize the input, predict BIO tags, and group them into entity spans."""
    words = word_tokenize(user_input)
    print("words -> ", words)

    preds = ensemble_predict(words, rembert_model, rembert_tokenizer,
                             banglabert_model, banglabert_tokenizer_base, CRFmodel)
    pred_labels_list = [id2label[label] for label in preds]
    print("Labels ----->", pred_labels_list)

    labeled_words = list(zip(words, pred_labels_list))

    # Merge consecutive B-/I- tags of the same type into (text, type) spans.
    entities = []
    current_entity = ""
    current_label = None

    for word, label in labeled_words:
        if label.startswith("B-"):
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = word
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            current_entity += " " + word
        else:
            # An "O" tag, or an I- tag that does not continue the open span:
            # close the current entity, if any.
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = ""
            current_label = None

    if current_entity and current_label:
        entities.append((current_entity.strip(), current_label))

    return entities

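# Example of the return shape (entity strings are placeholders):
#
#     ner_function("...")  # -> [("<person name>", "PER"), ("<place name>", "LOC")]
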
def build_ui():
    """Assemble the Gradio interface."""
    with gr.Blocks() as demo:
        gr.Markdown(
            "# Named Entity Recognition App Using Transformer Ensembles "
            "with CRF (RemBERT and BanglaBERT)\n"
            "Enter a sentence to detect named entities."
        )
        with gr.Row():
            input_text = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
        with gr.Row():
            submit_btn = gr.Button("Analyze Entities")
        with gr.Row():
            output_json = gr.JSON(label="Named Entities")

        submit_btn.click(fn=ner_function, inputs=input_text, outputs=output_json)

    return demo

app = build_ui()

if __name__ == "__main__":
    app.launch()
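
# To run locally (assuming this file is saved as app.py):
#
#     python app.py
#
# Gradio prints a local URL to open in the browser.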