# devsecops-platform / model / train_openthoughts.py
# refactor: merged structure - model at center, DevSecOps wrapped around it (commit 9d4d5c7)
"""
Train Llama-3.1-8B-Instruct on open-thoughts/OpenThoughts-114k (reasoning CoT).
This dataset contains DeepSeek-R1 distilled reasoning traces.
Focuses on: math, code, science with chain-of-thought thinking.
Uses LoRA Without Regret config (r=256, all-linear).
Smaller dataset (114K) so uses higher LR and fewer epochs.
Usage:
python train_openthoughts.py
python train_openthoughts.py --max_steps 50 # quick test
"""
import argparse
import torch
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
def convert_openthoughts(example):
    """Convert a ShareGPT-style example to the chat ``messages`` format.

    Args:
        example: Dict with an optional ``system`` prompt string and a
            ``conversations`` list of turns, each shaped
            ``{"from": <role tag>, "value": <text>}``.

    Returns:
        Dict with a single ``messages`` key: a list of
        ``{"role": ..., "content": ...}`` dicts as expected by SFTTrainer.
    """
    messages = []
    # Empty-string or missing system prompts are both skipped (falsy check).
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        # Accept both OpenThoughts role tags ("user"/"assistant") and the
        # classic ShareGPT tags ("human"/"gpt"). Any unrecognized tag falls
        # through to "assistant", matching the original behavior.
        role = "user" if turn["from"] in ("user", "human") else "assistant"
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}
def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-openthoughts-lora"):
    """Fine-tune Llama-3.1-8B-Instruct on OpenThoughts-114k with LoRA.

    Args:
        max_steps: Optional cap on optimizer steps (handy for smoke tests).
            ``None`` trains for the full configured number of epochs.
        push_hub: If True, push the trained adapter to the Hugging Face Hub.
        hub_model_id: Hub repository id used when pushing.
    """
    # Record the run's hyperparameters for experiment tracking.
    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-openthoughts",
        config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "dataset": "open-thoughts/OpenThoughts-114k",
            "lora_r": 256,
            "lora_alpha": 16,
            "target_modules": "all-linear",
            "learning_rate": 2e-4,
        },
    )

    # Load the raw ShareGPT-style dataset and convert it to chat messages,
    # dropping every original column so only "messages" remains.
    print("Loading open-thoughts/OpenThoughts-114k...")
    dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    print(f"Loaded {len(dataset)} examples (raw format)")
    remove_cols = [c for c in dataset.column_names if c != "messages"]
    dataset = dataset.map(convert_openthoughts, remove_columns=remove_cols)
    print(f"Converted to messages format: {len(dataset)} examples")

    # "LoRA Without Regret" configuration: high rank, adapters on all
    # linear layers.
    peft_config = LoraConfig(
        r=256,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )

    # Smaller dataset (114K examples) => higher LR and only 2 epochs.
    training_args = SFTConfig(
        output_dir="./output/llama3.1-8b-openthoughts-lora",
        push_to_hub=push_hub,
        hub_model_id=hub_model_id,
        model_init_kwargs={
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "flash_attention_2",
        },
        learning_rate=2e-4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,  # effective batch = 16
        num_train_epochs=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_length=4096,
        packing=True,
        packing_strategy="bfd_split",
        gradient_checkpointing=True,
        bf16=True,
        # Mask loss on user/system tokens; train on assistant turns only.
        assistant_only_loss=True,
        eos_token="<|eot_id|>",
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        report_to=["trackio"],
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        optim="adamw_torch",
    )
    # Fix: compare against None rather than truthiness so an explicit
    # max_steps=0 is honored instead of being silently ignored.
    if max_steps is not None:
        training_args.max_steps = max_steps

    trainer = SFTTrainer(
        model="meta-llama/Llama-3.1-8B-Instruct",
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
    )
    trainer.train()

    if push_hub:
        trainer.push_to_hub()
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")
    trackio.finish()
if __name__ == "__main__":
    # CLI entry point; flags mirror train()'s keyword arguments.
    cli = argparse.ArgumentParser()
    cli.add_argument("--max_steps", type=int, default=None)
    cli.add_argument(
        "--hub_model_id",
        type=str,
        default="shaikhsalman/llama-3.1-8b-openthoughts-lora",
    )
    cli.add_argument("--no_push", action="store_true")
    opts = cli.parse_args()
    train(
        max_steps=opts.max_steps,
        push_hub=not opts.no_push,
        hub_model_id=opts.hub_model_id,
    )