fernando-machina commited on
Commit
e9aeabf
·
verified ·
1 Parent(s): 4d0b701

Upload train_gemma_sportingbot.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_gemma_sportingbot.py +88 -0
train_gemma_sportingbot.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "transformers>=4.36.0", "datasets>=2.16.0"]
3
+ # ///
4
+
5
+ from datasets import load_dataset
6
+ from peft import LoraConfig
7
+ from trl import SFTTrainer, SFTConfig
8
+ import trackio
9
+
10
+ # Load dataset
11
+ dataset = load_dataset("machina-sports/sportingbot-classification", split="train")
12
+
13
+ # Create train/eval split (10% eval)
14
+ dataset_split = dataset.train_test_split(test_size=0.1, seed=42)
15
+
16
+ print(f"✅ Dataset loaded: {len(dataset_split['train'])} train, {len(dataset_split['test'])} eval")
17
+
18
+ # Configure LoRA
19
+ peft_config = LoraConfig(
20
+ r=32,
21
+ lora_alpha=64,
22
+ target_modules="all-linear",
23
+ lora_dropout=0.05,
24
+ bias="none",
25
+ task_type="CAUSAL_LM"
26
+ )
27
+
28
+ # Configure training
29
+ training_args = SFTConfig(
30
+ output_dir="sportingbot-gemma-classifier",
31
+
32
+ # Hub settings (CRITICAL - saves results)
33
+ push_to_hub=True,
34
+ hub_model_id="fernando-machina/sportingbot-gemma-classifier",
35
+ hub_strategy="every_save",
36
+ hub_private_repo=False,
37
+
38
+ # Training hyperparameters (from user's config)
39
+ num_train_epochs=5,
40
+ per_device_train_batch_size=2,
41
+ gradient_accumulation_steps=4,
42
+ learning_rate=0.0001,
43
+
44
+ # Optimization (bf16 for Gemma)
45
+ bf16=True,
46
+ gradient_checkpointing=True,
47
+
48
+ # Evaluation
49
+ eval_strategy="steps",
50
+ eval_steps=10,
51
+
52
+ # Checkpointing
53
+ save_strategy="steps",
54
+ save_steps=50,
55
+ save_total_limit=3,
56
+
57
+ # Logging
58
+ logging_steps=5,
59
+ report_to="trackio",
60
+
61
+ # Trackio monitoring
62
+ project="sportingbot-classification",
63
+ run_name="gemma-2-2b-it-v1",
64
+
65
+ # Sequence length
66
+ max_seq_length=512,
67
+ )
68
+
69
+ print("🚀 Starting training with Gemma 2-2B-it...")
70
+
71
+ # Create trainer
72
+ trainer = SFTTrainer(
73
+ model="google/gemma-2-2b-it",
74
+ train_dataset=dataset_split["train"],
75
+ eval_dataset=dataset_split["test"],
76
+ peft_config=peft_config,
77
+ args=training_args,
78
+ )
79
+
80
+ # Train
81
+ trainer.train()
82
+
83
+ print("✅ Training complete! Pushing to Hub...")
84
+
85
+ # Push final model
86
+ trainer.push_to_hub()
87
+
88
+ print(f"🎉 Model saved to: https://huggingface.co/fernando-machina/sportingbot-gemma-classifier")