| import gradio as gr |
| import torch |
| from bert_gru_classifier import BERTBiGRUClassifier |
| from bert_lstm_classifier import BERTBiLSTMClassifier |
| from transformers import AutoTokenizer |
|
|
# Lookup from a predicted class index (the argmax in predict_with_model) to
# the sentiment label shown in the UI; indices are assigned in tuple order.
CLASS_MAP = dict(enumerate(("Negative", "Neutral", "Positive")))
|
|
| |
# Single tokenizer used by predict_with_model for every classifier below.
TOKENIZER_CHECKPOINT = "LazarusNLP/NusaBERT-large"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
|
|
| |
def _load_for_inference(model_cls, repo_id):
    """Load ``repo_id`` with ``model_cls.from_pretrained`` and switch it to eval mode.

    Returns the loaded model object itself (``.eval()`` is called for its
    side effect; its return value is not relied upon).
    """
    loaded = model_cls.from_pretrained(repo_id)
    loaded.eval()
    return loaded


# Four classifiers compared side by side in the UI. They are loaded in the
# same order as before: BiGRU/NusaX-ace, BiGRU/NusaTranslate,
# BiLSTM/NusaX-ace, BiLSTM/NusaTranslate.
bigru_model = _load_for_inference(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaX-ace"
)
bigru_translate_model = _load_for_inference(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaTranslate-senti"
)
bilstm_model = _load_for_inference(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaX-ace"
)
bilstm_translate_model = _load_for_inference(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaTranslate-senti"
)
|
|
|
|
|
|
| |
def predict_with_model(model, text):
    """Classify ``text`` with ``model`` and return the top class and its probability.

    The text is encoded with the module-level ``tokenizer``: truncated and
    padded to exactly 128 tokens and returned as PyTorch tensors. It is then
    fed to ``model`` as keyword arguments with gradient tracking disabled.

    Args:
        model: classifier whose call accepts the tokenizer's output as keyword
            arguments and returns a mapping containing a ``"logits"`` entry.
            Logits are assumed to be (batch, num_classes); TODO confirm.
        text: the string to classify.

    Returns:
        ``(pred, confidence)``: the argmax class index along dim 1 and the
        softmax probability of that class for the first row of the batch.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        max_length=128,
        truncation=True,
        padding="max_length",
    )
    with torch.no_grad():
        scores = model(**encoded)["logits"]
    distribution = scores.softmax(dim=1)
    top_index = distribution.argmax(dim=1).item()
    return top_index, distribution[0, top_index].item()
|
|
| |
def compare_models(text):
    """Run ``text`` through all four classifiers and describe each prediction.

    Returns:
        A 4-tuple of strings ordered BiGRU (NusaX-ace), BiLSTM (NusaX-ace),
        BiGRU (NusaTranslate), BiLSTM (NusaTranslate). Each reads
        ``"Class: <index> (<label>) with confidence: <prob, 4 decimals>"``.
    """
    ordered_models = (
        bigru_model,
        bilstm_model,
        bigru_translate_model,
        bilstm_translate_model,
    )
    # Collect every prediction before formatting any of them.
    results = [predict_with_model(m, text) for m in ordered_models]
    return tuple(
        f"Class: {label_id} ({CLASS_MAP[label_id]}) with confidence: {score:.4f}"
        for label_id, score in results
    )
|
|
| |
# Gradio UI: one text input and four output boxes, one per classifier.
interface = gr.Interface(
    fn=compare_models,
    inputs=gr.Textbox(label="Input Text"),
    # Order must match compare_models' return tuple:
    # BiGRU-ace, BiLSTM-ace, BiGRU-translate, BiLSTM-translate.
    outputs=[
        gr.Textbox(label="NusaBERT-BiGRU-ace"),
        gr.Textbox(label="NusaBERT-BiLSTM-ace"),
        gr.Textbox(label="NusaBERT-BiGRU-translate"),
        gr.Textbox(label="NusaBERT-BiLSTM-translate"),
    ],
    title="Hybrid Model NusaBERT + RNN",
)


interface.launch()