| import csv |
| from pathlib import Path |
|
|
| import gradio as gr |
| import torch |
| from transformers import AutoModelForSequenceClassification, AutoTokenizer |
| from pyvi import ViTokenizer |
|
|
| |
# All resources are resolved relative to this script, so the app does not
# depend on the current working directory.
BASE_DIR = Path(__file__).resolve().parents[0]
MODEL_DIR = BASE_DIR.joinpath("model")
LABELS_CSV = BASE_DIR.joinpath("symptom2disease_vi.csv")

# Fail fast with an explicit message when the fine-tuned checkpoint is absent.
if not MODEL_DIR.exists():
    raise FileNotFoundError(f"Model directory not found: {MODEL_DIR}")
|
|
| |
# Prefer the tokenizer saved alongside the fine-tuned checkpoint when any of
# these tokenizer artefacts are present; otherwise fall back to the base
# PhoBERT tokenizer identified by its Hub id (loading it may need network
# access, since local_files_only is not set for that path).
tokenizer_files = ["bpe.codes", "merges.txt", "tokenizer.json"]
has_tokenizer_files = any(MODEL_DIR.joinpath(fname).exists() for fname in tokenizer_files)

_tokenizer_source, _tokenizer_kwargs = (
    (MODEL_DIR, {"local_files_only": True})
    if has_tokenizer_files
    else ("vinai/phobert-base", {})
)
tokenizer = AutoTokenizer.from_pretrained(_tokenizer_source, **_tokenizer_kwargs)

# The classifier weights are always loaded from the local checkpoint only.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR,
    local_files_only=True,
)
|
|
def load_label_list(csv_path: Path):
    """Return the sorted, de-duplicated class names from a CSV ``label`` column.

    Args:
        csv_path: CSV file whose header row contains a ``label`` column.

    Returns:
        A sorted list of distinct, non-blank label strings, or ``None`` when
        the file does not exist or yields no usable labels.
    """
    if not csv_path.exists():
        return None
    # "utf-8-sig" decodes plain UTF-8 identically but also drops a leading
    # BOM, which would otherwise turn the first header into "\ufefflabel"
    # and make the "label" column invisible to DictReader.
    with csv_path.open("r", encoding="utf-8-sig", newline="") as handle:
        reader = csv.DictReader(handle)
        # Strip before filtering: a whitespace-only cell is truthy and would
        # otherwise become an empty "" class name.
        stripped = ((row.get("label") or "").strip() for row in reader)
        labels = {label for label in stripped if label}
    if not labels:
        return None
    return sorted(labels)
|
|
label_list = load_label_list(LABELS_CSV)


def _has_placeholder_labels(id2label):
    """Return True when *id2label* holds only auto-generated ``LABEL_<i>`` names."""
    return all(str(name) == f"LABEL_{idx}" for idx, name in (id2label or {}).items())


# Attach human-readable class names from the dataset CSV. This assumes the
# model was trained with class ids assigned in sorted label order (TODO
# confirm against the training script). Only placeholder names are replaced,
# so a real mapping saved with the checkpoint is never silently reordered.
if (
    label_list
    and len(label_list) == model.config.num_labels
    and _has_placeholder_labels(model.config.id2label)
):
    model.config.id2label = {idx: name for idx, name in enumerate(label_list)}
    model.config.label2id = {name: idx for idx, name in enumerate(label_list)}

# Switch the network to inference mode (e.g. disables dropout).
model.eval()
|
|
def predict_disease(text):
    """Predict a disease label for a Vietnamese symptom description.

    Args:
        text: Raw symptom description in Vietnamese.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is looked up in
        ``model.config.id2label`` by int or str key, falling back to
        ``"LABEL_<id>"``. ``confidence`` is the softmax probability of the
        predicted class.
    """
    # Word-segment with pyvi before encoding; presumably this mirrors the
    # preprocessing used at training time (TODO confirm).
    text_segmented = ViTokenizer.tokenize(text)

    # A single sequence needs no padding: padding every request to
    # max_length only adds masked-out tokens and wasted compute.
    inputs = tokenizer(
        text_segmented,
        return_tensors="pt",
        truncation=True,
        max_length=256,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    prediction = int(torch.argmax(logits, dim=-1).item())
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Report the probability of the class actually predicted.
    confidence = probs[0, prediction].item()

    # id2label keys may be ints or strings depending on how the config was built.
    id2label = model.config.id2label or {}
    label = id2label.get(prediction) or id2label.get(str(prediction), f"LABEL_{prediction}")

    return label, confidence
|
|
def predict_for_ui(text):
    """Gradio callback wrapping ``predict_disease``.

    Blank or missing input short-circuits to ``("", 0.0)``. Otherwise the
    stripped text is classified and the confidence is rounded to 4 decimals.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return "", 0.0
    predicted_label, score = predict_disease(cleaned)
    return predicted_label, round(score, 4)
|
|
# Two outputs, matching the (label, confidence) tuple from predict_for_ui.
_result_components = [
    gr.Textbox(label="Predicted disease"),
    gr.Number(label="Confidence"),
]

demo = gr.Interface(
    fn=predict_for_ui,
    inputs=gr.Textbox(label="Symptoms (Vietnamese)", lines=4),
    outputs=_result_components,
    title="Symptom to Disease Predictor",
    description="Enter symptoms in Vietnamese to get a predicted disease label.",
)


if __name__ == "__main__":
    demo.launch()