catninja123 committed on
Commit 899504c · verified · 1 Parent(s): 5694ce7

Upload src/train_dpo.py with huggingface_hub

Files changed (1)
  1. src/train_dpo.py +375 -0
src/train_dpo.py ADDED
@@ -0,0 +1,375 @@
+ """
+ MASH Stage 3: DPO Alignment with GPTZero as Reward
+
+ 1. Use SFT model to generate paraphrases for training data
+ 2. Score each paraphrase with GPTZero API
+ 3. Construct preference pairs:
+    - chosen = human text (passes as human)
+    - rejected = model output that GPTZero detects as AI
+ 4. Train with DPO loss
+
+ GPTZero API is only called during data construction (~50-100 queries),
+ NOT during training itself.
+ """
+
+ import os
+ import sys
+ import json
+ import time
+ import argparse
+ import requests
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torch.optim import AdamW
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
+ from model import StyleBART
+ from dataset import MASHDPODataset, dpo_collate_fn
+
+
+ # ============================================================
+ # GPTZero API Integration
+ # ============================================================
+
+ GPTZERO_API_KEY = os.environ.get('GPTZERO_API_KEY', '')
+ GPTZERO_API_URL = 'https://api.gptzero.me/v2/predict/text'
+
+ def query_gptzero(text: str) -> dict:
+     """
+     Query GPTZero API for AI detection score.
+     Returns: {'ai_prob': float, 'human_prob': float, 'class': str}
+     """
+     if not GPTZERO_API_KEY:
+         raise ValueError("GPTZERO_API_KEY not set")
+
+     headers = {
+         'x-api-key': GPTZERO_API_KEY,
+         'Content-Type': 'application/json',
+     }
+     payload = {
+         'document': text,
+         'version': '2024-04-04',
+     }
+
+     for attempt in range(3):
+         try:
+             resp = requests.post(GPTZERO_API_URL, json=payload, headers=headers, timeout=30)
+             resp.raise_for_status()
+             result = resp.json()
+             doc = result.get('documents', [{}])[0]
+             return {
+                 'ai_prob': doc.get('completely_generated_prob', 0),
+                 'human_prob': 1 - doc.get('completely_generated_prob', 0),
+                 'class': doc.get('predicted_class', 'unknown'),
+             }
+         except Exception as e:
+             if attempt < 2:
+                 time.sleep(2 ** attempt)
+             else:
+                 print(f"GPTZero API error: {e}")
+                 return {'ai_prob': 0.5, 'human_prob': 0.5, 'class': 'error'}
+
+
+ # ============================================================
+ # DPO Data Construction
+ # ============================================================
+
+ def construct_dpo_data(sft_model_path: str, train_data_path: str,
+                        output_path: str, device: str = 'cuda',
+                        max_samples: int = 500, ai_threshold: float = 0.5):
+     """
+     Construct DPO preference pairs using SFT model + GPTZero.
+
+     For each sample:
+     1. Generate paraphrase with SFT model (human style)
+     2. Query GPTZero
+     3. If detected as AI → use as rejected; human text = chosen
+     4. If detected as human → skip (model already succeeds)
+     """
+     print(f"Loading SFT model from {sft_model_path}...")
+     model = StyleBART.load_pretrained(sft_model_path, device=device)
+     model = model.to(device)
+     model.eval()
+
+     # Load training data
+     raw_data = []
+     with open(train_data_path) as f:
+         for line in f:
+             raw_data.append(json.loads(line))
+
+     # Sample subset for DPO construction
+     import random
+     random.shuffle(raw_data)
+     raw_data = raw_data[:max_samples]
+
+     dpo_pairs = []
+     n_queried = 0
+     n_rejected = 0
+
+     print(f"Constructing DPO pairs from {len(raw_data)} samples...")
+
+     for i, d in enumerate(raw_data):
+         essay_type = d['type']
+         style_key = f'human_{essay_type}'
+
+         # Tokenize input
+         inputs = model.tokenizer(
+             d['input_text'],
+             max_length=512, truncation=True,
+             return_tensors='pt',
+         ).to(device)
+
+         # Generate with human style
+         with torch.no_grad():
+             generated = model.generate_text(
+                 inputs['input_ids'],
+                 inputs['attention_mask'],
+                 style_keys=[style_key],
+                 max_length=512, num_beams=4,
+             )
+
+         gen_text = model.tokenizer.decode(generated[0], skip_special_tokens=True)
+
+         # Query GPTZero
+         result = query_gptzero(gen_text)
+         n_queried += 1
+
+         if result['ai_prob'] > ai_threshold:
+             # Model failed to evade → good rejected sample
+             dpo_pairs.append({
+                 'input_text': d['input_text'],
+                 'chosen_text': d['human_text'],
+                 'rejected_text': gen_text,
+                 'style_key': style_key,
+                 'essay_type': essay_type,
+                 'gptzero_ai_prob': result['ai_prob'],
+             })
+             n_rejected += 1
+
+         if (i + 1) % 10 == 0:
+             print(f" [{i+1}/{len(raw_data)}] Queried: {n_queried}, "
+                   f"Rejected (usable): {n_rejected}")
+
+     # Save DPO data
+     os.makedirs(os.path.dirname(output_path), exist_ok=True)
+     with open(output_path, 'w') as f:
+         for pair in dpo_pairs:
+             f.write(json.dumps(pair, ensure_ascii=False) + '\n')
+
+     print(f"\nDPO data construction complete:")
+     print(f" Total queried: {n_queried}")
+     print(f" Usable rejected pairs: {n_rejected}")
+     print(f" Rejection rate: {n_rejected/max(n_queried,1)*100:.1f}%")
+     print(f" Saved to: {output_path}")
+
+     return dpo_pairs
+
+
+ # ============================================================
+ # DPO Training
+ # ============================================================
+
+ def compute_dpo_loss(model, batch, device, beta=0.1, ref_model=None):
+     """
+     Compute DPO loss.
+
+     L_DPO = -E[log σ(β · (log π(y_w|x) - log π_ref(y_w|x))
+                    - β · (log π(y_l|x) - log π_ref(y_l|x)))]
+     """
+     input_ids = batch['input_ids'].to(device)
+     attention_mask = batch['attention_mask'].to(device)
+     chosen_labels = batch['chosen_labels'].to(device)
+     rejected_labels = batch['rejected_labels'].to(device)
+     style_keys = batch['style_keys']
+
+     # Compute log probs for chosen
+     chosen_outputs = model(input_ids, attention_mask, chosen_labels, style_keys)
+     chosen_logits = chosen_outputs.logits
+     chosen_log_probs = compute_sequence_log_probs(chosen_logits, chosen_labels)
+
+     # Compute log probs for rejected
+     rejected_outputs = model(input_ids, attention_mask, rejected_labels, style_keys)
+     rejected_logits = rejected_outputs.logits
+     rejected_log_probs = compute_sequence_log_probs(rejected_logits, rejected_labels)
+
+     # Reference model log probs (frozen SFT model)
+     if ref_model is not None:
+         with torch.no_grad():
+             ref_chosen_outputs = ref_model(input_ids, attention_mask, chosen_labels, style_keys)
+             ref_chosen_log_probs = compute_sequence_log_probs(ref_chosen_outputs.logits, chosen_labels)
+
+             ref_rejected_outputs = ref_model(input_ids, attention_mask, rejected_labels, style_keys)
+             ref_rejected_log_probs = compute_sequence_log_probs(ref_rejected_outputs.logits, rejected_labels)
+     else:
+         ref_chosen_log_probs = chosen_log_probs.detach()
+         ref_rejected_log_probs = rejected_log_probs.detach()
+
+     # DPO loss
+     chosen_rewards = beta * (chosen_log_probs - ref_chosen_log_probs)
+     rejected_rewards = beta * (rejected_log_probs - ref_rejected_log_probs)
+
+     loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
+
+     # Metrics
+     with torch.no_grad():
+         reward_margin = (chosen_rewards - rejected_rewards).mean().item()
+         accuracy = (chosen_rewards > rejected_rewards).float().mean().item()
+
+     return loss, {
+         'loss': loss.item(),
+         'reward_margin': reward_margin,
+         'accuracy': accuracy,
+     }
+
+
+ def compute_sequence_log_probs(logits, labels):
+     """Compute per-sequence average log probability."""
+     shift_logits = logits[..., :-1, :].contiguous()
+     shift_labels = labels[..., 1:].contiguous()
+
+     log_probs = F.log_softmax(shift_logits, dim=-1)
+
+     # Gather log probs for actual tokens
+     token_log_probs = log_probs.gather(-1, shift_labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
+
+     # Mask padding (-100)
+     mask = (shift_labels != -100).float()
+
+     # Average log prob per sequence
+     seq_log_probs = (token_log_probs * mask).sum(dim=-1) / mask.sum(dim=-1).clamp(min=1)
+
+     return seq_log_probs
+
+
+ def train_dpo(model, ref_model, train_loader, optimizer, scheduler,
+               device, beta=0.1):
+     """Train one epoch of DPO."""
+     model.train()
+     total_metrics = {'loss': 0, 'reward_margin': 0, 'accuracy': 0}
+     n_batches = 0
+
+     for batch in train_loader:
+         loss, metrics = compute_dpo_loss(model, batch, device, beta, ref_model)
+
+         optimizer.zero_grad()
+         loss.backward()
+         torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+         optimizer.step()
+         scheduler.step()
+
+         for k in total_metrics:
+             total_metrics[k] += metrics[k]
+         n_batches += 1
+
+     return {k: v / n_batches for k, v in total_metrics.items()}
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--mode', choices=['construct', 'train', 'both'], default='both')
+     parser.add_argument('--sft_model_path', default='checkpoints/sft/best')
+     parser.add_argument('--train_data', default='data/train.jsonl')
+     parser.add_argument('--dpo_data', default='data/dpo_pairs.jsonl')
+     parser.add_argument('--output_dir', default='checkpoints/dpo')
+     parser.add_argument('--batch_size', type=int, default=4)
+     parser.add_argument('--epochs', type=int, default=3)
+     parser.add_argument('--lr', type=float, default=1e-5)
+     parser.add_argument('--beta', type=float, default=0.1)
+     parser.add_argument('--max_dpo_samples', type=int, default=500)
+     parser.add_argument('--ai_threshold', type=float, default=0.5)
+     parser.add_argument('--seed', type=int, default=42)
+     args = parser.parse_args()
+
+     torch.manual_seed(args.seed)
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     print(f"Device: {device}")
+
+     # Stage 3a: Construct DPO data
+     if args.mode in ['construct', 'both']:
+         construct_dpo_data(
+             sft_model_path=args.sft_model_path,
+             train_data_path=args.train_data,
+             output_path=args.dpo_data,
+             device=str(device),
+             max_samples=args.max_dpo_samples,
+             ai_threshold=args.ai_threshold,
+         )
+
+     # Stage 3b: DPO Training
+     if args.mode in ['train', 'both']:
+         # Check DPO data exists
+         if not os.path.exists(args.dpo_data):
+             print(f"ERROR: DPO data not found at {args.dpo_data}")
+             print("Run with --mode construct first")
+             return
+
+         # Load DPO model (initialized from SFT)
+         print(f"\nLoading DPO model from {args.sft_model_path}...")
+         model = StyleBART.load_pretrained(args.sft_model_path, device=str(device))
+         model = model.to(device)
+
+         # Load reference model (frozen SFT)
+         print("Loading reference model (frozen)...")
+         ref_model = StyleBART.load_pretrained(args.sft_model_path, device=str(device))
+         ref_model = ref_model.to(device)
+         ref_model.eval()
+         for p in ref_model.parameters():
+             p.requires_grad = False
+
+         # Dataset
+         dpo_dataset = MASHDPODataset(
+             args.dpo_data, model.tokenizer,
+             max_input_len=512, max_target_len=512,
+         )
+         dpo_loader = DataLoader(
+             dpo_dataset, batch_size=args.batch_size,
+             shuffle=True, collate_fn=dpo_collate_fn,
+         )
+
+         print(f"DPO training pairs: {len(dpo_dataset)}")
+
+         # Optimizer
+         optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
+         total_steps = len(dpo_loader) * args.epochs
+         scheduler = CosineAnnealingLR(optimizer, T_max=total_steps, eta_min=1e-7)
+
+         # Training
+         os.makedirs(args.output_dir, exist_ok=True)
+         best_margin = -float('inf')
+
+         print(f"\n{'='*60}")
+         print(f"Starting DPO Training")
+         print(f" Epochs: {args.epochs}")
+         print(f" Beta: {args.beta}")
+         print(f" LR: {args.lr}")
+         print(f"{'='*60}\n")
+
+         for epoch in range(1, args.epochs + 1):
+             t0 = time.time()
+
+             metrics = train_dpo(
+                 model, ref_model, dpo_loader, optimizer, scheduler,
+                 device, beta=args.beta,
+             )
+
+             elapsed = time.time() - t0
+
+             print(f"Epoch {epoch}/{args.epochs} ({elapsed:.0f}s)")
+             print(f" Loss: {metrics['loss']:.4f}")
+             print(f" Reward margin: {metrics['reward_margin']:.4f}")
+             print(f" Accuracy: {metrics['accuracy']:.2%}")
+
+             if metrics['reward_margin'] > best_margin:
+                 best_margin = metrics['reward_margin']
+                 model.save_pretrained(os.path.join(args.output_dir, 'best'))
+                 print(f" ★ New best model saved")
+
+         # Save final
+         model.save_pretrained(os.path.join(args.output_dir, 'final'))
+         print(f"\nDPO training complete! Models saved to {args.output_dir}/")
+
+
+ if __name__ == '__main__':
+     main()
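
For reference, a minimal standalone sketch of the objective that compute_dpo_loss reduces to once per-sequence log-probabilities are available. The tensor values below are made up for illustration; only beta matches the script's default.

import torch
import torch.nn.functional as F

beta = 0.1  # same default as --beta above

# Made-up per-sequence average log-probs for a batch of 3 preference pairs
chosen_log_probs = torch.tensor([-1.0, -2.0, -1.5])        # log pi(y_w|x), policy
rejected_log_probs = torch.tensor([-3.0, -1.8, -2.5])      # log pi(y_l|x), policy
ref_chosen_log_probs = torch.tensor([-1.2, -2.1, -1.6])    # log pi_ref(y_w|x)
ref_rejected_log_probs = torch.tensor([-2.8, -1.9, -2.4])  # log pi_ref(y_l|x)

chosen_rewards = beta * (chosen_log_probs - ref_chosen_log_probs)
rejected_rewards = beta * (rejected_log_probs - ref_rejected_log_probs)

# L_DPO = -E[log sigmoid(chosen_rewards - rejected_rewards)];
# it drops below log(2) once the policy widens the chosen/rejected margin
# relative to the reference model.
loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
accuracy = (chosen_rewards > rejected_rewards).float().mean()
print(f"loss={loss.item():.4f}  accuracy={accuracy.item():.2f}")

Note that compute_sequence_log_probs averages token log-probabilities over each sequence rather than summing them, so the effective scale of beta does not grow with target length.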