ivxivx commited on
Commit
e523149
·
unverified ·
1 Parent(s): dd73159

feat: sql agent

Browse files
Files changed (1) hide show
  1. app.py +30 -13
app.py CHANGED
@@ -77,8 +77,8 @@ def prepare_payout_table(engine, metadata):
77
  engine = create_engine("sqlite:///:memory:")
78
  metadata = MetaData()
79
 
80
- prepare_payment_table(engine, metadata)
81
- prepare_payout_table(engine, metadata)
82
 
83
  from smolagents import tool, CodeAgent, InferenceClientModel
84
 
@@ -115,7 +115,7 @@ for table in ["payments", "payouts"]:
115
  table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
116
  tool_description += "\n\n" + table_description
117
 
118
- print("SQL tool description", tool_description)
119
 
120
  from transformers import AutoModelForCausalLM, AutoTokenizer
121
  import gradio as gr
@@ -197,16 +197,16 @@ def extract_transaction_info(response):
197
  resp_json = json.loads(json_match.group())
198
  found = resp_json.get("found", False)
199
  if found == False:
200
- return None
201
 
202
  transaction_id = resp_json.get("transaction_id")
203
- if transaction_id:
204
- # Ensure the transaction ID is a string
205
- return str(transaction_id).strip()
206
  else:
207
- return None
208
  except Exception as e:
209
- return None
210
 
211
  def predict(message, history):
212
  # Always inject the user message into the system prompt's {input} placeholder
@@ -242,14 +242,31 @@ def predict(message, history):
242
  # Fallback: just return the decoded output
243
  analysis_response = decoded.strip()
244
 
245
- transaction_id = extract_transaction_info(analysis_response)
246
  if transaction_id == None:
247
  return analysis_response
248
 
249
- # If we successfully extracted a transaction ID, we can invoke the SQL tool
250
- sql_response = agent.run(f"find a record with id {transaction_id} from correct table in the database")
 
251
 
252
- print(f"SQL response: {sql_response}\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  try:
255
  analysis_json = json.loads(analysis_response)
 
77
  engine = create_engine("sqlite:///:memory:")
78
  metadata = MetaData()
79
 
80
+ payment_table=prepare_payment_table(engine, metadata)
81
+ payout_table=prepare_payout_table(engine, metadata)
82
 
83
  from smolagents import tool, CodeAgent, InferenceClientModel
84
 
 
115
  table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
116
  tool_description += "\n\n" + table_description
117
 
118
+ # print("SQL tool description", tool_description)
119
 
120
  from transformers import AutoModelForCausalLM, AutoTokenizer
121
  import gradio as gr
 
197
  resp_json = json.loads(json_match.group())
198
  found = resp_json.get("found", False)
199
  if found == False:
200
+ return None, None
201
 
202
  transaction_id = resp_json.get("transaction_id")
203
+ transaction_type = resp_json.get("transaction_type")
204
+ if transaction_id and transaction_type:
205
+ return str(transaction_id).strip(), transaction_type.strip()
206
  else:
207
+ return None, None
208
  except Exception as e:
209
+ return None, None
210
 
211
  def predict(message, history):
212
  # Always inject the user message into the system prompt's {input} placeholder
 
242
  # Fallback: just return the decoded output
243
  analysis_response = decoded.strip()
244
 
245
+ transaction_id, transaction_type = extract_transaction_info(analysis_response)
246
  if transaction_id == None:
247
  return analysis_response
248
 
249
+ # payment is required to use the SQL agent, error: Subscribe to PRO to get 20x more monthly
250
+ # sql_prompt = f"""
251
+ # Given the '{transaction_type}', find the corresponding table name in the database, the find a record with given id '{transaction_id}' from the table, return the record as JSON. If there is no record, return null.
252
 
253
+ # Example:
254
+ # Input: transaction_type: payout, transaction_id: payout-abc456
255
+ # Record in payouts: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
256
+ # Response: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
257
+
258
+ # Input: transaction_type: payout, transaction_id: null
259
+ # Record in payouts: None
260
+ # Response: null
261
+ # """
262
+
263
+ # sql_response = agent.run(sql_prompt)
264
+
265
+ # Directly use the sql_engine tool without CodeAgent to avoid payment
266
+ table = payment_table if transaction_type == "payment" else payout_table
267
+ sql_response = table.select().where(table.c.id == '{transaction_id}')
268
+
269
+ # print(f"SQL response: {sql_response}\n")
270
 
271
  try:
272
  analysis_json = json.loads(analysis_response)