""" Train Llama-3.1-8B-Instruct on open-thoughts/OpenThoughts-114k (reasoning CoT). This dataset contains DeepSeek-R1 distilled reasoning traces. Focuses on: math, code, science with chain-of-thought thinking. Uses LoRA Without Regret config (r=256, all-linear). Smaller dataset (114K) so uses higher LR and fewer epochs. Usage: python train_openthoughts.py python train_openthoughts.py --max_steps 50 # quick test """ import argparse import torch from datasets import load_dataset from peft import LoraConfig from trl import SFTTrainer, SFTConfig import trackio def convert_openthoughts(example): """Convert ShareGPT format to messages format.""" messages = [] if example.get("system"): messages.append({"role": "system", "content": example["system"]}) for turn in example["conversations"]: role = "user" if turn["from"] == "user" else "assistant" messages.append({"role": role, "content": turn["value"]}) return {"messages": messages} def train(max_steps=None, push_hub=True, hub_model_id="shaikhsalman/llama-3.1-8b-openthoughts-lora"): trackio.init( project="devsecops-ml", name="sft-llama3.1-8b-openthoughts", config={ "model": "meta-llama/Llama-3.1-8B-Instruct", "dataset": "open-thoughts/OpenThoughts-114k", "lora_r": 256, "lora_alpha": 16, "target_modules": "all-linear", "learning_rate": 2e-4, }, ) # Load and convert print("Loading open-thoughts/OpenThoughts-114k...") dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train") print(f"Loaded {len(dataset)} examples (raw format)") remove_cols = [c for c in dataset.column_names if c != "messages"] dataset = dataset.map(convert_openthoughts, remove_columns=remove_cols) print(f"Converted to messages format: {len(dataset)} examples") # LoRA Without Regret peft_config = LoraConfig( r=256, lora_alpha=16, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", target_modules="all-linear", ) # Smaller dataset = higher LR + more epochs training_args = SFTConfig( output_dir="./output/llama3.1-8b-openthoughts-lora", push_to_hub=push_hub, hub_model_id=hub_model_id, model_init_kwargs={ "torch_dtype": torch.bfloat16, "attn_implementation": "flash_attention_2", }, learning_rate=2e-4, per_device_train_batch_size=2, gradient_accumulation_steps=8, # effective batch = 16 num_train_epochs=2, lr_scheduler_type="cosine", warmup_ratio=0.1, max_length=4096, packing=True, packing_strategy="bfd_split", gradient_checkpointing=True, bf16=True, assistant_only_loss=True, eos_token="<|eot_id|>", logging_strategy="steps", logging_steps=25, logging_first_step=True, report_to=["trackio"], disable_tqdm=True, save_strategy="steps", save_steps=500, save_total_limit=3, optim="adamw_torch", ) if max_steps: training_args.max_steps = max_steps trainer = SFTTrainer( model="meta-llama/Llama-3.1-8B-Instruct", train_dataset=dataset, peft_config=peft_config, args=training_args, ) trainer.train() if push_hub: trainer.push_to_hub() print(f"Model pushed to: https://huggingface.co/{hub_model_id}") trackio.finish() if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--max_steps", type=int, default=None) parser.add_argument("--hub_model_id", type=str, default="shaikhsalman/llama-3.1-8b-openthoughts-lora") parser.add_argument("--no_push", action="store_true") args = parser.parse_args() train(max_steps=args.max_steps, push_hub=not args.no_push, hub_model_id=args.hub_model_id)