|
|
import os |
|
|
import json |
|
|
from pathlib import Path |
|
|
import sympy |
|
|
from sympy.parsing.latex import parse_latex |
|
|
from sympy.parsing.sympy_parser import parse_expr |
|
|
from datasets import load_dataset |
|
|
import jsonlines |
|
|
from typing import Dict, List, Any |
|
|
import sys |
|
|
import psutil |
|
|
|
|
|
class MathDataPreparer:
    """Clean math datasets (GSM8K, ProofNet) and write them out as JSON Lines.

    ``self.datasets`` holds loading hints per dataset; each dataset has its own
    ``process_*`` method, plus helpers to save, preview and report memory usage.
    """

    def __init__(self, output_dir: str = "processed_data"):
        """Create the preparer and make sure ``output_dir`` exists.

        Args:
            output_dir: Directory that ``save_to_jsonl`` writes into. Missing
                parent directories are created as well.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested paths such as "out/math/v1" work on a fresh checkout.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Per-dataset hints: dataset id ("source"), optional config name,
        # split to load, and the fields this preparer reads from each row.
        # NOTE(review): the published ProofNet schema may use different column
        # names (e.g. nl_statement / nl_proof) — verify against the dataset.
        self.datasets = {
            "gsm8k": {
                "source": "gsm8k",
                "config": "main",
                "split": "train",
                "fields": ["question", "answer"],
            },
            "proofnet": {
                "source": "hoskinson-center/proofnet",
                "split": "validation",
                "fields": ["problem", "solution", "proof_steps"],
            },
        }

    def normalize_equation(self, equation: str) -> str:
        """Normalize a mathematical expression with sympy.

        Surrounding whitespace and dollar-sign math delimiters are removed
        first. Text containing a backslash is then parsed as LaTeX with
        ``parse_latex``; anything else goes through ``parse_expr``.

        Args:
            equation: Expression text to normalize.

        Returns:
            ``str()`` of the parsed sympy object. If parsing fails for any
            reason, ``equation`` is returned unchanged and the failure is
            reported on stderr.
        """
        try:
            expr = equation.strip()
            # Remove $...$ / $$...$$ *before* choosing a parser, so e.g.
            # "$\frac{1}{2}$" reaches parse_latex without the dollar signs.
            if len(expr) >= 2 and expr.startswith("$") and expr.endswith("$"):
                expr = expr.strip("$").strip()
            if "\\" in expr:
                parsed = parse_latex(expr)
            else:
                # SECURITY: parse_expr evaluates its input with eval(). Dataset
                # text is external input; only use this on trusted data.
                parsed = parse_expr(expr)
            return str(parsed)
        except Exception as exc:
            # Deliberately best-effort: callers always get usable text back.
            print(f"Error normalizing equation: {equation} ({exc})", file=sys.stderr)
            return equation

    def process_proof_steps(self, steps: List[Any]) -> List[Dict[str, Any]]:
        """Convert raw proof steps into a list of step dictionaries.

        Each step is handled as follows:

        * dicts are kept as-is (already structured);
        * blank or whitespace-only strings are dropped;
        * strings holding a JSON object are decoded into that dict;
        * any other string becomes ``{"text": <stripped>, "valid": True}``.

        A step that raises while being handled (e.g. ``None``, which has no
        ``strip``) is reported on stderr and recorded as
        ``{"text": step, "valid": False, "error": <message>}``.
        """
        processed_steps: List[Dict[str, Any]] = []
        for step in steps:
            # Structured steps need no parsing; previously these were
            # recorded as errors because dicts have no .strip().
            if isinstance(step, dict):
                processed_steps.append(step)
                continue
            try:
                text = step.strip()
                if not text:
                    continue
                try:
                    structured = json.loads(step)
                except json.JSONDecodeError:
                    structured = None
                if isinstance(structured, dict):
                    processed_steps.append(structured)
                else:
                    processed_steps.append({"text": text, "valid": True})
            except Exception as exc:
                print(f"Error processing proof step: {step}", file=sys.stderr)
                processed_steps.append({
                    "text": step,
                    "valid": False,
                    "error": str(exc),
                })
        return processed_steps

    def process_gsm8k(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Strip GSM8K question/answer text and normalize fields containing '='.

        Examples that raise while being processed (e.g. a missing key or a
        non-string value) are reported on stderr and skipped.
        """
        processed = []
        for example in dataset:
            try:
                processed_example = {
                    "question": example["question"].strip(),
                    "answer": example["answer"].strip(),
                }
                # NOTE(review): this normalizes the *whole* field, not an
                # extracted equation. When sympy cannot parse it, the text is
                # kept unchanged but a line is logged to stderr.
                for field in ("question", "answer"):
                    if "=" in processed_example[field]:
                        processed_example[field] = self.normalize_equation(processed_example[field])
                processed.append(processed_example)
            except Exception as e:
                print(f"Error processing GSM8K example: {str(e)}", file=sys.stderr)
        return processed

    def _print_dataset_info(self, dataset: Any) -> None:
        """Print the container type, declared features and first-row layout of *dataset*."""
        print("\nProofNet dataset info:")
        print(f"Dataset type: {type(dataset)}")
        if hasattr(dataset, 'features'):
            print("\nDataset features:")
            for feature, dtype in dataset.features.items():
                print(f"{feature}: {dtype}")

        if len(dataset) > 0:
            first_example = dataset[0]
            print("\nFirst example keys:", list(first_example.keys()))
            print("\nFirst example preview:")
            for key, value in first_example.items():
                print(f"\n{key}:")
                print(f"Type: {type(value)}")
                if isinstance(value, str):
                    print(f"Length: {len(value)}")
                elif isinstance(value, list):
                    print(f"List length: {len(value)}")
                    if len(value) > 0:
                        print(f"First item type: {type(value[0])}")
        print("\n")

    def _coerce_proof_steps(self, idx: int, steps: Any) -> List[Any]:
        """Turn the raw ``proof_steps`` value of example *idx* into a list, printing diagnostics."""
        print(f"\nExample {idx} proof steps info:")
        print(f"Type: {type(steps)}")
        if isinstance(steps, str):
            print(f"Length: {len(steps)}")
            # A single string is treated as one step per line.
            steps = steps.split('\n')
            print(f"Split into {len(steps)} steps")
        elif isinstance(steps, list):
            print(f"List length: {len(steps)}")
            if len(steps) > 0:
                print(f"First item type: {type(steps[0])}")
        else:
            print(f"Warning: Unexpected proof steps type: {type(steps)}", file=sys.stderr)
            steps = []
        return steps

    def process_proofnet(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Process ProofNet rows into ``problem`` / ``solution`` / ``proof_steps`` records.

        Diagnostic information about the dataset and each row's proof steps
        is printed to stdout. Missing or ``None`` ``problem``/``solution``
        values become empty strings. Rows that still raise are reported on
        stderr, counted and skipped.

        Returns:
            The list of processed records.
        """
        processed = []
        error_count = 0

        self._print_dataset_info(dataset)

        for idx, example in enumerate(dataset):
            try:
                processed_example = {
                    # `or ""` also covers keys that are present but None.
                    "problem": (example.get("problem") or "").strip(),
                    "solution": (example.get("solution") or "").strip(),
                    "proof_steps": [],
                }

                if "proof_steps" in example:
                    steps = self._coerce_proof_steps(idx, example["proof_steps"])
                    processed_example["proof_steps"] = self.process_proof_steps(steps)

                for field in ("problem", "solution"):
                    if "=" in processed_example[field]:
                        # normalize_equation returns its input on failure.
                        processed_example[field] = self.normalize_equation(processed_example[field])

                processed.append(processed_example)
            except Exception as e:
                print(f"Error processing ProofNet example {idx}: {str(e)}", file=sys.stderr)
                error_count += 1

        print(f"\nProcessed {len(processed)} examples from ProofNet")
        print(f"Encountered {error_count} errors during processing")
        return processed

    def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str) -> Path:
        """Write *data* as JSON Lines to ``output_dir / filename``, overwriting it.

        Returns:
            The path that was written.
        """
        filepath = self.output_dir / filename
        with jsonlines.open(filepath, mode='w') as writer:
            writer.write_all(data)
        return filepath

    def print_memory_usage(self):
        """Print this process's resident set size (psutil ``rss``) in units of 1024*1024 bytes."""
        process = psutil.Process()
        memory_info = process.memory_info()
        print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")

    def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
        """Print the first *count* records of *data* in a readable form.

        Records with ``proof_steps`` are printed field by field. A step
        without a ``text`` entry is shown as JSON instead of raising KeyError.
        Other records are pretty-printed as JSON.
        """
        print("\nSample data:")
        for i, example in enumerate(data[:count]):
            print(f"\nSample {i+1}:")
            if "proof_steps" in example:
                print(f"Problem: {example.get('problem', '')}")
                print(f"Solution: {example.get('solution', '')}")
                print("\nProof Steps:")
                for step in example["proof_steps"]:
                    text = step.get("text") if isinstance(step, dict) else None
                    if text is None:
                        text = json.dumps(step, default=str)
                    print(f"- {text}")
            else:
                print(json.dumps(example, indent=2))
|
|
|
|
|
def main():
    """Prepare GSM8K (required) and ProofNet (best effort) as JSONL files.

    Dataset ids, config names and splits are read from ``preparer.datasets``,
    so the loading hints live in one place. Each processed dataset is saved
    before its preview is printed, so a display error cannot prevent the file
    from being written. The function finishes by reporting process memory usage.
    """
    preparer = MathDataPreparer()

    # --- GSM8K: required; an error here propagates and ends the script. ---
    print("\nProcessing GSM8K dataset...")
    gsm8k_cfg = preparer.datasets["gsm8k"]
    gsm8k_dataset = load_dataset(gsm8k_cfg["source"], gsm8k_cfg.get("config"), split=gsm8k_cfg["split"])
    print(f"Loaded {len(gsm8k_dataset)} samples from GSM8K")

    processed_gsm8k = preparer.process_gsm8k(gsm8k_dataset)
    print(f"Processed {len(processed_gsm8k)} samples")

    gsm8k_path = preparer.save_to_jsonl(processed_gsm8k, "gsm8k_processed.jsonl")
    print(f"\nSaved GSM8K processed data to: {gsm8k_path}")
    preparer.print_sample(processed_gsm8k)

    # --- ProofNet: optional; failures are reported and GSM8K output is kept. ---
    print("\nProcessing ProofNet dataset...")
    try:
        proofnet_cfg = preparer.datasets["proofnet"]
        proofnet_dataset = load_dataset(
            proofnet_cfg["source"], proofnet_cfg.get("config"), split=proofnet_cfg["split"]
        )
        print(f"Loaded {len(proofnet_dataset)} samples from ProofNet")

        processed_proofnet = preparer.process_proofnet(proofnet_dataset)
        print(f"Processed {len(processed_proofnet)} samples")

        # Save before previewing so a preview error cannot lose the processed data.
        proofnet_path = preparer.save_to_jsonl(processed_proofnet, "proofnet_processed.jsonl")
        print(f"\nSaved ProofNet processed data to: {proofnet_path}")
        preparer.print_sample(processed_proofnet)
    except Exception as e:
        print(f"Error processing ProofNet dataset: {str(e)}", file=sys.stderr)
        print("Continuing with GSM8K data only", file=sys.stderr)

    preparer.print_memory_usage()
|
|
|
|
|
# Run the data-preparation pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|