| |
| |
|
|
import json
import os

# Pin GPU selection BEFORE any CUDA-aware library (torch / transformers /
# datasets) is imported, otherwise the device mask is ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import numpy as np
import torch

from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
|
# Paths and inference settings for the evaluation run.
MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/distilbert_classifier"
TEST_DATA_PATH = "verified_combined_0-80_test.json"
MAX_LENGTH = 512  # tokenizer truncation length passed to the tokenizer below
ACCURACY_OUTPUT_PATH = "accuracy_results_distilbert.json"


# Class labels in index order — assumed to match the label order used when the
# classifier head was trained (TODO confirm against the training script).
LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}  # label name -> class index
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}  # class index -> label name
|
|
|
|
def build_input_text(fulltext: str, diff_label_texts: str) -> str:
    """Assemble the classification prompt for one example.

    The prompt is: task instruction, a glossary of the three literacy labels,
    the full source document, and finally the rewritten passage to be judged.

    Args:
        fulltext: The original (source) document text.
        diff_label_texts: The rewritten text whose literacy level is classified.

    Returns:
        The complete prompt string fed to the tokenizer.
    """
    instruction = (
        "Classify the health literacy level of the rewritten text.\n\n"
        "Labels:\n"
        "- low_health_literacy: very simple, living-room language, minimal jargon.\n"
        "- intermediate_health_literacy: standard public-friendly language, limited jargon.\n"
        "- proficient_health_literacy: technical, clinical, or academic language.\n\n"
    )
    body = f"Full Source Text:\n{fulltext}\n\nRewritten Text:\n{diff_label_texts}\n"
    return instruction + body
|
|
|
|
def main():
    """Evaluate the DistilBERT health-literacy classifier on the test split.

    Loads the fine-tuned model from MODEL_DIR, runs single-example inference
    over TEST_DATA_PATH, prints per-example predictions, and writes an
    accuracy summary JSON to ACCURACY_OUTPUT_PATH.
    """
    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.eval()
    # Fall back to CPU when no GPU is visible instead of crashing on .to("cuda").
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    # A single local JSON file is exposed by load_dataset under the "train" split.
    dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train")

    correct = 0
    total = 0

    for example in dataset:
        text = build_input_text(example["fulltext"], example["diff_label_texts"])
        inputs = tokenizer(
            text,
            max_length=MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # BUG FIX: the original wrapped the forward pass in
        # `np.errstate(all="ignore")`, which only silences NumPy
        # floating-point warnings and has no effect on a torch model.
        # `torch.no_grad()` is the intended context here: it disables
        # autograd bookkeeping during inference so no graph is retained.
        with torch.no_grad():
            outputs = model(**inputs)
        pred_id = int(outputs.logits.argmax(dim=-1).item())
        pred_label = ID2LABEL.get(pred_id, "")
        print(f"Predicted: {pred_label}, Expected: {example['label']}")
        if pred_label == example["label"]:
            correct += 1
        total += 1

    # Guard against an empty test set (avoid ZeroDivisionError).
    accuracy = (correct / total) if total else 0.0
    results = {
        "accuracy": round(accuracy, 6),
        "correct": correct,
        "total": total,
        "model_dir": MODEL_DIR,
        "test_data_path": TEST_DATA_PATH,
    }
    with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle:
        json.dump(results, handle, ensure_ascii=True)
        handle.write("\n")
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}")
|
|
|
|
# Script entry point: run the evaluation only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|