Spaces:
Sleeping
Sleeping
File size: 4,855 Bytes
570f7bd e3e0ac5 8b2d603 570f7bd 8b2d603 c1bc4eb e3e0ac5 8b2d603 e3e0ac5 8b2d603 6f6e439 8b2d603 c1bc4eb e3e0ac5 570f7bd 8b2d603 570f7bd e3e0ac5 6f6e439 8b2d603 6f6e439 8b2d603 c1bc4eb 8b2d603 6f6e439 8b2d603 6f6e439 8b2d603 e3e0ac5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
from __future__ import annotations
import re
from typing import Any, Dict, List, Tuple, Optional
__all__ = ["Planner"]
# --------- Heuristic schema trimming (safe, mypy-clean) ---------
def _tokenize_lower(s: str) -> List[str]:
return re.findall(r"[a-z_]+", (s or "").lower())
def _table_blocks(schema_text: str) -> List[Tuple[str, List[str]]]:
"""
Parse plain-text schema into [(table_name, lines)] blocks,
supporting both 'Table: name' and 'CREATE TABLE name (' styles.
"""
blocks: List[Tuple[str, List[str]]] = []
cur_name: Optional[str] = None
cur_lines: List[str] = []
def _flush() -> None:
nonlocal cur_name, cur_lines
if cur_name is not None and cur_lines:
blocks.append((cur_name, cur_lines[:]))
cur_name, cur_lines = None, []
for line in (schema_text or "").splitlines():
m = re.search(r"Table:\s*(\w+)", line, flags=re.IGNORECASE)
m2 = re.search(r"CREATE\s+TABLE\s+(\w+)\s*\(", line, flags=re.IGNORECASE)
started = False
name: Optional[str] = None
if m is not None:
name = m.group(1)
started = True
elif m2 is not None:
name = m2.group(1)
started = True
if started and name:
_flush()
cur_name = name
cur_lines.append(line)
else:
if cur_name is not None:
cur_lines.append(line)
if cur_name is not None and line.strip().endswith(");"):
_flush()
_flush()
return blocks
def _pick_relevant_tables(schema_text: str, question: str, k: int = 3) -> str:
    """Return the text of up to ``k`` tables that best match ``question``.

    Each table block is scored by lexical overlap with the question's
    tokens: one point per matching token of the table name, plus at most
    two points for identifier-like words found anywhere in the block's
    lines (header line included). The top ``max(1, k)`` blocks are taken in
    descending score order; blocks scoring zero are discarded unless that
    would leave nothing, in which case the top blocks are kept regardless.

    The original ``schema_text`` is returned unchanged when no table block
    can be parsed, when the trimmed result would be empty, or when any
    exception occurs (trimming is best-effort).
    """
    try:
        parsed = _table_blocks(schema_text)
        if not parsed:
            return schema_text

        question_words = set(_tokenize_lower(question))

        ranked: List[Tuple[int, str, List[str]]] = []
        for table_name, table_lines in parsed:
            name_hits = len(
                [tok for tok in _tokenize_lower(table_name) if tok in question_words]
            )
            identifiers = re.findall(r"\b([A-Za-z_]\w*)\b", " ".join(table_lines))
            body_hits = len(
                [ident for ident in identifiers if ident.lower() in question_words]
            )
            # Body matches are capped so a wide table cannot dominate.
            ranked.append((name_hits + min(2, body_hits), table_name, table_lines))

        # Stable sort: equally scored tables keep their schema order.
        ranked = sorted(ranked, key=lambda entry: entry[0], reverse=True)
        top = ranked[: max(1, k)]
        selected = [entry for entry in top if entry[0] > 0] or top

        pieces: List[str] = []
        for _score, _name, table_lines in selected:
            pieces.extend(table_lines)
            # Separate consecutive tables with exactly one blank line.
            if table_lines and table_lines[-1].strip():
                pieces.append("")

        result = "\n".join(pieces).strip()
        return result or schema_text
    except Exception:
        return schema_text
# --------- Schema size limit ---------
def _trim_if_large(schema_text: str, max_chars: int = 8000) -> str:
"""Trim schema if it's too large to prevent timeout"""
if len(schema_text) <= max_chars:
return schema_text
# Keep first part of schema that fits
lines = schema_text[:max_chars].splitlines()
# Try to end at a complete line
return "\n".join(lines[:-1]) if len(lines) > 1 else lines[0]
# ------------------------------ Planner ------------------------------
class Planner:
    """Planner wrapper around the LLM provider.

    Trims the schema preview to the tables most relevant to the question,
    enforces a size limit, and asks the LLM for a plan. Results are
    memoised in memory per ``(model_id, query, trimmed schema)``.
    """

    def __init__(
        self,
        *,
        llm,
        model_id: str | None = None,
        max_tables: int = 3,
        max_schema_chars: int = 8000,
        timeout: float = 120,
    ) -> None:
        """Create a planner.

        Args:
            llm: Provider object; ``run`` calls its ``plan`` method with the
                keyword arguments ``user_query``, ``schema_preview`` and
                ``timeout`` and unpacks a 4-tuple from the result.
            model_id: Identifier used in the cache key. Falls back to
                ``llm.model`` and then to ``"unknown"``.
            max_tables: Maximum number of tables kept by relevance trimming.
            max_schema_chars: Character budget for the trimmed schema.
            timeout: Value forwarded to ``llm.plan`` as ``timeout``.
        """
        self.llm = llm
        # ensure model_id is always a str (for mypy)
        self.model_id: str = str(model_id or getattr(llm, "model", "unknown"))
        self.max_tables = max_tables
        self.max_schema_chars = max_schema_chars
        self.timeout = timeout
        # in-memory cache: (model, query, trimmed schema) -> (plan, pin, pout, cost)
        # Keyed on the strings themselves rather than their hash() values so
        # that a hash collision can never return another request's plan.
        self._plan_cache: dict[tuple[str, str, str], tuple[str, int, int, float]] = {}

    def run(self, *, user_query: str, schema_preview: str) -> Dict[str, Any]:
        """Produce a plan for ``user_query`` against ``schema_preview``.

        Returns:
            ``{"plan": ..., "usage": {"prompt_tokens": ...,
            "completion_tokens": ..., "cost_usd": ...}}``. On a cache hit
            the stored plan and its originally recorded usage figures are
            returned without calling the LLM again.
        """
        # First apply relevance filtering
        trimmed = _pick_relevant_tables(
            schema_preview or "", user_query or "", k=self.max_tables
        )
        # Then apply size limit to prevent timeout
        trimmed = _trim_if_large(trimmed, max_chars=self.max_schema_chars)

        key: tuple[str, str, str] = (self.model_id, user_query or "", trimmed)
        cached = self._plan_cache.get(key)
        if cached is None:
            cached = tuple(
                self.llm.plan(
                    user_query=user_query,
                    schema_preview=trimmed,
                    timeout=self.timeout,
                )
            )
            self._plan_cache[key] = cached
        plan_text, pin, pout, cost = cached

        return {
            "plan": plan_text,
            "usage": {
                "prompt_tokens": pin,
                "completion_tokens": pout,
                "cost_usd": cost,
            },
        }
|