walidsobhie-code Claude Opus 4.6 commited on
Commit
444c0e7
·
1 Parent(s): fd566cb

feat: add local training script for Mac MPS

Browse files

- Simple one-command training on Mac
- Downloads model if not present
- Uses existing training data
- Configured for MPS (Apple Silicon)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. train_local.py +115 -0
train_local.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Stack 2.9 Local Training Script for Mac (MPS)
4
+ Run this on your Mac to train the model locally.
5
+ """
6
+
7
+ import os
8
+ import sys
9
+
10
+ # Add the training module to path
11
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'stack/training'))
12
+
13
+ # Set environment for MPS
14
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
15
+
16
+ def main():
17
+ print("=" * 60)
18
+ print("Stack 2.9 Local Training (Mac MPS)")
19
+ print("=" * 60)
20
+
21
+ # Check MPS availability
22
+ try:
23
+ import torch
24
+ print(f"PyTorch version: {torch.__version__}")
25
+ print(f"MPS available: {torch.backends.mps.is_available()}")
26
+ if torch.backends.mps.is_available():
27
+ print(f"MPS built: {torch.backends.mps.is_built()}")
28
+ except Exception as e:
29
+ print(f"⚠️ PyTorch/MPS check error: {e}")
30
+
31
+ # Check paths
32
+ base_model = "./base_model_qwen7b"
33
+ data_path = "./data/final/train.jsonl"
34
+ output_dir = "./training_output"
35
+ model_name = "Qwen/Qwen2.5-Coder-7B"
36
+
37
+ print(f"\n📁 Checking paths...")
38
+ print(f" Base model: {base_model} - {'✅ exists' if os.path.exists(base_model) else '❌ not found'}")
39
+ print(f" Data: {data_path} - {'✅ exists' if os.path.exists(data_path) else '❌ not found'}")
40
+
41
+ # Download model if not exists
42
+ if not os.path.exists(base_model):
43
+ print(f"\n⬇️ Downloading model ({model_name})...")
44
+ print(" This takes ~10-15 minutes...")
45
+ from transformers import AutoModelForCausalLM, AutoTokenizer
46
+
47
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
48
+ tokenizer.save_pretrained(base_model)
49
+ model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
50
+ model.save_pretrained(base_model)
51
+ print(f" ✅ Model saved to {base_model}")
52
+ else:
53
+ print(f" ✅ Model already exists!")
54
+
55
+ if not os.path.exists(data_path):
56
+ print("\n❌ Training data not found!")
57
+ print(" Expected: ./data/final/train.jsonl")
58
+ print(" Available data files:")
59
+ for root, dirs, files in os.walk("./data"):
60
+ for f in files:
61
+ if f.endswith(".jsonl"):
62
+ print(f" - {os.path.join(root, f)}")
63
+ return
64
+
65
+ # Create output directory
66
+ os.makedirs(output_dir, exist_ok=True)
67
+
68
+ # Load and update config
69
+ import yaml
70
+
71
+ config_path = "stack/training/train_config_local.yaml"
72
+ if os.path.exists(config_path):
73
+ with open(config_path, 'r') as f:
74
+ config = yaml.safe_load(f)
75
+ else:
76
+ print(f"⚠️ Config not found at {config_path}, using defaults")
77
+ config = {
78
+ 'model': {'name': base_model, 'trust_remote_code': True},
79
+ 'data': {'input_path': data_path, 'max_length': 2048},
80
+ 'lora': {'r': 16, 'alpha': 32, 'target_modules': ['q_proj', 'k_proj', 'v_proj', 'o_proj']},
81
+ 'training': {'num_epochs': 1, 'batch_size': 1, 'learning_rate': 2e-4},
82
+ 'output': {'lora_dir': f'{output_dir}/lora', 'merged_dir': f'{output_dir}/merged'},
83
+ 'hardware': {'device': 'mps'}
84
+ }
85
+
86
+ # Update config with local paths
87
+ config['model']['name'] = base_model
88
+ config['data']['input_path'] = data_path
89
+ config['output']['lora_dir'] = f"{output_dir}/lora"
90
+ config['output']['merged_dir'] = f"{output_dir}/merged"
91
+ config['hardware']['device'] = "mps"
92
+
93
+ # Save updated config
94
+ updated_config = f"{output_dir}/train_config.yaml"
95
+ with open(updated_config, 'w') as f:
96
+ yaml.dump(config, f)
97
+
98
+ print(f"\n✅ Config saved to: {updated_config}")
99
+ print(f"\n🚀 Starting training...")
100
+ print(f" Output will be at: {output_dir}/")
101
+ print("=" * 60)
102
+
103
+ # Run training
104
+ from train_lora import train_lora
105
+ trainer = train_lora(updated_config)
106
+
107
+ print("=" * 60)
108
+ print("✅ TRAINING COMPLETED!")
109
+ print(f" LoRA adapter: {output_dir}/lora/")
110
+ print(f" Merged model: {output_dir}/merged/")
111
+ print("=" * 60)
112
+
113
+
114
+ if __name__ == "__main__":
115
+ main()