ssdataanalysis committed on
Commit
1f5ea7f
·
verified ·
1 Parent(s): ac64bab

Add training script with QLoRA

Browse files
Files changed (1) hide show
  1. train.py +81 -168
train.py CHANGED
@@ -1,25 +1,19 @@
1
- #!/usr/bin/env python3
2
- """
3
- Hebrew-first SFT training for Gemma 4 E2B/E4B.
4
- DictaLM-style recipe: 50/50 Hebrew/English mix, LoRA r=64, 3 epochs.
5
- """
6
-
7
  import os
8
- import sys
 
9
  import random
10
  import json
11
  from datasets import load_dataset, concatenate_datasets
12
- from transformers import AutoModelForImageTextToText, AutoTokenizer
13
  from peft import LoraConfig
14
  from trl import SFTConfig, SFTTrainer
15
  import trackio
16
  import torch
17
  from transformers import TrainerCallback
18
 
19
- # Trackio init
20
  trackio.init(
21
- project=os.environ.get("TRACKIO_PROJECT", "hebrew-gemma4"),
22
- space_id=os.environ.get("TRACKIO_SPACE_ID", "ssdataanalysis/mlintern-heb4"),
23
  )
24
 
25
  class TrackioAlertCallback(TrainerCallback):
@@ -28,43 +22,23 @@ class TrackioAlertCallback(TrainerCallback):
28
  loss = logs["loss"]
29
  step = state.global_step
30
  if loss > 5.0 and step > 50:
31
- trackio.alert(
32
- title="High Loss Warning",
33
- text=f"loss={loss:.3f} at step {step} — lr may be too high, consider reducing",
34
- level="WARN"
35
- )
36
  elif step % 100 == 0:
37
- trackio.alert(
38
- title="Training Progress",
39
- text=f"loss={loss:.3f} at step {step}",
40
- level="INFO"
41
- )
42
-
43
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
44
  if metrics and "eval_loss" in metrics:
45
- trackio.alert(
46
- title="Eval Complete",
47
- text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}",
48
- level="INFO"
49
- )
50
 
51
  def convert_hebrew_qa_to_messages(example):
52
- """Convert yuvalav/hebrew-qa to messages format."""
53
  instruction = example.get("instruction", "")
54
  input_text = example.get("input", "")
55
  output = example.get("output", "")
56
  user_content = instruction
57
  if input_text and str(input_text).strip():
58
  user_content += "\n" + str(input_text)
59
- return {
60
- "messages": [
61
- {"role": "user", "content": user_content},
62
- {"role": "assistant", "content": output},
63
- ]
64
- }
65
 
66
  def convert_hebrew_chatml_to_messages(example):
67
- """Convert itayl/hebrewQA-chatml to messages format."""
68
  conversations = example.get("conversations", [])
69
  messages = []
70
  for turn in conversations:
@@ -79,164 +53,103 @@ def convert_hebrew_chatml_to_messages(example):
79
  return {"messages": messages}
80
 
81
  def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42):
82
- """Prepare mixed Hebrew-English instruction dataset."""
83
  random.seed(seed)
84
-
85
- hebrew_samples_target = int(max_total * hebrew_ratio)
86
- english_samples_target = max_total - hebrew_samples_target
87
-
88
  datasets_list = []
89
-
90
- # Hebrew datasets
91
  print("Loading Hebrew datasets...")
92
-
93
- # 1. yuvalav/hebrew-qa (~30K)
94
  ds_he1 = load_dataset("yuvalav/hebrew-qa", split="train")
95
  ds_he1 = ds_he1.map(convert_hebrew_qa_to_messages, remove_columns=ds_he1.column_names)
96
  datasets_list.append(("hebrew-qa", ds_he1))
97
  print(f" hebrew-qa: {len(ds_he1)}")
98
-
99
- # 2. itayl/hebrewQA-chatml (~30K)
100
  ds_he2 = load_dataset("itayl/hebrewQA-chatml", split="train")
101
  ds_he2 = ds_he2.map(convert_hebrew_chatml_to_messages, remove_columns=ds_he2.column_names)
102
  datasets_list.append(("hebrewQA-chatml", ds_he2))
