File size: 2,820 Bytes
ff05f7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# /// script
# dependencies = [
#   "trl",
#   "peft",
#   "trackio",
#   "transformers",
#   "datasets",
#   "bitsandbytes",
#   "accelerate"
# ]
# ///

import os
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from transformers import AutoTokenizer, BitsAndBytesConfig
import torch
import trackio

# Configuration
model_id = "Qwen/Qwen2.5-7B-Instruct"                   # base chat model to fine-tune
dataset_id = "daekeun-ml/naver-news-summarization-ko"   # Korean news (document, summary) pairs
output_dir = "Qwen2.5-7B-Summarize-Ko"                  # local checkpoint dir; reused as Hub repo name
hub_model_id = f"epinfomax/{output_dir}"                # full Hub id the trainer pushes to

print(f"Starting training for {model_id} on {dataset_id}")

# 1. Load and Format Dataset
# Only the train split is fetched; a small eval split is carved out of it later.
dataset = load_dataset(dataset_id, split="train")

def format_to_messages(example):
    """Turn one raw dataset row into a two-turn chat conversation.

    The row's 'document' becomes the user prompt and its 'summary'
    becomes the assistant reply, under a single "messages" key.
    """
    prompt = f"Summarize the following document:\n\n{example['document']}"
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": example["summary"]}
    return {"messages": [user_turn, assistant_turn]}

print("Formatting dataset...")
# Replace every raw column with the single chat-style "messages" column
# so the trainer sees a conversational dataset.
dataset = dataset.map(format_to_messages, remove_columns=dataset.column_names)
# Create a small eval split (5%); fixed seed keeps the split reproducible across runs.
dataset = dataset.train_test_split(test_size=0.05, seed=42)

print(f"Train size: {len(dataset['train'])}, Eval size: {len(dataset['test'])}")

# 2. Model & Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
# NOTE(review): reusing EOS as the pad token makes padding indistinguishable from
# end-of-sequence; common for causal-LM SFT, but confirm the collator masks pads.
tokenizer.pad_token = tokenizer.eos_token

# Quantization Config (4-bit)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",              # NF4 weight quantization (QLoRA-style)
    bnb_4bit_compute_dtype=torch.float16,   # compute dtype matches fp16 training below
)

# 3. LoRA Config
peft_config = LoraConfig(
    r=16,               # adapter rank
    lora_alpha=32,      # scaling factor (alpha / r = 2)
    lora_dropout=0.05,  # dropout on adapter activations
    bias="none",        # bias terms stay frozen
    task_type="CAUSAL_LM",
    # Adapters on every attention and MLP projection of the transformer blocks.
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
)

# 4. Trainer
# 4. Trainer
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,  # Adjust based on A10G memory
    gradient_accumulation_steps=4,  # effective batch = 4 * 4 = 16 per device
    learning_rate=2e-4,
    logging_steps=25,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    push_to_hub=True,
    hub_model_id=hub_model_id,
    report_to="trackio",
    project="BizFlow-Summarizer",
    run_name="Qwen-7B-SFT-Run1",
    fp16=True,
    max_seq_length=1024,  # Truncate to save memory/time
    # NOTE(review): with a conversational "messages" column TRL applies the chat
    # template automatically; dataset_text_field is kept for compatibility — verify
    # against the pinned TRL version.
    dataset_text_field="messages",
    packing=False,  # Qwen might be sensitive to packing with chat template? Better safe.
    # BUGFIX: bnb_config was defined above but never used, so the 7B model was being
    # loaded in full precision. The model is handed to SFTTrainer as a string id, so
    # the 4-bit quantization config must travel via model_init_kwargs, which TRL
    # forwards to AutoModelForCausalLM.from_pretrained().
    model_init_kwargs={"quantization_config": bnb_config},
)

trainer = SFTTrainer(
    model=model_id,                     # string id: SFTTrainer loads the model itself
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,            # LoRA adapters attached on top of the loaded model
    args=training_args,
    processing_class=tokenizer,
)

print("Starting training...")
trainer.train()

print("Pushing to hub...")
# Final explicit push to hub_model_id (push_to_hub=True also pushes at checkpoints).
trainer.push_to_hub()
print("Done!")