Instructions to use ssdataanalysis/gemma-4-E4B-hebrew-first with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ssdataanalysis/gemma-4-E4B-hebrew-first with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ssdataanalysis/gemma-4-E4B-hebrew-first", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
File size: 6,723 Bytes
import os
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
import random
from datasets import load_dataset, concatenate_datasets
from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer
import trackio
import torch
from transformers import TrainerCallback
# Initialise Trackio for this run; the project name and space id are fixed here.
trackio.init(project="hebrew-gemma4", space_id="ssdataanalysis/mlintern-heb4")
class TrackioAlertCallback(TrainerCallback):
    """Trainer callback that turns logged losses into Trackio alerts."""

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Send a WARN alert for a high loss, else an INFO alert on every 100th step.

        Does nothing when ``logs`` is empty or has no ``"loss"`` entry.
        """
        if not logs or "loss" not in logs:
            return

        loss = logs["loss"]
        step = state.global_step

        # The high-loss warning is only raised after the first 50 steps.
        if loss > 5.0 and step > 50:
            trackio.alert(
                title="High Loss Warning",
                text=f"loss={loss:.3f} at step {step} lr too high",
                level="WARN",
            )
            return

        if step % 100 == 0:
            trackio.alert(
                title="Training Progress",
                text=f"loss={loss:.3f} at step {step}",
                level="INFO",
            )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Send an INFO alert carrying ``eval_loss`` when the metrics include it."""
        if not metrics or "eval_loss" not in metrics:
            return
        trackio.alert(
            title="Eval Complete",
            text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}",
            level="INFO",
        )
def convert_alpaca_to_messages(example):
    """Convert one Alpaca-style record into a two-turn chat ``messages`` record.

    Args:
        example: Mapping that may hold ``"instruction"``, ``"input"`` and
            ``"output"`` entries. Missing keys and ``None`` values are treated
            as empty text; other values are converted with ``str``.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``. The user turn is the
        instruction followed, on a new line, by the input when the input is
        not blank. The assistant turn is the output.
    """

    def _as_text(value):
        # Null fields would otherwise break string concatenation downstream.
        return "" if value is None else str(value)

    instruction = _as_text(example.get("instruction"))
    input_text = _as_text(example.get("input"))
    output = _as_text(example.get("output"))

    # Join only the pieces that are present so that an empty instruction does
    # not leave the prompt starting with a stray newline.
    parts = [instruction] if instruction else []
    if input_text.strip():
        parts.append(input_text)
    user_content = "\n".join(parts)

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output},
        ]
    }
def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42):
    """Build the shuffled Hebrew + English chat dataset used for fine-tuning.

    Two Hebrew Alpaca-style datasets are converted to ``messages`` records and
    combined with a subsample of OpenHermes. All Hebrew rows are kept unless
    that would leave fewer than 10,000 English slots; in that case
    ``max_total // 2`` slots are reserved for English and the remaining Hebrew
    budget is split between the two Hebrew sources in proportion to their
    sizes.

    Args:
        hebrew_ratio: Accepted for interface compatibility.
            NOTE(review): the sizing logic below never reads this value;
            confirm whether the mix is supposed to honour it.
        max_total: Target upper bound on the number of combined samples.
        seed: Seed for ``random`` and for every dataset shuffle.

    Returns:
        The concatenated and shuffled dataset, holding only a ``messages``
        column.
    """
    random.seed(seed)

    print("Loading high-quality Hebrew datasets...")
    ds_he1 = load_dataset("ashercn97/hebrew_alpaca_gpt4", split="train")
    ds_he1 = ds_he1.map(convert_alpaca_to_messages, remove_columns=ds_he1.column_names)
    print(f" hebrew_alpaca_gpt4: {len(ds_he1)}")

    ds_he2 = load_dataset("saillab/alpaca-hebrew-cleaned", split="train")
    ds_he2 = ds_he2.map(convert_alpaca_to_messages, remove_columns=ds_he2.column_names)
    print(f" alpaca-hebrew-cleaned: {len(ds_he2)}")

    print("Loading English datasets...")
    ds_en = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
    ds_en = ds_en.remove_columns([c for c in ds_en.column_names if c != "messages"])

    def filter_messages(example):
        # Keep only conversations whose turns all use a supported role.
        msgs = example.get("messages", [])
        return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)

    ds_en = ds_en.filter(filter_messages)

    # Decide how many Hebrew rows fit while leaving room for English.
    hebrew_count = len(ds_he1) + len(ds_he2)
    english_target = max_total - hebrew_count
    if english_target < 10000:
        english_target = max_total // 2
    hebrew_cap = max_total - english_target
    he1_cap = int(hebrew_cap * len(ds_he1) / hebrew_count)
    he2_cap = hebrew_cap - he1_cap

    ds_he1 = ds_he1.shuffle(seed=seed).select(range(min(len(ds_he1), he1_cap)))
    ds_he2 = ds_he2.shuffle(seed=seed).select(range(min(len(ds_he2), he2_cap)))

    # Recompute from the capped sizes so English fills whatever is left.
    hebrew_count = len(ds_he1) + len(ds_he2)
    english_target = max_total - hebrew_count
    if len(ds_en) > english_target:
        ds_en = ds_en.shuffle(seed=seed).select(range(english_target))
    print(f" OpenHermes: {len(ds_en)}")

    # Concatenate the *capped* Hebrew subsets. Collecting the datasets before
    # capping would keep references to the full, uncapped datasets, letting
    # the result exceed max_total and making the summary below inaccurate.
    combined = concatenate_datasets([ds_he1, ds_he2, ds_en])
    combined = combined.shuffle(seed=seed)
    print(f"Final dataset: {len(combined)} samples ({hebrew_count} Hebrew, {len(ds_en)} English)")
    return combined
