import os

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTTrainer, SFTConfig

import trackio
# Run identifiers: local base checkpoint, raw instruction dataset, and the
# directory the fine-tuned adapter (and checkpoints) are written to.
model_name = "./SmolLM3-3B-Base/"
dataset_path = "./MathInstruct/MathInstruct.json"
output_dir = "./SmolLMathematician-3B"
project_name = "SmolLMathematician-3B"

# Hard cap (in tokens, after chat templating) on any training sample.
MAX_SEQ_LENGTH = 4096

# Open the experiment-tracking run up front so everything below logs into it.
trackio.init(project=project_name)
# Load the base model in bf16 with FlashAttention-2 and automatic device
# placement.  trust_remote_code is needed for the checkpoint's custom code.
# NOTE(review): the `dtype=` kwarg requires a recent transformers release;
# older versions spell it `torch_dtype=` — confirm against the pinned version.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Base (non-chat) checkpoints may ship without a pad token; reuse EOS so
# padding works, and keep the model config in sync with the tokenizer.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = model.config.eos_token_id

# Install the project's chat template on the tokenizer.
# Fix: read with an explicit encoding — without it, open() uses the platform
# locale encoding and the template could be decoded differently per machine.
with open("chat_template.jinja", "r", encoding="utf-8") as f:
    chat_template = f.read()
tokenizer.chat_template = chat_template

# Trade recompute for activation memory during backprop.
model.gradient_checkpointing_enable()

# Single-file JSON dataset with "instruction"/"output" fields (see the
# formatting function below).
dataset = load_dataset("json", data_files=dataset_path, split="train")
def formatInstructionWithTemplate(example: dict) -> str:
    """Render one instruction/output pair through the tokenizer's chat template.

    Builds a two-turn conversation (user = example["instruction"],
    assistant = example["output"]) and returns it as templated plain text:
    not tokenized, and with no trailing generation prompt.
    """
    conversation = [
        {"role": "user", "content": example["instruction"]},
        {"role": "assistant", "content": example["output"]},
    ]
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )
def checkSequenceLength(example: dict) -> bool:
    """Return True when the templated example fits in MAX_SEQ_LENGTH tokens.

    Used as a `dataset.filter` predicate so over-long samples are dropped
    before training.  The count is taken on the chat-templated text, the
    same form the trainer will see.
    """
    rendered = formatInstructionWithTemplate(example)
    token_count = len(tokenizer(rendered)["input_ids"])
    return token_count <= MAX_SEQ_LENGTH
# Drop samples whose templated form exceeds the model's context budget,
# then report how many were removed.
pre_filter_count = len(dataset)
train_dataset = dataset.filter(checkSequenceLength)
post_filter_count = len(train_dataset)
removed = pre_filter_count - post_filter_count
print(f"Dataset: {pre_filter_count} → {post_filter_count} samples (removed: {removed})")

# Release cached GPU allocator blocks before building the trainer.
torch.cuda.empty_cache()
# LoRA configuration: rank-16 adapters on the attention query/value
# projections only, with 2x alpha scaling, light dropout, and no bias terms.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
)
# Trainer hyper-parameters.  Effective batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 16 sequences
# per optimizer step.
#
# NOTE(review): the original also passed `resume_from_checkpoint=True` here,
# but that TrainingArguments field is not consumed by `Trainer.train()` — it
# is documented as intended for external scripts, so it was a silent no-op.
# It has been removed to avoid implying resumption takes effect; resuming
# must be requested via `trainer.train(resume_from_checkpoint=...)`.
training_args = SFTConfig(
    output_dir=output_dir,
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    # Paged 8-bit AdamW (bitsandbytes) to shrink optimizer-state memory.
    optim="paged_adamw_8bit",
    learning_rate=2e-5,
    weight_decay=0.01,
    adam_epsilon=1e-6,
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    logging_steps=8,
    eval_strategy="no",
    save_strategy="steps",
    save_steps=32,
    save_total_limit=4,
    report_to="trackio",
    bf16=True,
    # Pack multiple short samples into each max_length window for throughput.
    packing=True,
    max_length=MAX_SEQ_LENGTH,
    dataloader_pin_memory=False,
    # Non-reentrant checkpointing is the recommended mode for recent PyTorch.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
# Build the SFT trainer; LoRA adapters are injected via peft_config and each
# sample is rendered through the chat template by formatting_func.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    formatting_func=formatInstructionWithTemplate,
)
torch.cuda.empty_cache()

# Fix: setting resume_from_checkpoint in the config is ignored by
# Trainer.train(), so the original run always started from scratch and its
# periodic checkpoints were never reused.  Resolve the latest checkpoint in
# output_dir explicitly and resume from it when one exists (None = fresh run).
last_checkpoint = get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
trainer.train(resume_from_checkpoint=last_checkpoint)
torch.cuda.empty_cache()

# Saves the trained (LoRA adapter) weights, not the full base model.
trainer.save_model(output_dir)
print(f"LoRA adapter saved to {output_dir}")

# Close the tracking run cleanly.
trackio.finish()