File size: 7,666 Bytes
b69a231
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
from __future__ import annotations

from pathlib import Path
from typing import Optional, Dict, Any
import sqlite3

from sqlalchemy import create_engine

from orchestrator.settings import Settings

from langchain_groq import ChatGroq
from langchain_community.utilities.sql_database import SQLDatabase
from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
from langchain_community.agent_toolkits.sql.base import create_sql_agent


def _resolve_sqlite_path(settings: Settings, db_path: Optional[str] = None) -> Path:
    """Return an absolute path to the SQLite database file.

    Args:
        settings: Application settings; ``settings.sqlite_path`` is used when
            ``db_path`` is ``None`` or empty.
        db_path: Optional override for the database location.

    Returns:
        An absolute ``Path``. A leading ``~`` is expanded to the user's home
        directory. A relative path is anchored at the project root (the parent
        of the directory containing this module) instead of the current
        working directory. An already-absolute path is returned as given.
    """
    # expanduser() first: otherwise "~/student.db" is not absolute and would be
    # wrongly anchored under the project root as a literal "~" directory.
    p = Path(db_path or settings.sqlite_path).expanduser()
    if not p.is_absolute():
        # project root = parent of orchestrator/
        p = (Path(__file__).resolve().parents[1] / p).resolve()
    return p


def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
    """Open an existing SQLite file in read-only mode as a ``SQLDatabase``.

    The connection uses a ``file:`` URI with ``mode=ro``, so the database
    cannot be written through it.

    Args:
        sqlite_path: Location of the SQLite database file.

    Returns:
        A ``SQLDatabase`` backed by a SQLAlchemy engine whose connections are
        created by a read-only ``sqlite3.connect`` call.

    Raises:
        FileNotFoundError: If ``sqlite_path`` is not an existing regular file
            (this includes the case where it points at a directory).
    """
    if not sqlite_path.is_file():
        raise FileNotFoundError(
            f"SQLite DB not found at: {sqlite_path}\n"
            f"Fix: put student.db at project root OR set SQLITE_PATH to an absolute path."
        )

    # Path.as_uri() percent-encodes characters that are special inside a URI
    # ("?", "#", "%", spaces) and emits the "file:///C:/..." form on Windows;
    # interpolating the raw path would yield a malformed URI for such paths.
    # resolve() is needed because as_uri() only accepts absolute paths.
    readonly_uri = f"{sqlite_path.resolve().as_uri()}?mode=ro"

    def _connect():
        return sqlite3.connect(readonly_uri, uri=True)

    # The URL only selects the SQLite dialect; every connection comes from _connect.
    engine = create_engine("sqlite:///", creator=_connect)
    return SQLDatabase(engine)


def _make_llm(settings: Settings):
    """Create the Groq chat model described by ``settings`` with temperature 0.

    ``ChatGroq`` keyword names differ across versions, so two spellings are
    attempted: ``api_key`` / ``model`` first and, if that call raises
    ``TypeError``, ``groq_api_key`` / ``model_name``.
    """
    try:
        primary_kwargs = {
            "api_key": settings.groq_api_key,
            "model": settings.llm_model,
            "temperature": 0,
        }
        return ChatGroq(**primary_kwargs)
    except TypeError:
        # Retry with the alternative parameter names.
        fallback_kwargs = {
            "groq_api_key": settings.groq_api_key,
            "model_name": settings.llm_model,
            "temperature": 0,
        }
        return ChatGroq(**fallback_kwargs)


def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
    """Build a tool-calling SQL agent over the read-only SQLite database.

    Args:
        settings: Application settings. ``settings.debug`` toggles verbose
            logging and the return of intermediate steps; the LLM and the
            default database path are also derived from it.
        db_path: Optional override for the database location.

    Returns:
        A tuple ``(agent, db, sqlite_path)``: the agent returned by
        ``create_sql_agent``, the ``SQLDatabase`` it queries, and the resolved
        database path as a string.

    Raises:
        FileNotFoundError: Propagated from ``_make_sql_db_readonly`` when the
            database file does not exist.
    """
    llm = _make_llm(settings)

    sqlite_path = _resolve_sqlite_path(settings, db_path=db_path)
    db = _make_sql_db_readonly(sqlite_path)

    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    debug = bool(settings.debug)

    # Force the tool-calling SQL agent (most reliable on LC 1.2.x).
    #
    # Executor-level options (handle_parsing_errors, return_intermediate_steps)
    # must be passed via agent_executor_kwargs: extra top-level **kwargs of
    # create_sql_agent are forwarded to the agent object, not to the
    # AgentExecutor, so passing them there does not configure the executor.
    agent = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        agent_type="tool-calling",
        max_iterations=30,
        max_execution_time=60,
        verbose=debug,
        agent_executor_kwargs={
            "handle_parsing_errors": True,
            "return_intermediate_steps": debug,
        },
    )

    return agent, db, str(sqlite_path)


