# Last commit 76bf694 (verified) by ssdataanalysis:
#   "Fix: add assistant_only_loss=False to prevent all labels being masked to -100"
import math
import os
import random

# Set before any Hugging Face library is imported: huggingface_hub reads
# HF_HUB_ENABLE_HF_TRANSFER once, when its constants module is first loaded.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

import torch
import trackio
from datasets import concatenate_datasets, load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForImageTextToText,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainerCallback,
)
from trl import SFTConfig, SFTTrainer
# Experiment tracking: runs and alerts are logged under the "hebrew-gemma4"
# project, synced to the ssdataanalysis/mlintern-heb4 Space.
trackio.init(space_id="ssdataanalysis/mlintern-heb4", project="hebrew-gemma4")
class TrackioAlertCallback(TrainerCallback):
    """Forward training and evaluation milestones to trackio as alerts.

    ``on_log`` inspects each log record that carries a ``loss``:

    * a NaN/inf loss raises a WARN alert immediately (NaN never compares
      greater than a threshold, so it would otherwise go unreported);
    * a loss above ``high_loss_threshold`` once ``global_step`` exceeds
      ``grace_steps`` raises a WARN alert;
    * otherwise, steps that are multiples of ``progress_every`` raise an
      INFO progress alert.

    ``on_evaluate`` raises an INFO alert whenever metrics include ``eval_loss``.
    The thresholds are class attributes so they can be tuned per subclass or
    per instance.
    """

    # Loss above this value (after the grace period) is flagged as too high.
    high_loss_threshold = 5.0
    # Steps at or below this count never trigger the high-loss warning.
    grace_steps = 50
    # A progress alert is sent on steps that are multiples of this value.
    progress_every = 100

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Alert on a non-finite loss, a high loss, or periodic progress."""
        if not logs or "loss" not in logs:
            return
        loss = logs["loss"]
        step = state.global_step
        if not math.isfinite(loss):
            trackio.alert(
                title="Non-finite Loss",
                text=f"loss={loss} at step {step}; training has diverged",
                level="WARN",
            )
        elif loss > self.high_loss_threshold and step > self.grace_steps:
            trackio.alert(
                title="High Loss Warning",
                text=f"loss={loss:.3f} at step {step} lr too high",
                level="WARN",
            )
        elif step % self.progress_every == 0:
            trackio.alert(
                title="Training Progress",
                text=f"loss={loss:.3f} at step {step}",
                level="INFO",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report the eval loss after an evaluation run, if one was computed."""
        if metrics and "eval_loss" in metrics:
            trackio.alert(
                title="Eval Complete",
                text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}",
                level="INFO",
            )
def convert_alpaca_to_messages(example):
    """Convert one Alpaca-style row into a two-turn chat ``messages`` record.

    The user turn is the instruction, followed by the optional input on a new
    line (only when the input has non-whitespace content). The assistant turn
    is the output. Missing or ``None`` fields are treated as empty strings.
    A newline is inserted only between two non-empty parts, so a row with no
    instruction does not get a leading "\\n".

    Args:
        example: Mapping with optional ``instruction``, ``input`` and
            ``output`` keys.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}`` where each turn is a
        ``{"role": ..., "content": str}`` dict.
    """
    instruction = example.get("instruction")
    input_text = example.get("input")
    output = example.get("output")

    # `.get(key, "")` still returns None when the key exists with a None
    # value, so normalise explicitly.
    instruction = "" if instruction is None else str(instruction)
    output = "" if output is None else str(output)

    parts = [instruction] if instruction else []
    if input_text and str(input_text).strip():
        parts.append(str(input_text))
    user_content = "\n".join(parts)

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output},
        ]
    }
def _has_supported_roles(example):
    """Keep non-empty conversations whose turns are all user/assistant/system."""
    msgs = example.get("messages") or []
    return bool(msgs) and all(m.get("role") in ("user", "assistant", "system") for m in msgs)


