# Swaraj66's picture
# Upload app.py with huggingface_hub
# b7ba275 verified
import os
import torch
import torch.nn as nn
import random
import gradio as gr
import nltk
from nltk.tokenize import word_tokenize
from transformers import AutoTokenizer, AutoModelForTokenClassification
from huggingface_hub import hf_hub_download
# Seed both RNGs so runs are reproducible (CRFLayer draws its initial
# parameters from torch.randn before any saved weights are loaded).
_SEED = 42
random.seed(_SEED)
torch.manual_seed(_SEED)
# CRF Layer implementation
class CRFLayer(nn.Module):
    """Linear-chain CRF operating on one unbatched sequence of emissions.

    Learnable parameters (names are part of the saved ``state_dict``):

    * ``transitions[i, j]`` -- score of moving from tag ``i`` to tag ``j``.
    * ``start_transitions[j]`` -- score of starting a sequence in tag ``j``.
    * ``end_transitions[j]`` -- score of ending a sequence in tag ``j``.

    Args:
        num_tags: Number of distinct tags the layer scores.
    """

    def __init__(self, num_tags):
        super().__init__()
        self.num_tags = num_tags
        # Creation order is kept as-is: it determines how the global RNG is
        # consumed by the three torch.randn calls.
        self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
        self.start_transitions = nn.Parameter(torch.randn(num_tags))
        self.end_transitions = nn.Parameter(torch.randn(num_tags))

    def forward(self, emissions):
        """Decode ``emissions`` with Viterbi; see :meth:`viterbi_decode`."""
        return self.viterbi_decode(emissions)

    def compute_log_likelihood(self, emissions, tags):
        """Return ``log p(tags | emissions)`` as a 0-dim tensor.

        Args:
            emissions: Tensor of shape ``(seq_len, num_tags)``.
            tags: Indexable sequence of tag indices; positions
                ``0 .. seq_len - 1`` and ``-1`` are read.

        Returns:
            The score of the given tag path minus the log partition function.
        """
        seq_len = emissions.shape[0]

        # Unnormalised score of the supplied tag path.
        score = self.start_transitions[tags[0]] + emissions[0, tags[0]]
        for i in range(1, seq_len):
            score = score + self.transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
        score = score + self.end_transitions[tags[-1]]

        # Forward algorithm in log space: alphas[j] is the log-sum of the
        # scores of every partial path that ends in tag j.
        alphas = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # Rows index the previous tag, columns the current tag.
            step_scores = alphas.unsqueeze(1) + self.transitions
            alphas = torch.logsumexp(step_scores, dim=0) + emissions[i]
        log_partition = torch.logsumexp(alphas + self.end_transitions, dim=0)

        return score - log_partition

    def viterbi_decode(self, emissions):
        """Return the highest-scoring tag path for ``emissions``.

        Args:
            emissions: Tensor of shape ``(seq_len, num_tags)``.

        Returns:
            A list of ``seq_len`` Python ints. An empty sequence decodes to
            ``[]`` instead of raising ``IndexError``.
        """
        seq_len = emissions.shape[0]
        if seq_len == 0:
            return []

        backpointers = []
        viterbi_vars = self.start_transitions + emissions[0]
        for i in range(1, seq_len):
            # broadcast_score[prev, cur] = best score ending in prev + transition.
            broadcast_score = viterbi_vars.unsqueeze(1) + self.transitions
            best_score, best_prev = torch.max(broadcast_score, dim=0)
            viterbi_vars = best_score + emissions[i]
            backpointers.append(best_prev)

        best_tag = torch.argmax(viterbi_vars + self.end_transitions).item()

        # Walk the backpointers from the last step to the first, collecting
        # tags in reverse; one final reverse avoids repeated insert(0, ...).
        best_path = [best_tag]
        for bptrs in reversed(backpointers):
            best_tag = bptrs[best_tag].item()
            best_path.append(best_tag)
        best_path.reverse()
        return best_path
