import json
import os
import re

# Select the GPU BEFORE any CUDA-aware library (torch via unsloth) is
# imported — CUDA_VISIBLE_DEVICES has no effect once CUDA is initialized.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from datasets import load_dataset
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
# --- Configuration -----------------------------------------------------------
MODEL_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"
TEST_DATA_PATH = "verified_combined_0-80_test.json"
MAX_SEQ_LENGTH = 4096
ACCURACY_OUTPUT_PATH = "accuracy_results.json"

# System message sent with every classification request.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)

# User message template; <<<FULLTEXT>>> and <<<DIFF_LABEL_TEXTS>>> are
# placeholders filled in by build_user_prompt().
USER_PROMPT = """Classify the health literacy level of the rewritten text.

Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.

Input:
Full Source Text:
<<<FULLTEXT>>>

Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>

Output: Return only one label string from the list above."""

# Closed set of valid label strings.
LABELS = {
    "low_health_literacy",
    "intermediate_health_literacy",
    "proficient_health_literacy",
}
def build_user_prompt(fulltext: str, diff_label_texts: str) -> str:
    """Fill the USER_PROMPT template with the source and rewritten texts.

    Args:
        fulltext: The full original source text.
        diff_label_texts: The rewritten text to classify.

    Returns:
        The prompt string with both placeholders substituted.
    """
    return USER_PROMPT.replace("<<<FULLTEXT>>>", fulltext).replace(
        "<<<DIFF_LABEL_TEXTS>>>", diff_label_texts
    )
def extract_label(text: str) -> str:
    """Return the health-literacy label found in *text*, or "" if none.

    Takes the LAST occurrence rather than the first: when the decoded model
    output still contains the prompt echo (the prompt enumerates all three
    labels), re.search would always return the first label listed in the
    prompt instead of the model's actual answer, which appears at the end.
    """
    matches = re.findall(
        r"low_health_literacy|intermediate_health_literacy|proficient_health_literacy",
        text,
    )
    return matches[-1] if matches else ""
def main():
    """Evaluate the classifier on the test split and save an accuracy report.

    Loads the fine-tuned model from MODEL_DIR, greedily generates one label
    per example in TEST_DATA_PATH, compares it to the gold ``label`` field,
    and writes {accuracy, correct, total, ...} to ACCURACY_OUTPUT_PATH.
    """
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_DIR,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
    )
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")

    dataset = load_dataset("json", data_files=TEST_DATA_PATH, split="train")

    correct = 0
    total = 0

    for example in dataset:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": build_user_prompt(
                    example["fulltext"], example["diff_label_texts"]
                ),
            },
        ]
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = tokenizer(text, return_tensors="pt").to("cuda")
        outputs = model.generate(
            **inputs,
            max_new_tokens=20,
            # Greedy decoding made explicit. The previous temperature=0.0 /
            # top_p=1.0 were ignored because do_sample defaults to False.
            do_sample=False,
        )
        # Decode ONLY the newly generated tokens. generate() returns the
        # prompt tokens followed by the completion; the prompt enumerates all
        # three labels, so decoding the full sequence would make the label
        # extraction match a label echoed from the prompt rather than the
        # model's answer.
        generated_tokens = outputs[0][inputs["input_ids"].shape[1]:]
        decoded = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        pred = extract_label(decoded)
        print(f"Predicted: {pred}, Expected: {example['label']}")
        if pred == example["label"]:
            correct += 1
        total += 1

    # Guard against an empty test set to avoid ZeroDivisionError.
    accuracy = (correct / total) if total else 0.0
    results = {
        "accuracy": round(accuracy, 6),
        "correct": correct,
        "total": total,
        "model_dir": MODEL_DIR,
        "test_data_path": TEST_DATA_PATH,
    }
    with open(ACCURACY_OUTPUT_PATH, "w", encoding="utf-8") as handle:
        json.dump(results, handle, ensure_ascii=True)
        handle.write("\n")
    print(f"Accuracy: {accuracy:.4f} ({correct}/{total})")
    print(f"Saved accuracy info to {ACCURACY_OUTPUT_PATH}")
if __name__ == "__main__":
    main()