import csv from pathlib import Path import gradio as gr import torch from transformers import AutoModelForSequenceClassification, AutoTokenizer from pyvi import ViTokenizer # PhoBERT cần phân tách từ tiếng Việt # 1. Đường dẫn tuyệt đối đến thư mục chứa các file model BASE_DIR = Path(__file__).resolve().parent MODEL_DIR = BASE_DIR / "model" LABELS_CSV = BASE_DIR / "symptom2disease_vi.csv" if not MODEL_DIR.exists(): raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}") # Prefer local tokenizer files if present; otherwise fallback to base PhoBERT tokenizer. tokenizer_files = ["bpe.codes", "merges.txt", "tokenizer.json"] has_tokenizer_files = any((MODEL_DIR / name).exists() for name in tokenizer_files) if has_tokenizer_files: tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, local_files_only=True) else: tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base") model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, local_files_only=True) def load_label_list(csv_path: Path): if not csv_path.exists(): return None with csv_path.open("r", encoding="utf-8", newline="") as handle: reader = csv.DictReader(handle) labels = [row.get("label", "").strip() for row in reader if row.get("label")] if not labels: return None # Mimic sklearn LabelEncoder: sorted unique labels. return sorted(set(labels)) label_list = load_label_list(LABELS_CSV) if label_list and len(label_list) == model.config.num_labels: model.config.id2label = {idx: name for idx, name in enumerate(label_list)} model.config.label2id = {name: idx for idx, name in enumerate(label_list)} # Đưa model vào chế độ đánh giá (evaluation mode) model.eval() def predict_disease(text): # Bước A: Tiền xử lý văn bản (Phân tách từ cho tiếng Việt) text_segmented = ViTokenizer.tokenize(text) # Bước B: Mã hóa văn bản đầu vào inputs = tokenizer( text_segmented, return_tensors="pt", truncation=True, max_length=256, padding='max_length' ) # Bước C: Thực hiện dự đoán with torch.no_grad(): outputs = model(**inputs) # Bước D: Lấy kết quả có xác suất cao nhất (Logits -> Softmax -> Argmax) logits = outputs.logits prediction = torch.argmax(logits, dim=-1).item() # Bước E: Ánh xạ ID sang nhãn tên bệnh (nếu bạn đã lưu id2label trong config) id2label = model.config.id2label or {} label = id2label.get(prediction) or id2label.get(str(prediction), f"LABEL_{prediction}") # Tính xác suất (confidence score) probs = torch.nn.functional.softmax(logits, dim=-1) confidence = torch.max(probs).item() return label, confidence def predict_for_ui(text): if not text or not text.strip(): return "", 0.0 label, confidence = predict_disease(text.strip()) return label, round(confidence, 4) demo = gr.Interface( fn=predict_for_ui, inputs=gr.Textbox(lines=4, label="Symptoms (Vietnamese)"), outputs=[ gr.Textbox(label="Predicted disease"), gr.Number(label="Confidence"), ], title="Symptom to Disease Predictor", description="Enter symptoms in Vietnamese to get a predicted disease label.", ) if __name__ == "__main__": demo.launch()