"""Processing pipeline for mathematical expert-model training data.

Downloads configured datasets from Hugging Face, normalizes equations and
proof steps with sympy/yaml, validates each entry, and writes the surviving
entries back out as JSONL.
"""

import json  # NOTE(review): unused here, but kept — other tooling may patch this file
import statistics
from pathlib import Path
from typing import Any, Dict, List

import jsonlines
import sympy
import yaml
from huggingface_hub import hf_hub_download
from sympy.parsing.latex import parse_latex

from config import DATA_PROCESSING, DATASETS


class MathDataProcessor:
    """Download, normalize, validate, and persist math training data."""

    def __init__(self):
        # Entries that passed validate_entry, in processing order.
        self.processed_data: List[Dict[str, Any]] = []
        # dataset name -> local directory the split file was downloaded into
        self.dataset_paths: Dict[str, Path] = {}
        # Dispatch table: operation type -> bound handler method.
        self.math_operations = {
            "differentiation": self._process_differentiation,
            "integration": self._process_integration,
            "limits": self._process_limits,
            "simplification": self._process_simplification,
            "matrix": self._process_matrix,
            "probability": self._process_probability,
            "statistics": self._process_statistics,
        }

    def download_dataset(self, dataset_name: str) -> Path:
        """Download one configured dataset split from Hugging Face.

        Args:
            dataset_name: Key into the DATASETS configuration mapping.

        Returns:
            The local directory (``data/<dataset_name>``) the split file
            was placed in; also recorded in ``self.dataset_paths``.

        Raises:
            ValueError: If ``dataset_name`` is not defined in DATASETS.
        """
        if dataset_name not in DATASETS:
            raise ValueError(f"Dataset {dataset_name} not defined in configuration")

        dataset_config = DATASETS[dataset_name]
        dataset_path = Path(f"data/{dataset_name}")

        # Download from Hugging Face into the per-dataset directory.
        hf_hub_download(
            repo_id=dataset_config["dataset_name"],
            filename=f"{dataset_config['split']}.jsonl",
            local_dir=dataset_path,
        )

        self.dataset_paths[dataset_name] = dataset_path
        return dataset_path

    def normalize_equation(self, equation: str) -> str:
        """Normalize a mathematical equation via sympy.

        A backslash in the input is taken as a LaTeX marker and routed to
        ``parse_latex``; everything else goes through ``sympify``. On any
        parse failure the original string is returned unchanged
        (deliberate best-effort behavior).
        """
        try:
            if "\\" in equation:
                eq = parse_latex(equation)
            else:
                eq = sympy.sympify(equation)
            return str(eq)
        except Exception:  # narrowed from bare except; keep the fallback
            return equation

    def process_proof_steps(self, steps: List[str]) -> List[Dict[str, str]]:
        """Convert proof steps into a list of dicts.

        Steps that parse as a YAML mapping are kept as-is; anything else
        (including YAML parse failures) is wrapped as ``{"step": text}``.
        """
        processed_steps = []
        for step in steps:
            try:
                structured_step = yaml.safe_load(step)
                if isinstance(structured_step, dict):
                    processed_steps.append(structured_step)
                else:
                    processed_steps.append({"step": step})
            except Exception:  # malformed YAML -> keep the raw text
                processed_steps.append({"step": step})
        return processed_steps

    def _process_differentiation(self, expression: str) -> str:
        """Differentiate ``expression`` with respect to x; fall back to input."""
        x = sympy.Symbol("x")
        try:
            expr = sympy.sympify(expression)
            derivative = sympy.diff(expr, x)
            return str(derivative)
        except Exception:
            return expression

    def _process_integration(self, expression: str) -> str:
        """Integrate ``expression`` with respect to x; fall back to input."""
        x = sympy.Symbol("x")
        try:
            expr = sympy.sympify(expression)
            integral = sympy.integrate(expr, x)
            return str(integral)
        except Exception:
            return expression

    def _process_limits(self, expression: str) -> str:
        """Take the limit of ``expression`` as x -> oo; fall back to input."""
        x = sympy.Symbol("x")
        try:
            expr = sympy.sympify(expression)
            limit = sympy.limit(expr, x, sympy.oo)
            return str(limit)
        except Exception:
            return expression

    def _process_simplification(self, expression: str) -> str:
        """Simplify ``expression`` with sympy; fall back to input."""
        try:
            expr = sympy.sympify(expression)
            simplified = sympy.simplify(expr)
            return str(simplified)
        except Exception:
            return expression

    def _process_matrix(self, matrix_str: str) -> str:
        """Parse a ';'-row, whitespace-column matrix string into a sympy Matrix.

        Returns the Matrix's string form, or the input on parse failure.
        """
        try:
            matrix = sympy.Matrix(
                [[float(n) for n in row.split()] for row in matrix_str.split(";")]
            )
            return str(matrix)
        except Exception:
            return matrix_str

    def _process_probability(self, problem: str) -> Dict:
        """Extract parameters and a distribution guess from a probability problem.

        Returns ``{"type": "unknown"}`` when the text does not mention
        probability or when processing fails.
        """
        try:
            if "probability" in problem.lower():
                return {
                    "type": "probability",
                    "parameters": self._extract_parameters(problem),
                    "distribution": self._identify_distribution(problem),
                }
            return {"type": "unknown"}
        except Exception:
            return {"type": "unknown"}

    def _process_statistics(self, data: str) -> Dict:
        """Compute mean/median/population-std-dev from comma-separated numbers.

        Returns an ``{"error": ...}`` dict when the input has no comma or
        any value fails to parse.
        """
        try:
            if "," in data:
                numbers = [float(n) for n in data.split(",")]
                return {
                    "mean": sum(numbers) / len(numbers),
                    # BUGFIX: the old sorted(...)[len//2] picked the upper-middle
                    # element for even-length data; statistics.median averages
                    # the two middle values as the median is defined.
                    "median": statistics.median(numbers),
                    "std_dev": self._calculate_std_dev(numbers),
                }
            return {"error": "Invalid data format"}
        except Exception:
            return {"error": "Processing failed"}

    def _extract_parameters(self, text: str) -> Dict:
        """Split ``lhs = rhs`` text into ``equation``/``value`` parameters."""
        parameters = {}
        if "=" in text:
            parts = text.split("=")
            parameters["equation"] = parts[0].strip()
            parameters["value"] = parts[1].strip()
        return parameters

    def _identify_distribution(self, text: str) -> str:
        """Identify a probability distribution by keyword; 'unknown' if none."""
        distributions = {
            "binomial": ["binomial", "bernoulli"],
            "normal": ["normal", "gaussian"],
            "poisson": ["poisson"],
            "exponential": ["exponential"],
        }
        text_lower = text.lower()
        for dist, keywords in distributions.items():
            if any(keyword in text_lower for keyword in keywords):
                return dist
        return "unknown"

    def _calculate_std_dev(self, numbers: List[float]) -> float:
        """Return the population standard deviation (divides by N, not N-1)."""
        mean = sum(numbers) / len(numbers)
        variance = sum((x - mean) ** 2 for x in numbers) / len(numbers)
        return variance ** 0.5

    def process_math_operation(self, operation_type: str, content: str) -> Any:
        """Dispatch ``content`` to the handler for ``operation_type``.

        Unknown operation types return ``content`` unchanged.
        """
        if operation_type in self.math_operations:
            return self.math_operations[operation_type](content)
        return content

    def validate_entry(self, entry: Dict[str, Any]) -> bool:
        """Validate an entry against length, step, and mathematical checks.

        Returns False on any threshold failure, invalid equation, broken
        step continuity, invalid proof, or unexpected processing error.
        """
        steps = entry.get("steps", [])
        text = entry.get("question", "") + entry.get("answer", "")

        # Basic size thresholds from configuration.
        if len(steps) < DATA_PROCESSING["validation"]["min_steps"]:
            return False
        if len(text) < DATA_PROCESSING["validation"]["min_length"]:
            return False

        try:
            # Equation must be parseable if present.
            if "equation" in entry:
                sympy.sympify(entry["equation"])

            # Consecutive steps must chain (rhs of one == lhs of next).
            if len(steps) > 1:
                for i in range(len(steps) - 1):
                    if not self._check_step_continuity(steps[i], steps[i + 1]):
                        return False

            # Heuristic proof-structure check.
            if "proof" in entry:
                if not self._check_proof_validity(entry["proof"]):
                    return False

            return True
        except Exception:
            return False

    def _check_step_continuity(self, step1: str, step2: str) -> bool:
        """Check that step1's right-hand side matches step2's left-hand side.

        Steps without '=' on both sides are accepted as continuous.
        """
        try:
            if "=" in step1 and "=" in step2:
                s1 = step1.split("=")[1].strip()
                s2 = step2.split("=")[0].strip()
                return s1 == s2
            return True
        except Exception:
            return False

    def _check_proof_validity(self, proof: str) -> bool:
        """Heuristic keyword check on proof structure.

        Rejects proofs that 'assume' without a 'therefore', or claim a
        'contradiction' without mentioning 'false'.
        """
        lowered = proof.lower()
        if "assume" in lowered and "therefore" not in lowered:
            return False
        if "contradiction" in lowered and "false" not in lowered:
            return False
        return True

    def process_dataset(self, dataset_name: str):
        """Download, field-process, validate, and collect one dataset.

        Valid entries are appended to ``self.processed_data``.
        """
        dataset_path = self.download_dataset(dataset_name)
        dataset_config = DATASETS[dataset_name]

        with jsonlines.open(dataset_path / f"{dataset_config['split']}.jsonl") as reader:
            for entry in reader:
                processed_entry = {}

                # Only the configured fields are carried over, with
                # per-field normalization where applicable.
                for field in dataset_config["use_fields"]:
                    value = entry.get(field)
                    if value:
                        if field == "equation":
                            processed_entry[field] = self.normalize_equation(value)
                        elif field == "proof_steps":
                            processed_entry[field] = self.process_proof_steps(value)
                        else:
                            processed_entry[field] = value

                if self.validate_entry(processed_entry):
                    self.processed_data.append(processed_entry)

    def save_processed_data(self, output_path: str):
        """Write ``self.processed_data`` to ``output_path`` as JSONL.

        Creates the parent directory if needed (the old code crashed when
        e.g. ``processed_data/`` did not yet exist).
        """
        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(output_path, mode="w") as writer:
            writer.write_all(self.processed_data)


if __name__ == "__main__":
    processor = MathDataProcessor()

    # Process all defined datasets.
    for dataset in DATASETS.keys():
        processor.process_dataset(dataset)

    # Save processed data.
    output_path = "processed_data/math_expert_data.jsonl"
    processor.save_processed_data(output_path)