from __future__ import annotations import re from typing import Any, Dict, List, Tuple, Optional __all__ = ["Planner"] # --------- Heuristic schema trimming (safe, mypy-clean) --------- def _tokenize_lower(s: str) -> List[str]: return re.findall(r"[a-z_]+", (s or "").lower()) def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]: """ Parse plain-text schema into [(table_name, lines)] blocks, supporting both 'Table: name' and 'CREATE TABLE name (' styles. """ blocks: List[Tuple[str, List[str]]] = [] cur_name: Optional[str] = None cur_lines: List[str] = [] def _flush() -> None: nonlocal cur_name, cur_lines if cur_name is not None and cur_lines: blocks.append((cur_name, cur_lines[:])) cur_name, cur_lines = None, [] for line in (schema_text or "").splitlines(): m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE) m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE) started = False name: Optional[str] = None if m is not None: name = m.group(1) started = True elif m2 is not None: name = m2.group(1) started = True if started and name: _flush() cur_name = name cur_lines.append(line) else: if cur_name is not None: cur_lines.append(line) if cur_name is not None and line.strip().endswith(");"): _flush() _flush() return blocks def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str: """Keep up to k tables with highest lexical overlap with the question.""" try: blocks = _table_blocks(schema_text) if not blocks: return schema_text q_toks = set(_tokenize_lower(question)) scored: List[Tuple[int, str, List[str]]] = [] for name, lines in blocks: score = sum(1 for w in _tokenize_lower(name) if w in q_toks) cols_line = " ".join(lines) cols = re.findall(r"\b([A-Za-z_]\w*)\b", cols_line) score += min(2, sum(1 for c in cols if c.lower() in q_toks)) scored.append((score, name, lines)) scored.sort(key=lambda t: t[0], reverse=True) keep = [b for b in scored[: max(1, k)] if b[0] > 0] if not keep: keep = scored[: max(1, k)] out_lines: List[str] = [] for _, _, lines in keep: out_lines.extend(lines) if lines and lines[-1].strip() != "": out_lines.append("") trimmed = "\n".join(out_lines).strip() return trimmed if trimmed else schema_text except Exception: return schema_text # ------------------------------ Planner ------------------------------ class Planner: """Planner wrapper around the LLM provider.""" def __init__(self, *, llm, model_id: str | None = None) -> None: self.llm = llm # ensure model_id is always a str (for mypy) self.model_id: str = str(model_id or getattr(llm, "model", "unknown")) # in-memory cache: (model, hash(q), hash(trimmed)) → (plan, pin, pout, cost) self._plan_cache: dict[tuple[str, int, int], tuple[str, int, int, float]] = {} def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]: trimmed = _pick_relevant_tables(schema_preview or "", user_query or "", k=3) key: tuple[str, int, int] = ( self.model_id, hash(user_query or ""), hash(trimmed), ) if key in self._plan_cache: plan_text, pin, pout, cost = self._plan_cache[key] else: plan_text, pin, pout, cost = self.llm.plan( user_query=user_query, schema_preview=trimmed ) self._plan_cache[key] = (plan_text, pin, pout, cost) return { "plan": plan_text, "usage": { "prompt_tokens": pin, "completion_tokens": pout, "cost_usd": cost, }, }