| |
| |
|
|
import os

# Pin GPU enumeration to PCI bus order and expose only physical GPU 2.
# NOTE: these must be set before any CUDA-aware import (torch is pulled in
# by unsloth below) — otherwise the visibility mask is ignored.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

from unsloth.chat_templates import get_chat_template, train_on_responses_only
|
|
# --- Run configuration ----------------------------------------------------
MODEL_NAME = "unsloth/Qwen3-8B"  # base model to fine-tune
DATA_PATH = "verified_combined_0-80.json"  # full labelled dataset (JSON lines/records)
TEST_DATA_PATH = "verified_combined_0-80_test.json"  # where the held-out split is written
MAX_SEQ_LENGTH = 4096  # max context length for tokenizer/model
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"  # merged fp16 output dir
TEST_SPLIT_RATIO = 0.1  # fraction of the dataset held out for testing
SPLIT_SEED = 3407  # seed so the train/test split is reproducible
|
|
| SYSTEM_PROMPT = ( |
| "You are an expert medical editor and Health Literacy specialist. " |
| "Classify the health literacy level of the provided text." |
| ) |
|
|
| USER_PROMPT = """Classify the health literacy level of the rewritten text. |
| |
| Labels: |
| - low_health_literacy: very simple, living-room language, minimal jargon. |
| - intermediate_health_literacy: standard public-friendly language, limited jargon. |
| - proficient_health_literacy: technical, clinical, or academic language. |
| |
| Input: |
| Full Source Text: |
| <<<FULLTEXT>>> |
| |
| Rewritten Text: |
| <<<DIFF_LABEL_TEXTS>>> |
| |
| Output: Return only one label string from the list above.""" |
|
|
|
|
| def build_messages(fulltext: str, diff_label_texts: str, label: str): |
| user_content = USER_PROMPT.replace("<<<FULLTEXT>>>", fulltext).replace( |
| "<<<DIFF_LABEL_TEXTS>>>", diff_label_texts |
| ) |
| return [ |
| {"role": "system", "content": SYSTEM_PROMPT}, |
| {"role": "user", "content": user_content}, |
| {"role": "assistant", "content": label}, |
| ] |
|
|
|
|
def main():
    """Fine-tune Qwen3-8B with LoRA as a health-literacy classifier via SFT,
    then merge the adapters and save the model in fp16."""
    # Load base model + tokenizer in full precision (no 4-/8-bit quantization);
    # full_finetuning=False because LoRA adapters are attached below.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )

    # Attach LoRA adapters on all attention and MLP projection matrices.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,  # LoRA rank
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,  # scaling numerator; equal to r here
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # unsloth's memory-saving checkpointing
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Apply the Qwen3 chat template, then carve off a held-out test split and
    # persist it to disk so evaluation can be run separately later.
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)

    def formatting_prompts_func(examples):
        # Batched datasets.map callback: render each (fulltext, rewritten
        # text, label) triple into one chat-template string. Expects the
        # dataset to carry "fulltext", "diff_label_texts" and "label" columns.
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,  # held-out set is only saved to disk, not evaluated here
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,  # effective batch = 64 * 16 = 1024
            warmup_steps=5,

            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
        ),
    )

    # Mask everything up to the assistant turn so the loss is computed only on
    # the response (the label string), not on the prompt tokens. The marker
    # strings must match the applied chat template exactly.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )

    trainer.train()

    # Merge the LoRA weights into the base model and save as fp16.
    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()