catninja123 committed on
Commit
5694ce7
·
verified ·
1 Parent(s): 69da0bb

Upload src/run_training.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/run_training.py +206 -0
src/run_training.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
MASH Training Pipeline - Complete Entry Point for HuggingFace Space

Runs the full pipeline:
  1. Merge and prepare data
  2. Stage 2: Style-injection SFT
  3. Stage 3: DPO with GPTZero (optional, requires the GPTZERO_API_KEY environment variable)
  4. Evaluate and save results

Usage on HF Space:
    python run_training.py --stage merge          # Merge training data only
    python run_training.py --stage sft            # Run SFT only
    python run_training.py --stage dpo_construct  # Construct DPO pairs with GPTZero only
    python run_training.py --stage dpo_train      # Train DPO on existing pairs only
    python run_training.py --stage dpo            # Run DPO (requires SFT checkpoint)
    python run_training.py --stage all            # Run full pipeline
    python run_training.py --stage eval           # Evaluate model
15
+ """
16
+
17
+ import os
18
+ import sys
19
+ import json
20
+ import argparse
21
+ import subprocess
22
+ import time
23
+
24
# SRC_DIR is the directory holding this script; BASE_DIR is its parent, under
# which the data/ and checkpoints/ paths used below are resolved.
SRC_DIR = os.path.split(os.path.abspath(__file__))[0]
BASE_DIR = os.path.split(SRC_DIR)[0]
# Put SRC_DIR first on the import path.
sys.path.insert(0, SRC_DIR)
28
+
29
+
30
def run_merge():
    """Merge the Gemini and Grok pairs into the training data.

    Runs ``merge_pairs.py`` from SRC_DIR with the current interpreter.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STEP 0: Merging training data")
    print(rule)

    merge_script = os.path.join(SRC_DIR, 'merge_pairs.py')
    subprocess.run([sys.executable, merge_script], check=True)
36
+
37
+
38
def run_sft(args):
    """Launch style-injection SFT by running ``train_sft.py`` as a subprocess.

    Args:
        args: Parsed CLI namespace. Reads ``model_name``, ``style_dim``,
            ``sft_batch_size``, ``sft_epochs``, ``sft_lr``, ``lambda_recon``,
            ``recon_ratio`` and ``max_len`` (the latter is used for both the
            input and the target length limits).

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STEP 1: Style-injection SFT")
    print(rule)

    data_dir = os.path.join(BASE_DIR, 'data')
    max_len = str(args.max_len)

    # (flag, value) pairs in the order they are handed to the training script.
    options = (
        ('--train_data', os.path.join(data_dir, 'train.jsonl')),
        ('--val_data', os.path.join(data_dir, 'val.jsonl')),
        ('--output_dir', os.path.join(BASE_DIR, 'checkpoints', 'sft')),
        ('--model_name', args.model_name),
        ('--style_dim', str(args.style_dim)),
        ('--batch_size', str(args.sft_batch_size)),
        ('--epochs', str(args.sft_epochs)),
        ('--lr', str(args.sft_lr)),
        ('--lambda_recon', str(args.lambda_recon)),
        ('--recon_ratio', str(args.recon_ratio)),
        ('--max_input_len', max_len),
        ('--max_target_len', max_len),
    )

    cmd = [sys.executable, os.path.join(SRC_DIR, 'train_sft.py')]
    for flag, value in options:
        cmd.extend((flag, value))
    subprocess.run(cmd, check=True)
60
+
61
+
62
def run_dpo_construct(args):
    """Construct DPO preference pairs with GPTZero.

    Runs ``train_dpo.py`` in ``construct`` mode against the best SFT
    checkpoint and the training set, passing ``data/dpo_pairs.jsonl`` as the
    ``--dpo_data`` path.

    Args:
        args: Parsed CLI namespace. Reads ``dpo_samples`` and ``ai_threshold``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STEP 2a: Constructing DPO pairs with GPTZero")
    print(rule)

    data_dir = os.path.join(BASE_DIR, 'data')
    sft_best = os.path.join(BASE_DIR, 'checkpoints', 'sft', 'best')

    cmd = [sys.executable, os.path.join(SRC_DIR, 'train_dpo.py')]
    cmd += ['--mode', 'construct']
    cmd += ['--sft_model_path', sft_best]
    cmd += ['--train_data', os.path.join(data_dir, 'train.jsonl')]
    cmd += ['--dpo_data', os.path.join(data_dir, 'dpo_pairs.jsonl')]
    cmd += ['--max_dpo_samples', str(args.dpo_samples)]
    cmd += ['--ai_threshold', str(args.ai_threshold)]
    subprocess.run(cmd, check=True)
