Swaraj66 commited on
Commit
0243138
·
verified ·
1 Parent(s): 9c5a7de

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NER.py
2
+
3
+ from seqeval.metrics import classification_report, accuracy_score, f1_score, precision_score, recall_score
4
+ import torch
5
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
6
+ import gradio as gr
7
+ import nltk
8
+ from nltk.tokenize import word_tokenize
9
+
10
+ # Download necessary NLTK data
11
+ nltk.download('punkt')
12
+ nltk.download('punkt_tab')
13
+
14
+ # Load the two models
15
+ model_id = "Swaraj66/Banglabert-finetuned-ner"
16
+ model_id2 = "Swaraj66/Finetuned_RemBERT"
17
+
18
+ banglabert_tokenizer = AutoTokenizer.from_pretrained(model_id)
19
+ banglabert_model = AutoModelForTokenClassification.from_pretrained(model_id)
20
+
21
+ rembert_tokenizer = AutoTokenizer.from_pretrained(model_id2)
22
+ rembert_model = AutoModelForTokenClassification.from_pretrained(model_id2)
23
+
24
+ # Helper functions
25
+ def get_word_logits(model, tokenizer, tokens):
26
+ encodings = tokenizer(tokens, is_split_into_words=True, return_tensors="pt", padding=True, truncation=True)
27
+ word_ids = encodings.word_ids()
28
+
29
+ with torch.no_grad():
30
+ logits = model(**encodings).logits
31
+
32
+ selected_logits = []
33
+ seen = set()
34
+ for idx, word_idx in enumerate(word_ids):
35
+ if word_idx is None:
36
+ continue
37
+ if word_idx not in seen:
38
+ selected_logits.append(logits[0, idx])
39
+ seen.add(word_idx)
40
+
41
+ return torch.stack(selected_logits)
42
+
43
+ def ensemble_predict(tokens):
44
+ rembert_logits = get_word_logits(rembert_model, rembert_tokenizer, tokens)
45
+ banglabert_logits = get_word_logits(banglabert_model, banglabert_tokenizer, tokens)
46
+
47
+ min_len = min(rembert_logits.shape[0], banglabert_logits.shape[0])
48
+ rembert_logits = rembert_logits[:min_len]
49
+ banglabert_logits = banglabert_logits[:min_len]
50
+
51
+ ensemble_logits = rembert_logits + banglabert_logits
52
+
53
+ preds = torch.argmax(ensemble_logits, dim=-1)
54
+ return preds.tolist()
55
+
56
+ # Label mapping
57
+ id2label = {
58
+ 0: "O", 1: "B-PER", 2: "I-PER", 3: "B-ORG", 4: "I-ORG", 5: "B-LOC", 6: "I-LOC", 7: "B-MISC", 8: "I-MISC",
59
+ "0": "O", "1": "B-PER", "2": "I-PER", "3": "B-ORG", "4": "I-ORG", "5": "B-LOC", "6": "I-LOC", "7": "B-MISC", "8": "I-MISC"
60
+ }
61
+
62
+ # Main NER function
63
+ def ner_function(user_input):
64
+ words = word_tokenize(user_input)
65
+ preds = ensemble_predict(words)
66
+ pred_labels_list = [id2label[str(label)] for label in preds]
67
+
68
+ labeled_words = list(zip(words, pred_labels_list))
69
+
70
+ entities = []
71
+ current_entity = ""
72
+ current_label = None
73
+
74
+ for word, label in labeled_words:
75
+ if label.startswith("B-"):
76
+ if current_entity and current_label:
77
+ entities.append((current_entity.strip(), current_label))
78
+ current_entity = word
79
+ current_label = label[2:]
80
+ elif label.startswith("I-") and current_label == label[2:]:
81
+ current_entity += " " + word
82
+ else:
83
+ if current_entity and current_label:
84
+ entities.append((current_entity.strip(), current_label))
85
+ current_entity = ""
86
+ current_label = None
87
+
88
+ if current_entity and current_label:
89
+ entities.append((current_entity.strip(), current_label))
90
+
91
+ return entities
92
+
93
+ # Gradio UI
94
+ def build_ui():
95
+ with gr.Blocks() as demo:
96
+ gr.Markdown("# Named Entity Recognition App Using Ensemble Model (RemBERT + BanglaBERT)\nEnter a sentence to detect named entities.")
97
+ with gr.Row():
98
+ input_text = gr.Textbox(label="Enter a sentence", placeholder="Type your text here...")
99
+ with gr.Row():
100
+ submit_btn = gr.Button("Analyze Entities")
101
+ with gr.Row():
102
+ output_json = gr.JSON(label="Named Entities")
103
+
104
+ submit_btn.click(fn=ner_function, inputs=input_text, outputs=output_json)
105
+
106
+ return demo
107
+
108
+ # Launch the app
109
+ app = build_ui()
110
+
111
+ if __name__ == "__main__":
112
+ app.launch()