Spaces:
Running
Running
import json
import os
import re
import sys

from tqdm import tqdm

# Put this script's own directory at the front of sys.path so the local
# `data_factory` package resolves; this has to happen before importing it.
PROJECT_ROOT = os.path.dirname(os.path.abspath(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator  # noqa: E402  (needs the sys.path tweak above)

# Relative paths: both are resolved against the current working directory.
INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"
# Fallback domain lookup: a "Database: <name>" line inside the prompt text.
_DATABASE_RE = re.compile(r"Database:\s*([a-zA-Z0-9_]+)")


def _extract_domain(record):
    """Return the domain (database name) for *record*, or "unknown".

    Uses ``record["metadata"]["domain"]`` when it is a non-empty string other
    than "unknown". Otherwise it searches the ``content`` of the second entry
    of ``record["prompt"]`` for a ``Database: <name>`` line. Missing or
    malformed fields yield "unknown" instead of raising.
    """
    metadata = record.get("metadata")
    domain = metadata.get("domain") if isinstance(metadata, dict) else None
    if isinstance(domain, str) and domain and domain != "unknown":
        return domain

    prompt = record.get("prompt")
    content = None
    # A prompt that is not a list, has fewer than two messages, or whose
    # second message is not a dict is treated as having no content.
    if isinstance(prompt, list) and len(prompt) > 1 and isinstance(prompt[1], dict):
        content = prompt[1].get("content")
    if not isinstance(content, str):
        return "unknown"
    match = _DATABASE_RE.search(content)
    return match.group(1) if match else "unknown"


def _get_validator(domain, validators):
    """Return the cached SQLValidator for *domain*, creating it on first use.

    Returns None if construction raises. The failure is cached as None so
    construction is attempted only once per domain.
    """
    if domain not in validators:
        try:
            validators[domain] = SQLValidator(domain, seed=42)
        except Exception:
            validators[domain] = None
    return validators[domain]


def _rejection_reason(line, validators):
    """Return None if *line* should be kept, otherwise a short rejection reason.

    A line is kept only when it is a JSON object with a non-empty ``sql``
    string, a resolvable domain, and a validation result that has ``passed``
    set and ``row_count > 0``.
    """
    if not line.strip():
        return "blank line"
    try:
        record = json.loads(line)
    except json.JSONDecodeError:
        return "invalid JSON"
    if not isinstance(record, dict):
        return "not a JSON object"

    sql = record.get("sql")
    if not isinstance(sql, str) or not sql.strip():
        return "missing SQL"

    domain = _extract_domain(record)
    if domain == "unknown":
        return "unknown domain"

    validator = _get_validator(domain, validators)
    if validator is None:
        return "validator init error"

    try:
        result = validator.validate(sql.strip())
        # Keep ONLY if SQL is 100% perfect and returns data.
        if not result.passed:
            return "failed validation"
        if not (result.row_count > 0):
            return "no rows returned"
    except Exception:
        # A validator error disqualifies this row only; it must not abort the sweep.
        return "validation error"
    return None


def main():
    """Copy to OUTPUT_FILE only the INPUT_FILE records whose SQL validates and returns rows.

    Kept lines are written verbatim. Every other line is counted as removed,
    and the removals are grouped by reason in the printed summary. If
    INPUT_FILE does not exist, an error is printed to stderr and nothing is
    written.
    """
    from collections import Counter

    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found!", file=sys.stderr)
        return

    print("Sweeping dataset to remove bad SQLs...")
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    validators = {}  # domain -> SQLValidator, or None if construction failed
    reasons = Counter()
    cleaned_count = 0
    try:
        with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
            for line in tqdm(lines, desc="Filtering Garbage"):
                reason = _rejection_reason(line, validators)
                if reason is None:
                    out_f.write(line)
                    cleaned_count += 1
                else:
                    reasons[reason] += 1
    finally:
        # Release every validator, even if the sweep was interrupted.
        for validator in validators.values():
            if validator is not None:
                validator.close()

    failed_count = sum(reasons.values())
    print("\n" + "="*50)
    print("DATASET CLEANUP COMPLETE")
    print("="*50)
    print(f"Original Rows : {len(lines)}")
    print(f"Cleaned Rows : {cleaned_count} (100% Valid SQL)")
    print(f"Removed Rows : {failed_count}")
    for reason, count in reasons.most_common():
        print(f"  - {reason}: {count}")
    print(f"Saved To : {OUTPUT_FILE}")
    print("="*50)
if __name__ == "__main__":
    # Run the sweep only when executed as a script, not when imported.
    main()