lightita committed
Commit a53b482 · verified · 1 parent: fdfb5e5

Update train_seallm_khm_sum.py

Files changed (1):
  1. train_seallm_khm_sum.py +53 -25
train_seallm_khm_sum.py CHANGED
@@ -4,20 +4,22 @@ from datasets import load_dataset
 from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
-    BitsAndBytesConfig,
     TrainingArguments,
+    Trainer,
+    DataCollatorForLanguageModeling,
+    BitsAndBytesConfig,
 )
-from trl import SFTTrainer
-from peft import LoraConfig
+from peft import LoraConfig, get_peft_model
 
 MODEL_NAME = "SeaLLMs/SeaLLMs-v3-1.5B"
 DATASET_NAME = "bltlab/lr-sum"
 DATASET_CONFIG = "khm"
 
+
 def load_khm_dataset():
     raw = load_dataset(DATASET_NAME, DATASET_CONFIG)
 
-    # Try to find train/validation; if not, split test
+    # Try standard splits first
     if "train" in raw:
         train = raw["train"]
         if "validation" in raw:
@@ -28,7 +30,7 @@ def load_khm_dataset():
         split = train.train_test_split(test_size=0.05, seed=42)
         train, eval_ds = split["train"], split["test"]
     else:
-        # Some LR-Sum subsets only have 'test'; we split that.
+        # Some subsets only have 'test'; split that
        split = raw["test"].train_test_split(test_size=0.1, seed=42)
        train, eval_ds = split["train"], split["test"]
 
@@ -36,7 +38,7 @@ def load_khm_dataset():
         article = example["text"]
         summary = example["summary"]
 
-        # Simple Khmer instruction → Khmer summary
+        # Simple Khmer instruction-style format
         text = (
             "សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n"
             f"{article}\n\n"
@@ -62,7 +64,7 @@ def load_khm_dataset():
 
 
 def load_model_and_tokenizer():
-    # QLoRA 4-bit quantization config
+    # QLoRA 4-bit config
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_use_double_quant=True,
@@ -74,7 +76,6 @@ def load_model_and_tokenizer():
         MODEL_NAME,
         trust_remote_code=True,
     )
-
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
 
@@ -85,15 +86,16 @@ def load_model_and_tokenizer():
         trust_remote_code=True,
     )
 
-    # Enable gradient checkpointing for memory
     model.gradient_checkpointing_enable()
 
     return model, tokenizer
 
+
 def main():
     train_ds, eval_ds = load_khm_dataset()
     model, tokenizer = load_model_and_tokenizer()
 
+    # Apply LoRA to the model
     lora_config = LoraConfig(
         r=64,
         lora_alpha=16,
@@ -101,8 +103,40 @@ def main():
         bias="none",
         task_type="CAUSAL_LM",
     )
+    model = get_peft_model(model, lora_config)
+
+    # Tokenize datasets
+    max_length = 1024
+
+    def tokenize_function(batch):
+        out = tokenizer(
+            batch["text"],
+            max_length=max_length,
+            truncation=True,
+            padding="max_length",
+        )
+        # Causal LM: labels = input_ids
+        out["labels"] = out["input_ids"].copy()
+        return out
+
+    train_tokenized = train_ds.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=["text"],
+        desc="Tokenizing train set",
+    )
+    eval_tokenized = eval_ds.map(
+        tokenize_function,
+        batched=True,
+        remove_columns=["text"],
+        desc="Tokenizing eval set",
+    )
+
+    data_collator = DataCollatorForLanguageModeling(
+        tokenizer=tokenizer,
+        mlm=False,
+    )
 
-    # Use standard TrainingArguments instead of SFTConfig
     training_args = TrainingArguments(
         output_dir="seallm-khm-sum-lora",
         num_train_epochs=2,
@@ -115,33 +149,27 @@ def main():
         save_total_limit=2,
         lr_scheduler_type="cosine",
         warmup_ratio=0.03,
-        # old transformers may not support bf16, so let's be safe:
-        fp16=True,  # use fp16 instead of bf16
-        report_to="none",  # if this errors next, we'll drop it
+        fp16=True,  # safer for old transformers
+        report_to="none",  # remove if this crashes
     )
 
-
-    trainer = SFTTrainer(
+    trainer = Trainer(
         model=model,
-        tokenizer=tokenizer,
-        train_dataset=train_ds,
-        eval_dataset=eval_ds,
-        peft_config=lora_config,
         args=training_args,
-        dataset_text_field="text",
-        max_seq_length=1024,  # set here instead of in config
-        # packing=False,  # keep off for compatibility
+        train_dataset=train_tokenized,
+        eval_dataset=eval_tokenized,
+        data_collator=data_collator,
    )
 
     trainer.train()
 
-    # Save LoRA adapter and tokenizer
-    trainer.model.save_pretrained("seallm-khm-sum-lora")
+    # Save LoRA adapter + tokenizer
+    model.save_pretrained("seallm-khm-sum-lora")
     tokenizer.save_pretrained("seallm-khm-sum-lora")
 
     repo_id = os.environ.get("OUTPUT_REPO_ID", "")
     if repo_id:
-        trainer.model.push_to_hub(repo_id)
+        model.push_to_hub(repo_id)
         tokenizer.push_to_hub(repo_id)
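Not part of the commit, but for orientation: a minimal sketch of how the LoRA adapter this script saves could be loaded for inference. The base model name and adapter directory come from the script itself; the fp16 load, the generation settings, and the `article` placeholder are assumptions, and the prompt below only mirrors the instruction line visible in the diff (the rest of the training prompt is truncated there).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

BASE_MODEL = "SeaLLMs/SeaLLMs-v3-1.5B"
ADAPTER_DIR = "seallm-khm-sum-lora"  # output_dir used by the training script

tokenizer = AutoTokenizer.from_pretrained(ADAPTER_DIR)
base = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16,  # assumption: fp16 for inference instead of 4-bit
    trust_remote_code=True,
)
model = PeftModel.from_pretrained(base, ADAPTER_DIR)  # attach the LoRA weights
model.eval()

article = "..."  # Khmer article text goes here
# Instruction line from the training prompt: "Please summarize the text below in Khmer:"
prompt = f"សូមសង្ខេបអត្ថបទខាងក្រោមជាភាសាខ្មែរ៖\n\n{article}\n\n"

inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=128)
# Decode only the newly generated tokens, not the echoed prompt
summary = tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
print(summary)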
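If a single standalone checkpoint is preferred over the base-model-plus-adapter pair, peft can also fold the LoRA weights back into the base model. A sketch continuing from the snippet above; the merged output directory is a hypothetical name, not one the script uses:

merged = model.merge_and_unload()  # fold the LoRA deltas into the base weights
merged.save_pretrained("seallm-khm-sum-merged")  # hypothetical output path
tokenizer.save_pretrained("seallm-khm-sum-merged")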