catninja123 committed on
Commit
69da0bb
·
verified ·
1 Parent(s): 963219b

Upload src/prepare_data.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/prepare_data.py +132 -0
src/prepare_data.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MASH Stage 1: Data Preparation
3
+ - Load raw training pairs (human PS/Supp + AI paraphrased versions)
4
+ - Filter for quality (word count, length ratio)
5
+ - Split into train/val sets
6
+ - Save in format ready for Style-SFT and DPO
7
+ """
8
+
9
+ import json
10
+ import random
11
+ import os
12
+ from pathlib import Path
13
+
14
def load_raw_data(path: str) -> list:
    """Load a JSON Lines file into a list of parsed records.

    Args:
        path: Path to a file containing one JSON value per line.

    Returns:
        The parsed records, in file order. Blank (whitespace-only)
        lines are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    records = []
    # Decode explicitly as UTF-8: the platform default encoding is not
    # guaranteed to be UTF-8 and would corrupt or reject non-ASCII text.
    with open(path, encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at EOF)
            # instead of crashing inside json.loads.
            if not line.strip():
                continue
            records.append(json.loads(line))
    return records
21
+
22
def filter_data(
    data: list,
    min_words: int = 50,
    max_words: int = 800,
    min_ratio: float = 0.5,
    max_ratio: float = 2.0,
    min_chars: int = 100,
) -> list:
    """Filter for quality samples.

    A sample is kept only when all of the following hold:

    - both ``human_words`` and ``ai_words`` lie in
      ``[min_words, max_words]``;
    - the ratio ``ai_words / human_words`` lies in
      ``[min_ratio, max_ratio]``;
    - both ``human_text`` and ``ai_text`` have at least ``min_chars``
      characters after stripping surrounding whitespace.

    Args:
        data: Records (dicts) with the keys ``human_words``,
            ``ai_words``, ``human_text`` and ``ai_text``.
        min_words: Minimum word count allowed for either text.
        max_words: Maximum word count allowed for either text.
        min_ratio: Smallest accepted AI/human word-count ratio.
        max_ratio: Largest accepted AI/human word-count ratio.
        min_chars: Minimum stripped character length for either text.

    Returns:
        A new list with the records that passed every check, in their
        original order. The input list is not modified.

    Raises:
        KeyError: If a record lacks one of the required keys.
    """
    filtered = []
    for sample in data:
        hw = sample['human_words']
        aw = sample['ai_words']

        # Both texts should be within a reasonable word-count range.
        if hw < min_words or aw < min_words:
            continue
        if hw > max_words or aw > max_words:
            continue

        # The AI version should not differ too much in length. A zero
        # human word count (only reachable when min_words <= 0) is
        # treated as ratio 0 to avoid dividing by zero.
        ratio = aw / hw if hw > 0 else 0
        if ratio < min_ratio or ratio > max_ratio:
            continue

        # Neither text may be empty or too short.
        if (len(sample['human_text'].strip()) < min_chars
                or len(sample['ai_text'].strip()) < min_chars):
            continue

        filtered.append(sample)
    return filtered
42
+
43
def prepare_sft_data(data: list) -> list:
    """Reshape filtered records into the Style-injection SFT layout.

    Every output record carries:

    - ``id``: the source record's ``essay_id``
    - ``type``: copied from the source record
    - ``tier``: the source ``tier``, or ``'unknown'`` when absent
    - ``input_text``: the AI text (model input)
    - ``human_text``: the human text (style-transfer target)
    - ``ai_text``: the AI text again (reconstruction target)
    """
    return [
        {
            'id': sample['essay_id'],
            'type': sample['type'],
            'tier': sample.get('tier', 'unknown'),
            'input_text': sample['ai_text'],
            'human_text': sample['human_text'],
            'ai_text': sample['ai_text'],
        }
        for sample in data
    ]
