import os
import json
from pathlib import Path
import sympy
from sympy.parsing.latex import parse_latex
from sympy.parsing.sympy_parser import parse_expr
from datasets import load_dataset
import jsonlines
from typing import Dict, List, Any
import sys
import psutil


class MathDataPreparer:
    """Download, normalize, and export math datasets (GSM8K, ProofNet) as JSONL.

    Equations embedded in text are best-effort normalized through sympy;
    parse failures fall back to the original string rather than aborting.
    """

    def __init__(self, output_dir: str = "processed_data"):
        """Create the preparer and ensure the output directory exists.

        Args:
            output_dir: Directory where processed JSONL files are written.
        """
        self.output_dir = Path(output_dir)
        # parents=True so a nested output path (e.g. "out/run1") also works;
        # exist_ok=True makes reruns idempotent.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        # Declarative registry of the supported datasets. Currently
        # informational only — main() loads the datasets explicitly.
        self.datasets = {
            "gsm8k": {
                "source": "gsm8k",
                "config": "main",
                "split": "train",
                "fields": ["question", "answer"],
            },
            "proofnet": {
                "source": "hoskinson-center/proofnet",
                "split": "validation",
                "fields": ["problem", "solution", "proof_steps"],
            },
        }

    def normalize_equation(self, equation: str) -> str:
        """Normalize a mathematical equation string using sympy.

        Tries LaTeX first (backslash heuristic), then $-delimited markdown
        math, then plain sympy expression syntax.

        Returns:
            The normalized string form, or the input unchanged on any
            parse failure (best-effort by design).
        """
        try:
            if "\\" in equation:
                # A backslash strongly suggests LaTeX markup.
                eq = parse_latex(equation)
            elif equation.startswith('$') and equation.endswith('$'):
                # Markdown-style inline math: strip the $ delimiters.
                eq = parse_expr(equation[1:-1])
            else:
                eq = parse_expr(equation)
            return str(eq)
        except Exception as e:
            # Include the parser error so failures are diagnosable
            # (previously only the equation text was logged).
            print(f"Error normalizing equation: {equation} ({e})", file=sys.stderr)
            return equation

    def process_proof_steps(self, steps: List[str]) -> List[Dict[str, Any]]:
        """Process and validate proof steps.

        Each step is either parsed as a JSON object (kept as-is when it is
        a dict) or wrapped as ``{"text": ..., "valid": True}``. Steps that
        raise are recorded with ``"valid": False`` and the error message.
        """
        processed_steps = []
        for step in steps:
            try:
                # Skip empty / whitespace-only steps entirely.
                if not step.strip():
                    continue
                # Prefer structured data when the step is a JSON object.
                try:
                    structured_step = json.loads(step)
                    if isinstance(structured_step, dict):
                        processed_steps.append(structured_step)
                        continue
                except json.JSONDecodeError:
                    pass
                # Fall back to treating the step as plain text.
                processed_steps.append({
                    "text": step.strip(),
                    "valid": True
                })
            except Exception as e:
                print(f"Error processing proof step: {step}", file=sys.stderr)
                processed_steps.append({
                    "text": step,
                    "valid": False,
                    "error": str(e)
                })
        return processed_steps

    def process_gsm8k(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Process GSM8K dataset.

        Strips whitespace from question/answer and normalizes any field
        that looks like it contains an equation ("=" heuristic). Examples
        that raise are logged and skipped.
        """
        processed = []
        for example in dataset:
            try:
                processed_example = {
                    "question": example["question"].strip(),
                    "answer": example["answer"].strip()
                }
                # Normalize equations in question.
                if "=" in processed_example["question"]:
                    processed_example["question"] = self.normalize_equation(
                        processed_example["question"])
                # Normalize equations in answer.
                if "=" in processed_example["answer"]:
                    processed_example["answer"] = self.normalize_equation(
                        processed_example["answer"])
                processed.append(processed_example)
            except Exception as e:
                print(f"Error processing GSM8K example: {str(e)}", file=sys.stderr)
        return processed

    def process_proofnet(self, dataset: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Process ProofNet dataset.

        Emits verbose diagnostics about the dataset schema (the upstream
        schema has varied), then extracts problem/solution/proof_steps per
        example, normalizing equation-like fields.
        """
        processed = []
        error_count = 0

        # Diagnostic dump of the dataset structure.
        print("\nProofNet dataset info:")
        print(f"Dataset type: {type(dataset)}")
        if hasattr(dataset, 'features'):
            print("\nDataset features:")
            for feature, dtype in dataset.features.items():
                print(f"{feature}: {dtype}")

        # Preview the first example so schema drift is visible in logs.
        if len(dataset) > 0:
            first_example = dataset[0]
            print("\nFirst example keys:", list(first_example.keys()))
            print("\nFirst example preview:")
            for key, value in first_example.items():
                print(f"\n{key}:")
                print(f"Type: {type(value)}")
                if isinstance(value, str):
                    print(f"Length: {len(value)}")
                elif isinstance(value, list):
                    print(f"List length: {len(value)}")
                    if len(value) > 0:
                        print(f"First item type: {type(value[0])}")
            print("\n")

        for idx, example in enumerate(dataset):
            try:
                processed_example = {
                    "problem": example.get("problem", "").strip(),
                    "solution": example.get("solution", "").strip(),
                    "proof_steps": []
                }

                # Handle proof steps: may arrive as a newline-joined string
                # or as a list, depending on the dataset version.
                if "proof_steps" in example:
                    steps = example["proof_steps"]
                    print(f"\nExample {idx} proof steps info:")
                    print(f"Type: {type(steps)}")
                    if isinstance(steps, str):
                        print(f"Length: {len(steps)}")
                        # Split a single string into individual steps.
                        steps = steps.split('\n')
                        print(f"Split into {len(steps)} steps")
                    elif isinstance(steps, list):
                        print(f"List length: {len(steps)}")
                        if len(steps) > 0:
                            print(f"First item type: {type(steps[0])}")
                    else:
                        print(f"Warning: Unexpected proof steps type: {type(steps)}")
                        steps = []
                    processed_example["proof_steps"] = self.process_proof_steps(steps)

                # Normalize equations in the text fields.
                for field in ["problem", "solution"]:
                    if "=" in processed_example[field]:
                        try:
                            processed_example[field] = self.normalize_equation(
                                processed_example[field])
                        except Exception as e:
                            print(f"Error normalizing {field} in ProofNet example {idx}: {str(e)}")

                processed.append(processed_example)
            except Exception as e:
                print(f"Error processing ProofNet example {idx}: {str(e)}")
                error_count += 1

        print(f"\nProcessed {len(processed)} examples from ProofNet")
        print(f"Encountered {error_count} errors during processing")
        return processed

    def save_to_jsonl(self, data: List[Dict[str, Any]], filename: str) -> Path:
        """Save processed data to a JSONL file under the output directory.

        Returns:
            The path of the written file.
        """
        filepath = self.output_dir / filename
        with jsonlines.open(filepath, mode='w') as writer:
            writer.write_all(data)
        return filepath

    def print_memory_usage(self):
        """Print the current process's resident memory usage in MB."""
        process = psutil.Process()
        memory_info = process.memory_info()
        print(f"Current memory usage: {memory_info.rss / 1024 / 1024:.2f} MB")

    def print_sample(self, data: List[Dict[str, Any]], count: int = 3):
        """Print up to `count` examples from the processed data."""
        print("\nSample data:")
        for i, example in enumerate(data[:count]):
            print(f"\nSample {i+1}:")
            if "proof_steps" in example:
                # ProofNet samples: show problem, solution, and steps.
                print(f"Problem: {example['problem']}")
                print(f"Solution: {example['solution']}")
                print("\nProof Steps:")
                for step in example["proof_steps"]:
                    # Structured steps parsed from JSON may lack a "text"
                    # key; fall back to a JSON dump so printing never
                    # raises KeyError.
                    print(f"- {step.get('text', json.dumps(step))}")
            else:
                # GSM8K samples: dump the whole record.
                print(json.dumps(example, indent=2))


def main():
    """Process GSM8K (required) and ProofNet (best-effort), saving JSONL output."""
    preparer = MathDataPreparer()

    # Load and process GSM8K.
    print("\nProcessing GSM8K dataset...")
    gsm8k_dataset = load_dataset("gsm8k", "main", split="train")
    print(f"Loaded {len(gsm8k_dataset)} samples from GSM8K")
    processed_gsm8k = preparer.process_gsm8k(gsm8k_dataset)
    print(f"Processed {len(processed_gsm8k)} samples")
    preparer.print_sample(processed_gsm8k)

    # Save GSM8K.
    gsm8k_path = preparer.save_to_jsonl(processed_gsm8k, "gsm8k_processed.jsonl")
    print(f"\nSaved GSM8K processed data to: {gsm8k_path}")

    # Load and process ProofNet. Failures here are non-fatal so the
    # GSM8K output above is still usable.
    print("\nProcessing ProofNet dataset...")
    try:
        proofnet_dataset = load_dataset("hoskinson-center/proofnet", split="validation")
        print(f"Loaded {len(proofnet_dataset)} samples from ProofNet")
        processed_proofnet = preparer.process_proofnet(proofnet_dataset)
        print(f"Processed {len(processed_proofnet)} samples")
        preparer.print_sample(processed_proofnet)

        # Save ProofNet.
        proofnet_path = preparer.save_to_jsonl(processed_proofnet, "proofnet_processed.jsonl")
        print(f"\nSaved ProofNet processed data to: {proofnet_path}")
    except Exception as e:
        print(f"Error processing ProofNet dataset: {str(e)}")
        print("Continuing with GSM8K data only")

    # Print memory usage.
    preparer.print_memory_usage()


if __name__ == "__main__":
    main()