# --- Checkpoints ---
# Hugging Face Hub repository ids used by the loading code further down.
banglabert_checkpoint: str = "Swaraj66/BNER_Finetuned_BanglaBERT"
rembert_checkpoint: str = "Swaraj66/BNER_Finetuned_RemBERT"
crf_assets_checkpoint: str = "Swaraj66/BNER_CRF_Layer"
# --- Load BanglaBERT ---
# Tokenizer (fast implementation requested explicitly) and token-classification
# model, both from the BanglaBERT checkpoint.
banglabert_tokenizer = AutoTokenizer.from_pretrained(banglabert_checkpoint, use_fast=True)
banglabert_model = AutoModelForTokenClassification.from_pretrained(banglabert_checkpoint)

# --- Load RemBERT ---
# Tokenizer (library default settings) and token-classification model, both
# from the RemBERT checkpoint.
rembert_tokenizer = AutoTokenizer.from_pretrained(rembert_checkpoint)
rembert_model = AutoModelForTokenClassification.from_pretrained(rembert_checkpoint)
# --- Download CRF model weights from private repo ---
# Use the shared checkpoint constant instead of repeating the repo id literal.
model_path = hf_hub_download(
    repo_id=crf_assets_checkpoint,
    filename="crf_model.pt",  # must match the filename in the repo
)

