Spaces:
Sleeping
Sleeping
| """ | |
| Data Processor for RAG System | |
| Processes WikiSQL dataset and prepares data for the RAG system. | |
| """ | |
import copy
import json
import os
import re
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import pandas as pd
from datasets import load_dataset
from loguru import logger
class DataProcessor:
    """Processes WikiSQL dataset for RAG system."""

    # Characters normalised to a plain single quote by _clean_text: the ASCII
    # double quote plus the typographic (curly) double quotes U+201C / U+201D.
    _QUOTE_TRANSLATION = str.maketrans({'"': "'", "\u201c": "'", "\u201d": "'"})

    def __init__(self, data_dir: str = "./data"):
        """
        Initialize the data processor.

        Args:
            data_dir: Directory to store processed data (created if missing).
        """
        self.data_dir = Path(data_dir)
        self.data_dir.mkdir(parents=True, exist_ok=True)
        # File paths
        self.processed_data_path = self.data_dir / "processed_examples.json"
        self.vector_store_data_path = self.data_dir / "vector_store_data.json"
        self.statistics_path = self.data_dir / "data_statistics.json"
        logger.info(f"Data processor initialized at {self.data_dir}")

    def process_wikisql_dataset(self,
                                max_examples: Optional[int] = None,
                                split: str = "train") -> List[Dict[str, Any]]:
        """
        Process WikiSQL dataset and prepare examples for RAG system.

        Loads the requested split, converts each row with
        ``_process_single_example`` (dropping invalid rows), then writes the
        processed examples and their statistics to ``data_dir``.

        Args:
            max_examples: Maximum number of examples to process (None for all;
                0 or a negative value processes nothing).
            split: Dataset split to use ('train', 'validation', 'test').

        Returns:
            List of processed examples.

        Raises:
            Exception: Any error from loading or processing is logged and re-raised.
        """
        try:
            logger.info(f"Loading WikiSQL {split} dataset...")
            # Load dataset
            dataset = load_dataset("wikisql", split=split)
            # Compare against None explicitly: a plain truthiness test would make
            # max_examples=0 silently process the entire split.
            if max_examples is not None:
                limit = min(max(max_examples, 0), len(dataset))
                dataset = dataset.select(range(limit))
            total = len(dataset)
            logger.info(f"Processing {total} examples...")
            # Process examples
            processed_examples = []
            for i, example in enumerate(dataset):
                processed_example = self._process_single_example(example, i, split=split)
                if processed_example:
                    processed_examples.append(processed_example)
                # Progress logging
                if (i + 1) % 1000 == 0:
                    logger.info(f"Processed {i + 1}/{total} examples")
            # Save processed data
            self._save_processed_data(processed_examples)
            # Generate statistics
            stats = self._generate_statistics(processed_examples)
            self._save_statistics(stats)
            logger.info(f"Successfully processed {len(processed_examples)} examples")
            return processed_examples
        except Exception as e:
            logger.error(f"Error processing WikiSQL dataset: {e}")
            raise

    def _process_single_example(self,
                                example: Dict[str, Any],
                                index: int,
                                split: str = "train") -> Optional[Dict[str, Any]]:
        """
        Process a single WikiSQL example.

        Args:
            example: Raw example from WikiSQL dataset. Reads ``question``,
                ``table.header``, ``table.name`` and ``sql.human_readable``.
            index: Example index (used for ``example_id`` and metadata).
            split: Name of the split the example came from; recorded in the
                metadata (previously this was always hard-coded to "train").

        Returns:
            Processed example, or None if a required field is missing/empty or
            processing raised an exception (a warning is logged in that case).
        """
        try:
            # Extract basic information
            question = example.get("question", "").strip()
            table_headers = example.get("table", {}).get("header", [])
            sql_query = example.get("sql", {}).get("human_readable", "")
            # Validate example
            if not question or not table_headers or not sql_query:
                return None
            # Clean and normalize
            question = self._clean_text(question)
            table_headers = [self._clean_text(h) for h in table_headers]
            sql_query = self._clean_sql(sql_query)
            # Analyze complexity and categorize
            complexity = self._assess_example_complexity(question, sql_query)
            category = self._categorize_example(question, sql_query)
            return {
                "example_id": f"wikisql_{index}",
                "question": question,
                "table_headers": table_headers,
                "sql": sql_query,
                "difficulty": complexity,
                "category": category,
                "metadata": {
                    "source": "wikisql",
                    "split": split,
                    "original_index": index,
                    "table_name": example.get("table", {}).get("name", "unknown"),
                    "question_type": self._classify_question_type(question),
                    "sql_features": self._extract_sql_features(sql_query)
                }
            }
        except Exception as e:
            logger.warning(f"Error processing example {index}: {e}")
            return None

    def _clean_text(self, text: str) -> str:
        """
        Clean and normalize text.

        Collapses every run of whitespace to a single space, trims the ends,
        and replaces straight and curly double quotes with a single quote.
        Falsy input (None or "") returns "".
        """
        if not text:
            return ""
        # Remove extra whitespace
        text = " ".join(text.split())
        # Double quotes (ASCII and typographic) are turned into single quotes
        # because they might cause issues downstream.
        text = text.translate(self._QUOTE_TRANSLATION)
        return text.strip()

    def _clean_sql(self, sql: str) -> str:
        """
        Clean and normalize SQL query.

        Collapses whitespace, removes spaces just inside parentheses and
        before commas, and appends a trailing semicolon if missing. The space
        *after* ``,`` and ``)`` and *before* ``(`` is deliberately kept: WikiSQL
        column names and values contain those characters (e.g. "Area (km²)",
        "Points, For"), and the SQL must keep them identical to the cleaned
        table headers. Removing those spaces would also glue tokens together
        (``COUNT(*) FROM`` -> ``COUNT(*)FROM``).

        Returns:
            The normalized query, or "" when the input is empty/whitespace-only.
        """
        if not sql:
            return ""
        # Remove extra whitespace
        sql = " ".join(sql.split())
        if not sql:
            # Whitespace-only input: do not fabricate a lone ";" query.
            return ""
        sql = sql.replace(" ,", ",").replace("( ", "(").replace(" )", ")")
        # Add semicolon if missing
        if not sql.endswith(";"):
            sql += ";"
        return sql

    @staticmethod
    def _has_keyword(text: str, keyword: str) -> bool:
        """
        Return True if ``keyword`` occurs in ``text`` as whole word(s).

        Matching is case-insensitive, and multi-word keywords such as
        "group by" match any whitespace between the words. Word boundaries stop
        matches inside longer words (e.g. "join" in "Year joined"). They cannot
        tell a real SQL keyword apart from a column or value that is literally
        that word.
        """
        pattern = r"\b" + r"\s+".join(re.escape(word) for word in keyword.split()) + r"\b"
        return re.search(pattern, text, flags=re.IGNORECASE) is not None

    @staticmethod
    def _has_aggregation(sql: str, function: str) -> bool:
        """
        Return True if ``sql`` appears to apply the aggregate ``function``.

        Two forms are recognised (case-insensitive):
        * call syntax, e.g. ``COUNT(*)`` / ``AVG (salary)``;
        * the WikiSQL human-readable form, where the aggregate keyword directly
          follows SELECT without parentheses, e.g. ``SELECT COUNT Points FROM
          table``. A bare column literally named like the function
          (``SELECT Count FROM table``) is excluded by requiring a non-FROM
          token after it.
        """
        fn = re.escape(function)
        call_form = rf"\b{fn}\s*\("
        wikisql_form = rf"^\s*select\s+{fn}\s+(?![\s;]|from\b)"
        return (re.search(call_form, sql, flags=re.IGNORECASE) is not None
                or re.search(wikisql_form, sql, flags=re.IGNORECASE) is not None)

    def _assess_example_complexity(self, question: str, sql: str) -> str:
        """
        Assess the complexity of an example with a simple additive score.

        Long questions add 1-2 points. JOIN, GROUP BY, HAVING and any
        parenthesised expression (treated as a possible subquery) add 2 each.
        UNION/INTERSECT add 3. A score of 6+ is "hard", 3-5 is "medium", and
        anything lower is "easy".
        """
        complexity_score = 0
        # Question complexity
        word_count = len(question.split())
        if word_count > 15:
            complexity_score += 2
        elif word_count > 10:
            complexity_score += 1
        # SQL complexity
        sql_lower = sql.lower()
        if self._has_keyword(sql, "join"):
            complexity_score += 2
        if self._has_keyword(sql, "group by"):
            complexity_score += 2
        if self._has_keyword(sql, "having"):
            complexity_score += 2
        # Any parenthesised expression counts here (including function calls
        # such as AVG(x)); the explicit grouping mirrors the original precedence.
        if "subquery" in sql_lower or ("(" in sql_lower and ")" in sql_lower):
            complexity_score += 2
        if self._has_keyword(sql, "union") or self._has_keyword(sql, "intersect"):
            complexity_score += 3
        # Determine difficulty level
        if complexity_score >= 6:
            return "hard"
        if complexity_score >= 3:
            return "medium"
        return "easy"

    def _categorize_example(self, question: str, sql: str) -> str:
        """
        Categorize the example from keywords in the question text.

        Categories are checked in priority order (aggregation, grouping, join,
        sorting, filtering), using case-insensitive substring matching on the
        question. The first category with a matching keyword wins, and
        "simple" is the fallback. ``sql`` is accepted for interface stability
        but is currently not inspected.
        """
        question_lower = question.lower()
        category_keywords = (
            ("aggregation", ("count", "how many", "number of",
                             "average", "mean", "sum", "total")),
            ("grouping", ("group by", "grouped", "by department", "by category")),
            ("join", ("join", "combine", "merge", "connect")),
            ("sorting", ("order by", "sort", "rank", "top", "highest", "lowest")),
            ("filtering", ("where", "filter", "condition")),
        )
        for category, keywords in category_keywords:
            if any(word in question_lower for word in keywords):
                return category
        return "simple"

    def _classify_question_type(self, question: str) -> str:
        """
        Classify the type of question.

        Returns "interrogative" if the text contains "?". Otherwise it returns
        "display" for show/display/list, "retrieval" for find/get/retrieve
        (substring match), or "statement".
        """
        question_lower = question.lower()
        if "?" in question_lower:
            return "interrogative"
        if any(word in question_lower for word in ["show", "display", "list"]):
            return "display"
        if any(word in question_lower for word in ["find", "get", "retrieve"]):
            return "retrieval"
        return "statement"

    def _extract_sql_features(self, sql: str) -> List[str]:
        """
        Extract SQL features from the query.

        Keywords are matched as whole words (see ``_has_keyword``). Aggregates
        are detected in both call syntax and WikiSQL's ``SELECT COUNT col``
        form (see ``_has_aggregation``). Features are returned in a fixed order:
        clause keywords first, then count/avg/sum aggregations.
        """
        keyword_features = (
            ("select", "select"),
            ("from", "from"),
            ("where", "where"),
            ("join", "join"),
            ("group by", "group_by"),
            ("having", "having"),
            ("order by", "order_by"),
            ("limit", "limit"),
            ("distinct", "distinct"),
        )
        features = [name for keyword, name in keyword_features
                    if self._has_keyword(sql, keyword)]
        for function in ("count", "avg", "sum"):
            if self._has_aggregation(sql, function):
                features.append(f"{function}_aggregation")
        return features

    def _save_processed_data(self, examples: List[Dict[str, Any]]) -> None:
        """Save processed examples as JSON; errors are logged, not raised."""
        try:
            with open(self.processed_data_path, 'w', encoding='utf-8') as f:
                json.dump(examples, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved {len(examples)} processed examples to {self.processed_data_path}")
        except Exception as e:
            logger.error(f"Error saving processed data: {e}")

    def _save_statistics(self, stats: Dict[str, Any]) -> None:
        """Save data statistics as JSON; errors are logged, not raised."""
        try:
            with open(self.statistics_path, 'w', encoding='utf-8') as f:
                json.dump(stats, f, indent=2, ensure_ascii=False)
            logger.info(f"Saved statistics to {self.statistics_path}")
        except Exception as e:
            logger.error(f"Error saving statistics: {e}")

    def _generate_statistics(self, examples: List[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Generate summary statistics about the processed data.

        Returns {"error": ...} for an empty list. Otherwise it returns counts per
        difficulty, category, question type and SQL feature (in first-seen
        order), column-count stats over ``table_headers``, and data-quality
        counts.
        """
        if not examples:
            return {"error": "No examples to analyze"}
        difficulty_counts = Counter(e.get("difficulty", "unknown") for e in examples)
        category_counts = Counter(e.get("category", "unknown") for e in examples)
        question_type_counts = Counter(
            e.get("metadata", {}).get("question_type", "unknown") for e in examples
        )
        sql_features_counts = Counter(
            feature
            for e in examples
            for feature in e.get("metadata", {}).get("sql_features", [])
        )
        # Table schema statistics; non-empty because examples is non-empty.
        table_sizes = [len(e.get("table_headers", [])) for e in examples]
        return {
            "total_examples": len(examples),
            "difficulty_distribution": dict(difficulty_counts),
            "category_distribution": dict(category_counts),
            "question_type_distribution": dict(question_type_counts),
            "sql_features_distribution": dict(sql_features_counts),
            "table_schema_stats": {
                "average_columns": sum(table_sizes) / len(table_sizes),
                "min_columns": min(table_sizes),
                "max_columns": max(table_sizes)
            },
            "data_quality": {
                "examples_with_questions": sum(1 for e in examples if e.get("question")),
                "examples_with_sql": sum(1 for e in examples if e.get("sql")),
                "examples_with_headers": sum(1 for e in examples if e.get("table_headers"))
            }
        }

    def load_processed_data(self) -> List[Dict[str, Any]]:
        """Load previously processed data; returns [] if missing or unreadable."""
        try:
            if self.processed_data_path.exists():
                with open(self.processed_data_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                logger.info(f"Loaded {len(data)} processed examples")
                return data
            logger.warning("No processed data found")
            return []
        except Exception as e:
            logger.error(f"Error loading processed data: {e}")
            return []

    def get_data_statistics(self) -> Dict[str, Any]:
        """Load saved statistics; returns {"error": ...} if missing or unreadable."""
        try:
            if self.statistics_path.exists():
                with open(self.statistics_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            return {"error": "No statistics available"}
        except Exception as e:
            logger.error(f"Error loading statistics: {e}")
            return {"error": str(e)}

    def create_sample_dataset(self, num_examples: int = 100) -> List[Dict[str, Any]]:
        """
        Create a small sample dataset for testing.

        The three hand-written templates are cycled to reach ``num_examples``
        entries, renumbered ``sample_1`` .. ``sample_N``. Every entry is an
        independent deep copy, so mutating one example (e.g. its metadata) does
        not affect the others. A value of 0 or less yields an empty list.
        """
        templates = [
            {
                "example_id": "sample_1",
                "question": "How many employees are older than 30?",
                "table_headers": ["id", "name", "age", "department", "salary"],
                "sql": "SELECT COUNT(*) FROM employees WHERE age > 30;",
                "difficulty": "easy",
                "category": "aggregation",
                "metadata": {
                    "source": "sample",
                    "question_type": "interrogative",
                    "sql_features": ["select", "count_aggregation", "where"]
                }
            },
            {
                "example_id": "sample_2",
                "question": "Show all employees in IT department",
                "table_headers": ["id", "name", "age", "department", "salary"],
                "sql": "SELECT * FROM employees WHERE department = 'IT';",
                "difficulty": "easy",
                "category": "filtering",
                "metadata": {
                    "source": "sample",
                    "question_type": "display",
                    "sql_features": ["select", "where"]
                }
            },
            {
                "example_id": "sample_3",
                "question": "What is the average salary by department?",
                "table_headers": ["id", "name", "age", "department", "salary"],
                "sql": "SELECT department, AVG(salary) FROM employees GROUP BY department;",
                "difficulty": "medium",
                "category": "grouping",
                "metadata": {
                    "source": "sample",
                    "question_type": "interrogative",
                    "sql_features": ["select", "avg_aggregation", "group_by"]
                }
            }
        ]
        sample_examples = []
        for i in range(max(num_examples, 0)):
            # deepcopy: a shallow copy would share the nested metadata dict and
            # lists between every repeated example.
            example = copy.deepcopy(templates[i % len(templates)])
            example["example_id"] = f"sample_{i + 1}"
            sample_examples.append(example)
        return sample_examples