78
+
79
+
80
def run_dpo_train(args):
    """Run DPO training by invoking ``train_dpo.py`` in ``train`` mode.

    Starts from the best SFT checkpoint, uses ``data/dpo_pairs.jsonl`` as the
    ``--dpo_data`` path and ``checkpoints/dpo`` as the output directory.

    Args:
        args: Parsed CLI namespace. Reads ``dpo_batch_size``, ``dpo_epochs``,
            ``dpo_lr`` and ``dpo_beta``.

    Raises:
        subprocess.CalledProcessError: If the script exits with a non-zero status.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("STEP 2b: DPO Training")
    print(rule)

    checkpoints = os.path.join(BASE_DIR, 'checkpoints')

    cmd = [sys.executable, os.path.join(SRC_DIR, 'train_dpo.py')]
    cmd += ['--mode', 'train']
    cmd += ['--sft_model_path', os.path.join(checkpoints, 'sft', 'best')]
    cmd += ['--dpo_data', os.path.join(BASE_DIR, 'data', 'dpo_pairs.jsonl')]
    cmd += ['--output_dir', os.path.join(checkpoints, 'dpo')]
    cmd += ['--batch_size', str(args.dpo_batch_size)]
    cmd += ['--epochs', str(args.dpo_epochs)]
    cmd += ['--lr', str(args.dpo_lr)]
    cmd += ['--beta', str(args.dpo_beta)]
    subprocess.run(cmd, check=True)
98
+
99
+
100
def run_eval(args):
    """Evaluate the best available checkpoint on the validation set.

    The DPO checkpoint (``checkpoints/dpo/best``) is preferred; if it does not
    exist, the SFT checkpoint (``checkpoints/sft/best``) is used. The chosen
    model is passed to ``inference.py`` together with ``data/val.jsonl`` as
    input and ``checkpoints/eval_results.jsonl`` as output.

    Args:
        args: Parsed CLI namespace. Reads ``eval_batch_size`` and
            ``eval_gptzero`` (when true, ``--eval_gptzero`` is forwarded).

    Raises:
        FileNotFoundError: If neither checkpoint exists. Previously this case
            only printed a message and returned, so the pipeline reported
            success and exited with status 0 without evaluating anything.
        subprocess.CalledProcessError: If ``inference.py`` exits with a
            non-zero status.
    """
    print("\n" + "=" * 60)
    print("STEP 3: Evaluation")
    print("=" * 60)

    # Determine which model to evaluate: DPO takes precedence over SFT.
    dpo_path = os.path.join(BASE_DIR, 'checkpoints', 'dpo', 'best')
    sft_path = os.path.join(BASE_DIR, 'checkpoints', 'sft', 'best')

    if os.path.exists(dpo_path):
        model_path = dpo_path
        print(f"Evaluating DPO model: {model_path}")
    elif os.path.exists(sft_path):
        model_path = sft_path
        print(f"Evaluating SFT model: {model_path}")
    else:
        # Fail loudly, consistent with the check=True subprocess calls used
        # by every other stage, instead of silently skipping evaluation.
        print("ERROR: No trained model found", file=sys.stderr)
        raise FileNotFoundError(
            f"No trained model found (looked for {dpo_path} and {sft_path})"
        )

    cmd = [
        sys.executable, os.path.join(SRC_DIR, 'inference.py'),
        '--model_path', model_path,
        '--input', os.path.join(BASE_DIR, 'data', 'val.jsonl'),
        '--output', os.path.join(BASE_DIR, 'checkpoints', 'eval_results.jsonl'),
        '--batch_size', str(args.eval_batch_size),
    ]

    if args.eval_gptzero:
        cmd.append('--eval_gptzero')

    subprocess.run(cmd, check=True)
132
+
133
+
134
def _build_parser():
    """Create the command-line parser for the training pipeline."""
    parser = argparse.ArgumentParser(description='MASH Training Pipeline')

    # Stage selection
    parser.add_argument('--stage', default='all',
                        choices=['merge', 'sft', 'dpo_construct', 'dpo_train', 'dpo', 'all', 'eval'],
                        help='Which stage to run')

    # Model config
    parser.add_argument('--model_name', default='facebook/bart-base')
    parser.add_argument('--style_dim', type=int, default=64)
    parser.add_argument('--max_len', type=int, default=512)

    # SFT config
    parser.add_argument('--sft_batch_size', type=int, default=16)
    parser.add_argument('--sft_epochs', type=int, default=5)
    parser.add_argument('--sft_lr', type=float, default=3e-5)
    parser.add_argument('--lambda_recon', type=float, default=0.3)
    parser.add_argument('--recon_ratio', type=float, default=0.3)

    # DPO config
    parser.add_argument('--dpo_batch_size', type=int, default=4)
    parser.add_argument('--dpo_epochs', type=int, default=3)
    parser.add_argument('--dpo_lr', type=float, default=1e-5)
    parser.add_argument('--dpo_beta', type=float, default=0.1)
    parser.add_argument('--dpo_samples', type=int, default=500)
    parser.add_argument('--ai_threshold', type=float, default=0.5)

    # Eval config
    parser.add_argument('--eval_batch_size', type=int, default=8)
    parser.add_argument('--eval_gptzero', action='store_true')

    return parser


def _device_label():
    """Return 'CUDA' or 'CPU' as a best-effort guess for the startup banner.

    This is informational only; it does not influence where training runs.
    """
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible is None:
        # No explicit device mask: fall back to probing for the first NVIDIA
        # device node.
        return 'CUDA' if os.path.exists('/dev/nvidia0') else 'CPU'
    # An empty mask, or an invalid first entry such as -1, hides every GPU
    # from CUDA, so report CPU rather than treating any set value as "has GPU".
    first_entry = visible.split(',')[0].strip()
    return 'CPU' if first_entry in ('', '-1') else 'CUDA'


def main():
    """Parse CLI arguments and run the selected pipeline stage(s) in order.

    Stages run in the fixed order merge -> sft -> dpo_construct -> dpo_train
    -> eval; ``--stage all`` runs every one of them and ``--stage dpo`` runs
    both DPO sub-stages. DPO pair construction is skipped with a warning when
    the ``GPTZERO_API_KEY`` environment variable is not set, and DPO training
    is skipped with a warning when ``data/dpo_pairs.jsonl`` does not exist.
    """
    args = _build_parser().parse_args()
    stage = args.stage
    rule = "=" * 60

    print(rule)
    print("MASH Training Pipeline")
    print(f"Stage: {stage}")
    print(f"Model: {args.model_name}")
    print(f"Device: {_device_label()}")
    print(rule)

    # Monotonic clock: the reported duration is unaffected by system clock
    # adjustments during a long training run.
    started = time.monotonic()

    if stage in ('merge', 'all'):
        run_merge()

    if stage in ('sft', 'all'):
        run_sft(args)

    if stage in ('dpo_construct', 'dpo', 'all'):
        if os.environ.get('GPTZERO_API_KEY'):
            run_dpo_construct(args)
        else:
            print("\nWARNING: GPTZERO_API_KEY not set, skipping DPO construction")

    if stage in ('dpo_train', 'dpo', 'all'):
        dpo_data = os.path.join(BASE_DIR, 'data', 'dpo_pairs.jsonl')
        if os.path.exists(dpo_data):
            run_dpo_train(args)
        else:
            print("\nWARNING: DPO data not found, skipping DPO training")

    if stage in ('eval', 'all'):
        run_eval(args)

    elapsed = time.monotonic() - started
    print(f"\n{rule}")
    print(f"Pipeline complete in {elapsed/60:.1f} minutes")
    print(rule)
203
+
204
+
205
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()