catninja123 committed on
Commit
79ad8bc
·
verified ·
1 Parent(s): c7813a5

Upload src/generate_grok_pairs.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/generate_grok_pairs.py +156 -0
src/generate_grok_pairs.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate AI paraphrase pairs using Grok API (xAI).
3
+ Uses the cleaned human texts and generates Grok-style AI versions.
4
+ This complements the existing Gemini-generated pairs for multi-LLM training.
5
+ """
6
+
7
+ import json
8
+ import os
9
+ import time
10
+ import asyncio
11
+ import aiohttp
12
+
13
# xAI credentials and endpoint. The key is read from the environment and
# defaults to '' when unset; main() checks for that and exits before any request.
XAI_API_KEY = os.environ.get('XAI_API_KEY', '')
XAI_API_URL = 'https://api.x.ai/v1/chat/completions'

# JSONL input: main() reads 'essay_id', 'human_text' and 'type' from every
# record, plus optional 'tier' and 'year'.
INPUT_FILE = '/home/ubuntu/mash_training/data/human_texts_clean.jsonl'
# JSONL output: opened in append mode and also re-read at start-up so that
# essay_ids already present are skipped (resume support).
OUTPUT_FILE = '/home/ubuntu/mash_training/data/grok_pairs.jsonl'

# System message sent with every chat-completion request.
SYSTEM_PROMPT = """You are a writing assistant. Your task is to paraphrase the given text while:
1. Keeping ALL the same information and meaning
2. Using a polished, professional AI writing style
3. Making it sound like it was written by an AI language model
4. Using smooth transitions, parallel structures, and sophisticated vocabulary
5. Maintaining the same approximate length

