Swaraj66 commited on
Commit
b7ba275
·
verified ·
1 Parent(s): c03fb0d

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import random
5
+ import gradio as gr
6
+ import nltk
7
+ from nltk.tokenize import word_tokenize
8
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
9
+ from huggingface_hub import hf_hub_download
10
+
11
+
12
+
13
+ # Set seed for reproducibility
14
+ random.seed(42)
15
+ torch.manual_seed(42)
16
+
17
+ # CRF Layer implementation
18
+ class CRFLayer(nn.Module):
19
+ def __init__(self, num_tags):
20
+ super(CRFLayer, self).__init__()
21
+ self.num_tags = num_tags
22
+ self.transitions = nn.Parameter(torch.randn(num_tags, num_tags))
23
+ self.start_transitions = nn.Parameter(torch.randn(num_tags))
24
+ self.end_transitions = nn.Parameter(torch.randn(num_tags))
25
+
26
+ def forward(self, emissions):
27
+ return self.viterbi_decode(emissions)
28
+
29
+ def compute_log_likelihood(self, emissions, tags):
30
+ # emissions: (seq_len, num_tags)
31
+ seq_len = emissions.shape[0]
32
+
33
+ # Score for the given tag sequence
34
+ score = self.start_transitions[tags[0]] + emissions[0, tags[0]]
35
+ for i in range(1, seq_len):
36
+ score += self.transitions[tags[i - 1], tags[i]] + emissions[i, tags[i]]
37
+ score += self.end_transitions[tags[-1]]
38
+
39
+ # Compute partition function using log-sum-exp
40
+ alphas = self.start_transitions + emissions[0]
41
+ for i in range(1, seq_len):
42
+ emission = emissions[i].unsqueeze(0) # (1, num_tags)
43
+ alpha_exp = alphas.unsqueeze(1) + self.transitions # (num_tags, num_tags)
44
+ alphas = torch.logsumexp(alpha_exp, dim=0) + emission.squeeze()
45
+ Z = torch.logsumexp(alphas + self.end_transitions, dim=0)
46
+ return score - Z
47
+
48
+ def viterbi_decode(self, emissions):
49
+ seq_len = emissions.shape[0]
50
+ backpointers = []
51
+
52
+ viterbi_vars = self.start_transitions + emissions[0]
53
+ for i in range(1, seq_len):
54
+ broadcast_score = viterbi_vars.unsqueeze(1) + self.transitions
55
+ best_score, best_tag = torch.max(broadcast_score, dim=0)
56
+ viterbi_vars = best_score + emissions[i]
57
+ backpointers.append(best_tag)
58
+
59
+ best_score = viterbi_vars + self.end_transitions
60
+ best_tag = torch.argmax(best_score).item()
61
+
62
+ # Backtrace
63
+ best_path = [best_tag]
64
+ for bptrs in reversed(backpointers):
65
+ best_tag = bptrs[best_tag].item()
66
+ best_path.insert(0, best_tag)
67
+ return best_path
68
+
69
+
70
+
71
+
72
+
73
+ # --- Checkpoints ---
74
+ banglabert_checkpoint = "Swaraj66/BNER_Finetuned_BanglaBERT"
75
+ rembert_checkpoint = "Swaraj66/BNER_Finetuned_RemBERT"
76
+ crf_assets_checkpoint = "Swaraj66/BNER_CRF_Layer"
77
+
78
+ # --- Load BanglaBERT ---
79
+ banglabert_tokenizer = AutoTokenizer.from_pretrained(
80
+ banglabert_checkpoint, use_fast=True
81
+ )
82
+ banglabert_model = AutoModelForTokenClassification.from_pretrained(
83
+ banglabert_checkpoint
84
+ )
85
+
86
+ # --- Load RemBERT ---
87
+ rembert_tokenizer = AutoTokenizer.from_pretrained(
88
+ rembert_checkpoint
89
+ )
90
+ rembert_model = AutoModelForTokenClassification.from_pretrained(
91
+ rembert_checkpoint
92
+ )
93
+
94
+ # --- Download CRF model weights from private repo ---
95
+ model_path = hf_hub_download(
96
+ repo_id="Swaraj66/BNER_CRF_Layer",
97
+ filename="crf_model.pt" # <- must match the filename in repo
98
+
99
+ )
100
+
101
+ # --- Load CRF model with weights ---
102
+ CRFmodel = CRFLayer(num_tags=9)
103
+ CRFmodel.load_state_dict(torch.load(model_path, map_location="cpu"))
104
+ CRFmodel.eval()
105
+
106
+ print("✅ CRF model loaded from Hugging Face private repo")
107
+
108
+ def get_word_logits(model, tokenizer, tokens):
109
+ encodings = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
110
+ word_ids = encodings.word_ids()
111
+
112
+ with torch.no_grad():
113
+ logits = model(**encodings).logits
114
+
115
+ selected_logits = []
116
+ seen = set()
117
+ for idx, word_idx in enumerate(word_ids):
118
+ if word_idx is None:
119
+ continue
120
+ if word_idx not in seen:
121
+ selected_logits.append(logits[0, idx])
122
+ seen.add(word_idx)
123
+
124
+ return torch.stack(selected_logits) # (num_words, num_labels)
125
+
126
+ def ensemble_predict(tokens,rembert_model,rembert_tokenizer,Current_banglabert_model,Current_banglabert_tokenizer,CRFmodel):
127
+
128
+ rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
129
+ banglabert_logits = get_word_logits(Current_banglabert_model, Current_banglabert_tokenizer, tokens)
130
+
131
+ min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
132
+ rembert_logits = rembert_logits[:min_len]
133
+ banglabert_logits = banglabert_logits[:min_len]
134
+
135
+ ensemble_logits = rembert_logits + banglabert_logits
136
+ test_logits = [ensemble_logits]
137
+
138
+ # Test on a new emission (logits) sequence
139
+ with torch.no_grad():
140
+ for logits in test_logits: # test_logits = list of tensors
141
+ en_crf_predicted_sequence = CRFmodel(logits)
142
+
143
+
144
+
145
+ preds = torch.argmax(ensemble_logits, dim=-1)
146
+ just_ensembled=preds.tolist()
147
+
148
+
149
+ return en_crf_predicted_sequence
150
+
151
+ model_checkpoint_Base="csebuetnlp/banglabert"
152
+ banglabert_tokenizer_base = AutoTokenizer.from_pretrained(
153
+ model_checkpoint_Base, use_fast=True
154
+ )
155
+
156
+ id2label = {
157
+ 0: "O",
158
+ 1: "B-PER",
159
+ 2: "I-PER",
160
+ 3: "B-ORG",
161
+ 4: "I-ORG",
162
+ 5: "B-LOC",
163
+ 6: "I-LOC",
164
+ 7: "B-MISC",
165
+ 8: "I-MISC",
166
+ "0": "O",
167
+ "1": "B-PER",
168
+ "2": "I-PER",
169
+ "3": "B-ORG",
170
+ "4": "I-ORG",
171
+ "5": "B-LOC",
172
+ "6": "I-LOC",
173
+ "7": "B-MISC",
174
+ "8": "I-MISC"
175
+ }
176
+
177
+
178
+ # Make sure to download punkt if you haven't already
179
+ nltk.download('punkt')
180
+ nltk.download('punkt_tab')
181
+
182
+
183
+ def ner_function(user_input):
184
+ words = word_tokenize(user_input)
185
+ print("words -> ",words)
186
+ preds = ensemble_predict(words,rembert_model,rembert_tokenizer,banglabert_model,banglabert_tokenizer_base,CRFmodel)
187
+ pred_labels_list = [id2label[str(label)] for label in preds] # Convert to str for safety
188
+
189
+ print("Labels----->",pred_labels_list)
190
+
191
+ labeled_words = list(zip(words, pred_labels_list))
192
+
193
+ entities = []
194
+ current_entity = ""
195
+ current_label = None
196
+
197
+ for word, label in labeled_words:
198
+ if label.startswith("B-"):
199
+ if current_entity and current_label:
200
+ entities.append((current_entity.strip(), current_label))
201
+ current_entity = word
202
+ current_label = label[2:]
203
+ elif label.startswith("I-") and current_label == label[2:]:
204
+ current_entity += " " + word
205
+ else:
206
+ if current_entity and current_label:
207
+ entities.append((current_entity.strip(), current_label))
208
+ current_entity = ""
209
+ current_label = None
210
+
211
+ if current_entity and current_label:
212
+ entities.append((current_entity.strip(), current_label))
213
+
214
+ return entities
215
+
216
+ # Gradio app
217
+ def build_ui():
218
+ with gr.Blocks() as demo:
219
+ gr.Markdown("# Named Entity Recognition App Using Transformer Ensembles with CRF (RemBERT and Banglabert)\nEnter a sentence to detect named entities.")
220
+ with gr.Row():
221
+ input_text = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
222
+ with gr.Row():
223
+ submit_btn = gr.Button("Analyze Entities")
224
+ with gr.Row():
225
+ output_json = gr.JSON(label="Named Entities")
226
+
227
+ submit_btn.click(fn=ner_function, inputs=input_text, outputs=output_json)
228
+
229
+ return demo
230
+
231
+ # Create the app
232
+ app = build_ui()
233
+
234
+ # For local running (comment this out when deploying if you want)
235
+ if __name__ == "__main__":
236
+ app.launch()