feat: sql agent
Browse files
app.py
CHANGED
|
@@ -77,8 +77,8 @@ def prepare_payout_table(engine, metadata):
|
|
| 77 |
engine = create_engine("sqlite:///:memory:")
|
| 78 |
metadata = MetaData()
|
| 79 |
|
| 80 |
-
prepare_payment_table(engine, metadata)
|
| 81 |
-
prepare_payout_table(engine, metadata)
|
| 82 |
|
| 83 |
from smolagents import tool, CodeAgent, InferenceClientModel
|
| 84 |
|
|
@@ -115,7 +115,7 @@ for table in ["payments", "payouts"]:
|
|
| 115 |
table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
| 116 |
tool_description += "\n\n" + table_description
|
| 117 |
|
| 118 |
-
print("SQL tool description", tool_description)
|
| 119 |
|
| 120 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 121 |
import gradio as gr
|
|
@@ -197,16 +197,16 @@ def extract_transaction_info(response):
|
|
| 197 |
resp_json = json.loads(json_match.group())
|
| 198 |
found = resp_json.get("found", False)
|
| 199 |
if found == False:
|
| 200 |
-
return None
|
| 201 |
|
| 202 |
transaction_id = resp_json.get("transaction_id")
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
return str(transaction_id).strip()
|
| 206 |
else:
|
| 207 |
-
return None
|
| 208 |
except Exception as e:
|
| 209 |
-
return None
|
| 210 |
|
| 211 |
def predict(message, history):
|
| 212 |
# Always inject the user message into the system prompt's {input} placeholder
|
|
@@ -242,14 +242,31 @@ def predict(message, history):
|
|
| 242 |
# Fallback: just return the decoded output
|
| 243 |
analysis_response = decoded.strip()
|
| 244 |
|
| 245 |
-
transaction_id = extract_transaction_info(analysis_response)
|
| 246 |
if transaction_id == None:
|
| 247 |
return analysis_response
|
| 248 |
|
| 249 |
-
#
|
| 250 |
-
|
|
|
|
| 251 |
|
| 252 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
try:
|
| 255 |
analysis_json = json.loads(analysis_response)
|
|
|
|
| 77 |
engine = create_engine("sqlite:///:memory:")
|
| 78 |
metadata = MetaData()
|
| 79 |
|
| 80 |
+
payment_table=prepare_payment_table(engine, metadata)
|
| 81 |
+
payout_table=prepare_payout_table(engine, metadata)
|
| 82 |
|
| 83 |
from smolagents import tool, CodeAgent, InferenceClientModel
|
| 84 |
|
|
|
|
| 115 |
table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
|
| 116 |
tool_description += "\n\n" + table_description
|
| 117 |
|
| 118 |
+
# print("SQL tool description", tool_description)
|
| 119 |
|
| 120 |
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 121 |
import gradio as gr
|
|
|
|
| 197 |
resp_json = json.loads(json_match.group())
|
| 198 |
found = resp_json.get("found", False)
|
| 199 |
if found == False:
|
| 200 |
+
return None, None
|
| 201 |
|
| 202 |
transaction_id = resp_json.get("transaction_id")
|
| 203 |
+
transaction_type = resp_json.get("transaction_type")
|
| 204 |
+
if transaction_id and transaction_type:
|
| 205 |
+
return str(transaction_id).strip(), transaction_type.strip()
|
| 206 |
else:
|
| 207 |
+
return None, None
|
| 208 |
except Exception as e:
|
| 209 |
+
return None, None
|
| 210 |
|
| 211 |
def predict(message, history):
|
| 212 |
# Always inject the user message into the system prompt's {input} placeholder
|
|
|
|
| 242 |
# Fallback: just return the decoded output
|
| 243 |
analysis_response = decoded.strip()
|
| 244 |
|
| 245 |
+
transaction_id, transaction_type = extract_transaction_info(analysis_response)
|
| 246 |
if transaction_id == None:
|
| 247 |
return analysis_response
|
| 248 |
|
| 249 |
+
# payment is required to use the SQL agent, error: Subscribe to PRO to get 20x more monthly
|
| 250 |
+
# sql_prompt = f"""
|
| 251 |
+
# Given the '{transaction_type}', find the corresponding table name in the database, the find a record with given id '{transaction_id}' from the table, return the record as JSON. If there is no record, return null.
|
| 252 |
|
| 253 |
+
# Example:
|
| 254 |
+
# Input: transaction_type: payout, transaction_id: payout-abc456
|
| 255 |
+
# Record in payouts: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
|
| 256 |
+
# Response: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
|
| 257 |
+
|
| 258 |
+
# Input: transaction_type: payout, transaction_id: null
|
| 259 |
+
# Record in payouts: None
|
| 260 |
+
# Response: null
|
| 261 |
+
# """
|
| 262 |
+
|
| 263 |
+
# sql_response = agent.run(sql_prompt)
|
| 264 |
+
|
| 265 |
+
# Directly use the sql_engine tool without CodeAgent to avoid payment
|
| 266 |
+
table = payment_table if transaction_type == "payment" else payout_table
|
| 267 |
+
sql_response = table.select().where(table.c.id == '{transaction_id}')
|
| 268 |
+
|
| 269 |
+
# print(f"SQL response: {sql_response}\n")
|
| 270 |
|
| 271 |
try:
|
| 272 |
analysis_json = json.loads(analysis_response)
|