File size: 4,610 Bytes
91e7690
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
from __future__ import annotations

import itertools
import re
from dataclasses import dataclass
from typing import Iterable


@dataclass(frozen=True)
class AlgoConfig:
    w_coverage: float
    w_stat: float
    w_risk: float
    w_novelty: float
    limit_bonus: float
    repeat_penalty: float


def _query_features(sql: str) -> dict[str, float]:
    s = (sql or "").lower()
    return {
        "coverage": float(any(k in s for k in ["count(", "sum(", "avg(", "group by", "distinct"])),
        "stat": float(any(k in s for k in ["avg(", "stddev", "variance", "percentile", "try_cast", "strptime"])),
        "risk": float(any(k in s for k in ["drop", "truncate", "delete", "insert", "update", "alter", "create"])),
        "novelty": float(any(k in s for k in ["left join", "except", "not in", "having", "case when"])),
        "has_limit": float("limit" in s),
    }


def _task_keywords(task_id: int) -> list[str]:
    if task_id == 1:
        return ["null", "email", "customer_id", "duplicate", "group by"]
    if task_id == 2:
        return ["quantity", "amount", "n/a", "try_cast", "order_date"]
    return ["transactions_baseline", "transactions_current", "category", "user_id", "avg(amount)"]


def _task_relevance(task_id: int, sql: str) -> float:
    s = (sql or "").lower()
    keys = _task_keywords(task_id)
    hits = sum(1 for k in keys if k in s)
    return hits / max(1, len(keys))


def _sql_shape_penalty(sql: str) -> float:
    # Penalize very long and likely redundant SQL in a constrained step budget.
    length = len(sql or "")
    if length < 120:
        return 0.0
    if length < 300:
        return 0.02
    return 0.05


def algorithm_config_stream() -> Iterable[AlgoConfig]:
    # 11^4 * 7^2 = 717,409 total algorithm configurations.
    grid_a = [i / 10 for i in range(0, 11)]
    grid_b = [i / 20 for i in range(0, 7)]
    for a, b, c, d, e, f in itertools.product(grid_a, grid_a, grid_a, grid_a, grid_b, grid_b):
        yield AlgoConfig(
            w_coverage=a,
            w_stat=b,
            w_risk=c,
            w_novelty=d,
            limit_bonus=e,
            repeat_penalty=f,
        )


def _config_query_score(task_id: int, sql: str, cfg: AlgoConfig, q_prior: float) -> float:
    f = _query_features(sql)
    relevance = _task_relevance(task_id, sql)
    penalty_len = _sql_shape_penalty(sql)
    score = (
        cfg.w_coverage * f["coverage"]
        + cfg.w_stat * f["stat"]
        + cfg.w_novelty * f["novelty"]
        + cfg.limit_bonus * f["has_limit"]
        + 0.6 * relevance
        + 0.4 * q_prior
        - cfg.w_risk * f["risk"]
        - penalty_len
    )
    return score


def _ranking_for_config(task_id: int, queries: list[str], cfg: AlgoConfig, priors: list[float]) -> list[int]:
    pairs = []
    for i, q in enumerate(queries):
        pairs.append((i, _config_query_score(task_id, q, cfg, priors[i])))
    pairs.sort(key=lambda x: x[1], reverse=True)
    return [i for i, _ in pairs]


def select_best_config(task_id: int, queries: list[str], priors: list[float], max_configs: int = 100_000) -> AlgoConfig:
    best_cfg = None
    best_obj = -10**9

    for idx, cfg in enumerate(algorithm_config_stream()):
        if idx >= max_configs:
            break
        ranking = _ranking_for_config(task_id, queries, cfg, priors)

        # Objective: prioritize top-2 quality and diversity in SQL intent.
        top = ranking[:2]
        top_score = sum(_config_query_score(task_id, queries[i], cfg, priors[i]) for i in top)

        intents = set()
        for i in top:
            s = queries[i].lower()
            intent = "join" if any(k in s for k in ["join", "except", "not in"]) else "agg"
            intents.add(intent)
        diversity_bonus = 0.05 if len(intents) > 1 else 0.0

        obj = top_score + diversity_bonus
        if obj > best_obj:
            best_obj = obj
            best_cfg = cfg

    return best_cfg if best_cfg is not None else AlgoConfig(0.5, 0.5, 1.0, 0.5, 0.0, 0.0)


def ensemble_order(task_id: int, queries: list[str], priors: list[float], max_configs: int = 100_000) -> list[str]:
    cfg = select_best_config(task_id, queries, priors, max_configs=max_configs)
    ranking = _ranking_for_config(task_id, queries, cfg, priors)

    # De-prioritize unsafe SQL just in case external user-provided probes are included.
    safe = []
    unsafe = []
    for i in ranking:
        if re.search(r"\b(drop|truncate|delete|insert|update|alter|create)\b", queries[i], re.IGNORECASE):
            unsafe.append(queries[i])
        else:
            safe.append(queries[i])
    return safe + unsafe