catninja123 commited on
Commit
720d8ac
·
verified ·
1 Parent(s): 10427af

Upload src/dataset_dpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/dataset_dpo.py +135 -0
src/dataset_dpo.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MASH DPO Dataset - Preference pairs for Direct Preference Optimization
3
+
4
+ Preference pairs:
5
+ - prompt: instruction + AI text
6
+ - chosen: human-written text (naturally passes GPTZero)
7
+ - rejected: SFT model output (detected as AI by GPTZero)
8
+
9
+ For the initial DPO round, we use the training data's human texts as "chosen"
10
+ and the AI texts (which the SFT model was trained to approximate) as a proxy
11
+ for "rejected". This works because:
12
+ 1. The SFT model's outputs are stylistically similar to the AI inputs
13
+ 2. The human texts represent the target distribution we want to reach
14
+ 3. DPO will push the model away from AI-like patterns toward human-like ones
15
+ """
16
+
17
+ import json
18
+ import random
19
+ import torch
20
+ from torch.utils.data import Dataset
21
+
22
+
23
# Same instruction templates as SFT v4.
# Each template carries two placeholders: {type} (human-readable essay type,
# see TYPE_NAMES below) and {text} (the AI-generated essay to rewrite).
# One template is picked at random per sample in DPODataset.__getitem__ so the
# model sees varied instruction phrasings.
INSTRUCTIONS = [
    "Rewrite the following AI-generated {type} essay in a natural, authentic human voice. Preserve the original meaning and key details while making the writing sound genuinely human-written:\n\n{text}",
    "Transform this AI-written {type} essay into natural human writing. Keep the same ideas and details but make it sound like a real person wrote it:\n\n{text}",
    "Convert the following machine-generated {type} essay to sound authentically human. Maintain the core content while adopting a genuine, personal writing style:\n\n{text}",
    "Rewrite this {type} essay to remove all traces of AI writing. The output should read as if written by a real student, preserving the original meaning:\n\n{text}",
    "Make the following AI-generated {type} essay sound human-written. Keep the same content and structure but use natural, authentic language:\n\n{text}",
]

# Maps the short `type` codes found in the JSONL data to the human-readable
# names substituted into the {type} placeholder. Unknown codes fall back to
# the raw code itself (TYPE_NAMES.get(essay_type, essay_type) in
# DPODataset.__init__).
TYPE_NAMES = {
    'ps': 'personal statement',
    'supp': 'supplemental',
}
36
+
37
+
38
class DPODataset(Dataset):
    """
    Dataset for DPO training.

    Each sample contains:
    - input_ids / attention_mask: tokenized instruction + AI text (encoder input)
    - chosen_labels: tokenized human text (preferred output), padding -> -100
    - rejected_labels: tokenized AI text (dispreferred output), padding -> -100
    - chosen_attention_mask / rejected_attention_mask: 1 where not padding
    """

    def __init__(self, data_path: str, tokenizer,
                 max_input_len: int = 512, max_target_len: int = 512):
        """
        Args:
            data_path: path to a JSONL file; each line must contain
                'human_text' and either 'ai_text' or 'input_text'
                (plus an optional 'type' code, default 'supp').
            tokenizer: HF-style tokenizer exposing `pad_token_id` and
                supporting `text_target=` calls.
            max_input_len: max encoder (prompt) length in tokens.
            max_target_len: max target (chosen/rejected) length in tokens.
        """
        self.tokenizer = tokenizer
        self.max_input_len = max_input_len
        self.max_target_len = max_target_len

        self.examples = []
        # Explicit utf-8: essay text may contain non-ASCII characters and the
        # platform default encoding is not guaranteed to decode them.
        with open(data_path, encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank/trailing lines in the JSONL file.
                    continue
                d = json.loads(line)
                essay_type = d.get('type', 'supp')
                type_name = TYPE_NAMES.get(essay_type, essay_type)

                # For DPO: chosen=human, rejected=AI (the original AI text).
                # The AI text serves as a proxy for what the SFT model would
                # generate, since the SFT model was trained to approximate it.
                self.examples.append({
                    'ai_text': d.get('ai_text', d.get('input_text', '')),
                    'human_text': d['human_text'],
                    'type_name': type_name,
                })

    def __len__(self):
        return len(self.examples)

    def _encode_target(self, text: str):
        """Tokenize a target sequence to max_target_len with max-length padding."""
        return self.tokenizer(
            text_target=text,
            max_length=self.max_target_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

    @staticmethod
    def _to_labels(enc, pad_token_id):
        """Drop the batch dim and mask padding positions with -100 (ignored by the loss)."""
        # squeeze(0), not squeeze(): a bare squeeze would also collapse the
        # sequence dimension when the target length is 1.
        labels = enc['input_ids'].squeeze(0).clone()
        labels[labels == pad_token_id] = -100
        return labels

    def __getitem__(self, idx):
        ex = self.examples[idx]

        # Build instruction prompt (same templates as SFT); sampled randomly
        # so the model sees varied instruction phrasings across epochs.
        template = random.choice(INSTRUCTIONS)
        prompt_text = template.format(
            type=ex['type_name'],
            text=ex['ai_text'],
        )

        # Tokenize prompt (encoder input).
        prompt_enc = self.tokenizer(
            prompt_text,
            max_length=self.max_input_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
        )

        # chosen = human text; rejected = the original AI text (not the
        # instruction-wrapped prompt).
        chosen_enc = self._encode_target(ex['human_text'])
        rejected_enc = self._encode_target(ex['ai_text'])

        pad_id = self.tokenizer.pad_token_id
        return {
            'input_ids': prompt_enc['input_ids'].squeeze(0),
            'attention_mask': prompt_enc['attention_mask'].squeeze(0),
            'chosen_labels': self._to_labels(chosen_enc, pad_id),
            'rejected_labels': self._to_labels(rejected_enc, pad_id),
            'chosen_attention_mask': (chosen_enc['input_ids'].squeeze(0) != pad_id).long(),
            'rejected_attention_mask': (rejected_enc['input_ids'].squeeze(0) != pad_id).long(),
        }
124
+
125
+
126
def dpo_collate_fn(batch):
    """Collate function for DPO dataset: stack per-sample tensors into batches."""
    fields = (
        'input_ids',
        'attention_mask',
        'chosen_labels',
        'rejected_labels',
        'chosen_attention_mask',
        'rejected_attention_mask',
    )
    return {name: torch.stack([sample[name] for sample in batch]) for name in fields}