import os

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, hf_hub_url
from transformers import AutoConfig, AutoTokenizer

from model import IndoBERTClassifier
| |
|
| | |
# Hub repository that provides the config, the tokenizer and the weights.
MODEL_ID = "TrioF/KlikBERT"

# Configuration and tokenizer are read from the same Hub repository.
config = AutoConfig.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Instantiate the project-local classifier, then load its weights.
# hf_hub_download stores the file in the Hugging Face cache, which is keyed by
# repository and revision. torch.hub's URL loader caches by the bare file name
# ("pytorch_model.bin"), so a same-named checkpoint from another repository
# could be picked up by mistake.
model = IndoBERTClassifier(config)
model_path = hf_hub_download(repo_id=MODEL_ID, filename="pytorch_model.bin")
# NOTE(review): torch.load unpickles the downloaded file. Consider passing
# weights_only=True if the installed torch version supports it.
model.load_state_dict(torch.load(model_path, map_location="cpu"))
model.eval()  # inference mode: disables dropout and similar training behavior

# Label tables for the two prediction heads, taken from the custom config
# field; predict() indexes them with the class id converted to a string.
id2label_clickbait = config.custom_id2label['clickbait']
id2label_kategori = config.custom_id2label['kategori']
| |
|
| |
|
| | |
def predict(judul, isi):
    """Predict the clickbait label and the news category of an article.

    Args:
        judul: Headline text, given to the tokenizer as the first text.
        isi: Body text, given to the tokenizer as the second text.

    Returns:
        A ``(clickbait_label, kategori_label)`` tuple, looked up in
        ``id2label_clickbait`` and ``id2label_kategori`` respectively.
    """
    encoded = tokenizer(
        judul,
        isi,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        result = model(**encoded)

    # Fetch both heads first, then reduce each to its highest-scoring class id.
    logits_by_task = (result["clickbait_logits"], result["kategori_logits"])
    clickbait_id, kategori_id = (
        torch.argmax(logits, dim=1).item() for logits in logits_by_task
    )

    # The label tables are indexed with the id as a string.
    return id2label_clickbait[str(clickbait_id)], id2label_kategori[str(kategori_id)]
| |
|
| |
|
| | |
# Input widgets, listed in the same order as predict()'s parameters.
inputs = [
    gr.Textbox(
        label="Judul Berita",
        placeholder="Masukkan judul berita di sini...",
        lines=2,
    ),
    gr.Textbox(
        label="Isi Berita",
        placeholder="Masukkan isi berita di sini...",
        lines=10,
    ),
]

# Output widgets, listed in the same order as predict()'s returned pair.
outputs = [
    gr.Text(label="Prediksi Clickbait"),
    gr.Text(label="Prediksi Kategori Berita"),
]

title = "Model Multi-Task KlikBERT"
description = (
    "Model ini memprediksi apakah judul clickbait dan apa kategori beritanya. "
    "Model ini dimuat dari repositori TrioF/KlikBERT."
)

# Wire predict() to the widgets and start the web app.
iface = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
)
iface.launch()