# simpleLLM / math_expert / data_processor.py
# Uploaded to the Hugging Face Hub by hollywoodfrancis (revision b8ab4a2).
import json
import yaml
import sympy
from sympy.parsing.latex import parse_latex
from huggingface_hub import hf_hub_download
from pathlib import Path
import jsonlines
from typing import Dict, List, Any
from config import DATASETS, DATA_PROCESSING
class MathDataProcessor:
    """Download, normalize, and validate mathematical training data.

    Entries are read from Hugging Face JSONL dataset splits (configured in
    ``DATASETS``), individual fields are normalized with sympy, and only
    entries that pass :meth:`validate_entry` are kept in
    ``self.processed_data``.
    """

    def __init__(self):
        # Validated entries accumulated across process_dataset() calls.
        self.processed_data: List[Dict[str, Any]] = []
        # dataset name -> local directory the split file was downloaded into.
        self.dataset_paths: Dict[str, Path] = {}
        # Dispatch table: operation keyword -> handler method
        # (see process_math_operation).
        self.math_operations = {
            "differentiation": self._process_differentiation,
            "integration": self._process_integration,
            "limits": self._process_limits,
            "simplification": self._process_simplification,
            "matrix": self._process_matrix,
            "probability": self._process_probability,
            "statistics": self._process_statistics,
        }

    def download_dataset(self, dataset_name: str) -> Path:
        """Download one configured dataset split from the Hugging Face Hub.

        Args:
            dataset_name: Key into the ``DATASETS`` configuration mapping.

        Returns:
            Local directory the split file was downloaded into.

        Raises:
            ValueError: If ``dataset_name`` is not configured.
        """
        if dataset_name not in DATASETS:
            raise ValueError(f"Dataset {dataset_name} not defined in configuration")
        dataset_config = DATASETS[dataset_name]
        dataset_path = Path(f"data/{dataset_name}")
        # hf_hub_download creates local_dir as needed and places the file there.
        hf_hub_download(
            repo_id=dataset_config["dataset_name"],
            filename=f"{dataset_config['split']}.jsonl",
            local_dir=dataset_path,
        )
        self.dataset_paths[dataset_name] = dataset_path
        return dataset_path

    def normalize_equation(self, equation: str) -> str:
        """Canonicalize an equation string via sympy; echo the input on failure.

        A backslash is used as a cheap heuristic for LaTeX input.  NOTE:
        parse_latex() needs the optional antlr4 runtime installed, and
        sympy.sympify() evaluates expressions — treat dataset content as
        untrusted and keep this best-effort behavior in mind.
        """
        try:
            if "\\" in equation:
                parsed = parse_latex(equation)
            else:
                parsed = sympy.sympify(equation)
            return str(parsed)
        except Exception:
            # Unparseable input passes through untouched (deliberate best-effort).
            return equation

    def process_proof_steps(self, steps: List[str]) -> List[Dict[str, str]]:
        """Convert raw proof-step strings into structured dicts.

        A step that parses as a YAML mapping is kept as-is; everything else
        (including unparseable YAML) is wrapped as ``{"step": <raw text>}``.
        """
        processed_steps = []
        for step in steps:
            try:
                structured = yaml.safe_load(step)
            except Exception:
                structured = None
            if isinstance(structured, dict):
                processed_steps.append(structured)
            else:
                processed_steps.append({"step": step})
        return processed_steps

    def _process_differentiation(self, expression: str) -> str:
        """Differentiate ``expression`` w.r.t. x; echo the input on failure."""
        try:
            expr = sympy.sympify(expression)
            return str(sympy.diff(expr, sympy.Symbol('x')))
        except Exception:
            return expression

    def _process_integration(self, expression: str) -> str:
        """Integrate ``expression`` w.r.t. x; echo the input on failure."""
        try:
            expr = sympy.sympify(expression)
            return str(sympy.integrate(expr, sympy.Symbol('x')))
        except Exception:
            return expression

    def _process_limits(self, expression: str) -> str:
        """Take the limit of ``expression`` as x -> oo; echo the input on failure."""
        try:
            expr = sympy.sympify(expression)
            return str(sympy.limit(expr, sympy.Symbol('x'), sympy.oo))
        except Exception:
            return expression

    def _process_simplification(self, expression: str) -> str:
        """Simplify ``expression``; echo the input on failure."""
        try:
            return str(sympy.simplify(sympy.sympify(expression)))
        except Exception:
            return expression

    def _process_matrix(self, matrix_str: str) -> str:
        """Parse a ';'-separated, space-delimited matrix literal (e.g. "1 2; 3 4")."""
        try:
            rows = [[float(n) for n in row.split()]
                    for row in matrix_str.split(';')]
            return str(sympy.Matrix(rows))
        except Exception:
            return matrix_str

    def _process_probability(self, problem: str) -> Dict:
        """Extract type, parameters, and distribution from a probability problem."""
        try:
            if "probability" in problem.lower():
                return {
                    "type": "probability",
                    "parameters": self._extract_parameters(problem),
                    "distribution": self._identify_distribution(problem),
                }
            return {"type": "unknown"}
        except Exception:
            return {"type": "unknown"}

    def _process_statistics(self, data: str) -> Dict:
        """Compute mean/median/std-dev for a comma-separated number list.

        Returns an ``{"error": ...}`` dict when the input has no commas or
        contains non-numeric tokens.
        """
        try:
            if "," in data:
                numbers = [float(n) for n in data.split(',')]
                ordered = sorted(numbers)
                mid = len(ordered) // 2
                # Proper median: average the two middle values when the
                # count is even (the old code always took the upper middle).
                if len(ordered) % 2:
                    median = ordered[mid]
                else:
                    median = (ordered[mid - 1] + ordered[mid]) / 2
                return {
                    "mean": sum(numbers) / len(numbers),
                    "median": median,
                    "std_dev": self._calculate_std_dev(numbers),
                }
            return {"error": "Invalid data format"}
        except Exception:
            return {"error": "Processing failed"}

    def _extract_parameters(self, text: str) -> Dict:
        """Extract a crude equation/value pair from text containing '='.

        Splits on the FIRST '=' only, so right-hand sides that themselves
        contain '=' are preserved intact (the old code dropped them).
        """
        parameters: Dict[str, str] = {}
        if "=" in text:
            lhs, rhs = text.split("=", 1)
            parameters["equation"] = lhs.strip()
            parameters["value"] = rhs.strip()
        return parameters

    def _identify_distribution(self, text: str) -> str:
        """Identify a probability distribution by keyword; "unknown" if none match."""
        distributions = {
            "binomial": ["binomial", "bernoulli"],
            "normal": ["normal", "gaussian"],
            "poisson": ["poisson"],
            "exponential": ["exponential"],
        }
        text_lower = text.lower()
        for dist, keywords in distributions.items():
            if any(keyword in text_lower for keyword in keywords):
                return dist
        return "unknown"

    def _calculate_std_dev(self, numbers: List[float]) -> float:
        """Population standard deviation; 0.0 for an empty list (no div-by-zero)."""
        if not numbers:
            return 0.0
        mean = sum(numbers) / len(numbers)
        variance = sum((x - mean) ** 2 for x in numbers) / len(numbers)
        return variance ** 0.5

    def process_math_operation(self, operation_type: str, content: str) -> Any:
        """Dispatch ``content`` to the handler for ``operation_type``.

        Unknown operation types return ``content`` unchanged.
        """
        handler = self.math_operations.get(operation_type)
        return handler(content) if handler else content

    def validate_entry(self, entry: Dict[str, Any]) -> bool:
        """Return True when an entry passes structural and mathematical checks."""
        steps = entry.get("steps", [])
        text = entry.get("question", "") + entry.get("answer", "")
        # Structural thresholds come from configuration.
        if len(steps) < DATA_PROCESSING["validation"]["min_steps"]:
            return False
        if len(text) < DATA_PROCESSING["validation"]["min_length"]:
            return False
        try:
            # Any equation present must at least parse.
            if "equation" in entry:
                sympy.sympify(entry["equation"])
            # Consecutive steps must chain (RHS of one == LHS of the next).
            for prev, nxt in zip(steps, steps[1:]):
                if not self._check_step_continuity(prev, nxt):
                    return False
            # Reject structurally suspect proofs.
            if "proof" in entry and not self._check_proof_validity(entry["proof"]):
                return False
            return True
        except Exception:
            return False

    def _check_step_continuity(self, step1: str, step2: str) -> bool:
        """Check that the final RHS of step1 equals the first LHS of step2.

        Uses the LAST '='-segment of step1 so chained equalities
        ("a = b = c") compare their final form (the old code compared the
        middle segment). Steps without '=' are accepted unconditionally.
        """
        try:
            if "=" in step1 and "=" in step2:
                tail = step1.split("=")[-1].strip()
                head = step2.split("=")[0].strip()
                return tail == head
            return True
        except Exception:
            return False

    def _check_proof_validity(self, proof: str) -> bool:
        """Cheap heuristic proof check: assumptions need conclusions, etc."""
        lowered = proof.lower()
        if "assume" in lowered and "therefore" not in lowered:
            return False
        if "contradiction" in lowered and "false" not in lowered:
            return False
        return True

    def process_dataset(self, dataset_name: str):
        """Download, normalize, validate, and buffer one configured dataset."""
        dataset_path = self.download_dataset(dataset_name)
        dataset_config = DATASETS[dataset_name]
        with jsonlines.open(dataset_path / f"{dataset_config['split']}.jsonl") as reader:
            for entry in reader:
                processed_entry = {}
                for field in dataset_config["use_fields"]:
                    value = entry.get(field)
                    # `is None` (not truthiness) so present-but-falsy values
                    # (0, "", []) are no longer silently dropped.
                    if value is None:
                        continue
                    if field == "equation":
                        processed_entry[field] = self.normalize_equation(value)
                    elif field == "proof_steps":
                        processed_entry[field] = self.process_proof_steps(value)
                    else:
                        processed_entry[field] = value
                if self.validate_entry(processed_entry):
                    self.processed_data.append(processed_entry)

    def save_processed_data(self, output_path: str):
        """Save processed entries as JSONL, creating parent directories as needed."""
        out = Path(output_path)
        # Previously this crashed when the output directory did not exist.
        out.parent.mkdir(parents=True, exist_ok=True)
        with jsonlines.open(str(out), mode='w') as writer:
            writer.write_all(self.processed_data)
if __name__ == "__main__":
    processor = MathDataProcessor()

    # Run every configured dataset through the download/normalize/validate pipeline.
    for name in DATASETS:
        processor.process_dataset(name)

    # Persist everything that survived validation.
    output_path = "processed_data/math_expert_data.jsonl"
    processor.save_processed_data(output_path)