samiha-akter committed on
Commit
06823b1
·
verified ·
1 Parent(s): 008f708

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pickle
4
+ from transformers import AutoTokenizer, AutoModel
5
+ from normalizer import normalize
6
+ import gradio as gr
7
+
8
# --- Device ---
# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Load tokenizers and models ---
# Tokenizers first, then the encoders; each encoder is moved to `device`
# and switched to eval mode before use.
bert_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglabert")
titu_tokenizer = AutoTokenizer.from_pretrained("hishab/titulm-llama-3.2-1b-v1.0")

bert_model = AutoModel.from_pretrained("csebuetnlp/banglabert")
bert_model = bert_model.to(device).eval()

titu_model = AutoModel.from_pretrained("hishab/titulm-llama-3.2-1b-v1.0")
titu_model = titu_model.to(device).eval()
16
+
17
# --- Load trained LightGBM model and preprocessing info ---
# NOTE: pickle.load can execute arbitrary code embedded in the file, so both
# artifacts must come from a trusted source.
with open("multiclass_lightgbm_bert_titu.pkl", "rb") as model_file:
    classifier = pickle.load(model_file)

with open("preprocessing_info.pkl", "rb") as info_file:
    info = pickle.load(info_file)

# Label lookup used to turn the classifier's prediction into a name.
class_names = info["class_names"]

# Maximum token lengths handed to the BERT and Titu tokenizers respectively.
bert_max_len, titu_max_len = 45, 148
26
+
27
+ # --- Helper functions ---
28
def preprocess(text):
    """Return ``text`` after running it through the imported ``normalize`` helper."""
    normalized = normalize(text)
    return normalized
30
+
31
def get_embedding(text, tokenizer, model, max_len):
    """Embed a single text by mean-pooling the model's last hidden states.

    Args:
        text: The string to embed; it is tokenized as a batch of one.
        tokenizer: Callable tokenizer, invoked with truncation to ``max_len``
            and ``return_tensors="pt"``.
        model: Model whose output exposes ``last_hidden_state``.
        max_len: Maximum token length passed to the tokenizer.

    Returns:
        A NumPy array (on the CPU) holding the pooled embedding. When the
        encoding carries an ``attention_mask``, positions are weighted by
        the mask before averaging; otherwise a plain mean over dim 1 is used.
    """
    enc = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    ).to(device)

    with torch.inference_mode():
        out = model(**enc)
    # Assumes last_hidden_state is (batch, seq, hidden) -- TODO confirm.
    last_hidden = out.last_hidden_state

    attn = enc.get("attention_mask", None)
    if attn is None:
        emb = last_hidden.mean(dim=1)
    else:
        # Cast the (integer) mask to the hidden-state dtype up front so the
        # epsilon clamp below is applied in floating point and the division
        # guard does not depend on int->float promotion inside clamp().
        mask = attn.unsqueeze(-1).to(last_hidden.dtype)
        token_sum = (last_hidden * mask).sum(dim=1)
        token_count = mask.sum(dim=1).clamp(min=1e-6)
        emb = token_sum / token_count

    return emb.detach().cpu().numpy()
43
+
44
def predict(text):
    """Return the ``class_names`` entry predicted for ``text``.

    The text is preprocessed, embedded with both the BERT and Titu encoders,
    and the two embeddings are concatenated along axis 1 before being passed
    to ``classifier.predict``. The first prediction is used to index
    ``class_names``.
    """
    cleaned = preprocess(text)
    embeddings = [
        get_embedding(cleaned, bert_tokenizer, bert_model, bert_max_len),
        get_embedding(cleaned, titu_tokenizer, titu_model, titu_max_len),
    ]
    features = np.concatenate(embeddings, axis=1)
    predictions = classifier.predict(features)
    return class_names[predictions[0]]
51
+
52
# --- Gradio interface ---
# A single text box feeds `predict`; its return value is shown in a second text box.
iface = gr.Interface(
    predict,
    gr.Textbox(label="Enter Bangla text"),
    gr.Textbox(label="Predicted Trait"),
    title="Bangla Personality Trait Predictor",
    description="Enter Bangla text and get the predicted personality trait.",
)

# Start serving the interface as soon as the module is executed.
iface.launch()