File size: 3,989 Bytes
3aaeeb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
09a9397
3aaeeb3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Train Llama-3.1-8B-Instruct on open-thoughts/OpenThoughts-114k (reasoning CoT).

This dataset contains DeepSeek-R1 distilled reasoning traces.
Focuses on: math, code, science with chain-of-thought thinking.

Uses the LoRA Without Regret config (r=256, all-linear).
The dataset is relatively small (114K examples), so training uses a
higher learning rate (2e-4) and 2 epochs.

Usage:
  python train_openthoughts.py
  python train_openthoughts.py --max_steps 50  # quick test
"""

import argparse
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio


# Speaker tags found in ShareGPT-style rows, mapped to chat-template roles.
# Both the classic ShareGPT spelling (human/gpt) and the OpenAI-style
# spelling (user/assistant) are accepted.
_SHAREGPT_ROLES = {
    "system": "system",
    "human": "user",
    "user": "user",
    "gpt": "assistant",
    "assistant": "assistant",
}


def convert_openthoughts(example):
    """Convert one ShareGPT-style row into TRL's conversational ``messages`` format.

    Args:
        example: A dataset row with a ``conversations`` list of
            ``{"from": ..., "value": ...}`` turns and an optional ``system``
            prompt string.

    Returns:
        ``{"messages": [...]}`` where each message is a
        ``{"role": ..., "content": ...}`` dict. A non-empty ``system`` field
        becomes the leading system message. Speaker tags are translated via
        ``_SHAREGPT_ROLES``; any unrecognised tag falls back to ``assistant``.
    """
    messages = []
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        role = _SHAREGPT_ROLES.get(turn["from"], "assistant")
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}


def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-openthoughts-lora"):
    """Fine-tune Llama-3.1-8B-Instruct on OpenThoughts-114k with a LoRA adapter.

    Loads the dataset, converts it to ``messages`` format with
    ``convert_openthoughts``, trains with TRL's ``SFTTrainer`` and logs the
    run to trackio. The trackio run is closed even if training fails.

    Args:
        max_steps: Optional cap on optimizer steps (e.g. ``50`` for a quick
            smoke test). ``None`` or ``0`` trains for ``num_train_epochs``.
        push_hub: If True, enable Hub pushing in the trainer config and push
            the final model after training.
        hub_model_id: Hub repository id to push to.
    """
    # Single source of truth for values used both for training and for the
    # trackio run config, so the logged config cannot drift from reality.
    model_id = "meta-llama/Llama-3.1-8B-Instruct"
    dataset_id = "open-thoughts/OpenThoughts-114k"
    lora_r = 256
    lora_alpha = 16
    target_modules = "all-linear"
    learning_rate = 2e-4

    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-openthoughts",
        config={
            "model": model_id,
            "dataset": dataset_id,
            "lora_r": lora_r,
            "lora_alpha": lora_alpha,
            "target_modules": target_modules,
            "learning_rate": learning_rate,
        },
    )
    try:
        # Load and convert
        print(f"Loading {dataset_id}...")
        dataset = load_dataset(dataset_id, split="train")
        print(f"Loaded {len(dataset)} examples (raw format)")

        # Drop every raw column so only the new "messages" column remains.
        remove_cols = [c for c in dataset.column_names if c != "messages"]
        dataset = dataset.map(convert_openthoughts, remove_columns=remove_cols)
        print(f"Converted to messages format: {len(dataset)} examples")

        # LoRA Without Regret: high rank, adapters on every linear layer.
        peft_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            lora_dropout=0.05,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=target_modules,
        )

        # Relatively small dataset (114K examples): higher LR and 2 epochs.
        training_args = SFTConfig(
            output_dir="./output/llama3.1-8b-openthoughts-lora",
            push_to_hub=push_hub,
            hub_model_id=hub_model_id,
            model_init_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
            learning_rate=learning_rate,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=8,  # 2 * 8 = 16 sequences per device per step
            num_train_epochs=2,
            # -1 is the TrainingArguments default: derive the length from
            # num_train_epochs. A truthy max_steps overrides it.
            max_steps=max_steps if max_steps else -1,
            lr_scheduler_type="cosine",
            warmup_ratio=0.1,
            max_length=4096,
            packing=True,
            packing_strategy="bfd_split",
            gradient_checkpointing=True,
            bf16=True,
            # NOTE(review): TRL documents that assistant_only_loss needs a chat
            # template that emits assistant masks ({% generation %} markers).
            # Confirm the model's template provides them, or supply one via
            # the chat-template option; otherwise TRL errors at runtime.
            assistant_only_loss=True,
            eos_token="<|eot_id|>",
            logging_strategy="steps",
            logging_steps=25,
            logging_first_step=True,
            # NOTE(review): the Trainer's own trackio callback also logs here;
            # confirm it attaches to the run opened above rather than a new one.
            report_to=["trackio"],
            disable_tqdm=True,
            save_strategy="steps",
            save_steps=500,
            save_total_limit=3,
            optim="adamw_torch",
        )

        trainer = SFTTrainer(
            model=model_id,
            train_dataset=dataset,
            peft_config=peft_config,
            args=training_args,
        )

        trainer.train()

        if push_hub:
            trainer.push_to_hub()
            print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    finally:
        # Always close the trackio run, even when loading/training raises.
        trackio.finish()


if __name__ == "__main__":
    # Command-line entry point: read flags and forward them to train().
    cli = argparse.ArgumentParser()
    cli.add_argument("--max_steps", type=int, default=None)
    cli.add_argument(
        "--hub_model_id",
        type=str,
        default="shaikhsalman/llama-3.1-8b-openthoughts-lora",
    )
    cli.add_argument("--no_push", action="store_true")
    opts = cli.parse_args()

    # --no_push is an opt-out flag, so invert it into train()'s push_hub.
    should_push = not opts.no_push
    train(
        max_steps=opts.max_steps,
        push_hub=should_push,
        hub_model_id=opts.hub_model_id,
    )