# --- Load CRF model with weights ---
# num_tags=9 matches the nine distinct labels listed in id2label.
CRFmodel = CRFLayer(num_tags=9)
# NOTE(security): torch.load deserialises with pickle, which can execute code
# embedded in the file. Only load checkpoints from a repository you trust;
# consider passing weights_only=True if the installed torch supports it.
CRFmodel.load_state_dict(torch.load(model_path, map_location="cpu"))
CRFmodel.eval()
print("✅ CRF model loaded from Hugging Face private repo")
def get_word_logits(model, tokenizer, tokens):
    """Return one logit vector per word, taken from each word's first sub-token.

    Args:
        model: Callable accepting the tokenizer output as keyword arguments
            and returning an object with a ``logits`` attribute indexed as
            ``[batch, token, label]``.
        tokens: Pre-split words; passed to the tokenizer with
            ``is_split_into_words=True`` and ``truncation=True``.
        tokenizer: Callable whose result exposes ``word_ids()``.

    Returns:
        Tensor of shape ``(num_words, num_labels)``, where ``num_words`` counts
        only the words that received at least one token. When no word received
        a token the result has shape ``(0, num_labels)`` rather than raising
        from ``torch.stack`` on an empty list.
    """
    encodings = tokenizer(
        tokens,
        is_split_into_words=True,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    word_ids = encodings.word_ids()

    with torch.no_grad():
        logits = model(**encodings).logits

    # Record the position of the first sub-token of every word; entries whose
    # word id is None (no originating word) are skipped.
    first_positions = []
    seen = set()
    for position, word_idx in enumerate(word_ids):
        if word_idx is None or word_idx in seen:
            continue
        seen.add(word_idx)
        first_positions.append(position)

    # index_select copes with an empty index, unlike torch.stack([]).
    index = torch.tensor(first_positions, dtype=torch.long, device=logits.device)
    return logits[0].index_select(0, index)  # (num_words, num_labels)
def ensemble_predict(tokens, rembert_model, rembert_tokenizer, Current_banglabert_model, Current_banglabert_tokenizer, CRFmodel):
    """Sum the word-level logits of both models and decode them with the CRF.

    Args:
        tokens: Pre-split words of one sentence.
        rembert_model, rembert_tokenizer: First model/tokenizer pair.
        Current_banglabert_model, Current_banglabert_tokenizer: Second pair.
        CRFmodel: Callable applied to the summed ``(num_words, num_labels)``
            logits; its result is returned unchanged.

    Returns:
        Whatever ``CRFmodel`` returns for the summed logits.
    """
    rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
    banglabert_logits = get_word_logits(Current_banglabert_model, Current_banglabert_tokenizer, tokens)

    # The two tokenizers may cover a different number of words (presumably
    # when one of them truncates earlier), so only the shared prefix is used.
    min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
    ensemble_logits = rembert_logits[:min_len] + banglabert_logits[:min_len]

    with torch.no_grad():
        return CRFmodel(ensemble_logits)
# Tokenizer loaded from the "csebuetnlp/banglabert" checkpoint; this is the
# tokenizer ner_function passes to ensemble_predict for the BanglaBERT model.
model_checkpoint_Base = "csebuetnlp/banglabert"
banglabert_tokenizer_base = AutoTokenizer.from_pretrained(model_checkpoint_Base, use_fast=True)
# Tag names in class-index order; position i is the label for class id i.
_NER_TAGS = (
    "O",
    "B-PER",
    "I-PER",
    "B-ORG",
    "I-ORG",
    "B-LOC",
    "I-LOC",
    "B-MISC",
    "I-MISC",
)

# Lookup table accepting the class id either as an int or as its decimal
# string; integer keys come first, then the string keys.
id2label = dict(enumerate(_NER_TAGS))
id2label.update((str(index), tag) for index, tag in enumerate(_NER_TAGS))
# Fetch the NLTK 'punkt' resources at import time in case they are not
# already available locally.
for _resource in ('punkt', 'punkt_tab'):
    nltk.download(_resource)
def _group_entities(labeled_words):
    """Merge BIO-labelled words into ``(entity_text, entity_type)`` tuples.

    A ``B-X`` label starts a new entity, an ``I-X`` label extends the open
    entity only when its type matches, and any other label closes the open
    entity (an ``I-`` label with no matching open entity is dropped).
    """
    entities = []
    current_entity = ""
    current_label = None

    for word, label in labeled_words:
        if label.startswith("B-"):
            # A new entity begins: emit the one that was open, if any.
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = word
            current_label = label[2:]
        elif label.startswith("I-") and current_label == label[2:]:
            current_entity += " " + word
        else:
            if current_entity and current_label:
                entities.append((current_entity.strip(), current_label))
            current_entity = ""
            current_label = None

    # Emit an entity that runs up to the end of the sentence.
    if current_entity and current_label:
        entities.append((current_entity.strip(), current_label))
    return entities


def ner_function(user_input):
    """Run the NER ensemble on ``user_input`` and return the detected entities.

    Args:
        user_input: Raw sentence text.

    Returns:
        A list of ``(entity_text, entity_type)`` tuples. Empty, ``None`` or
        whitespace-only input yields ``[]`` without invoking the models.
    """
    # Nothing to tokenize: return early instead of letting the pipeline fail
    # on an empty word list.
    if not user_input or not user_input.strip():
        return []

    words = word_tokenize(user_input)
    print("words -> ",words)
    if not words:
        return []

    preds = ensemble_predict(words,rembert_model,rembert_tokenizer,banglabert_model,banglabert_tokenizer_base,CRFmodel)
    pred_labels_list = [id2label[str(label)] for label in preds]  # Convert to str for safety
    print("Labels----->",pred_labels_list)

    # zip stops at the shorter input, so words without a prediction are ignored.
    return _group_entities(zip(words, pred_labels_list))
# Gradio app
def build_ui():
    """Assemble and return the Gradio Blocks interface.

    The layout is a Markdown heading followed by three rows: a text box for
    the sentence, an "Analyze Entities" button, and a JSON panel. Clicking the
    button calls ``ner_function`` with the text box value and shows the result
    in the JSON panel.
    """
    heading = (
        "# Named Entity Recognition App Using Transformer Ensembles with CRF (RemBERT and Banglabert)\n"
        "Enter a sentence to detect named entities."
    )

    with gr.Blocks() as interface:
        gr.Markdown(heading)

        with gr.Row():
            sentence_box = gr.Textbox(
                label="Enter a sentence",
                placeholder="Type your text here...",
            )
        with gr.Row():
            analyze_button = gr.Button("Analyze Entities")
        with gr.Row():
            entities_panel = gr.JSON(label="Named Entities")

        # Wire the button to the NER pipeline.
        analyze_button.click(
            fn=ner_function,
            inputs=sentence_box,
            outputs=entities_panel,
        )

    return interface
# Build the interface once, at import time, and expose it as `app`.
app = build_ui()

# Start the server only when this file is run directly
# (this guard can be commented out when deploying, if desired).
if __name__ == "__main__":
    app.launch()