File size: 6,615 Bytes
7c19d46
bae9038
7c19d46
bae9038
 
 
 
 
 
 
 
7c19d46
 
 
 
 
bae9038
7c19d46
bae9038
7c19d46
 
 
 
 
bae9038
7c19d46
 
 
 
bae9038
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c19d46
 
bae9038
7c19d46
bae9038
7c19d46
bae9038
7c19d46
bae9038
 
 
7c19d46
bae9038
7c19d46
bae9038
 
 
 
 
7c19d46
 
bae9038
7c19d46
 
 
 
 
 
bae9038
 
 
 
 
 
 
7c19d46
 
bae9038
7c19d46
 
 
 
bae9038
7c19d46
 
 
bae9038
 
 
7c19d46
bae9038
 
7c19d46
 
 
 
 
bae9038
7c19d46
 
 
 
 
 
 
 
 
36df1e5
7c19d46
bae9038
7c19d46
 
 
bae9038
 
 
 
7c19d46
 
 
 
 
 
 
 
 
 
 
 
bae9038
7c19d46
bae9038
7c19d46
bae9038
 
7c19d46
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# =============================================================================
# HuggingFace Fine-Tuning Script — vNext Production Training
# =============================================================================
# Based on: "LoRA Without Regret" (Schulman et al., 2025)
#   - LoRA can match full fine-tuning on post-training-scale datasets when configured correctly
#   - Key: all-linear targets + r=256 + LR 2e-4 + batch < 32
#
# Datasets (ranked by quality):
#   PRIMARY:  allenai/tulu-3-sft-mixture (940K examples, 19 sources)
#   REASONING: open-thoughts/OpenThoughts-114k (CoT traces)
#   FALLBACK: HuggingFaceH4/ultrachat_200k (200K multi-turn chat)
# =============================================================================

import os
import torch
from dataclasses import dataclass, field
from typing import Optional, List

from datasets import load_dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio


# ---------- Dataset Registry ----------
DATASET_REGISTRY = {
    # Primary SFT mixture; rows already carry a conversational "messages" column.
    "tulu-3-sft": dict(
        name="allenai/tulu-3-sft-mixture",
        split="train",
        format="messages",
        size="~940K",
        quality="BEST — 19 curated sources (math, code, IF, safety, science)",
    ),
    # Reasoning traces stored as "conversations" ({"from", "value"} turns);
    # these must be converted to "messages" before training.
    "openthoughts-114k": dict(
        name="open-thoughts/OpenThoughts-114k",
        split="train",
        format="conversations",
        size="~114K",
        quality="EXCELLENT — reasoning CoT traces",
    ),
    # Multi-turn chat fallback; note the non-default split name.
    "ultrachat-200k": dict(
        name="HuggingFaceH4/ultrachat_200k",
        split="train_sft",
        format="messages",
        size="~200K",
        quality="GOOD — multi-turn chat (baseline fallback)",
    ),
}


def convert_openthoughts_to_messages(example):
    """Convert an OpenThoughts ``conversations`` row to a standard ``messages`` row.

    Args:
        example: A row with a ``conversations`` list of ``{"from", "value"}``
            turns and an optional ``system`` prompt string.

    Returns:
        ``{"messages": [...]}`` where each entry is ``{"role", "content"}``.
        A non-empty ``system`` prompt becomes the first message.
    """
    # Accept both the OpenThoughts role names ("user"/"assistant") and the
    # ShareGPT-style aliases ("human"/"gpt"), plus in-conversation "system"
    # turns. Any other speaker falls back to "assistant".
    role_map = {
        "user": "user",
        "human": "user",
        "assistant": "assistant",
        "gpt": "assistant",
        "system": "system",
    }
    messages = []
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        role = role_map.get(str(turn["from"]).lower(), "assistant")
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}


def load_and_prepare_dataset(dataset_key: str, max_samples: Optional[int] = None):
    """Load a dataset from ``DATASET_REGISTRY`` and normalize it for SFT.

    The returned dataset always has exactly one column, ``messages``,
    whatever the source format.

    Args:
        dataset_key: Key into ``DATASET_REGISTRY``.
        max_samples: If truthy, keep only the first ``max_samples`` rows
            (taken before any format conversion). ``None`` or ``0`` keeps
            every row.

    Returns:
        The prepared ``datasets.Dataset``.

    Raises:
        KeyError: ``dataset_key`` is not a registry key.
        ValueError: the prepared dataset has no ``messages`` column.
    """
    try:
        info = DATASET_REGISTRY[dataset_key]
    except KeyError:
        raise KeyError(
            f"Unknown dataset_key {dataset_key!r}; "
            f"expected one of {sorted(DATASET_REGISTRY)}"
        ) from None

    ds = load_dataset(info["name"], split=info["split"])

    if max_samples:
        # Subsample first so the conversion below only touches kept rows.
        ds = ds.select(range(min(max_samples, len(ds))))

    # Dispatch on the registry's declared format rather than on the key, so
    # every "conversations"-style entry gets converted.
    if info["format"] == "conversations":
        ds = ds.map(
            convert_openthoughts_to_messages,
            remove_columns=[c for c in ds.column_names if c != "messages"],
        )

    if "messages" not in ds.column_names:
        raise ValueError(
            f"Dataset {info['name']!r} has no 'messages' column after "
            f"preparation (columns: {ds.column_names})"
        )

    # Keep only "messages". Auxiliary columns (e.g. ultrachat's plain-string
    # "prompt") can be mistaken by TRL's SFT format detection for a
    # prompt-completion dataset.
    extra_cols = [c for c in ds.column_names if c != "messages"]
    if extra_cols:
        ds = ds.remove_columns(extra_cols)

    return ds


