| | |
| | |
| |
|
| | import os |
| | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" |
| | os.environ["CUDA_VISIBLE_DEVICES"] = "2" |
| | from datasets import load_dataset |
| | from unsloth import FastLanguageModel |
| | from trl import SFTConfig, SFTTrainer |
| |
|
| | from unsloth.chat_templates import get_chat_template, train_on_responses_only |
| |
|
# --- Configuration ----------------------------------------------------------
MODEL_NAME = "unsloth/Qwen3-8B"  # base model loaded through Unsloth
DATA_PATH = "verified_combined_0-80.json"  # full labeled dataset (JSON lines)
TEST_DATA_PATH = "verified_combined_0-80_test.json"  # held-out split is written here
MAX_SEQ_LENGTH = 4096  # max context length used for loading/training
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"  # merged fp16 output dir
TEST_SPLIT_RATIO = 0.1  # fraction of rows held out as the test split
SPLIT_SEED = 3407  # split seed (same value reused as the training seed below)

# System message prepended to every training example; frames the task as
# health-literacy classification.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)
| |
|
# User-turn template. <<<FULLTEXT>>> and <<<DIFF_LABEL_TEXTS>>> are literal
# placeholder tokens substituted by build_messages(); the assistant turn is
# expected to be exactly one of the three label strings listed here.
USER_PROMPT = """Classify the health literacy level of the rewritten text.

Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.

Input:
Full Source Text:
<<<FULLTEXT>>>

Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>

Output: Return only one label string from the list above."""
| |
|
| |
|
def build_messages(fulltext: str, diff_label_texts: str, label: str):
    """Assemble one three-turn chat example (system / user / assistant).

    The user turn is USER_PROMPT with both placeholders filled in, in the
    same order as before (FULLTEXT first, then DIFF_LABEL_TEXTS); the
    assistant turn carries the gold label string.
    """
    filled = USER_PROMPT.replace("<<<FULLTEXT>>>", fulltext)
    filled = filled.replace("<<<DIFF_LABEL_TEXTS>>>", diff_label_texts)
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": filled}
    assistant_turn = {"role": "assistant", "content": label}
    return [system_turn, user_turn, assistant_turn]
| |
|
| |
|
def main():
    """Fine-tune Qwen3-8B with LoRA as a health-literacy classifier.

    Pipeline: load base model -> attach LoRA adapters -> apply the Qwen3
    chat template -> split & format the dataset -> SFT on assistant
    responses only -> merge adapters and save an fp16 checkpoint.
    """
    # Load the base model/tokenizer in full precision (no 4/8-bit
    # quantization); full_finetuning=False so only LoRA adapters train.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )

    # Attach LoRA adapters to all attention and MLP projection matrices.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,  # alpha == r -> LoRA scaling factor of 1
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # Unsloth's memory-saving variant
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )

    # Install the Qwen3 instruct chat template; its <|im_start|> role markers
    # must match the instruction/response parts passed to
    # train_on_responses_only() below.
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    # Hold out a reproducible test split and persist it to disk so later
    # evaluation runs use exactly the same rows.
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)

    def formatting_prompts_func(examples):
        # Batched map callback: render each row's chat messages into one
        # training string under the "text" column.
        # NOTE(review): assumes every row has "fulltext", "diff_label_texts",
        # and "label" keys — verify against the dataset schema.
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            # add_generation_prompt=False: the assistant turn (the label)
            # is already included in the messages.
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)

    # NOTE(review): eval_dataset is None even though a test split was saved
    # above — evaluation is presumably done offline from TEST_DATA_PATH.
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,
        args=SFTConfig(
            dataset_text_field="text",
            # Effective batch size = 64 * 16 = 1024 sequences per step.
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,
            warmup_steps=5,
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
        ),
    )

    # Mask the loss on everything before the assistant turn so only the
    # label tokens are trained on; the markers mirror the qwen3 template.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )

    trainer.train()

    # Merge the LoRA adapters into the base weights and save an fp16
    # checkpoint (plus tokenizer) for standalone inference.
    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )
| |
|
| |
|
# Script entry point: run the full fine-tuning pipeline.
if __name__ == "__main__":
    main()