File size: 2,991 Bytes
91e7690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import Any


# A candidate query must open with SELECT or WITH (leading whitespace allowed).
SAFE_SQL_RE = re.compile(r"^\s*(select|with)\b", re.IGNORECASE)
# Any of these keywords anywhere in a query disqualifies it. Besides classic
# DDL/DML this covers write paths that can hide inside a statement that still
# starts with SELECT/WITH: ``SELECT ... INTO new_table`` (creates a table) and
# ``WITH ... MERGE``; GRANT/REVOKE guard against privilege changes.
# ``\b`` keeps identifiers such as ``created_at`` or ``updated_by`` allowed.
BLOCKED_SQL_RE = re.compile(
    r"\b(drop|truncate|delete|insert|update|alter|create|merge|into|grant|revoke)\b",
    re.IGNORECASE,
)


@dataclass
class PlanBundle:
    """Parsed investigation plan returned by ``parse_plan_json``.

    Both lists are plain strings; an empty bundle (two empty lists) is what the
    parser yields when the model reply cannot be used.
    """

    # Free-text hypotheses proposed by the model.
    hypotheses: list[str]
    # Additional SQL queries; parse_plan_json only stores ones that passed safe_query_filter.
    extra_queries: list[str]


def safe_query_filter(queries: list[str]) -> list[str]:
    """Return the read-only, de-duplicated subset of *queries*, in input order.

    A query is kept only if, after trimming surrounding whitespace and any
    trailing semicolons, it:

    * is non-empty,
    * is a single statement (no interior ``;``),
    * starts with SELECT/WITH (``SAFE_SQL_RE``),
    * contains none of the keywords matched by ``BLOCKED_SQL_RE``.

    Duplicates are detected case-insensitively with whitespace runs collapsed.
    The first occurrence wins and is returned in its trimmed form.
    ``None`` entries are treated as empty and skipped.
    """
    out: list[str] = []
    seen: set[str] = set()
    for q in queries:
        s = (q or "").strip()
        # Drop statement terminators, including spaced or repeated ones ("SELECT 1 ;;").
        while s.endswith(";"):
            s = s[:-1].rstrip()
        if not s:
            continue
        # Any remaining ';' means stacked statements ("SELECT 1; VACUUM"). The
        # keyword blocklist cannot vet whatever follows, so reject outright.
        # This also rejects a ';' inside a string literal, which is an
        # acceptable false positive for a safety filter.
        if ";" in s:
            continue
        if not SAFE_SQL_RE.match(s) or BLOCKED_SQL_RE.search(s):
            continue
        key = re.sub(r"\s+", " ", s.lower())
        if key in seen:
            continue
        seen.add(key)
        out.append(s)
    return out


# Matches a reply wrapped in a Markdown code fence (```json ... ``` or ``` ... ```)
# and captures the text between the fences.
_JSON_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*(.*?)\s*```\s*$", re.DOTALL | re.IGNORECASE)


def _as_text_list(value: Any) -> list[str]:
    """Normalise one plan field into a list of strings.

    A bare string becomes a one-element list, because iterating it would split
    it into single characters. List/tuple items are stringified with ``None``
    entries dropped. Any other value (``null``, numbers, objects) yields ``[]``.
    """
    if isinstance(value, str):
        return [value]
    if isinstance(value, (list, tuple)):
        return [str(item) for item in value if item is not None]
    return []


def parse_plan_json(raw: str, *, max_hypotheses: int = 6, max_queries: int = 3) -> PlanBundle:
    """Parse the model's plan reply into a PlanBundle without raising on bad input.

    Args:
        raw: Reply expected to be a JSON object with ``hypotheses`` and
            ``extra_queries`` keys. A surrounding Markdown code fence is tolerated.
        max_hypotheses: Cap on hypotheses kept (default 6).
        max_queries: Cap on extra queries kept after ``safe_query_filter``
            (default 3, the limit requested in ``build_plan_prompt``).

    Returns:
        A PlanBundle. Both lists are empty when *raw* is not a decodable JSON
        object. Each field is normalised independently, so a malformed
        ``hypotheses`` value does not discard valid ``extra_queries`` and vice versa.
    """
    empty = PlanBundle(hypotheses=[], extra_queries=[])
    if isinstance(raw, str):
        fenced = _JSON_FENCE_RE.match(raw)
        if fenced:
            raw = fenced.group(1)
    try:
        payload = json.loads(raw)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers JSONDecodeError; TypeError covers non-str/bytes input
        # such as None; RecursionError covers pathologically nested JSON.
        return empty
    if not isinstance(payload, dict):
        return empty

    hypotheses = _as_text_list(payload.get("hypotheses"))
    # Filter before truncating so unsafe entries do not consume the quota.
    queries = safe_query_filter(_as_text_list(payload.get("extra_queries")))
    return PlanBundle(
        hypotheses=hypotheses[: max(max_hypotheses, 0)],
        extra_queries=queries[: max(max_queries, 0)],
    )


def build_plan_prompt(task_id: int, table_name: str, schema: dict[str, str], base_queries: list[str]) -> str:
    """Serialise the planning request for the model as a JSON string.

    The payload echoes the caller's task id, table name, schema mapping and
    base queries. It then appends a fixed instruction asking for hypotheses
    plus at most three extra SELECT queries, returned as JSON.
    """
    instruction = (
        "Propose short investigation hypotheses and at most 3 additional safe SELECT queries. "
        "Return JSON only with keys: hypotheses (list[str]) and extra_queries (list[str])."
    )
    # json.dumps keeps insertion order, so this key order is the serialised order.
    request = dict(
        task_id=task_id,
        table_name=table_name,
        schema=schema,
        base_queries=base_queries,
        instruction=instruction,
    )
    return json.dumps(request)


# Expected top-level report fields and their required types, in the order that
# missing keys are appended to the repaired report.
_REPORT_SCHEMA: tuple[tuple[str, type], ...] = (
    ("null_issues", dict),
    ("duplicate_row_count", int),
    ("schema_violations", list),
    ("drifted_columns", list),
    ("drift_details", dict),
    ("recommended_fixes", list),
)


def _coerce_count(value: Any) -> int:
    """Coerce *value* to a non-negative int, falling back to 0 when impossible.

    bools become 0/1 (otherwise ``True`` would slip through as an int subclass).
    Numeric strings such as ``"3"`` or ``"3.0"`` are parsed; floats are
    truncated. Negative results are clamped to 0.
    """
    if isinstance(value, bool):
        return int(value)
    if not isinstance(value, int):
        try:
            value = int(value)
        except (TypeError, ValueError, OverflowError):
            try:
                # Accept float-looking strings ("3.0"); inf/nan still fail here.
                value = int(float(value))
            except (TypeError, ValueError, OverflowError):
                return 0
    return max(value, 0)


def validate_and_repair_report(report: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of *report* whose expected top-level fields exist and are well-typed.

    Missing or wrongly-typed container fields are replaced by empty ones of the
    required type. ``duplicate_row_count`` is coerced to a non-negative int.
    Unknown keys are preserved, and missing keys are appended in
    ``_REPORT_SCHEMA`` order. *report* itself is not modified; the copy is
    shallow, so nested values are shared with it.
    """
    fixed = dict(report)
    for key, required in _REPORT_SCHEMA:
        value = fixed.get(key)  # None when missing -> replaced below
        if required is int:
            fixed[key] = _coerce_count(value)
        elif not isinstance(value, required):
            fixed[key] = required()
    return fixed