# readctrl / code / text_classifier / qwen3_(4b)_instruct.py
# (uploaded by shahidul034 via upload-large-folder tool, commit 1db7196 verified)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os

# Pin GPU selection BEFORE any CUDA-aware library (unsloth/torch) is imported:
# PCI_BUS_ID ordering keeps device indices stable, and only physical GPU 2
# is exposed to this process.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer
from unsloth.chat_templates import get_chat_template, train_on_responses_only

# --- Training configuration -------------------------------------------------
# NOTE(review): the filename suggests Qwen3 4B Instruct, but the checkpoint
# loaded here is unsloth/Qwen3-8B — confirm this is intentional.
MODEL_NAME = "unsloth/Qwen3-8B"
# Training data: JSON records with "fulltext", "diff_label_texts", "label".
DATA_PATH = "verified_combined_0-80.json"
# Held-out split is written here by main() for later evaluation.
TEST_DATA_PATH = "verified_combined_0-80_test.json"
MAX_SEQ_LENGTH = 4096
# Destination for the merged (LoRA folded into base) fp16 model.
FP16_SAVE_DIR = "/home/mshahidul/readctrl_model/full_model/classifier_model"
TEST_SPLIT_RATIO = 0.1  # fraction of DATA_PATH reserved as a test split
SPLIT_SEED = 3407       # fixed seed so the train/test split is reproducible

# Prompts for the 3-way health-literacy classification task.
SYSTEM_PROMPT = (
    "You are an expert medical editor and Health Literacy specialist. "
    "Classify the health literacy level of the provided text."
)
# User-turn template; <<<FULLTEXT>>> and <<<DIFF_LABEL_TEXTS>>> are
# placeholder tokens substituted in build_messages().
USER_PROMPT = """Classify the health literacy level of the rewritten text.
Labels:
- low_health_literacy: very simple, living-room language, minimal jargon.
- intermediate_health_literacy: standard public-friendly language, limited jargon.
- proficient_health_literacy: technical, clinical, or academic language.
Input:
Full Source Text:
<<<FULLTEXT>>>
Rewritten Text:
<<<DIFF_LABEL_TEXTS>>>
Output: Return only one label string from the list above."""
def build_messages(fulltext: str, diff_label_texts: str, label: str):
    """Build one chat-format training example (system / user / assistant).

    The user turn is USER_PROMPT with its two placeholder tokens filled in;
    the assistant turn carries the gold label string.
    """
    filled_prompt = USER_PROMPT.replace(
        "<<<FULLTEXT>>>", fulltext
    ).replace("<<<DIFF_LABEL_TEXTS>>>", diff_label_texts)
    turns = (
        ("system", SYSTEM_PROMPT),
        ("user", filled_prompt),
        ("assistant", label),
    )
    return [{"role": role, "content": content} for role, content in turns]
def main():
    """Fine-tune a Qwen3 model as a health-literacy classifier via LoRA SFT.

    Pipeline: load base model -> attach LoRA adapters -> split dataset and
    persist the test split -> render chat-template text -> train on assistant
    responses only -> save the merged fp16 model to FP16_SAVE_DIR.
    """
    # Load the base checkpoint in full precision (no 4/8-bit quantization).
    # full_finetuning=False because LoRA adapters are attached below.
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=MODEL_NAME,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=False,
        load_in_8bit=False,
        full_finetuning=False,
    )
    # Attach rank-32 LoRA adapters to all attention and MLP projections.
    model = FastLanguageModel.get_peft_model(
        model,
        r=32,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=32,      # alpha == r, i.e. effective scaling of 1.0
        lora_dropout=0,
        bias="none",
        use_gradient_checkpointing="unsloth",  # unsloth's memory-saving variant
        random_state=3407,
        use_rslora=False,
        loftq_config=None,
    )
    # Use the Qwen3 instruct chat template so training text matches the
    # <|im_start|> markers expected by train_on_responses_only below.
    tokenizer = get_chat_template(tokenizer, chat_template="qwen3-instruct")
    dataset = load_dataset("json", data_files=DATA_PATH, split="train")
    # Reproducible split; the test portion is saved to disk and NOT trained on.
    split = dataset.train_test_split(test_size=TEST_SPLIT_RATIO, seed=SPLIT_SEED)
    train_dataset = split["train"]
    test_dataset = split["test"]
    test_dataset.to_json(TEST_DATA_PATH)

    def formatting_prompts_func(examples):
        """Render each (fulltext, diff_label_texts, label) row to chat text.

        Batched map: `examples` is a dict of parallel columns; returns a
        new "text" column with the fully templated conversation.
        """
        texts = []
        for fulltext, diff_label_texts, label in zip(
            examples["fulltext"],
            examples["diff_label_texts"],
            examples["label"],
        ):
            messages = build_messages(fulltext, diff_label_texts, label)
            # add_generation_prompt=False: the assistant turn (the label) is
            # already part of the rendered training text.
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            texts.append(text)
        return {"text": texts}

    train_dataset = train_dataset.map(formatting_prompts_func, batched=True)
    trainer = SFTTrainer(
        model=model,
        processing_class=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=None,  # evaluation done offline against TEST_DATA_PATH
        args=SFTConfig(
            dataset_text_field="text",
            # Effective batch size = 64 * 16 = 1024 sequences per step.
            per_device_train_batch_size=64,
            gradient_accumulation_steps=16,
            warmup_steps=5,
            # max_steps=60,
            num_train_epochs=1,
            learning_rate=2e-4,
            logging_steps=1,
            optim="adamw_8bit",  # 8-bit optimizer states to save GPU memory
            weight_decay=0.001,
            lr_scheduler_type="linear",
            seed=3407,
            report_to="none",
        ),
    )
    # Mask the loss on prompt tokens so only the assistant response
    # (the label string) contributes to training.
    trainer = train_on_responses_only(
        trainer,
        instruction_part="<|im_start|>user\n",
        response_part="<|im_start|>assistant\n",
    )
    trainer.train()
    os.makedirs(FP16_SAVE_DIR, exist_ok=True)
    # Fold the LoRA weights into the base model and save a standalone
    # fp16 checkpoint (no adapter files needed at inference time).
    model.save_pretrained_merged(
        FP16_SAVE_DIR,
        tokenizer,
        save_method="merged_16bit",
    )
# Script entry point: run the full fine-tuning pipeline.
if __name__ == "__main__":
    main()