File size: 2,792 Bytes
9c6961c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 | #!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
import numpy as np
# Restrict CUDA device enumeration/visibility for this process.
# NOTE(review): presumably placed ahead of the datasets/transformers imports
# so the values are already set before any CUDA-using library loads -- confirm.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "2",
    }
)
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Location of the saved classifier (tokenizer + model) loaded by main().
MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/distilbert_classifier"
# JSON file with the evaluation examples.
TEST_DATA_PATH = "verified_combined_0-80_test.json"
# Token limit passed to the tokenizer; longer inputs are truncated.
MAX_LENGTH = 512
# Where the accuracy summary is written as JSON.
ACCURACY_OUTPUT_PATH = "accuracy_results_distilbert.json"

# Class names; a label's position in this list is its integer class id.
LABELS = [
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
]
# Forward (name -> id) and reverse (id -> name) lookups derived from LABELS.
LABEL2ID = dict(zip(LABELS, range(len(LABELS))))
ID2LABEL = dict(enumerate(LABELS))
def build_input_text(fulltext: str, diff_label_texts: str) -> str:
return (
"Classify the health literacy level of the rewritten text.\n\n"
"Labels:\n"
"- low_health_literacy: very simple, living-room language, minimal jargon.\n"
"- intermediate_health_literacy: standard public-friendly language, limited jargon.\n"
"- proficient_health_literacy: technical, clinical, or academic language.\n\n"
f"Full Source Text:\n{fulltext}\n\n"
f"Rewritten Text:\n{diff_label_texts}\n"
)
def main():
    """Evaluate the classifier in MODEL_DIR on TEST_DATA_PATH and save accuracy.

    Loads the tokenizer and sequence-classification model from MODEL_DIR,
    moves the model to the "cuda" device, and classifies every example in the
    JSON file at TEST_DATA_PATH one at a time. Each example must provide the
    keys "fulltext", "diff_label_texts" and "label". The predicted and
    expected labels are printed per example; the final accuracy is printed
    and written as a single JSON object to ACCURACY_OUTPUT_PATH.
    """
    # Local import: the module does not import torch at the top level, and
    # importing it here keeps it after the CUDA environment variables are set.
    import torch

    tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
    model.eval()
    model.to("cuda")

    # A bare JSON file is exposed by load_dataset under the "train" split.
    dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train")

    correct = 0
    total = 0
    for example in dataset:
        text = build_input_text(example["fulltext"], example["diff_label_texts"])
        inputs = tokenizer(
            text,
            max_length=MAX_LENGTH,
            truncation=True,
            return_tensors="pt",
        )
        inputs = {k: v.to("cuda") for k, v in inputs.items()}

        # Inference only: disable autograd so no graph is recorded for the
        # forward pass. (np.errstate, used here previously, only changes
        # NumPy's floating-point error handling and has no effect on PyTorch.)
        with torch.no_grad():
            outputs = model(**inputs)

        pred_id = int(outputs.logits.argmax(dim=-1).item())
        # An id outside ID2LABEL maps to "" and so never matches a real label.
        pred_label = ID2LABEL.get(pred_id, "")
        print(f"Predicted: {pred_label}, Expected: {example['label']}")
        if pred_label == example["label"]:
            correct += 1
        total += 1

    # Guard against an empty dataset to avoid dividing by zero.
    accuracy = (correct / total) if total else 0.0
    results = {
        "accuracy": round(accuracy, 6),
        "correct": correct,
        "total": total,
        "model_dir": MODEL_DIR,
        "test_data_path": TEST_DATA_PATH,
    }
    with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle:
        json.dump(results, handle, ensure_ascii=True)
        handle.write("\n")

    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}")
# Run the evaluation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|