# Base checkpoint and output location; both can be overridden via the
# MODEL_ID / OUTPUT_DIR environment variables.
model_id = os.getenv("MODEL_ID", "google/gemma-4-E4B-it")
output_dir = os.getenv("OUTPUT_DIR", "ssdataanalysis/gemma-4-E4B-hebrew-first")
print(f"=== Training {model_id} -> {output_dir} ===")

train_dataset = prepare_dataset(max_total=120000, hebrew_ratio=0.5)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print("Loading model with 4-bit quantization...")
# NF4 4-bit weights with double quantization; compute dtype is bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation="sdpa",
)
# LoRA on every linear layer except the listed excluded modules.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules="all-linear",
    exclude_modules=["vision_tower", "multi_modal_projector"],
)
# Supervised fine-tuning configuration. Checkpoints are written every 500
# steps (keeping the last 3) and pushed to the hub repo named by `output_dir`.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    # The plain "constant" schedule ignores warmup_steps, so the 500 warmup
    # steps below only take effect with the "constant_with_warmup" variant.
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    weight_decay=0.01,
    max_length=2048,
    packing=False,
    assistant_only_loss=False,  # CRITICAL FIX: prevent all labels being masked to -100
    bf16=True,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=output_dir,
    report_to="trackio",
    run_name=output_dir.replace("/", "-") + "-optimal",
    remove_unused_columns=False,
    disable_tqdm=True,
    dataset_num_proc=4,
    gradient_checkpointing=True,
)
# Check for a previously saved adapter on the hub to resume from.
previous_adapter = False
try:
    from huggingface_hub import HfApi
    api = HfApi()
    files = api.list_repo_files(output_dir, repo_type="model")
except Exception as e:
    # Best effort: a missing repo or an unreachable hub means a fresh start.
    print(f"No previous adapter found ({e}), starting from scratch")
    files = []

if "adapter_model.safetensors" in files:
    print("Found previously saved adapter on hub. Loading it as starting point...")
    from peft import PeftModel
    # is_trainable=True is required to keep training the adapter; by default
    # PEFT loads adapters frozen, for inference only.
    # The load deliberately sits outside the try block above: if an existing
    # adapter cannot be loaded, the error should surface instead of silently
    # restarting from scratch and later pushing over the adapter on the hub.
    model = PeftModel.from_pretrained(model, output_dir, is_trainable=True)
    print("Loaded previous adapter weights. Training will continue from this state.")
    previous_adapter = True
# A fresh LoRA config is supplied only when no saved adapter was restored;
# otherwise the model already carries its adapter and peft_config is None.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    processing_class=tokenizer,
    peft_config=None if previous_adapter else peft_config,
    callbacks=[TrackioAlertCallback()],
)

print("Starting training...")
trainer.train()

# Save the result, upload it, then announce completion.
trainer.save_model(output_dir)
trainer.push_to_hub()
trackio.alert(
    title="Training Complete",
    text=f"Model {output_dir} training completed successfully",
    level="INFO",
)
print(f"Done! Model saved to {output_dir}")
|