File size: 2,501 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import json
import os
import sys
import re
from tqdm import tqdm

# Absolute path of the directory that contains this script.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
# Put that directory at the front of the import search path (only once) so
# that packages located next to this script can be imported regardless of the
# current working directory. This must run before the project-local import.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator

# JSONL dataset read by main(), one JSON record per line. The path is
# relative, so it is resolved against the current working directory.
INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
# Destination for the records that pass validation. main() opens it in "w"
# mode, so any existing file with this name is overwritten on each run.
OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"

def _extract_domain(record):
    """Return the database domain named by *record*, or ``"unknown"``.

    The domain is taken from ``record["metadata"]["domain"]`` when that is a
    non-empty string other than ``"unknown"``. Otherwise the second prompt
    message's ``content`` is searched for a ``Database: <name>`` marker.
    Any missing or unexpectedly-typed field yields ``"unknown"`` instead of
    raising, so one malformed record cannot abort the whole sweep.
    """
    metadata = record.get("metadata")
    domain = metadata.get("domain") if isinstance(metadata, dict) else None
    # Non-string domains (e.g. lists) are unusable as validator cache keys.
    if isinstance(domain, str) and domain and domain != "unknown":
        return domain

    # Fallback for domain extraction: look inside the second prompt message.
    prompt = record.get("prompt")
    if isinstance(prompt, list) and len(prompt) > 1 and isinstance(prompt[1], dict):
        content = prompt[1].get("content", "")
        if isinstance(content, str):
            match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
            if match:
                return match.group(1)
    return "unknown"


def main():
    """Filter INPUT_FILE down to records whose SQL validates and returns rows.

    Every line of INPUT_FILE is parsed as JSON. A record is copied verbatim to
    OUTPUT_FILE only when its domain can be determined and the domain's
    ``SQLValidator`` reports ``passed`` with ``row_count > 0``. Everything
    else (unparseable lines, non-object records, non-string SQL, unknown
    domain, validation errors) is counted as removed. A summary is printed at
    the end.

    One validator is created per domain and reused; all validators are closed
    even if the sweep is interrupted by an error. Errors while constructing a
    validator or writing the output file are not swallowed and propagate to
    the caller.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found!")
        return

    print("Sweeping dataset to remove bad SQLs...")

    # The whole file is read up front so the progress bar and the summary
    # know the total number of rows.
    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    validators = {}
    cleaned_count = 0
    failed_count = 0

    try:
        with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
            for line in tqdm(lines, desc="Filtering Garbage"):
                try:
                    record = json.loads(line)
                except json.JSONDecodeError:
                    failed_count += 1
                    continue

                # A line can be valid JSON without being an object
                # (e.g. a list or a bare number); such rows are unusable.
                if not isinstance(record, dict):
                    failed_count += 1
                    continue

                sql = record.get("sql", "")
                if not isinstance(sql, str):
                    # null / numeric "sql" values cannot be validated.
                    failed_count += 1
                    continue
                sql = sql.strip()

                domain = _extract_domain(record)
                if domain == "unknown":
                    failed_count += 1
                    continue

                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)

                try:
                    val_result = validators[domain].validate(sql)
                    # Keep ONLY if SQL is 100% perfect and returns data
                    keep = bool(val_result.passed and val_result.row_count > 0)
                except Exception:
                    # Best-effort sweep: any validator failure drops the row.
                    keep = False

                if not keep:
                    failed_count += 1
                    continue

                # Written outside the try above so that I/O errors surface
                # instead of being miscounted as bad SQL. A final input line
                # without a newline is terminated to keep the output valid JSONL.
                out_f.write(line if line.endswith("\n") else line + "\n")
                cleaned_count += 1
    finally:
        # Release validator resources even when the sweep aborts midway.
        for v in validators.values():
            v.close()

    print("\n" + "="*50)
    print("DATASET CLEANUP COMPLETE")
    print("="*50)
    print(f"Original Rows : {len(lines)}")
    print(f"Cleaned Rows  : {cleaned_count} (100% Valid SQL)")
    print(f"Removed Rows  : {failed_count}")
    print(f"Saved To      : {OUTPUT_FILE}")
    print("="*50)

# Run the cleanup only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()