Do NOT add new information. Do NOT remove any information. Just rephrase it in a polished AI style.
Output ONLY the paraphrased text, nothing else."""

# Maximum number of paraphrase_one() calls allowed inside the semaphore at once.
CONCURRENCY = 15  # Increased for faster generation
# NOTE(review): this semaphore is created at import time, before asyncio.run()
# starts its event loop. Presumably fine on Python >= 3.10; on older versions
# asyncio primitives bind to a loop at construction - confirm the target version.
semaphore = asyncio.Semaphore(CONCURRENCY)
31
+
32
+
33
async def paraphrase_one(session, essay_id, text, essay_type, retries=3):
    """Paraphrase one text using the Grok API.

    Args:
        session: Open ``aiohttp.ClientSession`` used to POST the request.
        essay_id: Identifier echoed back in the result so the caller can
            match the response to its input.
        text: Human-written text to paraphrase.
        essay_type: ``"ps"`` selects the "personal statement" prompt hint;
            any other value selects "college supplement essay".
        retries: Maximum number of request attempts.

    Returns:
        ``(essay_id, ai_text)`` on success, or ``(essay_id, None)`` when no
        attempt produced an acceptable paraphrase. Errors raised while
        sending the request or decoding the response are caught and
        reported, not propagated.
    """
    type_hint = "personal statement" if essay_type == "ps" else "college supplement essay"
    human_words = len(text.split())

    headers = {
        'Authorization': f'Bearer {XAI_API_KEY}',
        'Content-Type': 'application/json',
    }

    payload = {
        'model': 'grok-3-mini-fast',
        'messages': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': f'Paraphrase this {type_hint} excerpt in AI style:\n\n{text}'},
        ],
        'temperature': 0.7,
        # Word count is only a rough proxy for tokens, hence the 3x headroom
        # and the 512 floor for very short inputs.
        'max_tokens': max(human_words * 3, 512),
    }

    timeout = aiohttp.ClientTimeout(total=60)
    last_problem = 'no attempts made'

    async with semaphore:
        for attempt in range(retries):
            try:
                async with session.post(XAI_API_URL, json=payload, headers=headers, timeout=timeout) as resp:
                    status = resp.status
                    if status == 429:
                        # Rate limited - back off linearly before the next try.
                        last_problem = 'rate limited (HTTP 429)'
                        delay = 5 * (attempt + 1)
                    elif 400 <= status < 500 and status != 408:
                        # Other client errors (bad key, bad request, ...) will
                        # fail identically on every retry, so give up at once.
                        print(f"  ERROR {essay_id}: HTTP {status} (not retryable)", flush=True)
                        return essay_id, None
                    else:
                        resp.raise_for_status()
                        result = await resp.json()
                        ai_text = result['choices'][0]['message']['content'].strip()

                        # Basic validation: reject outputs far shorter than the input.
                        if len(ai_text.split()) >= human_words * 0.4:
                            return essay_id, ai_text
                        last_problem = 'paraphrase too short'
                        delay = 0
            except Exception as e:
                # Worker boundary: one failed essay must not abort the batch
                # gathered by the caller, so record the error and retry.
                last_problem = f"{type(e).__name__}: {e}"
                delay = 2 ** attempt

            # Sleep outside the response context (the connection is released
            # first) and never after the final attempt.
            if delay and attempt < retries - 1:
                await asyncio.sleep(delay)

    print(f"  ERROR {essay_id}: {last_problem}", flush=True)
    return essay_id, None
76
+
77
+
78
async def main():
    """Generate Grok paraphrases for every input essay not yet in OUTPUT_FILE.

    Reads INPUT_FILE (JSONL), skips essay_ids already present in OUTPUT_FILE,
    then processes the remainder in batches, appending one JSON record per
    successful paraphrase and printing progress after each batch. Returns
    early, after printing a message, when XAI_API_KEY is unset or when
    nothing remains to do.
    """
    if not XAI_API_KEY:
        print("ERROR: XAI_API_KEY not set")
        return

    # Load clean human texts. Read explicitly as UTF-8 to match how the output
    # is written, and tolerate blank lines in the JSONL.
    with open(INPUT_FILE, encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]
    print(f"Loaded {len(data)} clean samples", flush=True)

    # Load existing progress
    done_ids, needs_newline = set(), False
    if os.path.exists(OUTPUT_FILE):
        done_ids, needs_newline = _load_done_ids(OUTPUT_FILE)
        print(f"Already done: {len(done_ids)}", flush=True)

    # Filter remaining
    remaining = [d for d in data if d['essay_id'] not in done_ids]
    print(f"Remaining: {len(remaining)}", flush=True)

    if not remaining:
        print("All done!")
        return

    # Process in batches
    batch_size = 30
    total_done = 0
    total_errors = 0
    start_time = time.time()

    async with aiohttp.ClientSession() as session:
        with open(OUTPUT_FILE, 'a', encoding='utf-8') as out:
            if needs_newline:
                # A previous run was cut off mid-record; terminate that line
                # so the next record is not glued onto it.
                out.write('\n')

            for batch_start in range(0, len(remaining), batch_size):
                batch = remaining[batch_start:batch_start + batch_size]

                coros = [
                    paraphrase_one(session, d['essay_id'], d['human_text'], d['type'])
                    for d in batch
                ]
                # gather() returns results in the same order as `coros`,
                # so they can be zipped back onto `batch`.
                results = await asyncio.gather(*coros)

                for (_, ai_text), orig in zip(results, batch):
                    if ai_text:
                        pair = {
                            'essay_id': orig['essay_id'],
                            'type': orig['type'],
                            'tier': orig.get('tier', 'unknown'),
                            'year': orig.get('year', 'unknown'),
                            'human_text': orig['human_text'],
                            'ai_text': ai_text,
                            'human_words': len(orig['human_text'].split()),
                            'ai_words': len(ai_text.split()),
                            'ai_model': 'grok-3-mini-fast',
                        }
                        out.write(json.dumps(pair, ensure_ascii=False) + '\n')
                        total_done += 1
                    else:
                        total_errors += 1

                # Flush once per batch so an interrupted run loses at most
                # the batch in flight.
                out.flush()
                elapsed = time.time() - start_time
                rate = (total_done + total_errors) / elapsed if elapsed > 0 else 0
                remaining_count = len(remaining) - batch_start - len(batch)
                eta = remaining_count / rate / 60 if rate > 0 else 0
                print(f"  Batch {batch_start//batch_size + 1}: "
                      f"{total_done} done, {total_errors} errors, "
                      f"{rate:.1f}/s, ETA {eta:.0f}min", flush=True)

    elapsed = time.time() - start_time
    print(f"\nDONE: {total_done} pairs, {total_errors} errors in {elapsed/60:.1f} min", flush=True)


def _load_done_ids(path):
    """Return ``(done_ids, needs_newline)`` for an existing progress file.

    ``done_ids`` is the set of ``essay_id`` values found in the file. Blank or
    unparseable lines (for example a record truncated by an interrupted run)
    are skipped with a warning, so those essays are simply regenerated.
    ``needs_newline`` is True when the file's last line lacks a trailing
    newline.
    """
    done_ids = set()
    needs_newline = False
    with open(path, encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            needs_newline = not line.endswith('\n')
            if not line.strip():
                continue
            try:
                done_ids.add(json.loads(line)['essay_id'])
            except (json.JSONDecodeError, KeyError, TypeError):
                print(f"  WARNING: skipping malformed line {line_no} in {path}", flush=True)
    return done_ids, needs_newline
153
+
154
+
155
# Script entry point: run the async pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    asyncio.run(main())