File size: 5,069 Bytes
91ac57d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
import streamlit as st
import torch
import numpy as np
from transformers import BertTokenizer, BertForTokenClassification
import json
import os # <-- Pastikan 'os' di-import
import pandas as pd
# --- CONFIGURATION ---
# Fine-tuned model directory. A relative path is resolved against the current
# working directory first, then against the directory containing this script.
MODEL_DIR = "./fine_tuned_bert_ner"


# --- MODEL LOADING ---
def _resolve_model_dir(model_dir):
    """Return an absolute path for ``model_dir``.

    Absolute paths are returned unchanged. A relative path is first resolved
    against the current working directory (the previous behaviour). If no
    directory exists there, it is resolved against the directory containing
    this script, which is where the user-facing error message tells people to
    put the model folder. Passing an absolute path to ``from_pretrained`` also
    keeps ``transformers`` from mistaking a local folder name for a Hugging
    Face Hub repo ID.
    """
    if os.path.isabs(model_dir):
        return model_dir
    cwd_candidate = os.path.abspath(model_dir)
    if os.path.isdir(cwd_candidate):
        return cwd_candidate
    script_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.normpath(os.path.join(script_dir, model_dir))


class _MissingLabelMapError(Exception):
    """Raised when the model config has no usable ``id2label`` mapping."""


@st.cache_resource
def _load_resources_cached(model_dir_absolute):
    """Load (once) and cache the model, tokenizer, tag list and device.

    Errors are raised instead of being returned as a sentinel. Streamlit does
    not store the result of a call that raises, so a transient failure (e.g.
    the model folder not being present yet) is retried on the next rerun
    instead of being cached for the lifetime of the server.
    """
    model = BertForTokenClassification.from_pretrained(model_dir_absolute)
    tokenizer = BertTokenizer.from_pretrained(model_dir_absolute)
    id2label = getattr(model.config, "id2label", None)
    if not id2label:
        raise _MissingLabelMapError
    # Assumes class ids are the contiguous integers 0..N-1; a KeyError here is
    # reported by the caller like any other loading error.
    tag_values = [id2label[i] for i in range(len(id2label))]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    return model, tokenizer, tag_values, device


def load_model_and_tokenizer(model_dir_relative):
    """Load the model, tokenizer and tag list from a saved model directory.

    Args:
        model_dir_relative: Path to the fine-tuned model directory, relative
            or absolute.

    Returns:
        ``(model, tokenizer, tag_values, device)`` on success, where
        ``tag_values[i]`` is the label name for class id ``i``.
        ``(None, None, None, None)`` on failure, after an error message has
        been shown in the app.
    """
    try:
        return _load_resources_cached(_resolve_model_dir(model_dir_relative))
    except _MissingLabelMapError:
        st.error("Error: 'id2label' tidak ditemukan di dalam config.json model.")
    except Exception as e:
        st.error(f"Error saat memuat model: {e}")
        st.error(f"Pastikan folder '{model_dir_relative}' ada di direktori yang sama dengan app.py")
    return None, None, None, None
# --- PREDICTION ---
def predict(text, model, tokenizer, tag_values, device, max_length=512):
    """Run NER prediction on ``text``.

    The text is encoded with ``tokenizer`` (truncated to ``max_length`` ids,
    special tokens included) and passed through ``model``. The argmax class is
    taken for every token. WordPiece continuation pieces (``##...``) are
    glued back onto the preceding token, and each reconstructed word keeps the
    label predicted for its first sub-token. Tokens are vocabulary pieces, so
    they may differ from the raw text (e.g. in casing).

    Args:
        text: Raw input text.
        model: Token-classification model; ``model(input_ids)[0]`` must be
            the logits, indexed ``[batch, token, class]``.
        tokenizer: Tokenizer matching ``model``.
        tag_values: Label names indexed by class id.
        device: torch device the model lives on.
        max_length: Maximum number of token ids fed to the model. Text beyond
            this is silently truncated.

    Returns:
        List of ``(token, tag)`` tuples, one per reconstructed word.
    """
    token_ids = tokenizer.encode(text, truncation=True, max_length=max_length)
    input_ids = torch.tensor([token_ids], device=device)
    with torch.no_grad():
        logits = model(input_ids)[0]
    label_indices = torch.argmax(logits, dim=2)[0].tolist()

    # Tokens come straight from the id list; no need to copy the tensor back
    # from the device just to read the ids again.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    # Use the tokenizer's own special tokens rather than hard-coded strings.
    special_tokens = {tokenizer.cls_token, tokenizer.sep_token}

    merged_tokens, merged_labels = [], []
    for token, label_idx in zip(tokens, label_indices):
        if token in special_tokens:
            continue
        if token.startswith("##"):
            # Continuation piece: its own label is ignored.
            if merged_tokens:
                merged_tokens[-1] += token[2:]
        else:
            merged_labels.append(tag_values[label_idx])
            merged_tokens.append(token)
    return list(zip(merged_tokens, merged_labels))
# --- MAIN APPLICATION ---
def main():
    """Render the NER app: load the model, read user text, show predicted tags."""
    st.set_page_config(
        page_title="Aplikasi NER Medis",
        page_icon="🧪",
        layout="wide",
    )
    st.title("🧪 Aplikasi Named Entity Recognition (NER) dengan BERT")
    st.markdown("Aplikasi ini menggunakan model BERT yang di-fine-tune untuk mengenali entitas dari teks medis.")

    with st.spinner("Memuat model..."):
        model, tokenizer, tag_values, device = load_model_and_tokenizer(MODEL_DIR)

    # On failure the loader returns all None and has already shown an error.
    # Test explicitly instead of relying on the truthiness of torch/tokenizer objects.
    if model is None or tokenizer is None or not tag_values or device is None:
        return

    st.success("Model berhasil dimuat!")
    st.header("Analisis Teks Anda")
    default_text = (
        "Pasteurellosis in japanese quail (Coturnix coturnix japonica) caused by Pasteurella multocida multocida A:4. \n\n"
        "Evaluation of transdermal penetration enhancers using a novel skin alternative. \n\n"
        "A novel alternative to animal skin models was developed in order to aid in the screening of transdermal penetration enhancer."
    )
    user_input = st.text_area("Masukkan teks untuk dianalisis di sini:", default_text, height=150)

    if not st.button("🚀 Analisis Teks", type="primary"):
        return
    # Whitespace-only input would otherwise run the model and render an empty table.
    if not (user_input and user_input.strip()):
        st.warning("Silakan masukkan teks terlebih dahulu.")
        return

    with st.spinner("Menganalisis teks..."):
        results = predict(user_input, model, tokenizer, tag_values, device)

    st.subheader("Hasil Analisis (Tabel Data)")
    df = pd.DataFrame(results, columns=["Token", "Tag"])
    st.dataframe(df, use_container_width=True)
    with st.expander("Lihat Entitas yang Ditemukan Saja"):
        # 'O' is the "outside any entity" tag.
        entities_only = df[df["Tag"] != "O"]
        if entities_only.empty:
            st.info("Tidak ada entitas yang ditemukan.")
        else:
            st.dataframe(entities_only, use_container_width=True)


# Run the app when this file is executed as the main script.
if __name__ == "__main__":
    main()
|