@dataclass
class FinetuneConfig:
    """Fine-tuning hyperparameters — vNext (LoRA Without Regret config).

    ``finetune`` consumes this object and maps its fields onto ``LoraConfig``
    (LoRA adapter settings) and ``SFTConfig`` (training-loop settings).
    The field order below is the positional order of the generated
    ``__init__``, so do not reorder fields.
    """
    # Hub id of the base model that SFTTrainer loads by name.
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    dataset_key: str = "tulu-3-sft"  # Options: tulu-3-sft, openthoughts-114k, ultrachat-200k
    output_dir: str = "/output/models"  # Local checkpoint directory.
    hub_model_id: str = "devsecops/finetuned-llama-v2"  # Hub repo the model is pushed to.

    # LoRA (LoRA Without Regret optimal config)
    lora_r: int = 256          # Adapter rank; r=256 gives ample capacity for SFT-scale datasets.
    # PEFT's default scaling multiplies the adapter update by lora_alpha / lora_r
    # (16 / 256 = 0.0625 with these defaults).
    lora_alpha: int = 16       # alpha=16 — stable scaling
    lora_dropout: float = 0.05
    target_modules: str = "all-linear"  # ALL linear layers, not just attention

    # Training (LoRA Without Regret: batch < 32, LR = 2e-4)
    num_train_epochs: int = 1  # 1 epoch sufficient for 940K dataset (tulu-3-sft; revisit for smaller sets)
    per_device_train_batch_size: int = 2
    # Effective batch = per_device_train_batch_size * gradient_accumulation_steps
    # * number of devices: 16 on one device, so keep total devices small to stay < 32.
    gradient_accumulation_steps: int = 8  # effective batch = 16 (< 32!)
    learning_rate: float = 2e-4  # 10x full FT rate
    max_seq_length: int = 2048  # Passed to SFTConfig as max_length.
    warmup_ratio: float = 0.1
    lr_scheduler_type: str = "cosine"

    # Optimization
    bf16: bool = True
    gradient_checkpointing: bool = True
    optim: str = "adamw_torch"

    # Packing (LoRA Without Regret recommends packing=True)
    packing: bool = True
    # NOTE(review): presumably a best-fit-decreasing variant that splits
    # over-long sequences instead of truncating them; confirm the installed
    # TRL version accepts "bfd_split".
    packing_strategy: str = "bfd_split"  # Preserves all tokens

    # Loss
    # NOTE(review): TRL's assistant-only loss relies on the tokenizer's chat
    # template emitting assistant masks ({% generation %} markers); confirm the
    # base model's template supports this, otherwise training may fail.
    assistant_only_loss: bool = True  # Only compute loss on assistant tokens


def finetune(config: FinetuneConfig):
    """Fine-tune ``config.model_name`` with LoRA + SFT and push it to the Hub.

    Opens a Trackio run and loads the registry dataset. It then builds the
    LoRA and SFT configs from ``config`` and trains with ``SFTTrainer``, which
    loads the model by name and applies PEFT. Finally it pushes the result to
    ``config.hub_model_id``. The Trackio run is closed even when loading,
    training, or pushing raises.

    Args:
        config: Hyperparameters and output locations for this run.
    """
    run_name = f"sft-{config.model_name.split('/')[-1]}-{config.dataset_key}"

    # --- Trackio monitoring ---
    # Log a copy of the fields so nothing downstream can mutate the live config.
    trackio.init(
        project="devsecops-ml",
        name=run_name,
        config=dict(vars(config)),
    )

    try:
        # --- Dataset (best available) ---
        dataset = load_and_prepare_dataset(config.dataset_key)
        print(f"Dataset: {config.dataset_key} ({len(dataset)} examples)")

        # --- LoRA (LoRA Without Regret: all-linear, r=256) ---
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=config.target_modules,
        )

        # "<|eot_id|>" is the Llama-3 end-of-turn token and does not exist in
        # other tokenizers. Force it only for Llama-3 checkpoints; otherwise
        # pass None so TRL uses the tokenizer's own eos token.
        eos_token = "<|eot_id|>" if "llama-3" in config.model_name.lower() else None

        # --- SFT Config ---
        sft_config = SFTConfig(
            output_dir=config.output_dir,
            run_name=run_name,  # Keep the Trainer's reported run name consistent.
            num_train_epochs=config.num_train_epochs,
            per_device_train_batch_size=config.per_device_train_batch_size,
            gradient_accumulation_steps=config.gradient_accumulation_steps,
            learning_rate=config.learning_rate,
            max_length=config.max_seq_length,
            warmup_ratio=config.warmup_ratio,
            lr_scheduler_type=config.lr_scheduler_type,
            bf16=config.bf16,
            gradient_checkpointing=config.gradient_checkpointing,
            optim=config.optim,
            packing=config.packing,
            packing_strategy=config.packing_strategy,
            assistant_only_loss=config.assistant_only_loss,
            eos_token=eos_token,
            logging_strategy="steps",
            logging_steps=10,
            logging_first_step=True,
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            push_to_hub=True,
            hub_model_id=config.hub_model_id,
            report_to="trackio",
            disable_tqdm=True,
        )

        # --- Trainer (SFTTrainer handles model loading + PEFT) ---
        trainer = SFTTrainer(
            model=config.model_name,
            train_dataset=dataset,
            peft_config=peft_config,
            args=sft_config,
        )

        # --- Train ---
        trainer.train()

        # --- Save ---
        trainer.push_to_hub()
    finally:
        # Close the run on failure too, instead of leaving it dangling.
        trackio.finish()

    print(f"Model pushed to: https://huggingface.co/{config.hub_model_id}")


if __name__ == "__main__":
    # Script entry point: train with the default vNext configuration.
    finetune(FinetuneConfig())