File size: 2,714 Bytes
6963ef7
 
 
 
 
 
 
c4dd148
6963ef7
 
 
 
 
 
2638a07
6963ef7
 
 
 
 
2638a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6963ef7
2638a07
6963ef7
 
 
 
 
 
 
 
 
 
 
 
 
2638a07
 
6963ef7
 
2638a07
 
6963ef7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2638a07
 
 
 
6963ef7
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
#!/usr/bin/env python3
# /// script
# dependencies = [
#     "trl>=0.12.0",
#     "peft>=0.7.0",
#     "transformers>=4.36.0",
#     "accelerate>=0.24.0",
#     "trackio",
# ]
# ///

from datasets import load_dataset
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
import trackio
import os

# Announce the start of the demo.
print("🚀 Starting TRL + Trackio Demo")
print("=" * 50)

# Initialize Trackio with Space sync for remote viewing.
# Trackio will auto-create the Space if it doesn't exist.
# The Space id is named once so the dashboard URL printed below cannot drift
# from the Space the run is actually synced to.
TRACKIO_SPACE_ID = "evalstate/trl-trackio-dashboard"

print("\n📊 Initializing Trackio...")
trackio.init(
    project="trl-demo",
    space_id=TRACKIO_SPACE_ID,  # Auto-creates if needed!
    # Passed to Trackio as the run's config. These values mirror the model,
    # dataset, max_steps and learning_rate used later in this script.
    config={
        "model": "Qwen/Qwen2.5-0.5B",
        "dataset": "trl-lib/Capybara",
        "max_steps": 50,  # Longer for better visualization
        "learning_rate": 2e-5,
    },
)
print(f"✅ Trackio initialized! Dashboard: https://huggingface.co/spaces/{TRACKIO_SPACE_ID}")

# Load a small dataset (200 examples for better visualization).
# The split expression "train[:200]" takes only the first 200 training rows.
print("\n📊 Loading dataset...")
dataset = load_dataset("trl-lib/Capybara", split="train[:200]")
print(f"✅ Dataset loaded: {len(dataset)} examples")

# Get username for hub push. An unset, empty, or whitespace-only HF_USERNAME
# falls back to "evalstate"; a plain .get() default would let an empty value
# through and produce the invalid repo id "/trl-trackio-demo".
username = os.environ.get("HF_USERNAME", "").strip() or "evalstate"

# Training configuration with Trackio enabled
print("\n⚙️  Configuring training...")
config = SFTConfig(
    # Output and Hub settings
    output_dir="trl-demo",
    push_to_hub=True,
    hub_model_id=f"{username}/trl-trackio-demo",

    # Training settings (longer for better metrics)
    max_steps=50,  # More steps for visualization
    per_device_train_batch_size=2,

    # Logging (log frequently for real-time monitoring)
    logging_steps=5,

    # Trackio monitoring - this is the key!
    # NOTE(review): trackio.init() is also called manually earlier in this
    # script. Depending on the installed transformers version, the Trackio
    # reporting callback may initialize its own run (project / Space) - confirm
    # that metrics land in the intended "trl-demo" project.
    report_to="trackio",

    # Learning rate
    learning_rate=2e-5,
)

# LoRA configuration (reduces memory usage)
# NOTE(review): target_modules is not set here; presumably PEFT's default for
# the model architecture applies - confirm this is intended.
print("🔧 Setting up LoRA...")
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)

# Build the SFT trainer from the pieces prepared above: the model (given as
# a name string), the loaded dataset, the SFTConfig and the LoRA config.
print("\n🎯 Initializing trainer...")
trainer_kwargs = dict(
    args=config,
    peft_config=peft_config,
    train_dataset=dataset,
    model="Qwen/Qwen2.5-0.5B",
)
trainer = SFTTrainer(**trainer_kwargs)

# Train!
print("\n🏃 Training started...")
print("📈 Trackio will track: loss, learning rate, GPU usage, memory, throughput")
print("-" * 50)

# Training and the Hub upload run inside try/finally: if either step raises,
# trackio.finish() is still called instead of being skipped, and the
# exception then propagates unchanged.
try:
    trainer.train()

    # Save to Hub
    print("\n💾 Pushing to Hub...")
    trainer.push_to_hub()
finally:
    # Finish Trackio logging
    print("\n📊 Finalizing Trackio...")
    trackio.finish()

# Only reached when training and the push both succeeded.
print("\n✅ Demo complete!")
print(f"📦 Model saved to: https://huggingface.co/{username}/trl-trackio-demo")
print("📊 Check Trackio for training metrics and visualizations!")