103
  print(f" hebrewQA-chatml: {len(ds_he2)}")
104
-
105
- total_hebrew = len(ds_he1) + len(ds_he2)
106
- print(f"Total Hebrew: {total_hebrew}")
107
-
108
- # English datasets
109
  print("Loading English datasets...")
110
-
111
- # 3. OpenHermes 2.5 H4 (~950K, take subset)
112
  ds_en1 = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
113
  ds_en1 = ds_en1.remove_columns([c for c in ds_en1.column_names if c != "messages"])
114
- # Filter to only user/assistant/system roles
115
  def filter_messages(example):
116
  msgs = example.get("messages", [])
117
  return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)
118
  ds_en1 = ds_en1.filter(filter_messages)
119
- # Sample
120
- if len(ds_en1) > english_samples_target:
121
- ds_en1 = ds_en1.shuffle(seed=seed).select(range(english_samples_target))
122
  datasets_list.append(("OpenHermes", ds_en1))
123
  print(f" OpenHermes: {len(ds_en1)}")
124
-
125
- # Combine and shuffle
126
  all_datasets = [d for _, d in datasets_list]
127
  combined = concatenate_datasets(all_datasets)
128
  combined = combined.shuffle(seed=seed)
129
-
130
- # Verify format
131
- sample = combined[0]
132
- print(f"\nSample messages: {json.dumps(sample['messages'][:2], ensure_ascii=False)}")
133
-
134
- total = len(combined)
135
- hebrew_count = len(ds_he1) + len(ds_he2)
136
- print(f"\nFinal dataset: {total} samples ({hebrew_count} Hebrew, {len(ds_en1)} English)")
137
- print(f"Hebrew ratio: {hebrew_count/total:.2%}")
138
-
139
  return combined
140
 
141
- def train(model_id, output_dir, hebrew_ratio=0.5, max_total=120000):
142
- print(f"=== Training {model_id} -> {output_dir} ===")
143
-
144
- # Dataset
145
- train_dataset = prepare_dataset(hebrew_ratio=hebrew_ratio, max_total=max_total)
146
-
147
- # Create a small eval set (first 1000 samples)
148
- eval_dataset = train_dataset.select(range(min(1000, len(train_dataset))))
149
- train_dataset = train_dataset.select(range(min(1000, len(train_dataset)), len(train_dataset)))
150
-
151
- # Tokenizer
152
- print("Loading tokenizer...")
153
- tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
154
-
155
- # Model
156
- print("Loading model...")
157
- model = AutoModelForImageTextToText.from_pretrained(
158
- model_id,
159
- attn_implementation="sdpa",
160
- dtype="bfloat16",
161
- device_map="auto",
162
- )
163
-
164
- # LoRA config - Dicta style: high rank, all linear layers
165
- peft_config = LoraConfig(
166
- r=64,
167
- lora_alpha=16,
168
- lora_dropout=0.05,
169
- bias="none",
170
- task_type="CAUSAL_LM",
171
- target_modules="all-linear",
172
- exclude_modules=["vision_tower", "multi_modal_projector"],
173
- )
174
-
175
- # Training args
176
- training_args = SFTConfig(
177
- output_dir=output_dir,
178
- num_train_epochs=3,
179
- per_device_train_batch_size=1,
180
- gradient_accumulation_steps=8,
181
- learning_rate=2e-4,
182
- lr_scheduler_type="cosine",
183
- warmup_ratio=0.03,
184
- weight_decay=0.01,
185
- max_length=4096,
186
- packing=True,
187
- bf16=True,
188
- use_liger_kernel=True,
189
- logging_strategy="steps",
190
- logging_steps=10,
191
- logging_first_step=True,
192
- eval_strategy="steps",
193
- eval_steps=100,
194
- save_strategy="epoch",
195
- save_total_limit=2,
196
- push_to_hub=True,
197
- hub_model_id=output_dir,
198
- report_to="trackio",
199
- run_name=output_dir,
200
- remove_unused_columns=False,
201
- disable_tqdm=True,
202
- dataset_num_proc=8,
203
- gradient_checkpointing=True,
204
- )
205
-
206
- # Trainer
207
- trainer = SFTTrainer(
208
- model=model,
209
- args=training_args,
210
- train_dataset=train_dataset,
211
- eval_dataset=eval_dataset,
212
- peft_config=peft_config,
213
- processing_class=tokenizer,
214
- callbacks=[TrackioAlertCallback()],
215
- )
216
-
217
- # Train
218
- print("Starting training...")
219
- trainer.train()
220
-
221
- # Save
222
- trainer.save_model(output_dir)
223
- trainer.push_to_hub()
224
-
225
- trackio.alert(
226
- title="Training Complete",
227
- text=f"Model {output_dir} training completed successfully",
228
- level="INFO"
229
- )
230
 
