# training-scripts/train_summary.py
# /// script
# dependencies = [
#     "trl",
#     "peft",
#     "trackio",
#     "transformers",
#     "datasets",
#     "bitsandbytes",
#     "accelerate",
# ]
# ///
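# The inline metadata block above (PEP 723) lets a runner resolve the
# dependencies automatically, e.g. `uv run train_summary.py` (assumes uv is available).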
import os
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import trackio  # not used directly; importing here fails fast if the tracking backend is missing
# Configuration
model_id = "Qwen/Qwen2.5-7B-Instruct"
dataset_id = "daekeun-ml/naver-news-summarization-ko"
output_dir = "Qwen2.5-7B-Summarize-Ko"
hub_model_id = f"epinfomax/{output_dir}"
print(f"Starting training for {model_id} on {dataset_id}")
# 1. Load and Format Dataset
dataset = load_dataset(dataset_id, split="train")
def format_to_messages(example):
    # Map 'document' -> user turn, 'summary' -> assistant turn
    return {
        "messages": [
            {"role": "user", "content": f"Summarize the following document:\n\n{example['document']}"},
            {"role": "assistant", "content": example['summary']},
        ]
    }
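# One mapped record looks roughly like this (illustrative values):
# {"messages": [
#     {"role": "user", "content": "Summarize the following document:\n\n<article text>"},
#     {"role": "assistant", "content": "<reference summary>"},
# ]}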
print("Formatting dataset...")
dataset = dataset.map(format_to_messages, remove_columns=dataset.column_names)
# Create a small eval split
dataset = dataset.train_test_split(test_size=0.05, seed=42)
print(f"Train size: {len(dataset['train'])}, Eval size: {len(dataset['test'])}")
# 2. Model & Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token  # use EOS as the pad token for batched training
# Quantization Config (4-bit NF4)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
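# With NF4, the 7B base weights take roughly 4 bits each (~4 GB), leaving room
# on a 24 GB A10G for activations, gradients, and the LoRA optimizer state.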
# 3. LoRA Config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Target all attention and MLP projection layers of the Qwen2 architecture
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)
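# Optional sanity check: inspect how many parameters LoRA actually trains.
# A minimal sketch, commented out because SFTTrainer loads the model itself
# (running this would load a second 4-bit copy):
# from transformers import AutoModelForCausalLM
# from peft import get_peft_model
# base = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=bnb_config)
# get_peft_model(base, peft_config).print_trainable_parameters()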
# 4. Trainer
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,  # adjust to GPU memory (sized for an A10G here)
    gradient_accumulation_steps=4,  # effective batch size: 4 * 4 = 16
    learning_rate=2e-4,
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    report_to="trackio",
    project="BizFlow-Summarizer",
    run_name="Qwen-7B-SFT-Run1",
    fp16=True,
    max_seq_length=1024,  # truncate to save memory/time (newer TRL releases call this `max_length`)
    packing=False,  # packing can interleave unrelated conversations; safer off with a chat template
    # No dataset_text_field needed: TRL auto-detects the conversational "messages"
    # column and applies the tokenizer's chat template.
    # Pass the 4-bit config through to model loading; without this, the
    # bnb_config defined above would never be applied.
    model_init_kwargs={"quantization_config": bnb_config, "torch_dtype": torch.float16},
)
trainer = SFTTrainer(
    model=model_id,  # passed as a string; SFTTrainer loads it with model_init_kwargs above
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    args=training_args,
    processing_class=tokenizer,
)
print("Starting training...")
trainer.train()
print("Pushing to hub...")
trainer.push_to_hub()
print("Done!")
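# After training, the pushed adapter can be tried out with a short inference
# script. A minimal sketch, assuming the Hub push above succeeded (kept
# commented out so it is not part of the training run):
# from peft import AutoPeftModelForCausalLM
# from transformers import AutoTokenizer
# model = AutoPeftModelForCausalLM.from_pretrained(hub_model_id, device_map="auto")
# tok = AutoTokenizer.from_pretrained(hub_model_id)
# prompt = tok.apply_chat_template(
#     [{"role": "user", "content": "Summarize the following document:\n\n<article>"}],
#     tokenize=False,
#     add_generation_prompt=True,
# )
# inputs = tok(prompt, return_tensors="pt").to(model.device)
# out = model.generate(**inputs, max_new_tokens=256)
# print(tok.decode(out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))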