|
|
|
|
|
""" |
|
|
UltraData-Math L3 - Data Synthesis Script |
|
|
|
|
|
OpenAI API-based data synthesis tool, supporting: |
|
|
- Q&A synthesis |
|
|
- Multi-turn conversation synthesis |
|
|
- Multi-style rewriting |
|
|
- Knowledge extraction and textbook exercise generation |
|
|
|
|
|
Usage: |
|
|
python run_synthesis.py \ |
|
|
--input data.jsonl \ |
|
|
--output output.jsonl \ |
|
|
--task qa \ |
|
|
--level high_school \ |
|
|
--model gpt-4o \ |
|
|
--workers 10 |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import asyncio |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
from openai import AsyncOpenAI |
|
|
|
|
|
|
|
|
from qa_synthesis import QA_PROMPTS, get_qa_prompt |
|
|
from conversation_synthesis import CONVERSATION_PROMPTS, get_conversation_prompt |
|
|
from multistyle_rewrite import MULTISTYLE_PROMPTS, get_multistyle_prompt |
|
|
from knowledge_textbook import ( |
|
|
get_knowledge_extraction_prompt, |
|
|
get_textbook_exercise_prompt, |
|
|
TEXTBOOK_EXERCISE_PROMPTS, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Default settings, each overridable via a CLI flag of the same name ---
DEFAULT_MODEL = "gpt-4o"
DEFAULT_TEMPERATURE = 0.7
DEFAULT_MAX_TOKENS = 4096
# Number of samples processed concurrently (semaphore width in process_batch).
DEFAULT_WORKERS = 10
DEFAULT_MAX_RETRIES = 3
# Base backoff in seconds; SynthesisClient doubles it per retry attempt.
DEFAULT_RETRY_DELAY = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_qa_output(response: str) -> dict:
    """Extract the <problem> and <solution> sections from a Q&A response.

    The returned dict always carries the raw response under "raw"; the
    "problem" and "solution" keys are added only when their tags are found.
    """
    parsed = {"raw": response}

    for key in ("problem", "solution"):
        hit = re.search(rf"<{key}>(.*?)</{key}>", response, re.DOTALL)
        if hit is not None:
            parsed[key] = hit.group(1).strip()

    return parsed
|
|
|
|
|
|
|
|
def parse_conversation_output(response: str) -> dict:
    """Pull the first recognized conversation section out of *response*.

    Tags are tried in priority order (discussions, conversation,
    interaction); the first match is stored as "content" with the matching
    tag name under "type".  Only "raw" is present when no tag matches.
    """
    parsed = {"raw": response}

    for candidate in ("discussions", "conversation", "interaction"):
        hit = re.search(rf"<{candidate}>(.*?)</{candidate}>", response, re.DOTALL)
        if hit is None:
            continue
        parsed["content"] = hit.group(1).strip()
        parsed["type"] = candidate
        break

    return parsed
|
|
|
|
|
|
|
|
def parse_rewrite_output(response: str) -> dict:
    """Extract the <rewritten content> section from a rewrite response.

    Always returns the raw text under "raw"; adds "rewritten" (stripped)
    when the tag pair is present.
    """
    parsed = {"raw": response}

    hit = re.search(r"<rewritten content>(.*?)</rewritten content>", response, re.DOTALL)
    if hit is not None:
        parsed["rewritten"] = hit.group(1).strip()

    return parsed
|
|
|
|
|
|
|
|
def parse_knowledge_output(response: str) -> dict:
    """Collect numbered <mathematical knowledge point> sections.

    A response containing the phrase "no result" (case-insensitive) short-
    circuits to an empty "knowledge_points" list; otherwise every tagged
    section (with any numeric suffix) is extracted and stripped.
    """
    parsed = {"raw": response}

    # Model signals "nothing extractable" with a literal "no result".
    if "no result" in response.lower():
        parsed["knowledge_points"] = []
        return parsed

    tag_pattern = (
        r"<mathematical knowledge point\d*>"
        r"(.*?)"
        r"</mathematical knowledge point\d*>"
    )
    parsed["knowledge_points"] = [
        body.strip() for body in re.findall(tag_pattern, response, re.DOTALL)
    ]

    return parsed
|
|
|
|
|
|
|
|
def parse_textbook_output(response: str) -> dict:
    """Extract the <material> section from a textbook exercise response.

    "raw" always holds the full response; "material" is added (stripped)
    only when a <material>...</material> pair is found.
    """
    parsed = {"raw": response}

    found = re.search(r"<material>(.*?)</material>", response, re.DOTALL)
    if found is not None:
        parsed["material"] = found.group(1).strip()

    return parsed
|
|
|
|
|
|
|
|
# Maps each task type (the CLI --task value) to the parser that turns the
# model's raw response into structured fields.
OUTPUT_PARSERS = {
    "qa": parse_qa_output,
    "conversation": parse_conversation_output,
    "rewrite": parse_rewrite_output,
    "knowledge": parse_knowledge_output,
    "textbook": parse_textbook_output,
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthesisClient:
    """Thin async wrapper around the OpenAI chat-completions API.

    Resolves credentials from explicit arguments, falling back to the
    OPENAI_API_KEY / OPENAI_BASE_URL environment variables, and retries
    failed calls with exponential backoff.
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: str = DEFAULT_MODEL,
        temperature: float = DEFAULT_TEMPERATURE,
        max_tokens: int = DEFAULT_MAX_TOKENS,
        max_retries: int = DEFAULT_MAX_RETRIES,
        retry_delay: float = DEFAULT_RETRY_DELAY,
    ):
        self.client = AsyncOpenAI(
            api_key=api_key or os.getenv("OPENAI_API_KEY"),
            base_url=base_url or os.getenv("OPENAI_BASE_URL"),
        )
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.max_retries = max_retries
        self.retry_delay = retry_delay

    async def generate(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Retries up to ``self.max_retries`` times, sleeping
        ``retry_delay * 2**attempt`` between attempts (exponential
        backoff).  The final failure is re-raised unchanged.
        """
        for attempt in range(self.max_retries):
            try:
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=[{"role": "user", "content": prompt}],
                    temperature=self.temperature,
                    max_tokens=self.max_tokens,
                )
                return response.choices[0].message.content
            except Exception:
                if attempt >= self.max_retries - 1:
                    # Out of retries: propagate with the original traceback
                    # (bare `raise` instead of `raise e`).
                    raise
                await asyncio.sleep(self.retry_delay * (2 ** attempt))
        # NOTE(review): only reachable when max_retries <= 0 — the call is
        # silently skipped and an empty string returned.  Kept for backward
        # compatibility; consider validating max_retries >= 1 in __init__.
        return ""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SynthesisTask:
    """Abstract base for all synthesis tasks.

    Subclasses provide a prompt builder (`get_prompt`) and a response
    parser (`parse_output`); `process` wires them together around one
    API call.
    """

    def __init__(self, client: SynthesisClient, text_field: str = "text"):
        self.client = client
        self.text_field = text_field

    def get_prompt(self, sample: dict) -> str:
        """Build the model prompt for one input sample (subclass hook)."""
        raise NotImplementedError

    def parse_output(self, response: str) -> dict:
        """Parse the raw model response into structured fields (subclass hook)."""
        raise NotImplementedError

    async def process(self, sample: dict) -> dict:
        """Run one sample end to end: build prompt, call the API, parse."""
        reply = await self.client.generate(self.get_prompt(sample))
        return {**sample, "synthesis_result": self.parse_output(reply)}
|
|
|
|
|
|
|
|
class QASynthesisTask(SynthesisTask):
    """Synthesize problem/solution pairs from source text at a given level."""

    def __init__(self, client: SynthesisClient, level: str, text_field: str = "text"):
        super().__init__(client, text_field)
        self.level = level
        self.prompt_template = get_qa_prompt(level)

    def get_prompt(self, sample: dict) -> str:
        """Fill the level-specific template with the sample's text field."""
        return self.prompt_template.format(text=sample.get(self.text_field, ""))

    def parse_output(self, response: str) -> dict:
        return parse_qa_output(response)
|
|
|
|
|
|
|
|
class ConversationSynthesisTask(SynthesisTask):
    """Synthesize multi-turn conversations from source text in a given style."""

    def __init__(self, client: SynthesisClient, style: str, text_field: str = "text"):
        super().__init__(client, text_field)
        self.style = style
        self.prompt_template = get_conversation_prompt(style)

    def get_prompt(self, sample: dict) -> str:
        """Fill the style-specific template with the sample's text field."""
        return self.prompt_template.format(text=sample.get(self.text_field, ""))

    def parse_output(self, response: str) -> dict:
        return parse_conversation_output(response)
|
|
|
|
|
|
|
|
class RewriteSynthesisTask(SynthesisTask):
    """Rewrite source text into a target style (textbook, blog, ...)."""

    def __init__(self, client: SynthesisClient, style: str, text_field: str = "text"):
        super().__init__(client, text_field)
        self.style = style
        self.prompt_template = get_multistyle_prompt(style)

    def get_prompt(self, sample: dict) -> str:
        """Fill the style-specific template with the sample's text field."""
        return self.prompt_template.format(text=sample.get(self.text_field, ""))

    def parse_output(self, response: str) -> dict:
        return parse_rewrite_output(response)
|
|
|
|
|
|
|
|
class KnowledgeExtractionTask(SynthesisTask):
    """Extract mathematical knowledge points from source text."""

    def __init__(self, client: SynthesisClient, text_field: str = "text"):
        super().__init__(client, text_field)
        self.prompt_template = get_knowledge_extraction_prompt()

    def get_prompt(self, sample: dict) -> str:
        """Fill the extraction template with the sample's text field."""
        return self.prompt_template.format(text=sample.get(self.text_field, ""))

    def parse_output(self, response: str) -> dict:
        return parse_knowledge_output(response)
|
|
|
|
|
|
|
|
class TextbookExerciseTask(SynthesisTask):
    """Generate textbook exercises from a knowledge point at a set difficulty.

    Unlike the text-based tasks, the input field is a knowledge point
    (default field name: "knowledge_point") rather than raw text.
    """

    def __init__(self, client: SynthesisClient, difficulty: str, knowledge_field: str = "knowledge_point"):
        super().__init__(client)
        self.difficulty = difficulty
        self.knowledge_field = knowledge_field
        self.prompt_template = get_textbook_exercise_prompt(difficulty)

    def get_prompt(self, sample: dict) -> str:
        """Fill the difficulty-specific template with the knowledge point."""
        point = sample.get(self.knowledge_field, "")
        return self.prompt_template.format(mathematical_knowledge_point=point)

    def parse_output(self, response: str) -> dict:
        return parse_textbook_output(response)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
async def process_batch(
    task: "SynthesisTask",
    samples: list[dict],
    workers: int,
    progress_callback=None,
) -> list[dict]:
    """Process a batch of samples concurrently, bounded by *workers*.

    Args:
        task: Synthesis task whose ``process`` coroutine handles one sample.
        samples: Input records to process.
        workers: Maximum number of samples in flight at once.
        progress_callback: Optional ``fn(completed, total)`` invoked after
            each sample finishes.

    Returns:
        One result dict per input sample, in input order.  Each result is
        tagged ``"_status": "success"`` or ``"_status": "error"``; failures
        also carry the exception text under ``"_error"`` instead of
        aborting the whole batch.
    """
    semaphore = asyncio.Semaphore(workers)
    total = len(samples)
    completed = 0

    async def process_one(sample: dict) -> dict:
        nonlocal completed
        async with semaphore:
            try:
                result = await task.process(sample)
                result["_status"] = "success"
            except Exception as e:
                # Record the failure on the sample rather than failing the batch.
                result = {**sample, "_status": "error", "_error": str(e)}

            completed += 1
            if progress_callback:
                progress_callback(completed, total)

            return result

    # asyncio.gather returns results in the order its awaitables were
    # passed, so results already line up with `samples` — the previous
    # index-tagging and post-sort were redundant (as was the dead
    # `results = []` initialization).
    return await asyncio.gather(*(process_one(s) for s in samples))
|
|
|
|
|
|
|
|
def load_jsonl(filepath: str) -> list[dict]:
    """Read a JSONL file and return its records, skipping blank lines."""
    with open(filepath, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
|
|
|
def save_jsonl(data: list[dict], filepath: str):
    """Write records to *filepath*, one JSON object per line (UTF-8).

    Non-ASCII characters are written verbatim (ensure_ascii=False).
    """
    with open(filepath, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in data
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def create_task(args, client: SynthesisClient) -> SynthesisTask:
    """Instantiate the synthesis task selected by the parsed CLI arguments.

    Validates level/style/difficulty against the known prompt tables and
    raises ValueError listing the available options on a bad value or an
    unknown task type.
    """

    def _validated(value, table, label):
        # Shared guard: the chosen key must exist in its prompt table.
        if value not in table:
            raise ValueError(f"Invalid {label}: {value}. Available: {list(table.keys())}")
        return value

    kind = args.task

    if kind == "qa":
        level = _validated(args.level or "high_school", QA_PROMPTS, "QA level")
        return QASynthesisTask(client, level, args.text_field)

    if kind == "conversation":
        style = _validated(args.style or "teacher_student", CONVERSATION_PROMPTS, "conversation style")
        return ConversationSynthesisTask(client, style, args.text_field)

    if kind == "rewrite":
        style = _validated(args.style or "textbook", MULTISTYLE_PROMPTS, "rewrite style")
        return RewriteSynthesisTask(client, style, args.text_field)

    if kind == "knowledge":
        return KnowledgeExtractionTask(client, args.text_field)

    if kind == "textbook":
        difficulty = _validated(args.difficulty or "easy", TEXTBOOK_EXERCISE_PROMPTS, "difficulty")
        return TextbookExerciseTask(client, difficulty, args.knowledge_field)

    raise ValueError(f"Unknown task type: {kind}")
|
|
|
|
|
|
|
|
def print_progress(completed: int, total: int):
    """Render an in-place progress line on stdout (carriage-return overwrite)."""
    pct = completed / total * 100
    line = f"\rProgress: {completed}/{total} ({pct:.1f}%)"
    print(line, end="", flush=True)
|
|
|
|
|
|
|
|
async def main_async(args):
    """Drive one synthesis run: build client and task, load input,
    process the batch, report stats, and save results."""
    client = SynthesisClient(
        api_key=args.api_key,
        base_url=args.base_url,
        model=args.model,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        max_retries=args.max_retries,
    )
    task = create_task(args, client)

    # Load (and optionally truncate) the input records.
    print(f"Loading data from {args.input}...")
    records = load_jsonl(args.input)
    if args.limit:
        records = records[:args.limit]

    print(f"Processing {len(records)} samples with {args.workers} workers...")
    progress_cb = None if args.quiet else print_progress
    started = time.time()
    results = await process_batch(task, records, args.workers, progress_callback=progress_cb)
    elapsed = time.time() - started
    print(f"\nCompleted in {elapsed:.2f}s ({len(records)/elapsed:.1f} samples/s)")

    # Summarize per-sample outcomes tagged by process_batch.
    ok = sum(1 for r in results if r.get("_status") == "success")
    failed = len(results) - ok
    print(f"Success: {ok}, Error: {failed}")

    save_jsonl(results, args.output)
    print(f"Results saved to {args.output}")
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse arguments, then run the async pipeline."""
    parser = argparse.ArgumentParser(
        description="UltraData-Math L3 Data Synthesis Tool",
        # RawDescriptionHelpFormatter keeps the epilog's manual layout intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Q&A synthesis (high school level)
  python run_synthesis.py -i data.jsonl -o qa_output.jsonl -t qa --level high_school

  # Multi-turn conversation synthesis (teacher-student)
  python run_synthesis.py -i data.jsonl -o conv_output.jsonl -t conversation --style teacher_student

  # Multi-style rewrite (textbook style)
  python run_synthesis.py -i data.jsonl -o rewrite_output.jsonl -t rewrite --style textbook

  # Knowledge extraction
  python run_synthesis.py -i data.jsonl -o knowledge_output.jsonl -t knowledge

  # Textbook exercise generation (medium difficulty)
  python run_synthesis.py -i knowledge.jsonl -o textbook_output.jsonl -t textbook --difficulty medium

Task Types:
  qa            Q&A synthesis
                --level: grade_school, middle_school, high_school, college

  conversation  Multi-turn conversation synthesis
                --style: two_professors, teacher_student, two_students,
                         interview, problem_solving, layman_expert, debate

  rewrite       Multi-style rewrite
                --style: wikipedia, textbook, blog, popular_science,
                         academic_paper, learning_note, lecture_note

  knowledge     Knowledge extraction

  textbook      Textbook exercise generation
                --difficulty: easy, medium, hard
"""
    )

    # I/O paths.
    parser.add_argument("-i", "--input", required=True, help="Input JSONL file path")
    parser.add_argument("-o", "--output", required=True, help="Output JSONL file path")

    # Task selection; level/style/difficulty each apply to specific tasks
    # (validated later in create_task, with task-specific defaults).
    parser.add_argument("-t", "--task", required=True,
                        choices=["qa", "conversation", "rewrite", "knowledge", "textbook"],
                        help="Synthesis task type")
    parser.add_argument("--level", help="Q&A difficulty level")
    parser.add_argument("--style", help="Conversation/rewrite style")
    parser.add_argument("--difficulty", help="Textbook exercise difficulty")

    # Names of the input-record fields each task reads.
    parser.add_argument("--text-field", default="text", help="Input text field name (default: text)")
    parser.add_argument("--knowledge-field", default="knowledge_point", help="Knowledge point field name (default: knowledge_point)")

    # API / model configuration (env vars are the fallback for credentials).
    parser.add_argument("--api-key", help="OpenAI API Key (or set OPENAI_API_KEY env var)")
    parser.add_argument("--base-url", help="API Base URL (or set OPENAI_BASE_URL env var)")
    parser.add_argument("--model", default=DEFAULT_MODEL, help=f"Model name (default: {DEFAULT_MODEL})")
    parser.add_argument("--temperature", type=float, default=DEFAULT_TEMPERATURE, help=f"Sampling temperature (default: {DEFAULT_TEMPERATURE})")
    parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS, help=f"Max tokens to generate (default: {DEFAULT_MAX_TOKENS})")

    # Execution controls.
    parser.add_argument("-w", "--workers", type=int, default=DEFAULT_WORKERS, help=f"Concurrency (default: {DEFAULT_WORKERS})")
    parser.add_argument("--max-retries", type=int, default=DEFAULT_MAX_RETRIES, help=f"Max retries (default: {DEFAULT_MAX_RETRIES})")
    parser.add_argument("--limit", type=int, help="Limit number of samples to process")
    parser.add_argument("-q", "--quiet", action="store_true", help="Quiet mode")

    args = parser.parse_args()

    asyncio.run(main_async(args))
|
|
|
|
|
|
|
|
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|