231
- print(f"Done! Model saved to {output_dir}")
 
 
 
 
232
 
233
- if __name__ == "__main__":
234
- import argparse
235
- parser = argparse.ArgumentParser()
236
- parser.add_argument("--model_id", type=str, required=True)
237
- parser.add_argument("--output_dir", type=str, required=True)
238
- parser.add_argument("--hebrew_ratio", type=float, default=0.5)
239
- parser.add_argument("--max_total", type=int, default=120000)
240
- args = parser.parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
- train(args.model_id, args.output_dir, args.hebrew_ratio, args.max_total)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
3
+
4
  import random
5
  import json
6
  from datasets import load_dataset, concatenate_datasets
7
+ from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
8
  from peft import LoraConfig
9
  from trl import SFTConfig, SFTTrainer
10
  import trackio
11
  import torch
12
  from transformers import TrainerCallback
13
 
 
14
  trackio.init(
15
+ project="hebrew-gemma4",
16
+ space_id="ssdataanalysis/mlintern-heb4",
17
  )
18
 
19
  class TrackioAlertCallback(TrainerCallback):
 
22
  loss = logs["loss"]
23
  step = state.global_step
24
  if loss > 5.0 and step > 50:
25
+ trackio.alert(title="High Loss Warning", text=f"loss={loss:.3f} at step {step} lr too high", level="WARN")
 
 
 
 
26
  elif step % 100 == 0:
27
+ trackio.alert(title="Training Progress", text=f"loss={loss:.3f} at step {step}", level="INFO")
 
 
 
 
 
28
  def on_evaluate(self, args, state, control, metrics=None, **kwargs):
29
  if metrics and "eval_loss" in metrics:
30
+ trackio.alert(title="Eval Complete", text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}", level="INFO")
 
 
 
 
31
 
32
  def convert_hebrew_qa_to_messages(example):
 
33
  instruction = example.get("instruction", "")
34
  input_text = example.get("input", "")
35
  output = example.get("output", "")
36
  user_content = instruction
37
  if input_text and str(input_text).strip():
38
  user_content += "\n" + str(input_text)
39
+ return {"messages": [{"role": "user", "content": user_content}, {"role": "assistant", "content": output}]}
 
 
 
 
 
40
 
41
  def convert_hebrew_chatml_to_messages(example):
 
42
  conversations = example.get("conversations", [])
43
  messages = []
44
  for turn in conversations:
 
53
  return {"messages": messages}
54
 
55
  def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42):
 
56
  random.seed(seed)
 
 
 
 
57
  datasets_list = []
 
 
58
  print("Loading Hebrew datasets...")
 
 
59
  ds_he1 = load_dataset("yuvalav/hebrew-qa", split="train")
60
  ds_he1 = ds_he1.map(convert_hebrew_qa_to_messages, remove_columns=ds_he1.column_names)
61
  datasets_list.append(("hebrew-qa", ds_he1))
62
  print(f" hebrew-qa: {len(ds_he1)}")
 
 
63
  ds_he2 = load_dataset("itayl/hebrewQA-chatml", split="train")
64
  ds_he2 = ds_he2.map(convert_hebrew_chatml_to_messages, remove_columns=ds_he2.column_names)
65
  datasets_list.append(("hebrewQA-chatml", ds_he2))
66
  print(f" hebrewQA-chatml: {len(ds_he2)}")
 
 
 
 
 