def _hebrew_caps(n_he1, n_he2, max_total, hebrew_ratio, min_english):
    """Return how many rows to keep from each of the two Hebrew sources.

    All Hebrew rows are kept as long as that leaves at least ``min_english``
    of the ``max_total`` budget for English. Otherwise the English budget
    becomes ``int(max_total * (1 - hebrew_ratio))``, and the remaining budget
    is split between the two Hebrew sources in proportion to their sizes.
    """
    hebrew_count = n_he1 + n_he2
    if hebrew_count == 0:
        return 0, 0
    english_target = max_total - hebrew_count
    if english_target < min_english:
        english_target = int(max_total * (1 - hebrew_ratio))
    hebrew_cap = max(0, max_total - english_target)
    he1_cap = int(hebrew_cap * n_he1 / hebrew_count)
    he2_cap = hebrew_cap - he1_cap
    return min(n_he1, he1_cap), min(n_he2, he2_cap)


def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42, min_english=10000):
    """Build the shuffled Hebrew + English SFT mix in chat ``messages`` format.

    Sources are two Hebrew Alpaca-style datasets, converted with
    ``convert_alpaca_to_messages``, and the English OpenHermes-2.5-H4
    ``train_sft`` split. OpenHermes is filtered to non-empty conversations
    with only user/assistant/system turns. The capped Hebrew subsets are what
    get concatenated, so the result never exceeds ``max_total`` rows.

    Args:
        hebrew_ratio: Hebrew share of ``max_total``. It is applied only when
            keeping every Hebrew row would leave fewer than ``min_english``
            English slots; 0.5 then gives an even split. Must be in [0, 1].
        max_total: Upper bound on the number of rows returned. Fewer are
            returned if a source runs short.
        seed: Seed used for every shuffle, so the mix is reproducible.
        min_english: Smallest English budget tolerated before Hebrew is capped.

    Returns:
        A shuffled ``datasets.Dataset`` with a single ``messages`` column.

    Raises:
        ValueError: If ``hebrew_ratio`` is outside [0, 1].
    """
    if not 0.0 <= hebrew_ratio <= 1.0:
        raise ValueError(f"hebrew_ratio must be in [0, 1], got {hebrew_ratio!r}")
    # Seeds the stdlib RNG; every shuffle below takes `seed` directly.
    random.seed(seed)

    print("Loading high-quality Hebrew datasets...")
    ds_he1 = load_dataset("ashercn97/hebrew_alpaca_gpt4", split="train")
    ds_he1 = ds_he1.map(convert_alpaca_to_messages, remove_columns=ds_he1.column_names)
    print(f" hebrew_alpaca_gpt4: {len(ds_he1)}")
    ds_he2 = load_dataset("saillab/alpaca-hebrew-cleaned", split="train")
    ds_he2 = ds_he2.map(convert_alpaca_to_messages, remove_columns=ds_he2.column_names)
    print(f" alpaca-hebrew-cleaned: {len(ds_he2)}")

    print("Loading English datasets...")
    ds_en = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
    ds_en = ds_en.remove_columns([c for c in ds_en.column_names if c != "messages"])
    ds_en = ds_en.filter(_has_supported_roles)

    he1_keep, he2_keep = _hebrew_caps(
        len(ds_he1), len(ds_he2), max_total, hebrew_ratio, min_english
    )
    ds_he1 = ds_he1.shuffle(seed=seed).select(range(he1_keep))
    ds_he2 = ds_he2.shuffle(seed=seed).select(range(he2_keep))
    hebrew_count = len(ds_he1) + len(ds_he2)

    # English fills whatever budget the (capped) Hebrew data leaves.
    english_target = max(0, max_total - hebrew_count)
    if len(ds_en) > english_target:
        ds_en = ds_en.shuffle(seed=seed).select(range(english_target))
    print(f" OpenHermes: {len(ds_en)}")

    # Concatenate the *capped* datasets; collecting them before capping would
    # silently discard the Hebrew caps.
    combined = concatenate_datasets([ds_he1, ds_he2, ds_en])
    combined = combined.shuffle(seed=seed)
    print(f"Final dataset: {len(combined)} samples ({hebrew_count} Hebrew, {len(ds_en)} English)")
    return combined
