import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List

from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
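# Note: this script assumes recent releases of trl, peft, transformers, datasets,
# and trackio; options used below such as assistant_only_loss, packing_strategy,
# and report_to="trackio" only exist in newer versions (exact versions are not pinned here).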


DATASET_REGISTRY = {
    "tulu-3-sft": {
        "name": "allenai/tulu-3-sft-mixture",
        "split": "train",
        "format": "messages",
        "size": "~940K",
        "quality": "BEST – 19 curated sources (math, code, IF, safety, science)",
    },
    "openthoughts-114k": {
        "name": "open-thoughts/OpenThoughts-114k",
        "split": "train",
        "format": "conversations",
        "size": "~114K",
        "quality": "EXCELLENT – reasoning CoT traces",
    },
    "ultrachat-200k": {
        "name": "HuggingFaceH4/ultrachat_200k",
        "split": "train_sft",
        "format": "messages",
        "size": "~200K",
        "quality": "GOOD – multi-turn chat (baseline fallback)",
    },
}
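# For reference, SFTTrainer consumes chat-style records in the "messages" format
# (illustrative example only, not an actual row from the datasets above):
#   {"messages": [{"role": "user", "content": "What is 2 + 2?"},
#                 {"role": "assistant", "content": "4"}]}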


def convert_openthoughts_to_messages(example):
    """Convert OpenThoughts conversations format to standard messages."""
    messages = []
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        role = "user" if turn["from"] == "user" else "assistant"
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}
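# Illustrative input shape handled by the converter above (hypothetical record,
# not copied from the dataset):
#   {"system": "You are a helpful assistant.",
#    "conversations": [{"from": "user", "value": "..."},
#                      {"from": "assistant", "value": "..."}]}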


def load_and_prepare_dataset(dataset_key: str, max_samples: Optional[int] = None):
    """Load and format a dataset from the registry."""
    info = DATASET_REGISTRY[dataset_key]
    ds = load_dataset(info["name"], split=info["split"])

    if max_samples:
        ds = ds.select(range(min(max_samples, len(ds))))

    if dataset_key == "openthoughts-114k":
        remove_cols = [c for c in ds.column_names if c != "messages"]
        ds = ds.map(
            convert_openthoughts_to_messages,
            remove_columns=remove_cols,
        )

    return ds
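# The imported `concatenate_datasets` could be used to blend several registry entries
# into one mixture; a rough sketch, assuming the prepared splits end up with a
# compatible "messages" schema:
#   mixture = concatenate_datasets(
#       [load_and_prepare_dataset(k) for k in ("tulu-3-sft", "openthoughts-114k")]
#   ).shuffle(seed=42)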


@dataclass
class FinetuneConfig:
    """Fine-tuning hyperparameters – vNext (LoRA Without Regret config)."""
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    dataset_key: str = "tulu-3-sft"
    output_dir: str = "/output/models"
    hub_model_id: str = "devsecops/finetuned-llama-v2"

    # LoRA adapter settings
    lora_r: int = 256
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: str = "all-linear"

    # Optimization schedule
    num_train_epochs: int = 1
    per_device_train_batch_size: int = 2
    gradient_accumulation_steps: int = 8
    learning_rate: float = 2e-4
    max_seq_length: int = 2048
    warmup_ratio: float = 0.1
    lr_scheduler_type: str = "cosine"

    # Precision and memory
    bf16: bool = True
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"

    # Sequence packing
    packing: bool = True
    packing_strategy: str = "bfd_split"

    # Loss masking: train only on assistant tokens
    assistant_only_loss: bool = True
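# Effective batch size per device = per_device_train_batch_size * gradient_accumulation_steps
# = 2 * 8 = 16 sequences per optimizer step (scaled further by the number of GPUs
# when launched with accelerate or torchrun).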


def finetune(config: FinetuneConfig):
    """Fine-tune a model with LoRA + SFT (vNext – LoRA Without Regret config)."""

    # Start an experiment-tracking run
    trackio.init(
        project="devsecops-ml",
        name=f"sft-{config.model_name.split('/')[-1]}-{config.dataset_key}",
        config=vars(config),
    )

    # Load and format the training data
    dataset = load_and_prepare_dataset(config.dataset_key)
    print(f"Dataset: {config.dataset_key} ({len(dataset)} examples)")

    # LoRA adapter configuration
    peft_config = LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=config.target_modules,
    )

    # SFT training arguments
    sft_config = SFTConfig(
        output_dir=config.output_dir,
        num_train_epochs=config.num_train_epochs,
        per_device_train_batch_size=config.per_device_train_batch_size,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        learning_rate=config.learning_rate,
        max_length=config.max_seq_length,
        warmup_ratio=config.warmup_ratio,
        lr_scheduler_type=config.lr_scheduler_type,
        bf16=config.bf16,
        gradient_checkpointing=config.gradient_checkpointing,
        optim=config.optim,
        packing=config.packing,
        packing_strategy=config.packing_strategy,
        assistant_only_loss=config.assistant_only_loss,
        eos_token="<|eot_id|>",  # Llama 3.x end-of-turn token; change if the base model differs
        logging_strategy="steps",
        logging_steps=10,
        logging_first_step=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        push_to_hub=True,
        hub_model_id=config.hub_model_id,
        report_to="trackio",
        disable_tqdm=True,
    )

    # Build the trainer (the base model is loaded from the Hub by name)
    trainer = SFTTrainer(
        model=config.model_name,
        train_dataset=dataset,
        peft_config=peft_config,
        args=sft_config,
    )

    # Train
    trainer.train()

    # Push the adapter to the Hub and close the tracking run
    trainer.push_to_hub()
    trackio.finish()
    print(f"Model pushed to: https://huggingface.co/{config.hub_model_id}")


if __name__ == "__main__":
    config = FinetuneConfig()
    finetune(config)
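# To try a different mixture or a smaller adapter, the defaults above can be
# overridden when constructing the config, e.g. (illustrative values):
#   finetune(FinetuneConfig(dataset_key="openthoughts-114k", lora_r=64, lora_alpha=16))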