# samiha-akter's picture
# Create app.py
# 06823b1 verified
import torch
import numpy as np
import pickle
from transformers import AutoTokenizer, AutoModel
from normalizer import normalize
import gradio as gr
# --- Device ---
# Use CUDA when torch reports it as available; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Load tokenizers and models ---
bert_tokenizer = AutoTokenizer.from_pretrained("csebuetnlp/banglabert")
titu_tokenizer = AutoTokenizer.from_pretrained("hishab/titulm-llama-3.2-1b-v1.0")

# Each encoder is moved to `device` and put in eval mode right after loading.
bert_model = AutoModel.from_pretrained("csebuetnlp/banglabert")
bert_model = bert_model.to(device).eval()
titu_model = AutoModel.from_pretrained("hishab/titulm-llama-3.2-1b-v1.0")
titu_model = titu_model.to(device).eval()

# --- Load trained LightGBM model and preprocessing info ---
# NOTE: pickle.load can execute arbitrary code while unpickling, so these
# files must come from a trusted source.
with open("multiclass_lightgbm_bert_titu.pkl", "rb") as f:
    classifier = pickle.load(f)

with open("preprocessing_info.pkl", "rb") as f:
    info = pickle.load(f)

# Label lookup used to turn the classifier's prediction into a name.
class_names = info['class_names']

# Per-encoder `max_length` values handed to the tokenizers.
bert_max_len, titu_max_len = 45, 148
# --- Helper functions ---
def preprocess(text):
    """Return ``text`` after running it through ``normalize``."""
    normalized_text = normalize(text)
    return normalized_text
def get_embedding(text, tokenizer, model, max_len):
    """Embed ``text`` by mean-pooling the model's last hidden state.

    Args:
        text: The string to embed; it is tokenized as a one-element batch.
        tokenizer: Callable tokenizer; invoked with padding and truncation
            enabled and ``max_length=max_len``, returning PyTorch tensors.
        model: Model whose output exposes ``last_hidden_state``.
        max_len: Maximum token length passed to the tokenizer.

    Returns:
        A NumPy array (on the CPU) holding the pooled embedding.
        Assumes ``last_hidden_state`` is (batch, seq, hidden), so the result
        would be (1, hidden) -- TODO confirm against the models in use.
    """
    encoded = tokenizer(
        [text],
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_len,
    ).to(device)

    with torch.inference_mode():
        outputs = model(**encoded)
    hidden = outputs.last_hidden_state

    mask = encoded.get("attention_mask", None)
    if mask is None:
        # No mask available: average over every position along dim 1.
        pooled = hidden.mean(dim=1)
    else:
        # Weight each position by its mask value so masked-out positions do
        # not contribute, then divide by the mask total (floored at 1e-6 to
        # avoid dividing by zero).
        weights = mask.unsqueeze(-1)
        weighted_sum = (hidden * weights).sum(dim=1)
        weight_total = weights.sum(dim=1).clamp(min=1e-6)
        pooled = weighted_sum / weight_total

    return pooled.detach().cpu().numpy()
def predict(text):
    """Predict the trait label for a piece of Bangla text.

    The text is preprocessed, embedded once with each encoder, and the two
    embeddings are concatenated along axis 1 into a single feature row for
    the trained classifier.

    Args:
        text: Raw input text (the value of the Gradio textbox).

    Returns:
        The entry of ``class_names`` selected by the classifier's prediction
        for this text, or an empty string when ``text`` is ``None``, empty,
        or whitespace-only (no model is run in that case).
    """
    # Blank input carries nothing to classify: return early instead of
    # running both encoders and reporting a "trait" for an empty string.
    if text is None or not str(text).strip():
        return ""

    text = preprocess(text)
    bert_emb = get_embedding(text, bert_tokenizer, bert_model, bert_max_len)
    titu_emb = get_embedding(text, titu_tokenizer, titu_model, titu_max_len)

    # Place the two embeddings side by side so the classifier sees one
    # combined feature vector per input.
    features = np.concatenate([bert_emb, titu_emb], axis=1)

    # Only one text was submitted, so take the first (only) prediction.
    pred_idx = classifier.predict(features)[0]
    return class_names[pred_idx]
# --- Gradio interface ---
# Build the two textbox components first, then wire them to `predict`.
text_input = gr.Textbox(label="Enter Bangla text")
trait_output = gr.Textbox(label="Predicted Trait")

iface = gr.Interface(
    fn=predict,
    inputs=text_input,
    outputs=trait_output,
    title="Bangla Personality Trait Predictor",
    description="Enter Bangla text and get the predicted personality trait.",
)

# Start serving the interface.
iface.launch()