#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""LoRA fine-tuning of Qwen3-8B as a health-literacy-level classifier.

Reads (fulltext, diff_label_texts, label) records from a JSON file, renders
each record through the Qwen3 chat template, trains only on the assistant
response (the label string), and saves a merged fp16 checkpoint.
"""

import os

# GPU pinning must happen before any CUDA-aware library is imported.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
from unsloth.chat_templates import get_chat_template, train_on_responses_only

MODEL_NAME = "unsloth/Qwen3-8B"
DATA_PATH = "verified_combined_0-80.json"
TEST_DATA_PATH = "verified_combined_0-80_test.json"
MAX_SEQ_LENGTH = 4096
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"
TEST_SPLIT_RATIO = 0.1
SPLIT_SEED = 3407

SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)

# Two `<<>>` placeholders: first is the full source text, second the rewrite.
USER_PROMPT = """Classify the health literacy level of the rewritten text.
Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.

Input:
Full Source Text: <<>>
Rewritten Text: <<>>

Output: Return only one label string from the list above."""


def build_messages(fulltext: str, diff_label_texts: str, label: str) -> list[dict]:
    """Build one chat-format example: system prompt, filled user prompt, label.

    Args:
        fulltext: The full source text, inserted into the first placeholder.
        diff_label_texts: The rewritten text, inserted into the second one.
        label: The gold health-literacy label (assistant turn).

    Returns:
        A three-message list in OpenAI chat format.
    """
    # BUG FIX: ``str.replace`` with no count replaces *every* occurrence, so
    # the original code filled BOTH placeholders with ``fulltext`` and the
    # second ``.replace`` never inserted ``diff_label_texts`` at all.
    # Limiting each call to ``count=1`` fills the slots in order.
    user_content = USER_PROMPT.replace("<<>>", fulltext, 1).replace(
        "<<>>", diff_label_texts, 1
    )
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": label},
    ]


def main() -> None:
    """Load the model, prepare the data, run SFT, and save the merged model."""
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,  # LoRA adapters only; merged at save time.
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")

    # Hold out a fixed test split and persist it so later evaluation runs
    # score exactly the same rows.
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)

    def formatting_prompts_func(examples):
        """Render a batch of records into chat-template training strings."""
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,
        args=SFTConfig(
            dataset_text_field="text",
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,
            warmup_steps=5,
            # max_steps=60,  # NOTE: uncomment to cap training for smoke tests.
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
        ),
    )
    # Mask prompt tokens so the loss is computed on the assistant reply only.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )
    trainer.train()

    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )


if __name__ == "__main__":
    main()