def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
    """Answer a natural-language question about the SQLite database.

    Questions that ask for the table list are answered directly from the
    database, without constructing the LLM or the agent. Every other question
    is passed to the SQL agent.

    Args:
        settings: Application settings used to locate the database and build
            the agent.
        question: The user's question; ``None`` or empty is tolerated for the
            shortcut check and is passed to the agent unchanged.
        db_path: Optional override for the database location.

    Returns:
        A dict with ``"answer"`` and ``"db_path"``. Agent answers also carry
        ``"agent": "sql"`` and, when the agent output contains them,
        ``"intermediate_steps"``.

    Raises:
        FileNotFoundError: Propagated when the database file does not exist.
    """
    q = (question or "").strip().lower()

    # Deterministic shortcut: only the database is needed to list tables, so
    # this path works without LLM credentials and skips building the agent.
    shortcut_phrases = ("list the tables", "list tables", "show tables", "what tables")
    if any(phrase in q for phrase in shortcut_phrases):
        resolved_path = _resolve_sqlite_path(settings, db_path=db_path)
        db = _make_sql_db_readonly(resolved_path)
        tables = db.get_usable_table_names()
        return {"answer": "Tables: " + ", ".join(tables), "db_path": str(resolved_path)}

    agent, _db, sqlite_path = make_sql_agent(settings, db_path=db_path)

    # Run agent
    out = agent.invoke({"input": question})

    # Normalize output: the agent result is read from the "output" key when it
    # is a dict; anything else is stringified as-is.
    answer = out.get("output") if isinstance(out, dict) else str(out)

    result = {"answer": str(answer), "db_path": sqlite_path, "agent": "sql"}

    # Surface intermediate steps when the agent returned them (debug mode) so
    # the caller can display them.
    if isinstance(out, dict) and "intermediate_steps" in out:
        result["intermediate_steps"] = out["intermediate_steps"]

    return result







# from __future__ import annotations

# from pathlib import Path
# from typing import Optional, Dict, Any
# import sqlite3

# from sqlalchemy import create_engine

# from orchestrator.settings import Settings
# from orchestrator.factories import get_llm

# # --- Imports that vary across LangChain versions ---
# try:
#     # langchain >= 1.x
#     from langchain.sql_database import SQLDatabase
# except Exception:
#     # older / community
#     from langchain_community.utilities import SQLDatabase

# try:
#     from langchain_community.agent_toolkits.sql.toolkit import SQLDatabaseToolkit
# except Exception:
#     # older path (rare)
#     from langchain.agents.agent_toolkits import SQLDatabaseToolkit

# try:
#     from langchain.agents import create_sql_agent
# except Exception:
#     from langchain_community.agent_toolkits.sql.base import create_sql_agent


# def _resolve_sqlite_path(settings: Settings) -> Path:
#     """
#     Resolve SQLITE_PATH relative to project root (parent of orchestrator/),
#     so Streamlit's current working directory does not break DB loading.
#     """
#     p = Path(settings.sqlite_path)
#     if not p.is_absolute():
#         p = (Path(__file__).resolve().parents[1] / p).resolve()
#     return p


# def _make_sql_db_readonly(sqlite_path: Path) -> SQLDatabase:
#     """
#     Open SQLite in READ-ONLY mode so a wrong path does NOT create an empty DB file.
#     """
#     if not sqlite_path.exists():
#         raise FileNotFoundError(
#             f"SQLite DB not found at: {sqlite_path}\n"
#             f"Fix: put student.db at the project root OR set SQLITE_PATH to an absolute path."
#         )

#     def _connect():
#         return sqlite3.connect(f"file:{sqlite_path.as_posix()}?mode=ro", uri=True)

#     engine = create_engine("sqlite:///", creator=_connect)
#     return SQLDatabase(engine)


# def _create_agent(llm, toolkit, verbose: bool):
#     """
#     Create SQL agent WITHOUT passing kwargs that frequently clash with defaults
#     in langchain-classic AgentExecutor.
#     """
#     # Keep only the safest option; many builds already set other defaults internally.
#     agent_exec_kwargs = {"handle_parsing_errors": True}

#     # Some versions accept max_iterations/max_execution_time top-level.
#     # Some accept neither.
#     # We try progressively.
#     try:
#         return create_sql_agent(
#             llm=llm,
#             toolkit=toolkit,
#             verbose=verbose,
#             max_iterations=25,
#             max_execution_time=60,
#             agent_executor_kwargs=agent_exec_kwargs,
#         )
#     except TypeError:
#         # Try without time/iteration controls to avoid duplicate kwargs.
#         return create_sql_agent(
#             llm=llm,
#             toolkit=toolkit,
#             verbose=verbose,
#             agent_executor_kwargs=agent_exec_kwargs,
#         )


# def make_sql_agent(settings: Settings, *, db_path: Optional[str] = None):
#     llm = get_llm(settings, temperature=0)

#     sqlite_path = Path(db_path).expanduser().resolve() if db_path else _resolve_sqlite_path(settings)
#     db = _make_sql_db_readonly(sqlite_path)
#     toolkit = SQLDatabaseToolkit(db=db, llm=llm)

#     agent = _create_agent(llm, toolkit, verbose=getattr(settings, "debug", False))
#     return agent, db, str(sqlite_path)


# def sql_answer(settings: Settings, question: str, *, db_path: Optional[str] = None) -> Dict[str, Any]:
#     agent, db, sqlite_path = make_sql_agent(settings, db_path=db_path)

#     # Deterministic shortcut so this never loops.
#     q = (question or "").strip().lower()
#     if any(s in q for s in ["list the tables", "list tables", "show tables", "what tables"]):
#         try:
#             tables = db.get_usable_table_names()
#         except Exception:
#             # fallback for older SQLDatabase implementations
#             tables = []
#         return {
#             "answer": "Tables: " + (", ".join(tables) if tables else "(none found)"),
#             "db_path": sqlite_path,
#         }

#     # Run agent
#     out = agent.invoke({"input": question})

#     # Normalize output
#     if isinstance(out, dict):
#         answer = out.get("output") or out.get("answer") or str(out)
#     else:
#         answer = str(out)

#     return {"answer": answer, "db_path": sqlite_path}