File size: 4,855 Bytes
570f7bd
e3e0ac5
8b2d603
 
570f7bd
8b2d603
 
 
 
 
 
c1bc4eb
e3e0ac5
8b2d603
e3e0ac5
8b2d603
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f6e439
 
 
 
 
 
 
 
 
 
 
 
8b2d603
 
 
c1bc4eb
e3e0ac5
570f7bd
8b2d603
 
 
 
570f7bd
e3e0ac5
6f6e439
8b2d603
6f6e439
 
8b2d603
 
 
 
 
c1bc4eb
8b2d603
 
 
6f6e439
8b2d603
6f6e439
 
 
8b2d603
 
 
e3e0ac5
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
from __future__ import annotations

import re
from typing import Any, Dict, List, Tuple, Optional

__all__ = ["Planner"]


# --------- Heuristic schema trimming (safe, mypy-clean) ---------
def _tokenize_lower(s: str) -> List[str]:
    return re.findall(r"[a-z_]+", (s or "").lower())


def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
    """
    Parse plain-text schema into [(table_name, lines)] blocks,
    supporting both 'Table: name' and 'CREATE TABLE name (' styles.
    """
    blocks: List[Tuple[str, List[str]]] = []
    cur_name: Optional[str] = None
    cur_lines: List[str] = []

    def _flush() -> None:
        nonlocal cur_name, cur_lines
        if cur_name is not None and cur_lines:
            blocks.append((cur_name, cur_lines[:]))
        cur_name, cur_lines = None, []

    for line in (schema_text or "").splitlines():
        m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE)
        m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE)

        started = False
        name: Optional[str] = None
        if m is not None:
            name = m.group(1)
            started = True
        elif m2 is not None:
            name = m2.group(1)
            started = True

        if started and name:
            _flush()
            cur_name = name
            cur_lines.append(line)
        else:
            if cur_name is not None:
                cur_lines.append(line)

        if cur_name is not None and line.strip().endswith(");"):
            _flush()

    _flush()
    return blocks


def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
    """
    Return only the (at most *k*) schema tables that best match *question*.

    A table scores one point per token of its name found in the question,
    plus up to two points for identifiers in its lines found in the question.
    Among the top *k* tables (best first, schema order on ties) those with a
    positive score are kept; if none is positive, the top *k* are kept anyway.
    Falls back to the untouched *schema_text* when no tables are recognised,
    when the result would be empty, or on any error (best-effort trimming).
    """
    try:
        tables = _table_blocks(schema_text)
        if not tables:
            return schema_text

        question_words = set(_tokenize_lower(question))

        def _score(table_name: str, body: List[str]) -> int:
            name_hits = len([w for w in _tokenize_lower(table_name) if w in question_words])
            idents = re.findall(r"\b([A-Za-z_]\w*)\b", " ".join(body))
            ident_hits = len([c for c in idents if c.lower() in question_words])
            # Cap identifier matches so large tables don't dominate by size.
            return name_hits + min(2, ident_hits)

        # sorted() is stable, so equal scores keep their schema order.
        ranked = sorted(
            ((_score(table_name, body), table_name, body) for table_name, body in tables),
            key=lambda item: item[0],
            reverse=True,
        )
        top = ranked[: max(1, k)]
        chosen = [item for item in top if item[0] > 0] or top

        pieces: List[str] = []
        for _, _, body in chosen:
            pieces.extend(body)
            # Separate consecutive tables with one blank line.
            if body and body[-1].strip() != "":
                pieces.append("")
        result = "\n".join(pieces).strip()
        return result or schema_text
    except Exception:
        return schema_text


# --------- Schema size limit ---------
def _trim_if_large(schema_text: str, max_chars: int = 8000) -> str:
    """Trim schema if it's too large to prevent timeout"""
    if len(schema_text) <= max_chars:
        return schema_text

    # Keep first part of schema that fits
    lines = schema_text[:max_chars].splitlines()
    # Try to end at a complete line
    return "\n".join(lines[:-1]) if len(lines) > 1 else lines[0]


# ------------------------------ Planner ------------------------------
class Planner:
    """Planner wrapper around the LLM provider."""

    def __init__(self, *, llm, model_id: str | None = None) -> None:
        self.llm = llm
        # ensure model_id is always a str (for mypy)
        self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
        # in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost)
        self._plan_cache: dict[tuple[str, int, int], tuple[str, int, int, float]] = {}

    def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
        # First apply relevance filtering
        trimmed = _pick_relevant_tables(schema_preview or "", user_query or "", k=3)
        # Then apply size limit to prevent timeout
        trimmed = _trim_if_large(trimmed, max_chars=8000)

        key: tuple[str, int, int] = (
            self.model_id,
            hash(user_query or ""),
            hash(trimmed),
        )
        if key in self._plan_cache:
            plan_text, pin, pout, cost = self._plan_cache[key]
        else:
            # Call with increased timeout
            plan_text, pin, pout, cost = self.llm.plan(
                user_query=user_query,
                schema_preview=trimmed,
                timeout=120,  # Increase timeout for large schemas
            )
            self._plan_cache[key] = (plan_text, pin, pout, cost)

        return {
            "plan": plan_text,
            "usage": {
                "prompt_tokens": pin,
                "completion_tokens": pout,
                "cost_usd": cost,
            },
        }