62
+
63
def split_data(data: list, val_ratio: float = 0.1, seed: int = 42) -> tuple:
    """Split into train and validation sets, stratified by type.

    Samples are grouped by their ``type`` value. Each group is shuffled
    and its first ``max(1, int(len(group) * val_ratio))`` samples go to
    validation; the remainder go to training. Both splits are shuffled
    once more before being returned.

    The ``'ps'`` and ``'supp'`` groups are always processed first, in
    that order; any other type follows in order of first appearance, so
    no sample is dropped from the output.

    Note that every non-empty group contributes at least one validation
    sample, so a group with a single sample ends up entirely in
    validation.

    Args:
        data: Records (dicts) that each have a ``type`` key.
        val_ratio: Fraction of each group assigned to validation.
        seed: Seed for the private random generator used for shuffling.

    Returns:
        A ``(train_data, val_data)`` tuple of lists.

    Raises:
        KeyError: If a record has no ``type`` key.
    """
    # A private generator keeps the split reproducible without reseeding
    # the process-wide ``random`` state as a side effect.
    rng = random.Random(seed)

    # Seed the mapping with the two known types so they are always
    # handled first, in a fixed order.
    groups = {'ps': [], 'supp': []}
    for sample in data:
        groups.setdefault(sample['type'], []).append(sample)

    train_data = []
    val_data = []
    for members in groups.values():
        rng.shuffle(members)
        val_size = max(1, int(len(members) * val_ratio))
        val_data.extend(members[:val_size])
        train_data.extend(members[val_size:])

    rng.shuffle(train_data)
    rng.shuffle(val_data)

    return train_data, val_data
84
+
85
def save_jsonl(data: list, path: str):
    """Write records to ``path`` as JSON Lines, one record per line.

    Args:
        data: JSON-serialisable records.
        path: Destination file; it is created or overwritten.

    Raises:
        OSError: If the file cannot be opened for writing.
        TypeError: If a record is not JSON-serialisable.
    """
    # ensure_ascii=False emits non-ASCII characters verbatim, so the
    # file must be opened as UTF-8 explicitly; the platform default
    # encoding may be unable to represent them.
    with open(path, 'w', encoding='utf-8') as f:
        for record in data:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')
89
+
90
def main(
    raw_path: str = '/home/ubuntu/experiment/training_pairs_v3_final.jsonl',
    output_dir: str = '/home/ubuntu/mash_training/data',
) -> None:
    """Run the data-preparation pipeline end to end.

    Loads the raw pairs, filters them, converts them to the SFT layout,
    splits them into train/validation sets, prints summary statistics
    and writes ``train.jsonl``, ``val.jsonl`` and ``all.jsonl`` into
    ``output_dir``.

    Args:
        raw_path: JSON Lines file with the raw training pairs.
        output_dir: Directory for the output files; created if missing.
    """
    # Local import: the module's top-level imports do not include it.
    from collections import Counter

    os.makedirs(output_dir, exist_ok=True)

    # Load and filter
    print("Loading raw data...")
    raw_data = load_raw_data(raw_path)
    print(f" Raw samples: {len(raw_data)}")

    print("Filtering data...")
    filtered = filter_data(raw_data)
    print(f" After filtering: {len(filtered)}")

    # Prepare SFT format
    print("Preparing SFT data...")
    sft_data = prepare_sft_data(filtered)

    # Split
    print("Splitting into train/val...")
    train_data, val_data = split_data(sft_data)
    print(f" Train: {len(train_data)}")
    print(f" Val: {len(val_data)}")

    # Type distribution
    train_types = Counter(d['type'] for d in train_data)
    val_types = Counter(d['type'] for d in val_data)
    print(f" Train types: {dict(train_types)}")
    print(f" Val types: {dict(val_types)}")

    # Save
    save_jsonl(train_data, os.path.join(output_dir, 'train.jsonl'))
    save_jsonl(val_data, os.path.join(output_dir, 'val.jsonl'))
    save_jsonl(sft_data, os.path.join(output_dir, 'all.jsonl'))

    print(f"\nData saved to {output_dir}/")
    print(" train.jsonl - for training")
    print(" val.jsonl - for validation")
    print(" all.jsonl - complete dataset")


if __name__ == '__main__':
    main()