lightita committed on
Commit
8f20942
·
verified ·
1 Parent(s): 4d5cc5c

Create train_seallm_khm_sum.py

Files changed (1)
  1. train_seallm_khm_sum.py +150 -0
train_seallm_khm_sum.py ADDED
@@ -0,0 +1,150 @@
import os
import torch
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
DATASET_NAME = "bltlab/lr-sum"
DATASET_CONFIG = "khm"


def load_khm_dataset():
    raw = load_dataset(DATASET_NAME, DATASET_CONFIG)

    # Try to find train/validation splits; if they are missing, carve one out.
    if "train" in raw:
        train = raw["train"]
        if "validation" in raw:
            eval_ds = raw["validation"]
        elif "test" in raw:
            eval_ds = raw["test"]
        else:
            split = train.train_test_split(test_size=0.05, seed=42)
            train, eval_ds = split["train"], split["test"]
    else:
        # Some LR-Sum subsets only have 'test'; we split that.
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]

    def format_example(example):
        article = example["text"]
        summary = example["summary"]

        # Simple Khmer instruction → Khmer summary
        text = (
            "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n"
            f"{article}\n\n"
            "សេចក្តីសង្ខេប៖ "
            f"{summary}"
        )
        return {"text": text}

    cols_to_remove = list(train.features)

    train = train.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting train set",
    )
    eval_ds = eval_ds.map(
        format_example,
        remove_columns=cols_to_remove,
        desc="Formatting eval set",
    )

    return train, eval_ds


def load_model_and_tokenizer():
    # QLoRA 4-bit quantization config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    )

    # Enable gradient checkpointing to save memory
    model.gradient_checkpointing_enable()

    return model, tokenizer


def main():
    train_ds, eval_ds = load_khm_dataset()
    model, tokenizer = load_model_and_tokenizer()

    lora_config = LoraConfig(
        r=64,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )

    sft_config = SFTConfig(
        output_dir="seallm-khm-sum-lora",
        num_train_epochs=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        gradient_accumulation_steps=8,
        learning_rate=2e-4,
        logging_steps=10,
        eval_strategy="steps",
        eval_steps=200,
        save_steps=200,
        save_total_limit=2,
        max_seq_length=1024,
        packing=True,
        dataset_text_field="text",
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        bf16=True,
        gradient_checkpointing=True,
        report_to="none",  # or "wandb" etc.
    )

    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_ds,
        eval_dataset=eval_ds,
        peft_config=lora_config,
        args=sft_config,
    )

    trainer.train()

    # Save LoRA adapter and tokenizer
    trainer.model.save_pretrained("seallm-khm-sum-lora")
    tokenizer.save_pretrained("seallm-khm-sum-lora")

    # Optionally push directly to the Hub (needs HF_TOKEN env)
    repo_id = os.environ.get("OUTPUT_REPO_ID", "")
    if repo_id:
        trainer.model.push_to_hub(repo_id)
        tokenizer.push_to_hub(repo_id)


if __name__ == "__main__":
    main()
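
A minimal inference sketch for the adapter this script produces, assuming training has finished, the default output directory "seallm-khm-sum-lora", and the same Khmer prompt format as above; the article text is a hypothetical placeholder.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

# Load the 4-bit-trained LoRA adapter on top of the full-precision base model.
base = AutoModelForCausalLM.from_pretrained(
    "SeaLLMs/SeaLLMs-v3-1.5B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, "seallm-khm-sum-lora")
tokenizer = AutoTokenizer.from_pretrained("seallm-khm-sum-lora")

# Build the prompt exactly as during training, stopping right before the summary.
article = "..."  # placeholder: a Khmer article from LR-Sum
prompt = (
    "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n"
    f"{article}\n\n"
    "សេចក្តីសង្ខេប៖ "
)
inputs = tokenizer(prompt, return_tensors="pt").to(base.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)

# Print only the generated continuation (the summary).
print(tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True))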