AndreasThinks committed
Commit dc96d09 · verified · 1 parent: fcb8a43

Upload train_ministral.py with huggingface_hub
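Note: the uploaded script opens with a PEP 723 inline metadata block (the "# /// script" header in the diff below), so a compatible runner such as uv can execute it directly, e.g. with "uv run train_ministral.py", resolving the listed dependencies into an ephemeral environment automatically.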

Files changed (1)
  1. train_ministral.py +104 -0
train_ministral.py ADDED
@@ -0,0 +1,104 @@
+ # /// script
+ # dependencies = ["trl>=0.12.0", "peft>=0.7.0", "trackio", "torch>=2.0.0", "transformers>=4.40.0"]
+ # ///
+
+ """Fine-tune Ministral-3-3B-Instruct on the NATO doctrine dataset."""
+
+ from datasets import load_dataset
+ from peft import LoraConfig
+ from trl import SFTTrainer, SFTConfig
+ import trackio
+
+ # Load dataset from HF Hub
+ print("Loading NATO doctrine dataset...")
+ dataset = load_dataset("AndreasThinks/nato-doctrine-sft", split="train")
+ dataset_test = load_dataset("AndreasThinks/nato-doctrine-sft", split="test")
+
+ print(f"✓ Train set: {len(dataset)} examples")
+ print(f"✓ Test set: {len(dataset_test)} examples")
+
+ # Configure LoRA for efficient fine-tuning
+ peft_config = LoraConfig(
+     r=16,
+     lora_alpha=32,
+     lora_dropout=0.05,
+     bias="none",
+     task_type="CAUSAL_LM",
+     target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
+ )
+
+ # Training configuration
+ training_args = SFTConfig(
+     output_dir="nato-ministral-3b",
+
+     # Model saving
+     push_to_hub=True,
+     hub_model_id="AndreasThinks/ministral-3b-nato-doctrine",
+     hub_strategy="every_save",
+     hub_private_repo=False,
+
+     # Training parameters
+     num_train_epochs=3,
+     per_device_train_batch_size=2,
+     per_device_eval_batch_size=2,
+     gradient_accumulation_steps=8,  # Effective batch size = 16
+     gradient_checkpointing=True,
+
+     # Learning rate
+     learning_rate=2e-4,
+     lr_scheduler_type="cosine",
+     warmup_ratio=0.1,
+
+     # Optimization
+     optim="adamw_torch",
+     weight_decay=0.01,
+     max_grad_norm=1.0,
+
+     # Evaluation
+     eval_strategy="steps",
+     eval_steps=50,
+
+     # Logging and saving
+     logging_steps=10,
+     save_strategy="steps",
+     save_steps=100,
+     save_total_limit=3,
+
+     # Monitoring with Trackio
+     report_to="trackio",
+     run_name="nato-ministral-3b-v1",
+     project="nato-doctrine-training",
+
+     # Other
+     bf16=True,  # Use bfloat16 for better stability
+     seed=42,
+ )
+
+ # Initialize trainer
+ print("\n✓ Initializing SFT trainer...")
+ trainer = SFTTrainer(
+     model="mistralai/Ministral-3-3B-Instruct-2512",
+     train_dataset=dataset,
+     eval_dataset=dataset_test,
+     peft_config=peft_config,
+     args=training_args,
+ )
+
+ # Start training
+ print("\n✓ Starting training...")
+ print("  Model: mistralai/Ministral-3-3B-Instruct-2512")
+ print(f"  Training examples: {len(dataset)}")
+ print(f"  Test examples: {len(dataset_test)}")
+ print("  Epochs: 3")
+ print("  LoRA rank: 16")
+ print("  Output: AndreasThinks/ministral-3b-nato-doctrine\n")
+
+ trainer.train()
+
+ # Save final model
+ print("\n✓ Training complete! Saving final model...")
+ trainer.push_to_hub()
+
+ print("\n✅ Fine-tuning complete!")
+ print("  Model: https://huggingface.co/AndreasThinks/ministral-3b-nato-doctrine")
+ print("  Trackio: Check your dashboard for metrics")
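
For quick verification after the push completes, a minimal inference sketch follows. It is not part of this commit: it assumes the LoRA adapter was pushed to AndreasThinks/ministral-3b-nato-doctrine as configured above, that the base model is accessible, and that transformers, peft, and torch are installed locally; the prompt is a placeholder.

# Minimal inference sketch (not part of the commit): load the base model,
# attach the pushed LoRA adapter, and generate from a sample prompt.
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base_id = "mistralai/Ministral-3-3B-Instruct-2512"        # base model used in the script above
adapter_id = "AndreasThinks/ministral-3b-nato-doctrine"   # adapter repo pushed by the script

tokenizer = AutoTokenizer.from_pretrained(base_id)
model = AutoModelForCausalLM.from_pretrained(base_id, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter_id)  # wrap base weights with the fine-tuned LoRA

# Hypothetical prompt, for illustration only
messages = [{"role": "user", "content": "Summarise the purpose of NATO joint doctrine."}]
input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

output_ids = model.generate(input_ids, max_new_tokens=256)
print(tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True))

If a standalone checkpoint is preferred over base-plus-adapter loading, the adapter can also be folded into the base weights with PeftModel's merge_and_unload() before saving.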