# Hugging Face Space app (page status when captured: Sleeping)
| import torch | |
| import numpy as np | |
| import pickle | |
| from transformers import AutoTokenizer, AutoModel | |
| from normalizer import normalize | |
| import gradio as gr | |
# --- Device ---
# Run on the default CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# --- Load tokenizers and models ---
# Hugging Face Hub checkpoints for the two feature extractors.
BERT_CHECKPOINT = "csebuetnlp/banglabert"
TITU_CHECKPOINT = "hishab/titulm-llama-3.2-1b-v1.0"

# Tokenizers first, then the models (same order as before).
bert_tokenizer = AutoTokenizer.from_pretrained(BERT_CHECKPOINT)
titu_tokenizer = AutoTokenizer.from_pretrained(TITU_CHECKPOINT)

# Move each model to the selected device and switch it to eval mode
# (disables dropout) since it is only used for inference.
bert_model = AutoModel.from_pretrained(BERT_CHECKPOINT).to(device).eval()
titu_model = AutoModel.from_pretrained(TITU_CHECKPOINT).to(device).eval()
# --- Load trained LightGBM model and preprocessing info ---
def _load_pickle(path):
    """Open *path* in binary mode and return the unpickled object."""
    # NOTE(security): pickle.load can execute arbitrary code. These files are
    # presumably trusted artifacts shipped with the app; never point this at
    # user-supplied files.
    with open(path, "rb") as handle:
        return pickle.load(handle)


classifier = _load_pickle("multiclass_lightgbm_bert_titu.pkl")
info = _load_pickle("preprocessing_info.pkl")
# Lookup from the classifier's predicted index to a human-readable label.
class_names = info['class_names']

# max_length limits handed to the BanglaBERT and TituLM tokenizers; presumably
# the same limits used when the classifier was trained -- TODO confirm.
bert_max_len = 45
titu_max_len = 148
# --- Helper functions ---
def preprocess(text):
    """Return *text* passed through ``normalizer.normalize`` (applied before tokenization)."""
    normalized_text = normalize(text)
    return normalized_text
def get_embedding(text, tokenizer, model, max_len):
    """Embed one text as the attention-masked mean of the model's last hidden state.

    Args:
        text: Input string (the caller passes it already normalized).
        tokenizer: Hugging Face tokenizer matching ``model``.
        model: Hugging Face model whose forward output exposes
            ``last_hidden_state``.
        max_len: Maximum number of tokens kept; longer inputs are truncated.

    Returns:
        A NumPy array with one row (the input is a batch of one text); for
        standard HF models this is ``(1, hidden_size)``.
    """
    # A single text never needs padding. Requesting it anyway makes tokenizers
    # that define no pad token (common for Llama-family checkpoints) raise
    # ValueError, so padding is not requested.
    enc = tokenizer(
        [text],
        return_tensors="pt",
        truncation=True,
        max_length=max_len,
    ).to(device)

    with torch.inference_mode():
        last_hidden = model(**enc).last_hidden_state
        attn = enc.get("attention_mask")
        if attn is None:
            emb = last_hidden.mean(dim=1)
        else:
            # Cast the 0/1 mask to the hidden-state dtype so both the weighted
            # sum and the token count are floating point; clamp guards against
            # division by zero.
            mask = attn.unsqueeze(-1).to(last_hidden.dtype)
            emb = (last_hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-6)

    # Tensors created under inference_mode carry no autograd history, so no
    # detach() is needed before converting.
    return emb.cpu().numpy()
def predict(text):
    """Predict the personality-trait label for a piece of Bangla text.

    The text is normalized, embedded with BanglaBERT and TituLM, and the two
    embeddings are concatenated into one feature row for the classifier.

    Args:
        text: Raw input string from the UI.

    Returns:
        The entry of ``class_names`` selected by the classifier's prediction.
    """
    text = preprocess(text)
    bert_emb = get_embedding(text, bert_tokenizer, bert_model, bert_max_len)
    titu_emb = get_embedding(text, titu_tokenizer, titu_model, titu_max_len)
    # Feature order (BanglaBERT, then TituLM) presumably has to match the order
    # used at training time -- TODO confirm against the training script.
    features = np.concatenate([bert_emb, titu_emb], axis=1)
    pred_idx = _to_class_index(classifier.predict(features))
    return class_names[pred_idx]


def _to_class_index(pred):
    """Reduce a classifier ``predict`` result for the first sample to an ``int`` index.

    ``LGBMClassifier.predict`` returns 1-D labels, whereas a raw
    ``lightgbm.Booster.predict`` on a multiclass model returns an
    ``(n_samples, num_class)`` probability matrix; both forms are accepted.
    """
    arr = np.asarray(pred)
    if arr.ndim == 2:
        # Probability matrix: pick the most likely class of the first row.
        return int(np.argmax(arr[0]))
    return int(arr[0])
# --- Gradio interface ---
# Text-in / text-out UI that forwards the user's input to predict().
text_input = gr.Textbox(label="Enter Bangla text")
trait_output = gr.Textbox(label="Predicted Trait")

iface = gr.Interface(
    fn=predict,
    inputs=text_input,
    outputs=trait_output,
    title="Bangla Personality Trait Predictor",
    description="Enter Bangla text and get the predicted personality trait.",
)

# Start the Gradio server.
iface.launch()