#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
# Pin GPU selection before any CUDA-aware library is imported, so device 2
# (by PCI bus order) is the only GPU visible to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from datasets import load_dataset
# NOTE(review): unsloth's docs recommend importing it before trl/transformers so
# its runtime patches apply; here `datasets` comes first — confirm this ordering
# is intentional and works in this environment.
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
from unsloth.chat_templates import get_chat_template, train_on_responses_only
# --- Training configuration ---------------------------------------------------
MODEL_NAME = "unsloth/Qwen3-8B"  # base model that LoRA adapters are attached to
DATA_PATH = "verified_combined_0-80.json"  # full labelled dataset (JSON lines)
TEST_DATA_PATH = "verified_combined_0-80_test.json"  # held-out split is written here
MAX_SEQ_LENGTH = 4096  # max tokens per training example
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"  # merged fp16 output dir
TEST_SPLIT_RATIO = 0.1  # fraction of the dataset held out as a test split
SPLIT_SEED = 3407  # fixed seed so the train/test split is reproducible
# System message prepended to every training conversation.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)
# User-turn template; the <<<FULLTEXT>>> and <<<DIFF_LABEL_TEXTS>>> placeholders
# are substituted by build_messages() with the source and rewritten texts.
USER_PROMPT = """Classify the health literacy level of the rewritten text.
Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.
Input:
Full Source Text:
<<<FULLTEXT>>>
Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>
Output: Return only one label string from the list above."""
def build_messages(fulltext: str, diff_label_texts: str, label: str):
    """Build the three-turn chat (system / user / assistant) for one example.

    The user turn is USER_PROMPT with its two placeholders substituted; the
    assistant turn is the ground-truth label string the model must emit.
    """
    # Fill every placeholder in the user template.
    filled = USER_PROMPT
    for placeholder, value in (
        ("<<<FULLTEXT>>>", fulltext),
        ("<<<DIFF_LABEL_TEXTS>>>", diff_label_texts),
    ):
        filled = filled.replace(placeholder, value)
    roles = ("system", "user", "assistant")
    contents = (SYSTEM_PROMPT, filled, label)
    return [{"role": role, "content": content} for role, content in zip(roles, contents)]
def main():
    """Fine-tune Qwen3-8B with LoRA for health-literacy classification.

    Loads the base model, attaches LoRA adapters, formats the dataset as chat
    transcripts, trains on assistant responses only, and saves a merged fp16
    checkpoint to FP16_SAVE_DIR.
    """
    # Load the 16-bit base model: 4-bit, 8-bit, and full finetuning are all
    # disabled, so LoRA adapters (attached below) carry all trainable weights.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Attach rank-32 LoRA adapters to all attention and MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,  # alpha == r, i.e. scaling factor of 1
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # unsloth's memory-saving variant
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )
    # Use the Qwen3 chat template so apply_chat_template() emits the
    # <|im_start|> markers matched by train_on_responses_only() below.
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    # Carve out a held-out split and persist it so later evaluation can use
    # exactly the same examples.
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)
    def formatting_prompts_func(examples):
        # Batched datasets.map callback: render each (fulltext, rewritten
        # text, label) triple into one chat-formatted training string.
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}
    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,  # no eval during training; test split saved above
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,  # effective batch size 64 * 16 = 1024
            warmup_steps=5,
            # max_steps=60,
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",  # disable wandb/tensorboard reporting
        ),
    )
    # Mask the loss on everything up to the assistant marker so training
    # signal comes only from the label tokens, not the prompt.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )
    trainer.train()
    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    # Merge LoRA weights into the base model and save a standalone fp16 copy.
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )
# Entry point: run training only when executed as a script, not on import.
# Fix: removed the stray trailing "|" after main() — an extraction artifact
# that made the original line a SyntaxError.
if __name__ == "__main__":
    main()