67
  print("Loading English datasets...")
 
 
68
  ds_en1 = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
69
  ds_en1 = ds_en1.remove_columns([c for c in ds_en1.column_names if c != "messages"])
 
70
  def filter_messages(example):
71
  msgs = example.get("messages", [])
72
  return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)
73
  ds_en1 = ds_en1.filter(filter_messages)
74
+ english_target = max_total - (len(ds_he1) + len(ds_he2))
75
+ if len(ds_en1) > english_target:
76
+ ds_en1 = ds_en1.shuffle(seed=seed).select(range(english_target))
77
  datasets_list.append(("OpenHermes", ds_en1))
78
  print(f" OpenHermes: {len(ds_en1)}")
 
 
79
  all_datasets = [d for _, d in datasets_list]
80
  combined = concatenate_datasets(all_datasets)
81
  combined = combined.shuffle(seed=seed)
82
+ print(f"Final dataset: {len(combined)} samples")
 
 
 
 
 
 
 
 
 
83
  return combined
84
 
85
+ model_id = os.environ.get("MODEL_ID", "google/gemma-4-E2B-it")
86
+ output_dir = os.environ.get("OUTPUT_DIR", "ssdataanalysis/gemma-4-E2B-hebrew-first")
87
+ print(f"=== Training {model_id} -> {output_dir} ===")
88
+
89
+ train_dataset = prepare_dataset(hebrew_ratio=0.5, max_total=120000)
90
+ eval_dataset = train_dataset.select(range(min(1000, len(train_dataset))))
91
+ train_dataset = train_dataset.select(range(min(1000, len(train_dataset)), len(train_dataset)))
92
+
93
+ print("Loading tokenizer...")
94
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
95
+ print("Loading model with 4-bit quantization...")
96
+ bnb_config = BitsAndBytesConfig(
97
+ load_in_4bit=True,
98
+ bnb_4bit_quant_type="nf4",
99
+ bnb_4bit_compute_dtype=torch.bfloat16,
100
+ bnb_4bit_use_double_quant=True,
101
+ )
102
+ model = AutoModelForImageTextToText.from_pretrained(
103
+ model_id,
104
+ attn_implementation="sdpa",
105
+ quantization_config=bnb_config,
106
+ device_map="auto",
107
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
+ peft_config = LoraConfig(
110
+ r=64, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
111
+ target_modules="all-linear",
112
+ exclude_modules=["vision_tower", "multi_modal_projector"],
113
+ )
114
 
115
+ training_args = SFTConfig(
116
+ output_dir=output_dir,
117
+ num_train_epochs=3,
118
+ per_device_train_batch_size=1,
119
+ gradient_accumulation_steps=8,
120
+ learning_rate=2e-4,
121
+ lr_scheduler_type="cosine",
122
+ warmup_ratio=0.03,
123
+ weight_decay=0.01,
124
+ max_length=4096,
125
+ packing=True,
126
+ bf16=True,
127
+ use_liger_kernel=True,
128
+ logging_strategy="steps",
129
+ logging_steps=10,
130
+ logging_first_step=True,
131
+ eval_strategy="steps",
132
+ eval_steps=100,
133
+ save_strategy="epoch",
134
+ save_total_limit=2,
135
+ push_to_hub=True,
136
+ hub_model_id=output_dir,
137
+ report_to="trackio",
138
+ run_name=output_dir,
139
+ remove_unused_columns=False,
140
+ disable_tqdm=True,
141
+ dataset_num_proc=8,
142
+ gradient_checkpointing=True,
143
+ )
144
 
145
+ trainer = SFTTrainer(
146
+ model=model, args=training_args, train_dataset=train_dataset,
147
+ eval_dataset=eval_dataset, peft_config=peft_config,
148
+ processing_class=tokenizer, callbacks=[TrackioAlertCallback()],
149
+ )
150
+ print("Starting training...")
151
+ trainer.train()
152
+ trainer.save_model(output_dir)
153
+ trainer.push_to_hub()
154
+ trackio.alert(title="Training Complete", text=f"Model {output_dir} training completed successfully", level="INFO")
155
+ print(f"Done! Model saved to {output_dir}")