File size: 4,231 Bytes
1db7196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Fine-tune Qwen3-8B (LoRA via Unsloth) as a health-literacy-level classifier."""

import os
# GPU pinning must happen BEFORE importing unsloth/torch so CUDA sees only
# the selected device; do not move these below the framework imports.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

from unsloth.chat_templates import get_chat_template, train_on_responses_only

MODEL_NAME = "unsloth/Qwen3-8B"
DATA_PATH = "verified_combined_0-80.json"          # training data (JSON records)
TEST_DATA_PATH = "verified_combined_0-80_test.json"  # held-out split written at run time
MAX_SEQ_LENGTH = 4096
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"
TEST_SPLIT_RATIO = 0.1  # fraction of DATA_PATH reserved as the test split
SPLIT_SEED = 3407       # fixed seed so the train/test split is reproducible

# System turn prepended to every training example.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)

# User-turn template; <<<FULLTEXT>>> and <<<DIFF_LABEL_TEXTS>>> are filled
# per example by build_messages(). The model is trained to answer with
# exactly one of the three label strings listed here.
USER_PROMPT = """Classify the health literacy level of the rewritten text.

Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.

Input:
Full Source Text:
<<<FULLTEXT>>>

Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>

Output: Return only one label string from the list above."""


def build_messages(fulltext: str, diff_label_texts: str, label: str):
    """Assemble one training conversation: system, user, assistant turns.

    The user turn is USER_PROMPT with its two placeholders substituted;
    the assistant turn carries the gold label string for supervision.
    """
    substitutions = {
        "<<<FULLTEXT>>>": fulltext,
        "<<<DIFF_LABEL_TEXTS>>>": diff_label_texts,
    }
    filled_prompt = USER_PROMPT
    for placeholder, value in substitutions.items():
        filled_prompt = filled_prompt.replace(placeholder, value)

    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", filled_prompt),
        ("assistant", label),
    )
    return [{"role": role, "content": content} for role, content in turns]


def main():
    """Train a LoRA adapter on Qwen3-8B for literacy-level classification,
    then save the merged fp16 model to FP16_SAVE_DIR.

    Side effects: writes the held-out test split to TEST_DATA_PATH and the
    merged model/tokenizer to FP16_SAVE_DIR.
    """
    # Load the base model in full precision (no 4/8-bit quantization);
    # full_finetuning=False because LoRA adapters are attached below.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )

    # Attach LoRA adapters to all attention and MLP projection layers.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,  # alpha == r, i.e. effective LoRA scaling of 1.0
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # memory-saving checkpointing variant
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Apply the Qwen3 instruct chat template, then carve off a reproducible
    # held-out split and persist it so later evaluation reuses the same rows.
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)

    def formatting_prompts_func(examples):
        # Batched map callback: render each example's 3-turn chat into one
        # template string. add_generation_prompt=False because the assistant
        # turn (the label) is already part of the text.
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    # NOTE(review): SFTConfig is not given a max sequence length here, so
    # trl's default may truncate examples below MAX_SEQ_LENGTH (4096) —
    # confirm and pass it explicitly if long inputs matter.
    # Effective batch size is 64 * 16 = 1024 sequences per optimizer step.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,
            warmup_steps=5,
            # max_steps=60,
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
        ),
    )

    # Mask the loss on everything before the assistant turn so training
    # only supervises the label tokens, not the prompt.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )

    trainer.train()

    # Merge the LoRA weights into the base model and save in 16-bit precision.
    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )


# Script entry point — run training only when executed directly, not on import.
if __name__ == "__main__":
    main()