| | |
| | |
| |
|
| | import json |
| | import os |
| | import numpy as np |
| | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| | os.environ["CUDA_VISIBLE_DEVICES"] = "2" |
| |
|
| | from datasets import load_dataset |
| | from transformers import AutoTokenizer, AutoModelForSequenceClassification |
| |
|
| | MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/distilbert_classifier" |
| | TEST_DATA_PATH = "verified_combined_0-80_test.json" |
| | MAX_LENGTH = 512 |
| | ACCURACY_OUTPUT_PATH = "accuracy_results_distilbert.json" |
| |
|
| | LABELS = [ |
| | "low_health_literacy", |
| | "intermediate_health_literacy", |
| | "proficient_health_literacy", |
| | ] |
| | LABEL2ID = {label: idx for idx, label in enumerate(LABELS)} |
| | ID2LABEL = {idx: label for label, idx in LABEL2ID.items()} |
| |
|
| |
|
| | def build_input_text(fulltext: str, diff_label_texts: str) -> str: |
| | return ( |
| | "Classify the health literacy level of the rewritten text.\n\n" |
| | "Labels:\n" |
| | "- low_health_literacy: very simple, living-room language, minimal jargon.\n" |
| | "- intermediate_health_literacy: standard public-friendly language, limited jargon.\n" |
| | "- proficient_health_literacy: technical, clinical, or academic language.\n\n" |
| | f"Full Source Text:\n{fulltext}\n\n" |
| | f"Rewritten Text:\n{diff_label_texts}\n" |
| | ) |
| |
|
| |
|
| | def main(): |
| | tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True) |
| | model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR) |
| | model.eval() |
| | model.to("cuda") |
| |
|
| | dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train") |
| |
|
| | correct = 0 |
| | total = 0 |
| |
|
| | for example in dataset: |
| | text = build_input_text(example["fulltext"], example["diff_label_texts"]) |
| | inputs = tokenizer( |
| | text, |
| | max_length=MAX_LENGTH, |
| | truncation=True, |
| | return_tensors="pt", |
| | ) |
| | inputs = {k: v.to("cuda") for k, v in inputs.items()} |
| | with np.errstate(all="ignore"): |
| | outputs = model(**inputs) |
| | pred_id = int(outputs.logits.argmax(dim=-1).item()) |
| | pred_label = ID2LABEL.get(pred_id, "") |
| | print(f"Predicted: {pred_label}, Expected: {example['label']}") |
| | if pred_label == example["label"]: |
| | correct += 1 |
| | total += 1 |
| |
|
| | accuracy = (correct / total) if total else 0.0 |
| | results = { |
| | "accuracy": round(accuracy, 6), |
| | "correct": correct, |
| | "total": total, |
| | "model_dir": MODEL_DIR, |
| | "test_data_path": TEST_DATA_PATH, |
| | } |
| | with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle: |
| | json.dump(results, handle, ensure_ascii=True) |
| | handle.write("\n") |
| | print(f"Accuracy: {accuracy:.4f} ({correct}/{total})") |
| | print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|