import json
import os
import sys
import re
from collections import Counter
from tqdm import tqdm
# Add the project root to sys.path so `data_factory` resolves when this
# script is executed directly (not installed as a package).
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.validator import SQLValidator

# JSONL dataset to audit: one JSON record per line, relative to the CWD.
DATASET_FILE = "edge_cases.jsonl"
def _analyze_rows(lines):
    """Scan dataset lines and accumulate per-row quality statistics.

    Args:
        lines: iterable of raw JSONL lines.

    Returns:
        dict of counters/sets:
        total_rows, corrupt_json, sql_execution_failures, empty_outputs,
        missing_domains, persona_counts, domain_counts, unique_sqls,
        unique_questions.

    Validators are created lazily (one per domain, reused across rows) and
    are always closed, even if an unexpected error interrupts the scan.
    """
    stats = {
        "total_rows": 0,
        "corrupt_json": 0,
        "sql_execution_failures": 0,
        "empty_outputs": 0,
        "missing_domains": 0,
        "persona_counts": Counter(),
        "domain_counts": Counter(),
        "unique_sqls": set(),
        "unique_questions": set(),
    }
    validators = {}  # domain name -> SQLValidator, built on first use
    try:
        for line in tqdm(lines, desc="Analyzing Rows"):
            stats["total_rows"] += 1
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                stats["corrupt_json"] += 1
                continue

            prompt_block = record.get("prompt", [])
            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})

            # A usable row needs at least a system+user prompt pair and
            # a non-empty SQL string.
            if not prompt_block or len(prompt_block) < 2 or not sql:
                stats["empty_outputs"] += 1
                continue

            user_content = prompt_block[1].get("content", "")
            # If the marker is absent this yields the whole content —
            # acceptable for uniqueness counting.
            question = user_content.split("QUESTION: ")[-1]

            # Smart domain extraction: metadata first, prompt parsing fallback.
            domain = metadata.get("domain")
            if not domain:
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
                domain = match.group(1) if match else "unknown"

            stats["persona_counts"][metadata.get("persona", "unknown")] += 1
            stats["domain_counts"][domain] += 1
            stats["unique_sqls"].add(sql)
            stats["unique_questions"].add(question)

            # No known schema -> cannot validate; skip execution check.
            if domain == "unknown":
                stats["missing_domains"] += 1
                continue

            # Strict execution quality check: the SQL must run and
            # return at least one row.
            try:
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                val_result = validators[domain].validate(sql)
                if not val_result.passed or val_result.row_count == 0:
                    stats["sql_execution_failures"] += 1
            except Exception:
                # Schema/validator construction errors are bucketed with
                # missing domains rather than execution failures.
                stats["missing_domains"] += 1
    finally:
        # Always release validator resources, even on an unexpected error
        # mid-scan (the original only closed them on the happy path).
        for v in validators.values():
            v.close()
    return stats


def _print_report(stats):
    """Print the human-readable health report from accumulated stats."""
    total_rows = stats["total_rows"]
    corrupt_json = stats["corrupt_json"]
    empty_outputs = stats["empty_outputs"]
    missing_domains = stats["missing_domains"]
    sql_execution_failures = stats["sql_execution_failures"]

    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(stats['unique_sqls'])} (Base logic templates)")
    print(f"Unique NL Questions : {len(stats['unique_questions'])}")

    # Rows that parsed cleanly and carried both SQL and a known domain.
    # NOTE(review): distribution counters also include unknown-domain rows,
    # so percentages are approximate when missing_domains > 0.
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
    unique_questions = len(stats["unique_questions"])
    duplication_rate = (1 - (unique_questions / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in stats["persona_counts"].most_common():
        if valid_total:
            print(f" - {p}: {count} ({(count / valid_total) * 100:.1f}%)")
        else:
            print(f" - {p}: {count}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in stats["domain_counts"].most_common():
        if valid_total:
            print(f" - {d}: {count} ({(count / valid_total) * 100:.1f}%)")
        else:
            print(f" - {d}: {count}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")
    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)


def main():
    """Run a quality & sanity audit over DATASET_FILE and print a report."""
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return
    print("Starting Dataset Quality & Sanity Check...\n")
    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()
    stats = _analyze_rows(lines)
    _print_report(stats)


if __name__ == "__main__":
    main()