File size: 4,359 Bytes
a39d8ef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import json
import os
import sys
import re
from collections import Counter
from tqdm import tqdm

# Add project root to path
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator

DATASET_FILE = "edge_cases.jsonl"

def main():
    """Audit DATASET_FILE (a JSONL dataset) and print a health report.

    Each record is expected to look like:
        {"prompt": [system_msg, user_msg], "sql": "...", "metadata": {...}}

    The report covers structural integrity (corrupt JSON, missing fields),
    diversity (unique SQL / NL questions, persona and domain distribution),
    and a strict execution check that runs each SQL through a per-domain
    SQLValidator.
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return

    print("Starting Dataset Quality & Sanity Check...\n")

    # Row-level failure counters.
    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0

    # Diversity accumulators.
    persona_counts = Counter()
    domain_counts = Counter()
    unique_sqls = set()
    unique_questions = set()

    # One validator per domain, created lazily and reused across rows.
    validators = {}

    # Read everything up front so tqdm can show a total/ETA.
    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    try:
        for line in tqdm(lines, desc="Analyzing Rows"):
            # Skip blank/whitespace-only lines (e.g. a trailing newline)
            # instead of misreporting them as corrupt JSON.
            if not line.strip():
                continue
            total_rows += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                corrupt_json += 1
                continue

            prompt_block = record.get("prompt", [])
            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})

            # Need at least [system, user] messages and a non-empty SQL string.
            if not prompt_block or len(prompt_block) < 2 or not sql:
                empty_outputs += 1
                continue

            user_content = prompt_block[1].get("content", "")
            question = user_content.split("QUESTION: ")[-1]

            # Smart Domain Extraction: Try metadata first, fallback to prompt parsing
            domain = metadata.get("domain")
            if not domain:
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
                domain = match.group(1) if match else "unknown"

            persona = metadata.get("persona", "unknown")

            persona_counts[persona] += 1
            domain_counts[domain] += 1
            unique_sqls.add(sql)
            unique_questions.add(question)

            # Skip validation if domain is completely unknown/corrupted
            if domain == "unknown":
                missing_domains += 1
                continue

            # Strict Execution Quality Check
            try:
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)

                val_result = validators[domain].validate(sql)
                if not val_result.passed or val_result.row_count == 0:
                    sql_execution_failures += 1
            except Exception:
                # Schema/connection errors are counted as missing domains
                # rather than execution failures.
                missing_domains += 1
    finally:
        # Close validators even if the analysis loop raised part-way through.
        for v in validators.values():
            v.close()

    # --- REPORT GENERATION ---
    print("\n" + "="*60)
    print("DATASET HEALTH REPORT")
    print("="*60)
    print(f"Total Rows Parsed       : {total_rows}")
    print(f"Corrupt JSON Lines      : {corrupt_json}")
    print(f"Missing SQL/Domains     : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries      : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions     : {len(unique_questions)}")

    # Rows that survived all structural checks; guards every percentage below.
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate     : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        print(f" - {p}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {p}: {count}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        print(f" - {d}: {count} ({(count/valid_total)*100:.1f}%)" if valid_total else f" - {d}: {count}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
    print(f"SQL Execution Failures  : {sql_execution_failures} ({fail_rate:.2f}%)")

    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("="*60)

# Script entry point: run the audit only when executed directly, not on import.
if __name__ == "__main__":
    main()