File size: 3,469 Bytes
9c6961c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import json
import re
import os
# Restrict this process to a single GPU (PCI bus ordering, physical device 2).
# These are assigned before the unsloth/torch imports below because CUDA only
# reads them when it initializes.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from datasets import load_dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template

# Fine-tuned classifier checkpoint, passed to FastLanguageModel.from_pretrained in main().
MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"
# JSON test set loaded with datasets.load_dataset("json", ...); main() reads each
# row's "fulltext", "diff_label_texts" and "label" fields.
TEST_DATA_PATH = "verified_combined_0-80_test.json"
# max_seq_length requested when loading the model.
MAX_SEQ_LENGTH = 4096
# main() writes the accuracy summary here as a single JSON object.
ACCURACY_OUTPUT_PATH = "accuracy_results.json"

# System message sent ahead of every classification request.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)

# User message template. build_user_prompt() substitutes the <<<FULLTEXT>>> and
# <<<DIFF_LABEL_TEXTS>>> placeholders. The template itself spells out every
# label name, so any text that still contains this prompt will match all three
# labels when parsed.
USER_PROMPT = """Classify the health literacy level of the rewritten text.

Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.

Input:
Full Source Text:
<<<FULLTEXT>>>

Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>

Output: Return only one label string from the list above."""

# The three valid label strings.
# NOTE(review): not referenced by the code in this file as shown; extract_label()
# matches the same three names with its own regex, so keep the two in sync.
LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}


# Matches either placeholder in the user-prompt template.
_PLACEHOLDER_RE = re.compile(r"<<<(FULLTEXT|DIFF_LABEL_TEXTS)>>>")


def build_user_prompt(
    fulltext: str, diff_label_texts: str, *, template: "str | None" = None
) -> str:
    """Fill the user-prompt template with the source and rewritten texts.

    Args:
        fulltext: Original source document; replaces ``<<<FULLTEXT>>>``.
        diff_label_texts: Rewritten text to classify; replaces
            ``<<<DIFF_LABEL_TEXTS>>>``.
        template: Template to fill. Defaults to ``USER_PROMPT``.

    Returns:
        The template with every placeholder occurrence replaced.

    Both placeholders are substituted in a single pass over the template,
    so placeholder-like text that happens to appear inside ``fulltext`` is
    inserted verbatim instead of being rewritten by a second replacement.
    """
    if template is None:
        template = USER_PROMPT
    values = {"FULLTEXT": fulltext, "DIFF_LABEL_TEXTS": diff_label_texts}
    # A callable replacement is inserted literally (no backslash/group
    # escape processing), so arbitrary document text is safe here.
    return _PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)


# Case-insensitive so that capitalised model output still maps to a label.
# The alternation mirrors the module-level LABELS set; keep them in sync.
_LABEL_RE = re.compile(
    r"(low_health_literacy|intermediate_health_literacy|proficient_health_literacy)",
    re.IGNORECASE,
)


def extract_label(text: str) -> str:
    """Return the first health-literacy label that appears in *text*.

    Matching ignores case and the result is normalised to the lower-case
    spelling used in the pattern, e.g. ``"Low_Health_Literacy"`` ->
    ``"low_health_literacy"``.

    Args:
        text: Model output to scan.

    Returns:
        The matched label in lower case, or ``""`` if no label is present.
    """
    match = _LABEL_RE.search(text)
    return match.group(1).lower() if match else ""


def _classify_example(model, tokenizer, example) -> str:
    """Generate a label for one dataset row and return the parsed result.

    Builds the system + user chat messages from the row's ``fulltext`` and
    ``diff_label_texts`` fields, runs greedy generation, and parses only the
    newly generated tokens with ``extract_label``. Returns ``""`` when the
    completion contains no recognised label.
    """
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": build_user_prompt(
                example["fulltext"], example["diff_label_texts"]
            ),
        },
    ]
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=20,
        # Request greedy decoding explicitly. temperature=0.0 is not a valid
        # sampling temperature, so transformers rejects it whenever sampling
        # is enabled (e.g. by the checkpoint's saved generation_config).
        do_sample=False,
    )
    # generate() returns prompt + completion. Decode only the completion:
    # USER_PROMPT lists every label name, so parsing the full sequence would
    # match the label list in the prompt instead of the model's answer.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    )
    return extract_label(completion)


def main():
    """Evaluate the classifier on TEST_DATA_PATH and save its accuracy.

    Loads the model from MODEL_DIR and predicts a label for every row of the
    JSON test file. Each prediction is printed next to the row's ``label``
    field. The accuracy, correct/total counts and input paths are written to
    ACCURACY_OUTPUT_PATH as one JSON object.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_DIR,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
    )
    # Switch the Unsloth model into inference mode before calling generate().
    FastLanguageModel.for_inference(model)
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")

    # A single file passed via data_files is exposed as the "train" split.
    dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train")

    correct = 0
    total = 0
    for example in dataset:
        pred = _classify_example(model, tokenizer, example)
        print(f"Predicted: {pred}, Expected: {example['label']}")
        if pred == example["label"]:
            correct += 1
        total += 1

    # Avoid dividing by zero on an empty test file.
    accuracy = (correct / total) if total else 0.0
    results = {
        "accuracy": round(accuracy, 6),
        "correct": correct,
        "total": total,
        "model_dir": MODEL_DIR,
        "test_data_path": TEST_DATA_PATH,
    }
    with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle:
        json.dump(results, handle, ensure_ascii=True)
        handle.write("\n")
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()