rdsarjito committed on
Commit
3ab25d3
Β·
1 Parent(s): feaf663
Files changed (1) hide show
  1. app.py +54 -77
app.py CHANGED
@@ -1,107 +1,84 @@
1
  import streamlit as st
2
  import torch
3
- import re
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
- import requests
6
- from bs4 import BeautifulSoup
7
- import os
 
 
 
 
 
8
 
9
- # === Konfigurasi Umum ===
10
- MODEL_PATH = 'model/alergen_model.pt' # Pastikan path dan nama file model benar
11
- LABELS = ['susu', 'kacang', 'telur', 'makanan_laut', 'gandum']
12
- MAX_LEN = 128
13
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
- # === Load Model & Tokenizer ===
16
  @st.cache_resource
17
  def load_model():
18
- tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p2")
19
- model = AutoModelForSequenceClassification.from_pretrained(
20
- "indobenchmark/indobert-base-p2",
21
- num_labels=len(LABELS),
22
- problem_type="multi_label_classification"
23
- )
24
-
25
- # Load state dict dan target columns
26
- state = torch.load(MODEL_PATH, map_location=DEVICE)
27
- model.load_state_dict(state['model_state_dict'])
28
- target_columns = state['target_columns'] # Simpan target_columns
29
-
30
- model.to(DEVICE)
31
  model.eval()
32
-
 
 
33
  return tokenizer, model, target_columns
34
 
35
- # === Cleaning Teks ===
36
  def clean_text(text):
37
  text = text.replace('--', ' ')
38
  text = re.sub(r"http\S+", "", text)
39
- text = re.sub('\n', ' ', text)
40
  text = re.sub("[^a-zA-Z0-9\s]", " ", text)
41
  text = re.sub(" {2,}", " ", text)
42
- return text.lower().strip()
43
-
44
- # === Scrape dari Cookpad ===
45
- def scrape_ingredients(url):
46
- try:
47
- headers = {'User-Agent': 'Mozilla/5.0'}
48
- r = requests.get(url, headers=headers)
49
- soup = BeautifulSoup(r.content, 'html.parser')
50
- ingredients_div = soup.find('div', id='ingredients')
51
- if ingredients_div:
52
- return ingredients_div.get_text(separator=' ')
53
- except:
54
- return None
55
 
56
- # === Prediksi Alergen ===
57
- def predict_alergen(text, tokenizer, model, target_columns, threshold):
58
- text = clean_text(text)
59
  encoding = tokenizer.encode_plus(
60
- text,
61
  add_special_tokens=True,
62
- max_length=MAX_LEN,
63
  truncation=True,
64
- padding='max_length',
65
- return_tensors='pt'
66
  )
67
- input_ids = encoding['input_ids'].to(DEVICE)
68
- attention_mask = encoding['attention_mask'].to(DEVICE)
 
69
 
70
  with torch.no_grad():
71
  outputs = model(input_ids=input_ids, attention_mask=attention_mask)
72
- probs = torch.sigmoid(outputs.logits).cpu().numpy()[0]
 
73
 
74
- return {label: float(prob) for label, prob in zip(target_columns, probs)}
 
75
 
76
- # === UI Streamlit ===
77
- st.set_page_config(page_title="Deteksi Alergen IndoBERT", page_icon="🍲")
78
- st.title("🍲 Deteksi Alergen dari Resep Cookpad (IndoBERT)")
79
 
80
  tokenizer, model, target_columns = load_model()
81
 
82
- input_mode = st.radio("Pilih input:", ["Teks Manual", "URL Cookpad"])
83
- if input_mode == "Teks Manual":
84
- user_input = st.text_area("πŸ“ Masukkan bahan makanan:")
85
- else:
86
- url = st.text_input("πŸ”— Masukkan URL Cookpad:")
87
- user_input = ""
88
- if url:
89
- scraped = scrape_ingredients(url)
90
- if scraped:
91
- user_input = scraped
92
- st.success("βœ… Berhasil mengambil bahan dari URL")
93
- st.text_area("πŸ“‹ Bahan dari URL:", value=user_input, height=200)
94
- else:
95
- st.error("❌ Gagal mengambil data dari URL.")
96
-
97
- threshold = st.slider("🎚 Threshold (default 0.5):", 0.0, 1.0, 0.5)
98
 
99
- if st.button("πŸš€ Prediksi"):
100
- if user_input.strip():
101
- result = predict_alergen(user_input, tokenizer, model, target_columns, threshold)
102
- st.subheader("πŸ“Š Hasil Prediksi Alergen:")
103
- for label, prob in result.items():
104
- status = "βœ… Ada" if prob >= threshold else "❌ Tidak Ada"
105
- st.write(f"- **{label}**: {status} ({prob:.2f})")
106
  else:
107
- st.warning("⚠️ Masukkan teks bahan atau URL terlebih dahulu.")
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import torch
3
+ import torch.nn as nn
4
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
5
+ import re
6
+
7
# ----- Define model class -----
class MultilabelBertClassifier(nn.Module):
    """IndoBERT sequence classifier with a freshly initialized multi-label head.

    The pretrained classifier layer is swapped for a new ``nn.Linear`` so the
    checkpoint's own head weights can be loaded on top of the backbone.
    """

    def __init__(self, model_name, num_labels):
        super().__init__()
        backbone = AutoModelForSequenceClassification.from_pretrained(
            model_name, num_labels=num_labels
        )
        # Replace the stock head with an untrained linear layer sized for our labels.
        backbone.classifier = nn.Linear(backbone.config.hidden_size, num_labels)
        self.bert = backbone

    def forward(self, input_ids, attention_mask):
        # Return raw logits; callers apply sigmoid for multi-label probabilities.
        return self.bert(input_ids=input_ids, attention_mask=attention_mask).logits
 
 
17
 
18
# ----- Load model and tokenizer -----
@st.cache_resource
def load_model():
    """Load the fine-tuned allergen checkpoint and its tokenizer (cached).

    Returns:
        Tuple ``(tokenizer, model, target_columns)`` where ``target_columns``
        is the label list stored inside the checkpoint.
    """
    checkpoint_path = "model/alergen_model.pt"
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    labels = checkpoint["target_columns"]

    classifier = MultilabelBertClassifier(
        "indobenchmark/indobert-base-p1", num_labels=len(labels)
    )
    classifier.load_state_dict(checkpoint["model_state_dict"])
    classifier.eval()  # inference mode: disable dropout

    tokenizer = AutoTokenizer.from_pretrained("indobenchmark/indobert-base-p1")

    return tokenizer, classifier, labels
32
 
33
# ----- Preprocessing function -----
def clean_text(text):
    """Normalize raw recipe text for tokenization.

    Drops URLs, replaces every character that is not alphanumeric or
    whitespace with a space, collapses runs of spaces, and lowercases.

    Args:
        text: Raw ingredient text (may contain newlines and URLs).

    Returns:
        The cleaned, lowercased, stripped string.
    """
    text = text.replace('--', ' ')
    # Remove URLs before stripping punctuation, so "http" fragments don't survive.
    text = re.sub(r"http\S+", "", text)
    text = text.replace("\n", " ")
    # Raw strings: "\s" inside a plain literal is an invalid escape sequence
    # (DeprecationWarning today, SyntaxError in a future Python).
    text = re.sub(r"[^a-zA-Z0-9\s]", " ", text)
    text = re.sub(r" {2,}", " ", text)
    return text.strip().lower()
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
# ----- Prediction function -----
def predict_alergens(text, tokenizer, model, target_columns, max_length=128, threshold=0.5):
    """Predict which allergens are present in a recipe text.

    Args:
        text: Raw ingredient text; normalized internally via ``clean_text``.
        tokenizer: HuggingFace tokenizer providing ``encode_plus``.
        model: Callable returning logits of shape (1, len(target_columns)).
        target_columns: Label names, in the model's output order.
        max_length: Token length to pad/truncate to.
        threshold: Sigmoid probability above which a label counts as present.
            Previously hard-coded to 0.5; the default preserves old behavior.

    Returns:
        Dict mapping each label name to a bool (True = allergen detected).
    """
    cleaned = clean_text(text)
    encoding = tokenizer.encode_plus(
        cleaned,
        add_special_tokens=True,
        max_length=max_length,
        truncation=True,
        return_tensors='pt',
        padding='max_length'
    )
    input_ids = encoding["input_ids"]
    attention_mask = encoding["attention_mask"]

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.sigmoid(outputs)
        flags = (probs > threshold).squeeze(0).tolist()

    # zip pairs labels with predictions without manual indexing.
    return {label: bool(flag) for label, flag in zip(target_columns, flags)}
64
 
65
# ----- Streamlit App UI -----
st.title("Deteksi Alergen dari Resep")

tokenizer, model, target_columns = load_model()

# Input form: text area plus submit button, so the model runs only on submit.
with st.form("alergen_form"):
    input_text = st.text_area("Masukkan daftar bahan (ingredients):", height=200)
    submitted = st.form_submit_button("Deteksi Alergen")

if submitted:
    if not input_text.strip():
        # Empty submission: prompt for input instead of running the model.
        st.warning("Mohon masukkan teks bahan terlebih dahulu.")
    else:
        results = predict_alergens(input_text, tokenizer, model, target_columns)
        st.subheader("Hasil Deteksi Alergen:")
        for alergi, status in results.items():
            # Red banner when the allergen is detected, green when safe.
            if status:
                st.error(f"- {alergi.capitalize()}")
            else:
                st.success(f"- {alergi.capitalize()}: Aman")