# moos124 - "Upload train.py" - commit fb159f8 (verified)
import torch
from datasets import load_dataset
from trl import SFTConfig, SFTTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from peft import LoraConfig
import trackio
# Configuration
# Where results go: the Hub repository passed as `hub_model_id`, and the local
# directory used both as the trainer's `output_dir` and for the final save.
HUB_MODEL_ID = "moos124/qwen-coder-multilingual-sft"
OUTPUT_DIR = "./qwen-coder-multilingual-sft"
# What is trained: the dataset handed to `load_dataset`, and the checkpoint
# name used for both the tokenizer and the trainer's `model` argument.
DATASET_ID = "iamtarun/code_instructions_120k_alpaca"
MODEL_ID = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
def preprocess_function(example):
    """Convert one Alpaca-style record into a two-turn chat conversation.

    Args:
        example: Mapping with an ``"instruction"`` entry and an ``"output"``
            entry, plus an optional ``"input"`` entry with extra context.

    Returns:
        A dict with a single ``"messages"`` key holding two role/content
        dicts: a ``user`` message (the instruction, followed by a blank line
        and ``Input: <input>`` when the input is present and non-empty) and
        an ``assistant`` message (the output).
    """
    user_content = example["instruction"]
    # A missing or empty "input" is skipped so the prompt carries no dangling
    # "Input:" label.
    extra_input = example.get("input")
    if extra_input:
        user_content = f"{user_content}\n\nInput: {extra_input}"

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": example["output"]},
        ]
    }
def main():
    """Run the LoRA supervised fine-tuning pipeline end to end.

    Loads DATASET_ID and reshapes each record with ``preprocess_function``,
    loads the tokenizer for MODEL_ID, builds the LoRA and SFT configurations,
    trains, then saves the result to OUTPUT_DIR and pushes it to the Hub.
    """
    # Function-scope import: only needed to inspect SFTConfig's fields below.
    import dataclasses

    # 1. Load Dataset
    dataset = load_dataset(DATASET_ID, split="train")
    # Drop the original Alpaca columns so only "messages" remains.
    dataset = dataset.map(preprocess_function, remove_columns=dataset.column_names)

    # 2. Tokenizer & Model (the model itself is loaded by SFTTrainer from MODEL_ID)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    # Only fall back to EOS when the tokenizer defines no pad token. Aliasing
    # pad to EOS unconditionally can make collators that mask padding by token
    # id also mask the end-of-sequence token, so it is never trained on.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # 3. PEFT Config (LoRA)
    peft_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    # 4. SFTConfig
    # NOTE(review): newer TRL releases name the sequence-length limit
    # `max_length`, older ones `max_seq_length` - use whichever the installed
    # SFTConfig declares (assumes SFTConfig is a dataclass; confirm).
    sft_field_names = {field.name for field in dataclasses.fields(SFTConfig)}
    if "max_length" in sft_field_names:
        max_length_key = "max_length"
    else:
        max_length_key = "max_seq_length"

    # `dataset_text_field` is deliberately not set: it names a plain-text
    # column, while this dataset is conversational ("messages").
    # NOTE(review): SFTTrainer is expected to detect the "messages" column and
    # apply the chat template itself - confirm against the installed TRL.
    sft_config = SFTConfig(
        output_dir=OUTPUT_DIR,
        packing=False,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        num_train_epochs=1,
        weight_decay=0.01,
        lr_scheduler_type="cosine",
        warmup_steps=100,
        logging_steps=10,
        logging_first_step=True,
        disable_tqdm=True,
        logging_strategy="steps",
        bf16=True,
        gradient_checkpointing=True,
        push_to_hub=True,
        hub_model_id=HUB_MODEL_ID,
        save_strategy="steps",
        save_steps=500,
        report_to="trackio",
        **{max_length_key: 2048},
    )

    # 5. Trainer
    trainer = SFTTrainer(
        model=MODEL_ID,
        train_dataset=dataset,
        args=sft_config,
        peft_config=peft_config,
        processing_class=tokenizer,
    )

    # 6. Train
    trainer.train()

    # 7. Save & Push
    trainer.save_model(OUTPUT_DIR)
    trainer.push_to_hub()
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()