| """ |
| Train Llama-3.1-8B-Instruct on open-thoughts/OpenThoughts-114k (reasoning CoT). |
| |
| This dataset contains DeepSeek-R1 distilled reasoning traces. |
| Focuses on: math, code, science with chain-of-thought thinking. |
| |
| Uses LoRA Without Regret config (r=256, all-linear). |
| Smaller dataset (114K) so uses higher LR and fewer epochs. |
| |
| Usage: |
| python train_openthoughts.py |
| python train_openthoughts.py --max_steps 50 # quick test |
| """ |

import argparse

import torch
import trackio
from datasets import load_dataset
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


def convert_openthoughts(example):
    """Convert a ShareGPT-style example to the chat ``messages`` format."""
    messages = []
    if example.get("system"):
        messages.append({"role": "system", "content": example["system"]})
    for turn in example["conversations"]:
        # Some ShareGPT exports label turns "human"/"gpt" instead of
        # "user"/"assistant"; treat any non-user turn as the assistant.
        role = "user" if turn["from"] in ("user", "human") else "assistant"
        messages.append({"role": role, "content": turn["value"]})
    return {"messages": messages}
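

# Illustrative example (hypothetical row): an input such as
#   {"system": "You are a helpful assistant.",
#    "conversations": [{"from": "user", "value": "What is 2 + 2?"},
#                      {"from": "assistant", "value": "<think>...</think> 4."}]}
# becomes
#   {"messages": [{"role": "system", "content": "You are a helpful assistant."},
#                 {"role": "user", "content": "What is 2 + 2?"},
#                 {"role": "assistant", "content": "<think>...</think> 4."}]}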


def train(max_steps=None, push_hub=True,
          hub_model_id="shaikhsalman/llama-3.1-8b-openthoughts-lora"):
    # Log the run configuration to Trackio for experiment tracking.
    trackio.init(
        project="devsecops-ml",
        name="sft-llama3.1-8b-openthoughts",
        config={
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "dataset": "open-thoughts/OpenThoughts-114k",
            "lora_r": 256,
            "lora_alpha": 16,
            "target_modules": "all-linear",
            "learning_rate": 2e-4,
        },
    )
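    # trackio mirrors the wandb-style init/log/finish API. With
    # report_to=["trackio"] below, the Trainer logs training metrics as well;
    # this explicit init records the hyperparameter config (assuming the
    # integration attaches to the already-active run, as the wandb one does).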

    print("Loading open-thoughts/OpenThoughts-114k...")
    dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train")
    print(f"Loaded {len(dataset)} examples (raw format)")

    # Drop all original columns so only the new `messages` column remains.
    remove_cols = [c for c in dataset.column_names if c != "messages"]
    dataset = dataset.map(convert_openthoughts, remove_columns=remove_cols)
    print(f"Converted to messages format: {len(dataset)} examples")

    # LoRA Without Regret: high rank, adapters on all linear layers.
    peft_config = LoraConfig(
        r=256,
        lora_alpha=16,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules="all-linear",
    )
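
    # With r=256 and lora_alpha=16, the standard LoRA scaling factor is
    # alpha / r = 16 / 256 = 0.0625, so the adapter's contribution to each
    # layer's output is scaled down by that factor.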

    training_args = SFTConfig(
        output_dir="./output/llama3.1-8b-openthoughts-lora",
        push_to_hub=push_hub,
        hub_model_id=hub_model_id,
        model_init_kwargs={
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "flash_attention_2",
        },
        learning_rate=2e-4,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=8,
        num_train_epochs=2,
        lr_scheduler_type="cosine",
        warmup_ratio=0.1,
        max_length=4096,
        packing=True,
        packing_strategy="bfd",  # TRL accepts "bfd" (best-fit decreasing) or "wrapped"
        gradient_checkpointing=True,
        bf16=True,
        # Mask loss to assistant turns only; requires a chat template with
        # {% generation %} markers, otherwise TRL raises an error.
        assistant_only_loss=True,
        eos_token="<|eot_id|>",  # Llama 3.1 end-of-turn token
        logging_strategy="steps",
        logging_steps=25,
        logging_first_step=True,
        report_to=["trackio"],
        disable_tqdm=True,
        save_strategy="steps",
        save_steps=500,
        save_total_limit=3,
        optim="adamw_torch",
    )
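
    # Effective batch: 2 sequences/device x 8 accumulation steps = 16 packed
    # sequences per optimizer step per device; with packing at max_length=4096
    # that is up to 16 * 4096 = 65,536 tokens per step.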

    # max_steps takes precedence over num_train_epochs, so this caps quick
    # test runs regardless of the epoch setting.
    if max_steps:
        training_args.max_steps = max_steps

    # Passing a model id string lets SFTTrainer load the model itself, using
    # the model_init_kwargs set in the config above.
    trainer = SFTTrainer(
        model="meta-llama/Llama-3.1-8B-Instruct",
        train_dataset=dataset,
        peft_config=peft_config,
        args=training_args,
    )

    trainer.train()

    if push_hub:
        trainer.push_to_hub()
        print(f"Model pushed to: https://huggingface.co/{hub_model_id}")

    trackio.finish()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_steps", type=int, default=None,
                        help="Cap training at this many optimizer steps (quick test).")
    parser.add_argument("--hub_model_id", type=str,
                        default="shaikhsalman/llama-3.1-8b-openthoughts-lora")
    parser.add_argument("--no_push", action="store_true",
                        help="Do not push the trained adapter to the Hub.")
    args = parser.parse_args()
    train(max_steps=args.max_steps, push_hub=not args.no_push,
          hub_model_id=args.hub_model_id)