epinfomax committed (verified)
Commit ff05f7b · Parent: 7dc6fb1

Upload train_summary.py with huggingface_hub
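("Upload ... with huggingface_hub" is the default commit message produced by huggingface_hub uploads. For context, a hypothetical sketch of how such an upload is typically performed; the repo_id below is an assumption for illustration, not taken from this commit:)

from huggingface_hub import HfApi

# Hypothetical example: push the training script to a Hub repo.
# The repo_id is assumed, not confirmed by this commit.
api = HfApi()
api.upload_file(
    path_or_fileobj="train_summary.py",
    path_in_repo="train_summary.py",
    repo_id="epinfomax/Qwen2.5-7B-Summarize-Ko",
    repo_type="model",
)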

Files changed (1)
  1. train_summary.py +106 -0
train_summary.py ADDED
@@ -0,0 +1,106 @@
+ # /// script
+ # dependencies = [
+ #     "trl",
+ #     "peft",
+ #     "trackio",
+ #     "transformers",
+ #     "datasets",
+ #     "bitsandbytes",
+ #     "accelerate"
+ # ]
+ # ///
+
+ import os
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ from transformers import AutoTokenizer, BitsAndBytesConfig
+ import torch
+ import trackio
+
+ # Configuration
+ model_id = "Qwen/Qwen2.5-7B-Instruct"
+ dataset_id = "daekeun-ml/naver-news-summarization-ko"
+ output_dir = "Qwen2.5-7B-Summarize-Ko"
+ hub_model_id = f"epinfomax/{output_dir}"
+
+ print(f"Starting training for {model_id} on {dataset_id}")
+
+ # 1. Load and Format Dataset
+ dataset = load_dataset(dataset_id, split="train")
+
+ def format_to_messages(example):
+     # Map 'document' -> input, 'summary' -> output
+     return {
+         "messages": [
+             {"role": "user", "content": f"Summarize the following document:\n\n{example['document']}"},
+             {"role": "assistant", "content": example['summary']}
+         ]
+     }
+
+ print("Formatting dataset...")
+ dataset = dataset.map(format_to_messages, remove_columns=dataset.column_names)
+ # Create a small eval split
+ dataset = dataset.train_test_split(test_size=0.05, seed=42)
+
+ print(f"Train size: {len(dataset['train'])}, Eval size: {len(dataset['test'])}")
+
+ # 2. Model & Tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Quantization Config (4-bit)
+ bnb_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_quant_type="nf4",
+     bnb_4bit_compute_dtype=torch.float16,
+ )
+
+ # 3. LoRA Config
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
+ )
+
+ # 4. Trainer
+ training_args = SFTConfig(
+     output_dir=output_dir,
+     num_train_epochs=3,
+     per_device_train_batch_size=4,  # Adjust based on A10G memory
+     gradient_accumulation_steps=4,
+     learning_rate=2e-4,
+     logging_steps=25,
+     eval_strategy="steps",
+     eval_steps=100,
+     save_strategy="steps",
+     save_steps=100,
+     push_to_hub=True,
+     hub_model_id=hub_model_id,
+     report_to="trackio",
+     project="BizFlow-Summarizer",
+     run_name="Qwen-7B-SFT-Run1",
+     fp16=True,
+     max_seq_length=1024,  # Truncate to save memory/time
+     dataset_text_field="messages",  # Use the messages column
+     packing=False,  # Keep packing off: packed sequences can interact badly with the chat template
+     model_init_kwargs={"quantization_config": bnb_config},  # Apply the 4-bit config when SFTTrainer loads model_id
+ )
+
+ trainer = SFTTrainer(
+     model=model_id,
+     train_dataset=dataset["train"],
+     eval_dataset=dataset["test"],
+     peft_config=peft_config,
+     args=training_args,
+     processing_class=tokenizer,
+ )
+
+ print("Starting training...")
+ trainer.train()
+
+ print("Pushing to hub...")
+ trainer.push_to_hub()
+ print("Done!")