#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Evaluate a DistilBERT health-literacy classifier on a JSON test set.

For every test example the script builds the classification prompt from the
full source text and its rewritten version, predicts one of ``LABELS`` with
the sequence-classification model stored in ``MODEL_DIR``, compares the
prediction with the example's ``label`` field, and writes the overall accuracy
to ``ACCURACY_OUTPUT_PATH``.
"""
import json
import os

import numpy as np

# Restrict CUDA to device 2, numbered in PCI bus order. The heavy ML imports
# live inside ``main`` so they are always executed after these are set.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/distilbert_classifier"
TEST_DATA_PATH = "verified_combined_0-80_test.json"
MAX_LENGTH = 512
ACCURACY_OUTPUT_PATH = "accuracy_results_distilbert.json"

# List order defines the class ids. Presumably it matches the label mapping
# used when the classifier was trained — TODO confirm against training code.
LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
LABEL2ID = {label: idx for idx, label in enumerate(LABELS)}
ID2LABEL = {idx: label for label, idx in LABEL2ID.items()}


def build_input_text(fulltext: str, diff_label_texts: str) -> str:
    """Return the classification prompt for one (source, rewrite) pair.

    Args:
        fulltext: The original full source text.
        diff_label_texts: The rewritten text whose health-literacy level is
            being classified.

    Returns:
        A prompt that describes the three labels, followed by the full source
        text and then the rewritten text.
    """
    return (
        "Classify the health literacy level of the rewritten text.\n\n"
        "Labels:\n"
        "- low_health_literacy: very simple, living-room language, minimal jargon.\n"
        "- intermediate_health_literacy: standard public-friendly language, limited jargon.\n"
        "- proficient_health_literacy: technical, clinical, or academic language.\n\n"
        f"Full Source Text:\n{fulltext}\n\n"
        f"Rewritten Text:\n{diff_label_texts}\n"
    )


def main():
    """Classify every test example and save the resulting accuracy.

    Prints each prediction next to its expected label, then the overall
    accuracy, and writes a one-line JSON summary (accuracy, counts, model
    and data paths) to ``ACCURACY_OUTPUT_PATH``.
    """
    # Imported here so the CUDA_* variables set at module level are already
    # in effect when torch initializes CUDA, and so that importing this module
    # (e.g. to reuse build_input_text) does not require the ML stack.
    import torch
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    # Fall back to CPU instead of crashing when no GPU is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.eval()
    model.to(device)

    # The `datasets` JSON loader exposes a plain data file as the "train" split.
    dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train")

    correct = 0
    total = 0
    for example in dataset:
        text = build_input_text(example["fulltext"], example["diff_label_texts"])
        # NOTE(review): the source text precedes the rewritten text, so a long
        # source can push the rewritten text past MAX_LENGTH and truncate it
        # away. Left unchanged because it presumably mirrors the training-time
        # preprocessing — confirm before altering the prompt or truncation.
        inputs = tokenizer(
            text,
            max_length=MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # Evaluation never calls backward(); disabling autograd avoids building
        # and holding a gradient graph for every forward pass.
        with torch.no_grad():
            outputs = model(**inputs)
        pred_id = int(outputs.logits.argmax(dim=-1).item())
        pred_label = ID2LABEL.get(pred_id, "")
        print(f"Predicted: {pred_label}, Expected: {example['label']}")
        if pred_label == example["label"]:
            correct += 1
        total += 1

    # Report 0.0 rather than dividing by zero on an empty test file.
    accuracy = (correct / total) if total else 0.0
    results = {
        "accuracy": round(accuracy, 6),
        "correct": correct,
        "total": total,
        "model_dir": MODEL_DIR,
        "test_data_path": TEST_DATA_PATH,
    }
    with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle:
        json.dump(results, handle, ensure_ascii=True)
        handle.write("\n")

    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}")


if __name__ == "__main__":
    main()