CS403 / app.py
dinox16
Add Gradio app and model
84b0435
import csv
from pathlib import Path
import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from pyvi import ViTokenizer # PhoBERT cần phân tách từ tiếng Việt
# 1. Đường dẫn tuyệt đối đến thư mục chứa các file model
BASE_DIR = Path(__file__).resolve().parent
MODEL_DIR = BASE_DIR / "model"
LABELS_CSV = BASE_DIR / "symptom2disease_vi.csv"
if not MODEL_DIR.exists():
raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}")
# Prefer local tokenizer files if present; otherwise fallback to base PhoBERT tokenizer.
tokenizer_files = ["bpe.codes", "merges.txt", "tokenizer.json"]
has_tokenizer_files = any((MODEL_DIR / name).exists() for name in tokenizer_files)
if has_tokenizer_files:
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True)
else:
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True)
def load_label_list(csv_path: Path):
if not csv_path.exists():
return None
with csv_path.open("r", encoding="utf-8", newline="") as handle:
reader = csv.DictReader(handle)
labels = [row.get("label", "").strip() for row in reader if row.get("label")]
if not labels:
return None
# Mimic sklearn LabelEncoder: sorted unique labels.
return sorted(set(labels))
label_list = load_label_list(LABELS_CSV)
if label_list and len(label_list) == model.config.num_labels:
model.config.id2label = {idx: name for idx, name in enumerate(label_list)}
model.config.label2id = {name: idx for idx, name in enumerate(label_list)}
# Đưa model vào chế độ đánh giá (evaluation mode)
model.eval()
def predict_disease(text):
# Bước A: Tiền xử lý văn bản (Phân tách từ cho tiếng Việt)
text_segmented = ViTokenizer.tokenize(text)
# Bước B: Mã hóa văn bản đầu vào
inputs = tokenizer(
text_segmented,
return_tensors="pt",
truncation=True,
max_length=256,
padding='max_length'
)
# Bước C: Thực hiện dự đoán
with torch.no_grad():
outputs = model(**inputs)
# Bước D: Lấy kết quả có xác suất cao nhất (Logits -> Softmax -> Argmax)
logits = outputs.logits
prediction = torch.argmax(logits, dim=-1).item()
# Bước E: Ánh xạ ID sang nhãn tên bệnh (nếu bạn đã lưu id2label trong config)
id2label = model.config.id2label or {}
label = id2label.get(prediction) or id2label.get(str(prediction), f"LABEL_{prediction}")
# Tính xác suất (confidence score)
probs = torch.nn.functional.softmax(logits, dim=-1)
confidence = torch.max(probs).item()
return label, confidence
def predict_for_ui(text):
if not text or not text.strip():
return "", 0.0
label, confidence = predict_disease(text.strip())
return label, round(confidence, 4)
demo = gr.Interface(
fn=predict_for_ui,
inputs=gr.Textbox(lines=4, label="Symptoms (Vietnamese)"),
outputs=[
gr.Textbox(label="Predicted disease"),
gr.Number(label="Confidence"),
],
title="Symptom to Disease Predictor",
description="Enter symptoms in Vietnamese to get a predicted disease label.",
)
if __name__ == "__main__":
demo.launch()