Instructions to use ssdataanalysis/gemma-4-E2B-hebrew-first with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ssdataanalysis/gemma-4-E2B-hebrew-first with Transformers:
```python
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ssdataanalysis/gemma-4-E2B-hebrew-first", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" | |
| import random | |
| from datasets import load_dataset, concatenate_datasets | |
| from transformers import AutoModelForImageTextToText, AutoTokenizer, BitsAndBytesConfig | |
| from peft import LoraConfig | |
| from trl import SFTConfig, SFTTrainer | |
| import trackio | |
| import torch | |
| from transformers import TrainerCallback | |
# Experiment tracking: open the Trackio project (hosted on the given
# Hugging Face Space) that later metrics and alerts are sent to.
trackio.init(
    space_id="ssdataanalysis/mlintern-heb4",
    project="hebrew-gemma4",
)
class TrackioAlertCallback(TrainerCallback):
    """Trainer callback that mirrors training milestones into Trackio alerts.

    - ``on_log``: emits a WARN alert when the logged training loss is NaN/inf
      (at any step) or exceeds 5.0 after step 50; otherwise emits an INFO
      progress alert whenever the global step is a multiple of 100.
    - ``on_evaluate``: emits an INFO alert carrying ``eval_loss``.
    """

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Inspect a log record and alert on its training ``loss``, if present."""
        import math  # stdlib; imported locally to keep the callback self-contained

        if not logs or "loss" not in logs:
            return
        loss = logs["loss"]
        step = state.global_step
        # A diverged run reports NaN/inf. Since `nan > 5.0` is False, without
        # this check divergence would never trigger a WARN alert.
        if not math.isfinite(loss):
            trackio.alert(title="Non-finite Loss", text=f"loss={loss} at step {step}", level="WARN")
        elif loss > 5.0 and step > 50:
            trackio.alert(title="High Loss Warning", text=f"loss={loss:.3f} at step {step} lr too high", level="WARN")
        elif step % 100 == 0:
            trackio.alert(title="Training Progress", text=f"loss={loss:.3f} at step {step}", level="INFO")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Send an INFO alert with ``eval_loss`` when evaluation reports it."""
        if metrics and "eval_loss" in metrics:
            trackio.alert(title="Eval Complete", text=f"eval_loss={metrics['eval_loss']:.3f} at step {state.global_step}", level="INFO")
def convert_alpaca_to_messages(example):
    """Convert one Alpaca-style record into a two-turn chat example.

    Args:
        example: Mapping with optional ``instruction``, ``input`` and ``output``
            keys. Missing keys and ``None`` values are treated as empty text.

    Returns:
        ``{"messages": [user_turn, assistant_turn]}``. The user turn is the
        instruction, followed by a newline and the input when the input is not
        blank. If the instruction is empty, the user turn is the input alone
        (no leading newline). The assistant turn is the output.
    """

    def _text(key):
        # Normalize missing/None fields to "" so string concatenation never fails.
        value = example.get(key)
        return "" if value is None else str(value)

    instruction = _text("instruction")
    input_text = _text("input")
    output = _text("output")

    user_content = instruction
    if input_text.strip():
        user_content = f"{instruction}\n{input_text}" if instruction else input_text

    return {
        "messages": [
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": output},
        ]
    }
def _load_alpaca_as_messages(repo_id):
    """Load the ``train`` split of an Alpaca-format Hub dataset as chat ``messages``."""
    ds = load_dataset(repo_id, split="train")
    # Replace every original column with the single "messages" column.
    return ds.map(convert_alpaca_to_messages, remove_columns=ds.column_names)


def _split_hebrew_budget(len_he1, len_he2, max_total, hebrew_ratio, min_english):
    """Return ``(he1_cap, he2_cap)``, the row caps for the two Hebrew sources.

    All Hebrew rows are kept if that still leaves at least ``min_english`` of
    ``max_total`` rows for English. Otherwise Hebrew is limited to roughly a
    ``hebrew_ratio`` share of ``max_total``. The Hebrew cap is split between
    the sources in proportion to their sizes. Integer arithmetic is used, so
    float truncation cannot drop a row.
    """
    hebrew_count = len_he1 + len_he2
    if hebrew_count == 0:
        return 0, 0
    english_target = max_total - hebrew_count
    if english_target < min_english:
        # For hebrew_ratio=0.5 this equals the former hard-coded max_total // 2.
        english_target = int(max_total * (1 - hebrew_ratio))
    hebrew_cap = max_total - english_target
    he1_cap = hebrew_cap * len_he1 // hebrew_count
    return he1_cap, hebrew_cap - he1_cap


def prepare_dataset(hebrew_ratio=0.5, max_total=120000, seed=42, min_english=10000):
    """Build a shuffled Hebrew + English SFT mixture in chat ``messages`` format.

    Hebrew comes from two Alpaca-style datasets; English comes from
    OpenHermes-2.5-H4. English fills whatever remains of ``max_total`` after
    the (possibly capped) Hebrew rows.

    Args:
        hebrew_ratio: Fraction of ``max_total`` reserved for Hebrew when the
            Hebrew sources alone would leave fewer than ``min_english`` English
            rows. Must be in [0, 1].
        max_total: Target total number of rows.
        seed: Seed for ``random`` and for every dataset shuffle.
        min_english: Minimum English budget before Hebrew gets capped.

    Returns:
        A shuffled ``datasets.Dataset`` with a single ``messages`` column.

    Raises:
        ValueError: If ``hebrew_ratio`` is outside [0, 1].
    """
    if not 0.0 <= hebrew_ratio <= 1.0:
        raise ValueError(f"hebrew_ratio must be in [0, 1], got {hebrew_ratio}")
    random.seed(seed)

    print("Loading high-quality Hebrew datasets...")
    ds_he1 = _load_alpaca_as_messages("ashercn97/hebrew_alpaca_gpt4")
    print(f" hebrew_alpaca_gpt4: {len(ds_he1)}")
    ds_he2 = _load_alpaca_as_messages("saillab/alpaca-hebrew-cleaned")
    print(f" alpaca-hebrew-cleaned: {len(ds_he2)}")

    print("Loading English datasets...")
    ds_en = load_dataset("HuggingFaceTB/OpenHermes-2.5-H4", split="train_sft")
    ds_en = ds_en.remove_columns([c for c in ds_en.column_names if c != "messages"])

    def filter_messages(example):
        # Keep only conversations whose turns all use user/assistant/system roles.
        msgs = example.get("messages", [])
        return all(m.get("role") in ["user", "assistant", "system"] for m in msgs)

    ds_en = ds_en.filter(filter_messages)

    he1_cap, he2_cap = _split_hebrew_budget(
        len(ds_he1), len(ds_he2), max_total, hebrew_ratio, min_english
    )
    ds_he1 = ds_he1.shuffle(seed=seed).select(range(min(len(ds_he1), he1_cap)))
    ds_he2 = ds_he2.shuffle(seed=seed).select(range(min(len(ds_he2), he2_cap)))
    hebrew_count = len(ds_he1) + len(ds_he2)

    english_target = max(0, max_total - hebrew_count)
    if len(ds_en) > english_target:
        ds_en = ds_en.shuffle(seed=seed).select(range(english_target))
    print(f" OpenHermes: {len(ds_en)}")

    # Concatenate the *capped* Hebrew datasets, not the objects loaded above;
    # otherwise the caps would silently have no effect.
    combined = concatenate_datasets([ds_he1, ds_he2, ds_en])
    combined = combined.shuffle(seed=seed)
    print(f"Final dataset: {len(combined)} samples ({hebrew_count} Hebrew, {len(ds_en)} English)")
    return combined
# Run configuration; both values can be overridden through environment variables.
model_id = os.environ.get("MODEL_ID", "google/gemma-4-E2B-it")
output_dir = os.environ.get("OUTPUT_DIR", "ssdataanalysis/gemma-4-E2B-hebrew-first")
print(f"=== Training {model_id} -> {output_dir} ===")

train_dataset = prepare_dataset(hebrew_ratio=0.5, max_total=120000)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

print("Loading model with 4-bit quantization...")
# 4-bit NF4 weights with double quantization; computation runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    attn_implementation="sdpa",
    quantization_config=bnb_config,
    device_map="auto",
)

# LoRA on all linear layers except those under the vision tower and the
# multimodal projector. Note: adapter scaling is lora_alpha / r = 16 / 64 = 0.25.
peft_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    exclude_modules=["vision_tower", "multi_modal_projector"],
)

training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # effective batch per device = 4 * 4 = 16
    learning_rate=2e-4,
    # The plain "constant" scheduler takes no warmup and silently ignores
    # warmup_steps; "constant_with_warmup" ramps up over 500 steps, then holds.
    lr_scheduler_type="constant_with_warmup",
    warmup_steps=500,
    weight_decay=0.01,
    max_length=2048,
    packing=False,
    assistant_only_loss=False,  # CRITICAL FIX: prevent all labels being masked to -100
    bf16=True,
    logging_strategy="steps",
    logging_steps=10,
    logging_first_step=True,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=500,
    save_total_limit=3,
    push_to_hub=True,
    hub_model_id=output_dir,
    report_to="trackio",
    run_name=output_dir.replace("/", "-") + "-optimal",
    remove_unused_columns=False,
    disable_tqdm=True,
    dataset_num_proc=4,
    gradient_checkpointing=True,
)

trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    callbacks=[TrackioAlertCallback()],
)

print("Starting training...")
trainer.train()
trainer.save_model(output_dir)
trainer.push_to_hub()
trackio.alert(title="Training Complete", text=f"Model {output_dir} training completed successfully", level="INFO")
print(f"Done! Model saved to {output_dir}")