# nl2sql-bench / clean_dataset.py
# Uploaded by ritvik360 using huggingface_hub (commit a39d8ef, verified)
import json
import os
import sys
import re
from tqdm import tqdm
# Make the project root importable so `data_factory` resolves when this
# script is run directly (rather than as an installed package).
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
from data_factory.validator import SQLValidator

# Input/output JSONL paths, resolved relative to the current working directory.
INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"
def main():
if not os.path.exists(INPUT_FILE):
print(f"Error: {INPUT_FILE} not found!")
return
print(f"Sweeping dataset to remove bad SQLs...")
with open(INPUT_FILE, "r", encoding="utf-8") as f:
lines = f.readlines()
validators = {}
cleaned_count = 0
failed_count = 0
with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
for line in tqdm(lines, desc="Filtering Garbage"):
try:
record = json.loads(line)
except json.JSONDecodeError:
failed_count += 1
continue
sql = record.get("sql", "").strip()
metadata = record.get("metadata", {})
domain = metadata.get("domain")
# Fallback for domain extraction
if not domain or domain == "unknown":
content = record.get("prompt", [{}, {}])[1].get("content", "")
match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
domain = match.group(1) if match else "unknown"
if domain == "unknown":
failed_count += 1
continue
if domain not in validators:
validators[domain] = SQLValidator(domain, seed=42)
try:
val_result = validators[domain].validate(sql)
# Keep ONLY if SQL is 100% perfect and returns data
if val_result.passed and val_result.row_count > 0:
out_f.write(line)
cleaned_count += 1
else:
failed_count += 1
except Exception:
failed_count += 1
for v in validators.values():
v.close()
print("\n" + "="*50)
print("DATASET CLEANUP COMPLETE")
print("="*50)
print(f"Original Rows : {len(lines)}")
print(f"Cleaned Rows : {cleaned_count} (100% Valid SQL)")
print(f"Removed Rows : {failed_count}")
print(f"Saved To : {OUTPUT_FILE}")
print("="*50)
# CLI entry point: run the cleanup sweep when executed as a script.
if __name__ == "__main__":
    main()