|
|
---
library_name: trl
pipeline_tag: text-generation
datasets:
- SRP-base-model-training/dataset_04_06_2025
language:
- kk
- en
- ru
---
|
|
|
|
|
## Base model v2: gemma_3_800M_base_v2_multilingual_10B_data

June 23

Base model trained on 10B tokens of mixed Kazakh (kk), English (en), and Russian (ru) data.
|
|
|
|
|
|
|
|
### Inference example
|
|
```python
import os

import torch
from transformers import AutoTokenizer, Gemma3ForCausalLM

os.environ["CUDA_VISIBLE_DEVICES"] = "0,1"

# Load the trained model
model_path = "SRP-base-model-training/gemma_3_800M_base_v2_multilingual_10B_data"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = Gemma3ForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

# Example translation prompts (ru/kk/en)
# example = {"system": "Вы профессиональный переводчик. Переведите следующее предложение на қазақ язык.", "user": "<src=ru><tgt=kk>\nЗа один год с тех пор какие изменения произошли в Туркестане, какое дело доведено до конца?", "assistant": "Содан бергі бір жыл ішінде Түркістанда қандай өзгерістер болды, нендей іс тындырылды?"}
# example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nСауда-саттықта салқынқандылық басым.", "assistant": "Composure prevails in trade."}
example = {"system": "Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.", "user": "<src=kk><tgt=en>\nқала картасы", "assistant": "city map"}

s = example["system"]
u = example["user"]
a = example["assistant"]

tok = tokenizer

# Prompt in chat format
prompt = (
    f"<start_of_turn>system\n{s}<end_of_turn>\n"
    f"<start_of_turn>user\n{u}<end_of_turn>\n"
    f"<start_of_turn>assistant"
)

model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
input_len = model_inputs["input_ids"].shape[-1]

with torch.inference_mode():
    generation = model.generate(
        **model_inputs,
        max_new_tokens=64,
        do_sample=True,
        top_p=0.9,
        # temperature=0.7,
        # repetition_penalty=1.2,
        eos_token_id=tok.convert_tokens_to_ids("<end_of_turn>"),
        pad_token_id=tok.eos_token_id,
        # min_new_tokens=5,
    )
generation = generation[0][input_len:]

decoded = tokenizer.decode(generation, skip_special_tokens=True)
print(decoded)
```
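
For repeated translation calls, the prompt construction and generation above can be wrapped in a small helper. This is a minimal sketch that reuses the `model` and `tokenizer` loaded above, together with the same `<start_of_turn>` chat markup and `<src=..><tgt=..>` tags; the `translate` function itself is illustrative and not part of the released code.

```python
def translate(text: str, src: str, tgt: str, system: str, max_new_tokens: int = 64) -> str:
    """Build the chat-style prompt shown above and decode the model's answer."""
    prompt = (
        f"<start_of_turn>system\n{system}<end_of_turn>\n"
        f"<start_of_turn>user\n<src={src}><tgt={tgt}>\n{text}<end_of_turn>\n"
        f"<start_of_turn>assistant"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[-1]
    with torch.inference_mode():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=0.9,
            eos_token_id=tokenizer.convert_tokens_to_ids("<end_of_turn>"),
            pad_token_id=tokenizer.eos_token_id,
        )
    return tokenizer.decode(out[0][input_len:], skip_special_tokens=True)

# Usage (kk -> en), mirroring the example above:
print(translate(
    "қала картасы", src="kk", tgt="en",
    system="Сіз кәсіби аудармашысыз. Төмендегі сөйлемді English тіліне аударыңыз.",
))
```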
|
|
|
|
|
|
|
|
|
|
|
### Training

Main training script:
|
|
|
|
|
```python
# gemma_pretrain_mix_cli.py – balance 50% KK, 30% RU, 20% EN

import os, math, json, argparse
from pathlib import Path

from datasets import (load_dataset, concatenate_datasets,
                      disable_caching)
from transformers import (AutoTokenizer, Gemma3TextConfig,
                          Gemma3ForCausalLM,
                          DataCollatorForLanguageModeling)
from trl import SFTTrainer, SFTConfig

disable_caching()

# ────────── CLI ──────────
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_path", required=True)
parser.add_argument("--meta_files", nargs=3, required=True,
                    metavar=("META_KK", "META_RU", "META_EN"),
                    help="paths to meta_*.json files, in the order kk ru en")
parser.add_argument("--output_dir", default="runs/gemma_mix_50_30_20")
parser.add_argument("--model_path")
parser.add_argument("--max_seq_length", type=int, default=2048)
parser.add_argument("--per_device_batch_size", type=int, default=32)
parser.add_argument("--gradient_accumulation_steps", type=int, default=8)
parser.add_argument("--learning_rate", type=float, default=3e-4)
parser.add_argument("--wandb_project", default="gemma-pretrain")
parser.add_argument("--wandb_run_name")
args = parser.parse_args()

cpu = os.cpu_count()
os.environ["WANDB_PROJECT"] = args.wandb_project
os.environ["TOKENIZERS_PARALLELISM"] = "true"

# ────────── Tokenizer / Model ──────────
tok = AutoTokenizer.from_pretrained(args.tokenizer_path, use_fast=True)

if args.model_path:
    model = Gemma3ForCausalLM.from_pretrained(
        args.model_path, torch_dtype="bfloat16", _attn_implementation="eager")
else:
    # TODO: the original script flags this from-scratch config as WRONG;
    # verify the architecture values before training from scratch.
    cfg = Gemma3TextConfig(
        vocab_size=len(tok),
        bos_token_id=tok.bos_token_id, eos_token_id=tok.eos_token_id, pad_token_id=tok.pad_token_id,
        hidden_size=2304, num_hidden_layers=26, num_attention_heads=4, head_dim=256,
        intermediate_size=9216, max_position_embeddings=32_768,
        torch_dtype="bfloat16", _attn_implementation="eager")
    model = Gemma3ForCausalLM(cfg)
model.resize_token_embeddings(len(tok))

# ────────── Load helper ──────────
def load_meta(path: str):
    meta = json.load(open(path))
    return concatenate_datasets(
        [load_dataset("json", data_files=i["path"], split="train")
         for i in meta.values()]
    )

kk_ds, ru_ds, en_ds = [load_meta(p) for p in args.meta_files]
print(f"Raw rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Target sizes 50 / 30 / 20 ──────────
target_total = int(len(kk_ds) / 0.50)   # kk = 50 %
need_ru = int(target_total * 0.30)
need_en = int(target_total * 0.20)

def resize(ds, need):
    if len(ds) >= need:                 # down-sample
        return ds.shuffle(seed=42).select(range(need))
    reps = need // len(ds) + 1          # up-sample
    big = concatenate_datasets([ds] * reps).shuffle(seed=42)
    return big.select(range(need))

ru_ds = resize(ru_ds, need_ru)
en_ds = resize(en_ds, need_en)
print(f"Balanced rows — KK={len(kk_ds):,}, RU={len(ru_ds):,}, EN={len(en_ds):,}")

# ────────── Merge & preprocess ──────────
ds = concatenate_datasets([kk_ds, ru_ds, en_ds]).shuffle(seed=42)

def add_bos_eos(ex):
    return {"text": f"{tok.bos_token}{ex['text']}{tok.eos_token}"}

ds = ds.map(add_bos_eos, num_proc=cpu)

# ────────── Training params ──────────
world = int(os.getenv("WORLD_SIZE", 1))
eff_bs = args.per_device_batch_size * args.gradient_accumulation_steps * world
max_st = math.ceil(len(ds) / eff_bs)
print(f"Dataset={len(ds):,} eff_batch={eff_bs} max_steps={max_st}")

collator = DataCollatorForLanguageModeling(tok, mlm=False)
cfg_t = SFTConfig(
    output_dir=args.output_dir,
    max_seq_length=args.max_seq_length,
    packing=True, bf16=True,
    per_device_train_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=args.gradient_accumulation_steps,
    learning_rate=args.learning_rate,
    warmup_ratio=0.05,
    max_grad_norm=2.0,
    max_steps=max_st,
    lr_scheduler_type="cosine",
    optim="paged_adamw_8bit",
    save_steps=200, save_total_limit=20,
    logging_steps=1,
    deepspeed="ds_stage1.json",
    run_name=args.wandb_run_name,
    report_to="wandb",
    dataloader_num_workers=8,
    dataset_text_field="text",
    dataset_num_proc=cpu,
)

trainer = SFTTrainer(model=model, args=cfg_t,
                     train_dataset=ds, data_collator=collator,
                     processing_class=tok, formatting_func=None)

if __name__ == "__main__":
    print("🚀 Start pre-training 50/30/20")
    trainer.train()
    trainer.save_model(f"{args.output_dir}/checkpoint-final")
    tok.save_pretrained(f"{args.output_dir}/checkpoint-final")
```
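
The `SFTConfig` above references a DeepSpeed config file, `ds_stage1.json`, which is not included in this card. The exact file used for this run is not shown; a minimal ZeRO stage 1 config along the following lines works with the Trainer/DeepSpeed integration and can serve as a starting point:

```json
{
  "bf16": { "enabled": true },
  "zero_optimization": { "stage": 1 },
  "gradient_accumulation_steps": "auto",
  "gradient_clipping": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "train_batch_size": "auto"
}
```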
|
|
|
|
|
To run training, use a bash script similar to the following:
|
|
|
|
|
```bash
export TRITON_CACHE_DIR=/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/utils/cache/.triton
mkdir -p "$TRITON_CACHE_DIR"

export WANDB_API_KEY=""

OUTPUT_DIR='/scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling'
WANDB_RUN_NAME='base-model-v1_gemma_1B_test_v2_with_kk_en_ru'
if [ ! -d "$OUTPUT_DIR" ]; then
    mkdir -p "$OUTPUT_DIR"
fi

# --model_path "/scratch/vladimir_albrekht/projects/smollm/trl_italian_apporach/runs/my_experiment/checkpoint-final" \

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
torchrun --standalone --nproc_per_node 8 base_train_v2_multi.py \
    --tokenizer_path "/scratch/vladimir_albrekht/projects/smollm/models/tokenizers/tok_best_version_50_000_vocab_abai_20_june" \
    --max_seq_length 2048 \
    --meta_files \
        /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_kk.json \
        /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_ru.json \
        /scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/meta_en.json \
    --per_device_batch_size 32 \
    --gradient_accumulation_steps 8 \
    --learning_rate 3e-4 \
    --output_dir ${OUTPUT_DIR} \
    --wandb_project "small_llm_SRP" \
    --wandb_run_name ${WANDB_RUN_NAME}
```
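
With the flags above and all 8 GPUs, the effective batch size follows the same formula the training script uses (`per_device_batch_size * gradient_accumulation_steps * WORLD_SIZE`). A quick sanity check:

```python
per_device_batch_size = 32
gradient_accumulation_steps = 8
world_size = 8  # GPUs launched by the torchrun command above

eff_bs = per_device_batch_size * gradient_accumulation_steps * world_size
print(eff_bs)          # 2048 packed sequences per optimizer step
print(eff_bs * 2048)   # ~4.19M tokens per optimizer step at max_seq_length=2048
```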
|
|
|
|
|
Each meta file has the following format:
|
|
|
|
|
```json
{
  "train_en_news_cleaned_v2_splited_processed.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train.jsonl",
    "examples": 268890,
    "tokens": 92970273
  },
  "train_en_news_cleaned_v2_splited_processed_2.jsonl": {
    "path": "/scratch/vladimir_albrekht/projects/smollm/data/base_train_dataset_04_06_2025/en_data/train_2.jsonl",
    "examples": 268123,
    "tokens": 64523423
  }
}
```
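
The training script only reads the `path` field of each entry (the `examples` and `tokens` counts are informational). If you need to regenerate a meta file, a sketch like the one below produces this format; the `build_meta` helper is hypothetical and assumes each `.jsonl` line carries a `text` field, as expected by the training script.

```python
import json
from pathlib import Path

from transformers import AutoTokenizer

def build_meta(jsonl_dir: str, tokenizer_path: str, out_path: str) -> None:
    """Write a meta_*.json mapping each .jsonl file to its path, example count, and token count."""
    tok = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    meta = {}
    for fp in sorted(Path(jsonl_dir).glob("*.jsonl")):
        n_examples = n_tokens = 0
        with open(fp, encoding="utf-8") as f:
            for line in f:
                text = json.loads(line)["text"]
                n_examples += 1
                n_tokens += len(tok(text)["input_ids"])
        meta[fp.name] = {"path": str(fp), "examples": n_examples, "tokens": n_tokens}
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)

# Example:
# build_meta("/path/to/en_data", "/path/to/tokenizer", "meta_en.json")
```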
|
|
|
|
|
> Notes: checkpoint path /scratch/vladimir_albrekht/projects/smollm/output_checkpoints/test_2_multiling/checkpoint-1978