Instructions to use ssdataanalysis/gemma-4-E2B-hebrew-first with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ssdataanalysis/gemma-4-E2B-hebrew-first with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ssdataanalysis/gemma-4-E2B-hebrew-first", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
File size: 6,037 Bytes
c572108 1f5ea7f c572108 1f5ea7f c572108 1f5ea7f c572108 1f5ea7f c572108 1f5ea7f c572108 1f5ea7f c572108 960c757 c572108 1f5ea7f c572108 960c757 c572108 960c757 c572108 960c757 c572108 960c757 c572108 1f5ea7f c572108 1f5ea7f 960c757 1f5ea7f c572108 1f5ea7f 522d2a2 1f5ea7f 960c757 8792aab 1f5ea7f 8792aab 960c757 61b62c0 1f5ea7f f0b6c36 522d2a2 1f5ea7f 960c757 1f5ea7f 8792aab 1f5ea7f c572108 1f5ea7f f0b6c36 1f5ea7f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import random
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
import trackio
import torch
from transformers import TrainerCallback
# Open the trackio run for this script: experiment project name plus the
# Hugging Face Space that hosts the dashboard.
trackio.init(project="hebrew-gemma4", space_id="ssdataanalysis/mlintern-heb4")
class TrackioAlertCallback(TrainerCallback):
    """Trainer callback that mirrors training milestones as trackio alerts.

    * ``on_log``: when the logged metrics contain ``loss``, a non-finite loss
      (NaN/inf) or a loss above 5.0 after step 50 raises a WARN alert;
      otherwise every global step divisible by 100 raises an INFO alert.
    * ``on_evaluate``: raises an INFO alert with ``eval_loss`` when present.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Alert on divergence (non-finite or high loss) or periodic progress."""
        import math  # local import keeps this block self-contained

        if not logs or "loss" not in logs:
            return
        loss = logs["loss"]
        step = state.global_step
        # NaN compares False against every threshold, so it has to be tested
        # explicitly; otherwise a diverged run never produces a warning.
        if not math.isfinite(loss):
            trackio.alert(
                title="High Loss Warning",
                text=f"loss={loss} at step {step} (non-finite)",
                level="WARN",
            )
        elif loss > 5.0 and step > 50:
            trackio.alert(
                title="High Loss Warning",
                text=f"loss={loss:.3f} at step {step} lr too high",
                level="WARN",
            )
        elif step % 100 == 0:
            trackio.alert(
                title="Training Progress",
                text=f"loss={loss:.3f} at step {step}",
                level="INFO",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Report the evaluation loss, if the metrics include one."""
        if metrics and "eval_loss" in metrics:
            trackio.alert(
                title="Eval Complete",
                text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}",
                level="INFO",
            )
def convert_alpaca_to_messages(example):
    """Convert one Alpaca-style record into a two-turn chat record.

    The user turn is ``instruction``, followed by a newline and ``input`` when
    ``input`` is non-blank; the assistant turn is ``output``. Missing keys and
    null (``None``/falsy) values are treated as empty strings, so records with
    null columns no longer raise ``TypeError`` or emit ``content=None``.

    Args:
        example: Mapping with optional ``instruction``, ``input`` and ``output``.

    Returns:
        ``{"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}``.
    """
    instruction = str(example.get("instruction") or "")
    input_text = str(example.get("input") or "")
    output = str(example.get("output") or "")

    user_content = instruction
    # Only append the optional context when it has visible content; it is
    # appended as-is (not stripped).
    if input_text.strip():
        user_content += "\n" + input_text
    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output},
        ]
    }
def _hebrew_caps(he1_len, he2_len, max_total, hebrew_ratio, min_english):
    """Return ``(he1_cap, he2_cap)``: rows to keep from each Hebrew source.

    All Hebrew rows are kept when at least ``min_english`` slots of the
    ``max_total`` budget remain for English (or when there is no Hebrew data).
    Otherwise English is allotted ``int(max_total * (1 - hebrew_ratio))`` rows
    and the remaining budget is split between the two Hebrew sources in
    proportion to their sizes, never exceeding a source's length.
    """
    hebrew_count = he1_len + he2_len
    # hebrew_count == 0 short-circuits the proportional split (division by zero).
    if hebrew_count == 0 or max_total - hebrew_count >= min_english:
        return he1_len, he2_len
    english_target = int(max_total * (1 - hebrew_ratio))
    hebrew_cap = max_total - english_target
    he1_cap = int(hebrew_cap * he1_len / hebrew_count)
    he2_cap = hebrew_cap - he1_cap
    return min(he1_len, he1_cap), min(he2_len, he2_cap)


def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42, min_english=10000):
    """Build the shuffled Hebrew + English SFT mix as a ``messages`` dataset.

    Sources: ``ashercn97/hebrew_alpaca_gpt4`` and
    ``saillab/alpaca-hebrew-cleaned`` (Alpaca format, converted with
    ``convert_alpaca_to_messages``), plus the ``train_sft`` split of
    ``HuggingFaceTB/OpenHermes-2.5-H4`` restricted to conversations whose
    roles are all user/assistant/system.

    Budget: every Hebrew row is used if that still leaves at least
    ``min_english`` of the ``max_total`` rows for English; otherwise English
    is allotted ``int(max_total * (1 - hebrew_ratio))`` rows and Hebrew is
    down-sampled to the rest, proportionally per source. English then fills
    whatever budget remains (or contributes all it has).

    Args:
        hebrew_ratio: Hebrew share of ``max_total`` applied only in the
            fallback case above; must be in [0, 1]. With 0.5 this reproduces
            the previous hard-coded half/half split.
        max_total: Upper bound on the number of returned rows.
        seed: Seed for every ``Dataset.shuffle`` call; also seeds Python's
            global ``random`` module.
        min_english: Minimum English rows that must remain before the
            ratio fallback is triggered.

    Returns:
        A shuffled ``datasets.Dataset`` with a single ``messages`` column.

    Raises:
        ValueError: If ``hebrew_ratio`` is outside [0, 1].
    """
    if not 0.0 <= hebrew_ratio <= 1.0:
        raise ValueError(f"hebrew_ratio must be in [0, 1], got {hebrew_ratio!r}")
    # Nothing below draws from `random`; sampling is seeded via shuffle(seed=...).
    random.seed(seed)

    print("Loading high-quality Hebrew datasets...")
    ds_he1 = load_dataset("ashercn97/hebrew_alpaca_gpt4", split="train")
    ds_he1 = ds_he1.map(convert_alpaca_to_messages, remove_columns=ds_he1.column_names)
    print(f" hebrew_alpaca_gpt4: {len(ds_he1)}")
    ds_he2 = load_dataset("saillab/alpaca-hebrew-cleaned", split="train")
    ds_he2 = ds_he2.map(convert_alpaca_to_messages, remove_columns=ds_he2.column_names)
    print(f" alpaca-hebrew-cleaned: {len(ds_he2)}")

    print("Loading English datasets...")
    ds_en = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
    # Keep only the chat column so it concatenates with the Hebrew sets.
    ds_en = ds_en.remove_columns([c for c in ds_en.column_names if c != "messages"])

    def filter_messages(example):
        # Drop conversations containing roles other than user/assistant/system.
        msgs = example.get("messages", [])
        return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)

    ds_en = ds_en.filter(filter_messages)

    he1_cap, he2_cap = _hebrew_caps(len(ds_he1), len(ds_he2), max_total, hebrew_ratio, min_english)
    ds_he1 = ds_he1.shuffle(seed=seed).select(range(he1_cap))
    ds_he2 = ds_he2.shuffle(seed=seed).select(range(he2_cap))
    hebrew_count = len(ds_he1) + len(ds_he2)

    # English fills the remaining budget.
    english_target = max(0, max_total - hebrew_count)
    if len(ds_en) > english_target:
        ds_en = ds_en.shuffle(seed=seed).select(range(english_target))
    print(f" OpenHermes: {len(ds_en)}")

    combined = concatenate_datasets([ds_he1, ds_he2, ds_en])
    combined = combined.shuffle(seed=seed)
    print(f"Final dataset: {len(combined)} samples ({hebrew_count} Hebrew, {len(ds_en)} English)")
    return combined
# Base checkpoint and run destination; both can be overridden via environment
# variables. output_dir holds a Hub repo id ("user/name") and is also used as
# the local output directory.
model_id = os.environ.get("MODEL_ID", "google/gemma-4-E2B-it")
output_dir = os.environ.get("OUTPUT_DIR", "ssdataanalysis/gemma-4-E2B-hebrew-first")
print(f"=== Training {model_id} -> {output_dir} ===")
# Build the mixed Hebrew/English chat dataset (text-only "messages" rows).
train_dataset = prepare_dataset(hebrew_ratio=0.5, max_total=120000)
print("Loading tokenizer...")
# NOTE(review): trust_remote_code=True allows executing Python code shipped in
# the checkpoint repo; confirm it is actually required for this tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print("Loading model with 4-bit quantization...")
# 4-bit NF4 weights with double quantization; matmuls are computed in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
# The checkpoint is loaded through the image-text-to-text auto class even
# though training data here is text-only; SDPA attention, automatic device
# placement.
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    quantization_config=bnb_config,
    device_map="auto",
)
# LoRA adapter trained on top of the 4-bit base model.
# Effective adapter scaling is lora_alpha / r = 16 / 64 = 0.25.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    # NOTE(review): PEFT matches a *list* here by exact module name or name
    # suffix, so Linear layers nested inside the vision tower / projector may
    # still get adapters; verify on the loaded model (a regex string would
    # exclude the whole subtree).
    exclude_modules=["vision_tower", "multi_modal_projector"],
)

# Optimizer step = per_device_train_batch_size (4) x
# gradient_accumulation_steps (4) = 16 sequences per device.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # "constant" makes transformers build a plain constant schedule that
    # ignores warmup_steps entirely; "constant_with_warmup" ramps the LR up
    # linearly over warmup_steps and then holds it at learning_rate.
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    weight_decay=0.01,
    max_length=2048,  # truncate tokenized examples to 2048 tokens
    packing=False,
    # Keep False: per the original author, enabling assistant-only loss here
    # masked every label to -100 (no training signal).
    assistant_only_loss=False,
    bf16=True,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=output_dir,
    report_to="trackio",
    run_name=output_dir.replace("/", "-") + "-optimal",
    remove_unused_columns=False,
    disable_tqdm=True,
    dataset_num_proc=4,
    gradient_checkpointing=True,
)
# SFTTrainer attaches the LoRA adapter described by peft_config to the
# quantized model, uses the tokenizer as its processing class, and forwards
# log/eval events to TrackioAlertCallback.
trainer = SFTTrainer(
    model=model, args=training_args, train_dataset=train_dataset,
    peft_config=peft_config,
    processing_class=tokenizer, callbacks=[TrackioAlertCallback()],
)
print("Starting training...")
trainer.train()
# Save the final model under output_dir (for a PEFT-wrapped model this is
# presumably the adapter weights only — confirm).
# NOTE(review): training_args sets push_to_hub=True, and Trainer.save_model
# also pushes in that mode, so the explicit push_to_hub() below may be
# redundant; confirm before removing either call.
trainer.save_model(output_dir)
trainer.push_to_hub()
# Final notification on the trackio dashboard.
trackio.alert(title="Training Complete", text=f"Model {output_dir} training completed successfully", level="INFO")
print(f"Done! Model saved to {output_dir}")
|