KlikBERT / app.py
TrioF's picture
Update app.py
5bb7e34 verified
import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, hf_hub_url
from transformers import AutoTokenizer, AutoConfig

# Explicitly import the project's custom model class
from model import IndoBERTClassifier
# --- Configuration and model loading ---
MODEL_ID = "Hydra-RKMI/KlikBERT"

# Load the tokenizer and config from the Hub.
config = AutoConfig.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Instantiate the custom classifier and load its weights from the Hub.
model = IndoBERTClassifier(config)
# Fetch the checkpoint through the Hugging Face cache, which is keyed by repo
# and revision. Downloading by URL via torch.hub caches the file under its
# bare name ("pytorch_model.bin"), so a checkpoint left behind by another repo
# or an older revision could be picked up silently instead of this one.
model_path = hf_hub_download(repo_id=MODEL_ID, filename="pytorch_model.bin")
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()

# --- Label mapping ---
# config.json must define 'custom_id2label' with 'clickbait' and 'kategori'
# entries; a missing attribute or key fails here at startup.
id2label_clickbait = config.custom_id2label['clickbait']
id2label_kategori = config.custom_id2label['kategori']
# --- Fungsi Prediksi ---
def predict(judul, isi):
    """Classify one news article with both model heads.

    Args:
        judul: The news headline, tokenized as the first segment.
        isi: The news body, tokenized as the second segment.

    Returns:
        A ``(clickbait_label, kategori_label)`` tuple, i.e. two separate
        values rather than a dictionary, looked up in ``id2label_clickbait``
        and ``id2label_kategori``.
    """
    # Encode headline and body as a sentence pair, truncated to 512 tokens.
    encoded = tokenizer(
        judul,
        isi,
        truncation=True,
        padding=True,
        max_length=512,
        return_tensors="pt",
    )

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        result = model(**encoded)

    logits_clickbait = result["clickbait_logits"]
    logits_kategori = result["kategori_logits"]

    # Index of the highest-scoring class for each head.
    clickbait_idx = logits_clickbait.argmax(dim=1).item()
    kategori_idx = logits_kategori.argmax(dim=1).item()

    # The label maps are indexed by the class id rendered as a string.
    label_clickbait = id2label_clickbait[str(clickbait_idx)]
    label_kategori = id2label_kategori[str(kategori_idx)]
    return label_clickbait, label_kategori
# --- Gradio interface ---
inputs = [
    gr.Textbox(lines=2, label="Judul Berita", placeholder="Masukkan judul berita di sini..."),
    gr.Textbox(lines=10, label="Isi Berita", placeholder="Masukkan isi berita di sini..."),
]

# Two separate output components, one per value returned by predict().
outputs = [
    gr.Text(label="Prediksi Clickbait"),
    gr.Text(label="Prediksi Kategori Berita"),
]

title = "Model Multi-Task KlikBERT"
# Build the description from MODEL_ID so it always names the repository the
# weights are actually loaded from (the hard-coded repo name had drifted).
description = (
    "Model ini memprediksi apakah judul clickbait dan apa kategori beritanya. "
    f"Model ini dimuat dari repositori {MODEL_ID}."
)

iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, description=description)
iface.launch()