# simpleLLM / math_expert / prepare_data.py
# Uploaded by hollywoodfrancis ("Upload 11 files"), commit b8ab4a2 (verified)
import os
import json
from pathlib import Path
import sympy
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from datasets import load_dataset
import jsonlines
from typing import Dict, List, Any
import sys
import psutil
class MathDataPreparer:
    """Clean math datasets (GSM8K, ProofNet) and write them out as JSONL.

    ``self.datasets`` records where each source lives on the Hugging Face hub.
    The ``process_*`` methods turn raw rows into plain dicts, and
    ``save_to_jsonl`` writes those dicts under ``self.output_dir``.
    """

    def __init__(self, output_dir: str = "processed_data"):
        """Create the preparer and make sure ``output_dir`` exists.

        Args:
            output_dir: Directory that ``save_to_jsonl`` writes into. Missing
                parent directories are created as well.
        """
        self.output_dir = Path(output_dir)
        # parents=True so nested targets such as "out/math/v1" also work.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Hub location of each dataset. "fields" mirrors the columns that the
        # matching process_* method reads; this class does not consult it.
        self.datasets = {
            "gsm8k": {
                "source": "gsm8k",
                "config": "main",
                "split": "train",
                "fields": ["question", "answer"],
            },
            "proofnet": {
                "source": "hoskinson-center/proofnet",
                "split": "validation",
                "fields": ["problem", "solution", "proof_steps"],
            },
        }

    def normalize_equation(self, equation: str) -> str:
        """Return ``equation`` re-rendered by sympy, or unchanged if that fails.

        Text that contains a backslash is parsed as LaTeX; anything else goes
        through ``parse_expr``. Surrounding ``$...$`` / ``$$...$$`` delimiters
        are removed before parsing. The original string is returned when:
        it is empty or not a string, parsing raises, or the parse collapses
        to a bare truth value. For example, sympy evaluates ``"x == 1"`` to
        ``False``, which would otherwise replace the text with ``"False"``.

        Security note: sympy's ``parse_expr`` evaluates its input with
        ``eval``. Do not feed it untrusted text.
        """
        if not isinstance(equation, str) or not equation.strip():
            return equation

        text = equation.strip()
        # Strip markdown math delimiters so the parsers only see the math itself.
        if len(text) > 1 and text.startswith("$") and text.endswith("$"):
            text = text.strip("$").strip()

        try:
            parsed = parse_latex(text) if "\\" in text else parse_expr(text)
        except Exception as e:
            print(
                f"Error normalizing equation: {equation} ({type(e).__name__}: {e})",
                file=sys.stderr,
            )
            return equation

        # A comparison can evaluate to a plain truth value; keep the source text.
        if isinstance(parsed, bool) or parsed is sympy.true or parsed is sympy.false:
            return equation
        return str(parsed)

    def _normalize_if_equation(self, text: str) -> str:
        """Apply ``normalize_equation`` to ``text`` only when it contains ``=``.

        NOTE(review): callers pass whole natural-language fields, not extracted
        equations. Also, ``parse_expr`` runs without sympy's
        ``convert_equals_signs`` transformation. As a result, most such text
        appears to fail parsing and comes back unchanged. Confirm this is
        intended.
        """
        return self.normalize_equation(text) if "=" in text else text

    def process_proof_steps(self, steps: List[Any]) -> List[Dict[str, Any]]:
        """Turn raw proof steps into a list of dicts.

        Each step is handled as follows:
          * ``dict``: kept as-is (already structured);
          * blank string: dropped;
          * string containing a JSON object: the decoded dict;
          * any other string: ``{"text": <stripped>, "valid": True}``.
        A step that cannot be handled (e.g. a non-string) is recorded as
        ``{"text": step, "valid": False, "error": <message>}`` and reported on
        stderr.

        Args:
            steps: The raw steps. ``None`` is treated as no steps.
        """
        processed_steps: List[Dict[str, Any]] = []
        for step in steps or []:
            if isinstance(step, dict):
                processed_steps.append(step)
                continue
            try:
                text = step.strip()
                if not text:
                    continue
                try:
                    structured = json.loads(text)
                except json.JSONDecodeError:
                    structured = None
                if isinstance(structured, dict):
                    processed_steps.append(structured)
                else:
                    processed_steps.append({"text": text, "valid": True})
            except Exception as e:
                print(f"Error processing proof step: {step}", file=sys.stderr)
                processed_steps.append({
                    "text": step,
                    "valid": False,
                    "error": str(e),
                })
        return processed_steps

    def process_gsm8k(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return GSM8K rows as ``{"question", "answer"}`` dicts with whitespace trimmed.

        A field containing ``=`` is passed through ``normalize_equation``.
        Rows that lack a field, or whose values are not strings, are skipped
        and reported on stderr.
        """
        processed = []
        for example in dataset:
            try:
                processed.append({
                    "question": self._normalize_if_equation(example["question"].strip()),
                    "answer": self._normalize_if_equation(example["answer"].strip()),
                })
            except Exception as e:
                print(f"Error processing GSM8K example: {str(e)}", file=sys.stderr)
        return processed

    def _print_dataset_info(self, dataset: Any) -> None:
        """Print the dataset's type, its features (if any) and the first row's structure."""
        print("\nProofNet dataset info:")
        print(f"Dataset type: {type(dataset)}")
        if hasattr(dataset, 'features'):
            print("\nDataset features:")
            for feature, dtype in dataset.features.items():
                print(f"{feature}: {dtype}")
        if len(dataset) > 0:
            first_example = dataset[0]
            print("\nFirst example keys:", list(first_example.keys()))
            print("\nFirst example preview:")
            for key, value in first_example.items():
                print(f"\n{key}:")
                print(f"Type: {type(value)}")
                if isinstance(value, str):
                    print(f"Length: {len(value)}")
                elif isinstance(value, list):
                    print(f"List length: {len(value)}")
                    if len(value) > 0:
                        print(f"First item type: {type(value[0])}")
        print("\n")

    def _coerce_proof_steps(self, steps: Any, idx: int, verbose: bool) -> List[Any]:
        """Return a raw ``proof_steps`` value as a list.

        A string is split on newlines, a list is returned as-is, and any other
        type is reported and replaced by ``[]``.
        """
        if verbose:
            print(f"\nExample {idx} proof steps info:")
            print(f"Type: {type(steps)}")
        if isinstance(steps, str):
            lines = steps.split('\n')
            if verbose:
                print(f"Length: {len(steps)}")
                print(f"Split into {len(lines)} steps")
            return lines
        if isinstance(steps, list):
            if verbose:
                print(f"List length: {len(steps)}")
                if steps:
                    print(f"First item type: {type(steps[0])}")
            return steps
        print(f"Warning: Unexpected proof steps type: {type(steps)}")
        return []

    def process_proofnet(
        self, dataset: List[Dict[str, Any]], verbose: bool = False
    ) -> List[Dict[str, Any]]:
        """Return ProofNet rows as ``{"problem", "solution", "proof_steps"}`` dicts.

        Missing or ``None`` text fields become ``""``. ``proof_steps`` may be a
        newline-separated string or a list; any other type is reported and
        treated as no steps. Rows that still fail are counted and skipped.

        NOTE(review): the column names read here ("problem", "solution",
        "proof_steps") may not match the hub dataset's actual schema. If they
        differ, every row comes out empty. Verify against the dataset card.

        Args:
            dataset: Rows to process (a list of dicts or a ``datasets.Dataset``).
            verbose: Print per-example proof-step diagnostics. Off by default
                because it emits several lines per row.
        """
        processed = []
        error_count = 0
        self._print_dataset_info(dataset)

        for idx, example in enumerate(dataset):
            try:
                processed_example = {
                    "problem": (example.get("problem") or "").strip(),
                    "solution": (example.get("solution") or "").strip(),
                    "proof_steps": [],
                }
                if "proof_steps" in example:
                    steps = self._coerce_proof_steps(example["proof_steps"], idx, verbose)
                    processed_example["proof_steps"] = self.process_proof_steps(steps)
                for field in ("problem", "solution"):
                    processed_example[field] = self._normalize_if_equation(processed_example[field])
                processed.append(processed_example)
            except Exception as e:
                print(f"Error processing ProofNet example {idx}: {str(e)}")
                error_count += 1

        print(f"\nProcessed {len(processed)} examples from ProofNet")
        print(f"Encountered {error_count} errors during processing")
        return processed

    def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str):
        """Write ``data`` (one record per line) to ``output_dir / filename``.

        Returns:
            The ``Path`` that was written.
        """
        filepath = self.output_dir / filename
        with jsonlines.open(filepath, mode='w') as writer:
            writer.write_all(data)
        return filepath

    def print_memory_usage(self):
        """Print this process's resident set size in MB (via psutil)."""
        process = psutil.Process()
        memory_info = process.memory_info()
        print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")

    def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
        """Print the first ``count`` processed examples.

        Examples that have ``proof_steps`` (ProofNet) are shown field by field,
        with one line per step. All other examples (GSM8K) are dumped as
        indented JSON.
        """
        print("\nSample data:")
        for i, example in enumerate(data[:count]):
            print(f"\nSample {i+1}:")
            if "proof_steps" in example:
                print(f"Problem: {example.get('problem', '')}")
                print(f"Solution: {example.get('solution', '')}")
                print("\nProof Steps:")
                for step in example["proof_steps"]:
                    # Steps decoded from JSON need not have a "text" key.
                    text = step.get("text", step) if isinstance(step, dict) else step
                    print(f"- {text}")
            else:
                print(json.dumps(example, indent=2))
def main():
    """Prepare GSM8K and ProofNet and write each as JSONL under ``processed_data``.

    Dataset locations are read from ``MathDataPreparer.datasets``, so they are
    defined in one place. If GSM8K fails to load or process, the run aborts.
    A ProofNet failure is reported on stderr, and the run continues with the
    GSM8K output already saved.
    """
    preparer = MathDataPreparer()

    # --- GSM8K -------------------------------------------------------------
    print("\nProcessing GSM8K dataset...")
    gsm8k_cfg = preparer.datasets["gsm8k"]
    gsm8k_dataset = load_dataset(
        gsm8k_cfg["source"], gsm8k_cfg.get("config"), split=gsm8k_cfg["split"]
    )
    print(f"Loaded {len(gsm8k_dataset)} samples from GSM8K")
    processed_gsm8k = preparer.process_gsm8k(gsm8k_dataset)
    print(f"Processed {len(processed_gsm8k)} samples")
    preparer.print_sample(processed_gsm8k)
    gsm8k_path = preparer.save_to_jsonl(processed_gsm8k, "gsm8k_processed.jsonl")
    print(f"\nSaved GSM8K processed data to: {gsm8k_path}")

    # --- ProofNet (best effort: failure must not lose the GSM8K output) -----
    print("\nProcessing ProofNet dataset...")
    proofnet_cfg = preparer.datasets["proofnet"]
    try:
        proofnet_dataset = load_dataset(
            proofnet_cfg["source"], proofnet_cfg.get("config"), split=proofnet_cfg["split"]
        )
        print(f"Loaded {len(proofnet_dataset)} samples from ProofNet")
        processed_proofnet = preparer.process_proofnet(proofnet_dataset)
        print(f"Processed {len(processed_proofnet)} samples")
        preparer.print_sample(processed_proofnet)
        proofnet_path = preparer.save_to_jsonl(processed_proofnet, "proofnet_processed.jsonl")
        print(f"\nSaved ProofNet processed data to: {proofnet_path}")
    except Exception as e:
        print(f"Error processing ProofNet dataset: {str(e)}", file=sys.stderr)
        print("Continuing with GSM8K data only")

    preparer.print_memory_usage()
if __name__ == "__main__":
    # Run the full data-preparation pipeline only when executed as a script,
    # not when this module is imported.
    main()