# /// script
# dependencies = [
#     "trl",
#     "peft",
#     "trackio",
#     "transformers",
#     "datasets",
#     "bitsandbytes",
#     "accelerate",
# ]
# ///
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import trackio

# Configuration
model_id = "Qwen/Qwen2.5-7B-Instruct"
dataset_id = "daekeun-ml/naver-news-summarization-ko"
output_dir = "Qwen2.5-7B-Summarize-Ko"
hub_model_id = f"epinfomax/{output_dir}"

print(f"Starting training for {model_id} on {dataset_id}")

# 1. Load and Format Dataset
dataset = load_dataset(dataset_id, split="train")


def format_to_messages(example):
    # Map 'document' -> user prompt, 'summary' -> assistant reply
    return {
        "messages": [
            {"role": "user", "content": f"Summarize the following document:\n\n{example['document']}"},
            {"role": "assistant", "content": example["summary"]},
        ]
    }


print("Formatting dataset...")
dataset = dataset.map(format_to_messages, remove_columns=dataset.column_names)

# Create a small eval split
dataset = dataset.train_test_split(test_size=0.05, seed=42)
print(f"Train size: {len(dataset['train'])}, Eval size: {len(dataset['test'])}")

# 2. Model & Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

# Quantization Config (4-bit)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# 3. LoRA Config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
)

# 4. Trainer
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,  # Adjust based on A10G memory
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    report_to="trackio",
    project="BizFlow-Summarizer",
    run_name="Qwen-7B-SFT-Run1",
    fp16=True,
    max_seq_length=1024,  # Truncate to save memory/time
    dataset_text_field="messages",
    packing=False,  # Keep each chat-templated conversation as its own sequence
    # Pass the 4-bit quantization config through to from_pretrained, since the
    # model is given to SFTTrainer as a string id rather than a loaded model.
    model_init_kwargs={"quantization_config": bnb_config},
)

trainer = SFTTrainer(
    model=model_id,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    args=training_args,
    processing_class=tokenizer,
)

print("Starting training...")
trainer.train()

print("Pushing to hub...")
trainer.push_to_hub()

print("Done!")
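
# Optional post-training sanity check (an illustrative sketch added here, not part
# of the original recipe): generate a summary for one held-out example with the
# freshly trained adapter to confirm the chat template and LoRA weights behave as
# expected. The variable names below are only for this example.
sample_prompt = dataset["test"][0]["messages"][0]["content"]
sample_inputs = tokenizer.apply_chat_template(
    [{"role": "user", "content": sample_prompt}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(next(trainer.model.parameters()).device)

trainer.model.eval()
with torch.no_grad():
    sample_output = trainer.model.generate(sample_inputs, max_new_tokens=128)

print("Sample summary:")
print(tokenizer.decode(sample_output[0][sample_inputs.shape[-1]:], skip_special_tokens=True))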