ssdataanalysis commited on
Commit
779c4ca
·
verified ·
1 Parent(s): 3ea3a00

Disable eval to fix OOM on A10G

Browse files
Files changed (1) hide show
  1. train.py +153 -0
train.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
3
+
4
+ import random
5
+ import json
6
+ from datasets import load_dataset, concatenate_datasets
7
+ from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
8
+ from peft import LoraConfig
9
+ from trl import SFTConfig, SFTTrainer
10
+ import trackio
11
+ import torch
12
+ from transformers import TrainerCallback
13
+
14
+ trackio.init(
15
+ project="hebrew-gemma4",
16
+ space_id="ssdataanalysis/mlintern-heb4",
17
+ )
18
+
19
+ class TrackioAlertCallback(TrainerCallback):
20
+ def on_log(self, args, state, control, logs=None, **kwargs):
21
+ if logs and "loss" in logs:
22
+ loss = logs["loss"]
23
+ step = state.global_step
24
+ if loss > 5.0 and step > 50:
25
+ trackio.alert(title="High Loss Warning", text=f"loss={loss:.3f} at step {step} lr too high", level="WARN")
26
+ elif step % 100 == 0:
27
+ trackio.alert(title="Training Progress", text=f"loss={loss:.3f} at step {step}", level="INFO")
28
+ def on_evaluate(self, args, state, control, metrics=None, **kwargs):
29
+ if metrics and "eval_loss" in metrics:
30
+ trackio.alert(title="Eval Complete", text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}", level="INFO")
31
+
32
+ def convert_hebrew_qa_to_messages(example):
33
+ instruction = example.get("instruction", "")
34
+ input_text = example.get("input", "")
35
+ output = example.get("output", "")
36
+ user_content = instruction
37
+ if input_text and str(input_text).strip():
38
+ user_content += "\n" + str(input_text)
39
+ return {"messages": [{"role": "user", "content": user_content}, {"role": "assistant", "content": output}]}
40
+
41
+ def convert_hebrew_chatml_to_messages(example):
42
+ conversations = example.get("conversations", [])
43
+ messages = []
44
+ for turn in conversations:
45
+ role = turn.get("from", "")
46
+ content = turn.get("value", "")
47
+ if role == "human":
48
+ messages.append({"role": "user", "content": content})
49
+ elif role == "gpt":
50
+ messages.append({"role": "assistant", "content": content})
51
+ else:
52
+ messages.append({"role": role, "content": content})
53
+ return {"messages": messages}
54
+
55
+ def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42):
56
+ random.seed(seed)
57
+ datasets_list = []
58
+ print("Loading Hebrew datasets...")
59
+ ds_he1 = load_dataset("yuvalav/hebrew-qa", split="train")
60
+ ds_he1 = ds_he1.map(convert_hebrew_qa_to_messages, remove_columns=ds_he1.column_names)
61
+ datasets_list.append(("hebrew-qa", ds_he1))
62
+ print(f" hebrew-qa: {len(ds_he1)}")
63
+ ds_he2 = load_dataset("itayl/hebrewQA-chatml", split="train")
64
+ ds_he2 = ds_he2.map(convert_hebrew_chatml_to_messages, remove_columns=ds_he2.column_names)
65
+ datasets_list.append(("hebrewQA-chatml", ds_he2))
66
+ print(f" hebrewQA-chatml: {len(ds_he2)}")
67
+ print("Loading English datasets...")
68
+ ds_en1 = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
69
+ ds_en1 = ds_en1.remove_columns([c for c in ds_en1.column_names if c != "messages"])
70
+ def filter_messages(example):
71
+ msgs = example.get("messages", [])
72
+ return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)
73
+ ds_en1 = ds_en1.filter(filter_messages)
74
+ english_target = max_total - (len(ds_he1) + len(ds_he2))
75
+ if len(ds_en1) > english_target:
76
+ ds_en1 = ds_en1.shuffle(seed=seed).select(range(english_target))
77
+ datasets_list.append(("OpenHermes", ds_en1))
78
+ print(f" OpenHermes: {len(ds_en1)}")
79
+ all_datasets = [d for _, d in datasets_list]
80
+ combined = concatenate_datasets(all_datasets)
81
+ combined = combined.shuffle(seed=seed)
82
+ print(f"Final dataset: {len(combined)} samples")
83
+ return combined
84
+
85
+ model_id = os.environ.get("MODEL_ID", "google/gemma-4-E4B-it")
86
+ output_dir = os.environ.get("OUTPUT_DIR", "ssdataanalysis/gemma-4-E4B-hebrew-first")
87
+ print(f"=== Training {model_id} -> {output_dir} ===")
88
+
89
+ train_dataset = prepare_dataset(hebrew_ratio=0.5, max_total=120000)
90
+ # No eval dataset to avoid OOM during evaluation on A10G 24GB
91
+ # We will rely on training loss and periodic checkpointing
92
+
93
+ print("Loading tokenizer...")
94
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
95
+ print("Loading model with 4-bit quantization...")
96
+ bnb_config = BitsAndBytesConfig(
97
+ load_in_4bit=True,
98
+ bnb_4bit_quant_type="nf4",
99
+ bnb_4bit_compute_dtype=torch.bfloat16,
100
+ bnb_4bit_use_double_quant=True,
101
+ )
102
+ model = AutoModelForImageTextToText.from_pretrained(
103
+ model_id,
104
+ attn_implementation="sdpa",
105
+ quantization_config=bnb_config,
106
+ device_map="auto",
107
+ )
108
+
109
+ peft_config = LoraConfig(
110
+ r=64, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM",
111
+ target_modules="all-linear",
112
+ exclude_modules=["vision_tower", "multi_modal_projector"],
113
+ )
114
+
115
+ training_args = SFTConfig(
116
+ output_dir=output_dir,
117
+ num_train_epochs=3,
118
+ per_device_train_batch_size=1,
119
+ gradient_accumulation_steps=16,
120
+ learning_rate=2e-4,
121
+ lr_scheduler_type="cosine",
122
+ warmup_steps=500,
123
+ weight_decay=0.01,
124
+ max_length=2048,
125
+ packing=False,
126
+ bf16=True,
127
+ logging_strategy="steps",
128
+ logging_steps=10,
129
+ logging_first_step=True,
130
+ eval_strategy="no",
131
+ save_strategy="epoch",
132
+ save_total_limit=2,
133
+ push_to_hub=True,
134
+ hub_model_id=output_dir,
135
+ report_to="trackio",
136
+ run_name=output_dir,
137
+ remove_unused_columns=False,
138
+ disable_tqdm=True,
139
+ dataset_num_proc=4,
140
+ gradient_checkpointing=True,
141
+ )
142
+
143
+ trainer = SFTTrainer(
144
+ model=model, args=training_args, train_dataset=train_dataset,
145
+ peft_config=peft_config,
146
+ processing_class=tokenizer, callbacks=[TrackioAlertCallback()],
147
+ )
148
+ print("Starting training...")
149
+ trainer.train()
150
+ trainer.save_model(output_dir)
151
+ trainer.push_to_hub()
152
+ trackio.alert(title="Training Complete", text=f"Model {output_dir} training completed successfully", level="INFO")
153
+ print(f"Done! Model saved to {output_dir}")