# Base model and destination (local dir + Hub repo id) can be overridden via
# the MODEL_ID / OUTPUT_DIR environment variables.
model_id = os.getenv("MODEL_ID", "google/gemma-4-E4B-it")
output_dir = os.getenv("OUTPUT_DIR", "ssdataanalysis/gemma-4-E4B-hebrew-first")
print(f"=== Training {model_id} -> {output_dir} ===")

train_dataset = prepare_dataset(max_total=120000, hebrew_ratio=0.5)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print("Loading model with 4-bit quantization...")
# 4-bit NF4 weights with double quantization; compute happens in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",
)
# LoRA adapters on the linear layers selected by PEFT's "all-linear" shortcut,
# skipping modules named vision_tower / multi_modal_projector.
# NOTE(review): lora_alpha / r = 16 / 64 = 0.25, so adapter updates are scaled
# down relative to the common alpha == r setup; confirm this is intended.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    exclude_modules=["vision_tower", "multi_modal_projector"],
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)
# Effective batch size = per_device_train_batch_size (4) x
# gradient_accumulation_steps (4) = 16 sequences per optimizer step per device.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # "constant" ignores warmup_steps entirely; "constant_with_warmup" ramps
    # linearly over warmup_steps and then holds learning_rate.
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    weight_decay=0.01,
    max_length=2048,
    packing=False,
    # Compute loss on all tokens, not only assistant turns: assistant-only
    # masking had been masking every label to -100 (nothing to learn from).
    assistant_only_loss=False,
    bf16=True,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    # Checkpoints and the final model are pushed to the Hub repo `output_dir`.
    push_to_hub=True,
    hub_model_id=output_dir,
    report_to="trackio",
    run_name=output_dir.replace("/", "-") + "-optimal",
    remove_unused_columns=False,
    disable_tqdm=True,
    dataset_num_proc=4,
    gradient_checkpointing=True,
)
# Check for a previously saved adapter on the Hub to continue from.
# Only the Hub lookup is guarded: a missing repo (or an unreachable Hub) means
# "start from scratch". A failure while loading an adapter that *does* exist
# propagates, instead of silently training a fresh adapter that would later
# be pushed over the old one.
# NOTE(review): this restores adapter weights only. trainer.train() is called
# without resume_from_checkpoint, so optimizer/scheduler state and the step
# counter start over and num_train_epochs runs again in full.
previous_adapter = False
try:
    from huggingface_hub import HfApi

    files = HfApi().list_repo_files(output_dir, repo_type="model")
except Exception as e:
    files = []
    print(f"No previous adapter found ({e}), starting from scratch")

if "adapter_model.safetensors" in files:
    print("Found previously saved adapter on hub. Loading it as starting point...")
    from peft import PeftModel

    # PeftModel.from_pretrained loads adapters frozen (inference mode) unless
    # is_trainable=True, which would leave nothing to train.
    # NOTE(review): SFTTrainer gets peft_config=None on this path; verify the
    # gradient-checkpointing / k-bit setup it applies to fresh adapters still
    # happens for a pre-wrapped PeftModel.
    model = PeftModel.from_pretrained(model, output_dir, is_trainable=True)
    print("Loaded previous adapter weights. Training will continue from this state.")
    previous_adapter = True
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    # Only hand over the LoRA config when no adapter was loaded above.
    peft_config=None if previous_adapter else peft_config,
    callbacks=[TrackioAlertCallback()],
)

print("Starting training...")
trainer.train()

# Save the final model into output_dir, then push it to the configured Hub repo.
trainer.save_model(output_dir)
trainer.push_to_hub()

trackio.alert(
    title="Training Complete",
    text=f"Model {output_dir} training completed successfully",
    level="INFO",
)
print(f"Done! Model saved to {output_dir}")