evalstate HF Staff committed on
Commit
6963ef7
·
verified ·
1 Parent(s): 9ffa969

Upload demo_train.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. demo_train.py +80 -0
demo_train.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # /// script
3
+ # dependencies = [
4
+ # "trl>=0.12.0",
5
+ # "peft>=0.7.0",
6
+ # "transformers>=4.36.0",
7
+ # "accelerate>=0.24.0",
8
+ # ]
9
+ # ///
10
+
11
+ from datasets import load_dataset
12
+ from peft import LoraConfig
13
+ from trl import SFTTrainer, SFTConfig
14
+ import os
15
+
16
+ print("🚀 Starting TRL + Trackio Demo")
17
+ print("=" * 50)
18
+
19
+ # Load a tiny dataset (just 50 examples for demo)
20
+ print("\n📊 Loading dataset...")
21
+ dataset = load_dataset("trl-lib/Capybara", split="train[:50]")
22
+ print(f"✅ Dataset loaded: {len(dataset)} examples")
23
+
24
+ # Get username for hub push
25
+ username = os.environ.get("HF_USERNAME", "evalstate") # fallback to evalstate
26
+
27
+ # Training configuration with Trackio enabled
28
+ print("\n⚙️ Configuring training...")
29
+ config = SFTConfig(
30
+ # Output and Hub settings
31
+ output_dir="trl-demo",
32
+ push_to_hub=True,
33
+ hub_model_id=f"{username}/trl-trackio-demo",
34
+
35
+ # Quick demo settings
36
+ max_steps=10, # Very short for demo
37
+ per_device_train_batch_size=2,
38
+
39
+ # Logging
40
+ logging_steps=2,
41
+
42
+ # Trackio monitoring - this is the key!
43
+ report_to="trackio",
44
+
45
+ # Learning rate
46
+ learning_rate=2e-5,
47
+ )
48
+
49
+ # LoRA configuration (reduces memory usage)
50
+ print("🔧 Setting up LoRA...")
51
+ peft_config = LoraConfig(
52
+ r=8,
53
+ lora_alpha=16,
54
+ lora_dropout=0.05,
55
+ bias="none",
56
+ task_type="CAUSAL_LM",
57
+ )
58
+
59
+ # Initialize trainer
60
+ print("\n🎯 Initializing trainer...")
61
+ trainer = SFTTrainer(
62
+ model="Qwen/Qwen2.5-0.5B",
63
+ train_dataset=dataset,
64
+ args=config,
65
+ peft_config=peft_config,
66
+ )
67
+
68
+ # Train!
69
+ print("\n🏃 Training started...")
70
+ print("📈 Trackio will track: loss, learning rate, GPU usage, memory, throughput")
71
+ print("-" * 50)
72
+ trainer.train()
73
+
74
+ # Save to Hub
75
+ print("\n💾 Pushing to Hub...")
76
+ trainer.push_to_hub()
77
+
78
+ print("\n✅ Demo complete!")
79
+ print(f"📦 Model saved to: https://huggingface.co/{username}/trl-trackio-demo")
80
+ print("📊 Check Trackio for training metrics and visualizations!")