# Amal17 — commit 4ac1935: "4 models"
import gradio as gr
import torch
from bert_gru_classifier import BERTBiGRUClassifier
from bert_lstm_classifier import BERTBiLSTMClassifier
from transformers import AutoTokenizer
# Predicted class index -> sentiment label, shared by all four classifiers'
# formatted outputs.
CLASS_MAP = dict(enumerate(("Negative", "Neutral", "Positive")))
# Load a single tokenizer and reuse it for every model (predict_with_model
# tokenizes all inputs with this one instance).
TOKENIZER_CHECKPOINT = "LazarusNLP/NusaBERT-large"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)
# Load the four fine-tuned classifiers.
def _load_for_inference(model_cls, repo_id):
    """Create ``model_cls`` via ``from_pretrained(repo_id)`` and call ``.eval()`` on it.

    ``.eval()`` switches off training-only behaviour (e.g. dropout), assuming
    the class is a ``torch.nn.Module`` subclass.
    """
    loaded = model_cls.from_pretrained(repo_id)
    loaded.eval()
    return loaded


bigru_model = _load_for_inference(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaX-ace"
)
bigru_translate_model = _load_for_inference(
    BERTBiGRUClassifier, "Amal17/NusaBERT-concate-BiGRU-NusaTranslate-senti"
)
bilstm_model = _load_for_inference(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaX-ace"
)
bilstm_translate_model = _load_for_inference(
    BERTBiLSTMClassifier, "Amal17/NusaBERT-concate-BiLSTM-NusaTranslate-senti"
)
# Inference helper
def predict_with_model(model, text):
    """Classify ``text`` with ``model`` and return ``(class_index, confidence)``.

    Args:
        model: One of the loaded classifiers. It is called with the
            tokenizer's tensors as keyword arguments, and its output must
            support ``["logits"]`` indexing.
        text: Raw input string.

    Returns:
        ``(int, float)``: the arg-max class index over dim 1 of the softmaxed
        logits, and that class's probability for the first row of the batch.
    """
    # Every model receives identical input: truncated and padded to exactly
    # 128 tokens.
    # NOTE(review): the RNN head may be sensitive to padding tokens, so keep
    # this aligned with how the checkpoints were trained. Confirm before
    # changing.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=128,
        padding="max_length",
    )

    with torch.no_grad():
        scores = torch.softmax(model(**encoded)["logits"], dim=1)

    best = torch.argmax(scores, dim=1).item()
    return best, scores[0][best].item()
# Gradio interface function
def compare_models(text):
    """Run ``text`` through all four classifiers and format each prediction.

    Returns:
        A 4-tuple of strings in the order BiGRU-ace, BiLSTM-ace,
        BiGRU-translate, BiLSTM-translate. This matches the order of the
        Gradio output textboxes.
    """
    ordered_models = (
        bigru_model,
        bilstm_model,
        bigru_translate_model,
        bilstm_translate_model,
    )
    # Run every prediction first, then format the results.
    predictions = [predict_with_model(candidate, text) for candidate in ordered_models]
    return tuple(
        f"Class: {label_id} ({CLASS_MAP[label_id]}) with confidence: {score:.4f}"
        for label_id, score in predictions
    )
# Build the Gradio UI: one input textbox and one output textbox per model.
# The output labels must stay in the same order as the tuple returned by
# compare_models (BiGRU-ace, BiLSTM-ace, BiGRU-translate, BiLSTM-translate).
interface = gr.Interface(
    fn=compare_models,
    inputs=gr.Textbox(label="Input Text"),
    outputs=[
        gr.Textbox(label="NusaBERT-BiGRU-ace"),
        gr.Textbox(label="NusaBERT-BiLSTM-ace"),
        gr.Textbox(label="NusaBERT-BiGRU-translate"),
        gr.Textbox(label="NusaBERT-BiLSTM-translate"),
    ],
    title="Hybrid Model NusaBERT + RNN",
)
interface.launch()