TrioF commited on
Commit
5bb7e34
·
verified ·
1 Parent(s): 75d06a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +73 -73
app.py CHANGED
@@ -1,74 +1,74 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import AutoTokenizer, AutoConfig
4
- from huggingface_hub import hf_hub_url
5
- import os
6
-
7
- # Impor kelas kustom Anda secara eksplisit
8
- from model import IndoBERTClassifier
9
-
10
- # --- Konfigurasi dan Pemuatan Model ---
11
- MODEL_ID = "TrioF/KlikBERT"
12
-
13
- # Muat tokenizer dan config dari Hub
14
- config = AutoConfig.from_pretrained(MODEL_ID)
15
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
-
17
- # Inisialisasi kelas kustom dan muat bobot dari Hub
18
- model = IndoBERTClassifier(config)
19
- model_path = hf_hub_url(repo_id=MODEL_ID, filename="pytorch_model.bin")
20
- model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location="cpu"))
21
- model.eval()
22
-
23
- # --- Pemetaan Label ---
24
- # Pastikan config.json Anda sudah menggunakan 'custom_id2label'
25
- id2label_clickbait = config.custom_id2label['clickbait']
26
- id2label_kategori = config.custom_id2label['kategori']
27
-
28
-
29
- # --- Fungsi Prediksi ---
30
- def predict(judul, isi):
31
- inputs = tokenizer(
32
- judul,
33
- isi,
34
- truncation=True,
35
- padding=True,
36
- max_length=512,
37
- return_tensors="pt"
38
- )
39
-
40
- with torch.no_grad():
41
- outputs = model(**inputs)
42
-
43
- clickbait_logits = outputs["clickbait_logits"]
44
- kategori_logits = outputs["kategori_logits"]
45
-
46
- pred_clickbait_id = torch.argmax(clickbait_logits, dim=1).item()
47
- pred_kategori_id = torch.argmax(kategori_logits, dim=1).item()
48
-
49
- pred_clickbait_label = id2label_clickbait[str(pred_clickbait_id)]
50
- pred_kategori_label = id2label_kategori[str(pred_kategori_id)]
51
-
52
- # --- PERUBAHAN DI SINI ---
53
- # Kembalikan dua nilai terpisah, bukan dictionary
54
- return pred_clickbait_label, pred_kategori_label
55
-
56
-
57
- # --- Antarmuka Gradio ---
58
- inputs = [
59
- gr.Textbox(lines=2, label="Judul Berita", placeholder="Masukkan judul berita di sini..."),
60
- gr.Textbox(lines=10, label="Isi Berita", placeholder="Masukkan isi berita di sini...")
61
- ]
62
-
63
- # --- PERUBAHAN DI SINI ---
64
- # Gunakan dua komponen output terpisah
65
- outputs = [
66
- gr.Text(label="Prediksi Clickbait"),
67
- gr.Text(label="Prediksi Kategori Berita")
68
- ]
69
-
70
- title = "Model Multi-Task KlikBERT"
71
- description = "Model ini memprediksi apakah judul clickbait dan apa kategori beritanya. Model ini dimuat dari repositori TrioF/KlikBERT."
72
-
73
- iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, description=description)
74
  iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoConfig
4
+ from huggingface_hub import hf_hub_url
5
+ import os
6
+
7
+ # Impor kelas kustom Anda secara eksplisit
8
+ from model import IndoBERTClassifier
9
+
10
+ # --- Konfigurasi dan Pemuatan Model ---
11
+ MODEL_ID = "Hydra-RKMI/KlikBERT"
12
+
13
+ # Muat tokenizer dan config dari Hub
14
+ config = AutoConfig.from_pretrained(MODEL_ID)
15
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
16
+
17
+ # Inisialisasi kelas kustom dan muat bobot dari Hub
18
+ model = IndoBERTClassifier(config)
19
+ model_path = hf_hub_url(repo_id=MODEL_ID, filename="pytorch_model.bin")
20
+ model.load_state_dict(torch.hub.load_state_dict_from_url(model_path, map_location="cpu"))
21
+ model.eval()
22
+
23
+ # --- Pemetaan Label ---
24
+ # Pastikan config.json Anda sudah menggunakan 'custom_id2label'
25
+ id2label_clickbait = config.custom_id2label['clickbait']
26
+ id2label_kategori = config.custom_id2label['kategori']
27
+
28
+
29
+ # --- Fungsi Prediksi ---
30
+ def predict(judul, isi):
31
+ inputs = tokenizer(
32
+ judul,
33
+ isi,
34
+ truncation=True,
35
+ padding=True,
36
+ max_length=512,
37
+ return_tensors="pt"
38
+ )
39
+
40
+ with torch.no_grad():
41
+ outputs = model(**inputs)
42
+
43
+ clickbait_logits = outputs["clickbait_logits"]
44
+ kategori_logits = outputs["kategori_logits"]
45
+
46
+ pred_clickbait_id = torch.argmax(clickbait_logits, dim=1).item()
47
+ pred_kategori_id = torch.argmax(kategori_logits, dim=1).item()
48
+
49
+ pred_clickbait_label = id2label_clickbait[str(pred_clickbait_id)]
50
+ pred_kategori_label = id2label_kategori[str(pred_kategori_id)]
51
+
52
+ # --- PERUBAHAN DI SINI ---
53
+ # Kembalikan dua nilai terpisah, bukan dictionary
54
+ return pred_clickbait_label, pred_kategori_label
55
+
56
+
57
+ # --- Antarmuka Gradio ---
58
+ inputs = [
59
+ gr.Textbox(lines=2, label="Judul Berita", placeholder="Masukkan judul berita di sini..."),
60
+ gr.Textbox(lines=10, label="Isi Berita", placeholder="Masukkan isi berita di sini...")
61
+ ]
62
+
63
+ # --- PERUBAHAN DI SINI ---
64
+ # Gunakan dua komponen output terpisah
65
+ outputs = [
66
+ gr.Text(label="Prediksi Clickbait"),
67
+ gr.Text(label="Prediksi Kategori Berita")
68
+ ]
69
+
70
+ title = "Model Multi-Task KlikBERT"
71
+ description = "Model ini memprediksi apakah judul clickbait dan apa kategori beritanya. Model ini dimuat dari repositori TrioF/KlikBERT."
72
+
73
+ iface = gr.Interface(fn=predict, inputs=inputs, outputs=outputs, title=title, description=description)
74
  iface.launch()