ivxivx committed on
Commit
dd73159
·
unverified ·
1 Parent(s): a7de299

feat: sql tool

Browse files
Files changed (2) hide show
  1. app.py +154 -12
  2. requirements.txt +3 -1
app.py CHANGED
@@ -4,6 +4,119 @@ import os
4
  from huggingface_hub import login
5
  login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  import gradio as gr
9
  import torch
@@ -18,12 +131,6 @@ def get_device_type() -> str:
18
  else:
19
  return "cpu"
20
 
21
- # # HuggingFaceTB/SmolLM2-135M-Instruct
22
- # model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" # 15G
23
-
24
- # "meta-llama/Llama-3.2-3B does not work
25
- model_name="meta-llama/Llama-3.2-3B-Instruct" # 6.5G
26
-
27
  device = get_device_type()
28
  tokenizer = AutoTokenizer.from_pretrained(model_name)
29
  model = AutoModelForCausalLM.from_pretrained(model_name).to(device, dtype=torch.float16)
@@ -81,6 +188,26 @@ examples = [
81
  "I am having trouble with my transaction",
82
  ]
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  def predict(message, history):
85
  # Always inject the user message into the system prompt's {input} placeholder
86
  sys_prompt = system_prompt.replace("{input}", message)
@@ -105,17 +232,32 @@ def predict(message, history):
105
 
106
  # Extract only the assistant's message (after the last user message)
107
  if "<|start_header_id|>assistant<|end_header_id|>" in decoded:
108
- response = decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
109
- response = response.replace("<|eot_id|>", "").strip()
110
  elif "<|im_start|>assistant" in decoded:
111
  # This works for most chat templates that append the assistant's reply at the end
112
- response = decoded.split("<|im_start|>assistant")[-1]
113
- response = response.replace("<|im_end|>", "").strip()
114
  else:
115
  # Fallback: just return the decoded output
116
- response = decoded.strip()
 
 
 
 
117
 
118
- return response
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  demo = gr.ChatInterface(predict, type="messages", examples=examples)
121
 
 
4
  from huggingface_hub import login
5
  login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))
6
 
7
# Dependencies for the SQL-tool feature (already listed in requirements.txt):
# %pip install gradio smolagents sqlalchemy

# NOTE(review): the base "meta-llama/Llama-3.2-3B" model reportedly does not
# work here — use the Instruct variant.
model_name="meta-llama/Llama-3.2-3B-Instruct" # ~6.5G download
12
+
13
+ from sqlalchemy import (
14
+ create_engine,
15
+ MetaData,
16
+ Table,
17
+ Column,
18
+ DECIMAL,
19
+ TEXT,
20
+ insert,
21
+ inspect,
22
+ text,
23
+ )
24
+
25
def insert_rows_into_table(rows, table, engine):
    """Insert rows into a table atomically.

    All rows are written inside a single transaction (executemany form),
    so either every row is committed or none are. The previous version
    opened one transaction per row, which could leave a partial insert
    behind if a later row failed.

    Args:
        rows: Iterable of dicts mapping column names to values.
        table: SQLAlchemy Table to insert into.
        engine: Engine providing the connection/transaction.
    """
    rows = list(rows)
    if not rows:
        # Nothing to insert; avoid opening an empty transaction.
        return
    with engine.begin() as connection:
        # Passing the list of dicts executes one prepared INSERT for all rows.
        connection.execute(insert(table), rows)
30
+
31
def prepare_payment_table(engine, metadata):
    """Create and seed the 'payments' demo table.

    Seed rows are inserted only when the table is freshly created; the
    previous version re-inserted them even when the table already existed,
    which raises a duplicate primary-key error on any rerun against a
    persistent database.

    Args:
        engine: SQLAlchemy engine the table is created on.
        metadata: MetaData collection the table is registered in.

    Returns:
        The SQLAlchemy Table object for 'payments'.
    """
    inspector = inspect(engine)
    table_name = "payments"

    if inspector.has_table(table_name):
        # Table already exists: reflect it and assume it is already seeded.
        return Table(table_name, metadata, autoload_with=engine)

    table = Table(
        table_name,
        metadata,
        Column("id", TEXT, primary_key=True),
        Column("amount", DECIMAL),
        Column("created_at", TEXT),
    )
    metadata.create_all(engine)

    rows = [
        {"id": "payment-123", "amount": 100.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payment-abc-12", "amount": 200.0, "created_at": "2021-01-02 00:00:00"},
    ]
    insert_rows_into_table(rows, table, engine)

    return table
53
+
54
def prepare_payout_table(engine, metadata):
    """Create and seed the 'payouts' demo table.

    Seed rows are inserted only when the table is freshly created; the
    previous version re-inserted them even when the table already existed,
    which raises a duplicate primary-key error on any rerun against a
    persistent database.

    Args:
        engine: SQLAlchemy engine the table is created on.
        metadata: MetaData collection the table is registered in.

    Returns:
        The SQLAlchemy Table object for 'payouts'.
    """
    inspector = inspect(engine)
    table_name = "payouts"

    if inspector.has_table(table_name):
        # Table already exists: reflect it and assume it is already seeded.
        return Table(table_name, metadata, autoload_with=engine)

    table = Table(
        table_name,
        metadata,
        Column("id", TEXT, primary_key=True),
        Column("amount", DECIMAL),
        Column("created_at", TEXT),
    )
    metadata.create_all(engine)

    rows = [
        {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payout-b2c2", "amount": 100.0, "created_at": "2021-01-02 00:00:00"},
    ]
    insert_rows_into_table(rows, table, engine)

    return table
76
+
77
# In-memory SQLite database: data exists only for the lifetime of this process,
# so the demo tables are recreated on every app start.
engine = create_engine("sqlite:///:memory:")
metadata = MetaData()

# Create and seed the two tables the SQL tool can query.
prepare_payment_table(engine, metadata)
prepare_payout_table(engine, metadata)
82
+
83
+ from smolagents import tool, CodeAgent, InferenceClientModel
84
+
85
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.

    Args:
        query: The query to perform. This should be correct SQL.
    """
    with engine.connect() as con:
        result = con.execute(text(query))
        # Statements that return no rows (INSERT/UPDATE/DDL) raise if we try
        # to iterate their result; report the affected row count instead.
        if not result.returns_rows:
            return f"{result.rowcount} row(s) affected"
        # Same output shape as before: each row on its own line, string
        # starts with a newline (empty string for an empty result set).
        return "".join("\n" + str(row) for row in result)
99
+
100
+
101
# Code-writing agent that can call the SQL tool to answer lookup requests.
# NOTE(review): the agent uses the HF Inference API (InferenceClientModel),
# not the locally loaded model below — confirm that is intentional.
agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id=model_name),
)

tool_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""

# Append each table's schema (column name + type) so the agent knows what it
# can query.
inspector = inspect(engine)
for table in ["payments", "payouts"]:
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]

    table_description = f"Table '{table}':\n"

    table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
    tool_description += "\n\n" + table_description

print("SQL tool description", tool_description)
119
+
120
  from transformers import AutoModelForCausalLM, AutoTokenizer
121
  import gradio as gr
122
  import torch
 
131
  else:
132
  return "cpu"
133
 
 
 
 
 
 
 
134
  device = get_device_type()
135
  tokenizer = AutoTokenizer.from_pretrained(model_name)
136
  model = AutoModelForCausalLM.from_pretrained(model_name).to(device, dtype=torch.float16)
 
188
  "I am having trouble with my transaction",
189
  ]
190
 
191
+ import json, re
192
+
193
def extract_transaction_info(response):
    """Extract a transaction ID from a model response that may contain JSON.

    Finds the first {...} span in *response*, parses it as JSON and, when
    the payload reports a truthy "found" flag, returns the transaction ID
    normalized to a stripped string.

    Args:
        response: Raw model output, possibly with prose around the JSON.

    Returns:
        The transaction ID as a string, or None when no usable ID exists
        (no JSON span, invalid JSON, "found" is falsy, or no transaction_id).
    """
    try:
        json_match = re.search(r'\{.*\}', response, re.DOTALL)
        if not json_match:
            return None
        resp_json = json.loads(json_match.group())
        # A missing or falsy "found" flag means there is nothing to look up.
        if not resp_json.get("found", False):
            return None
        transaction_id = resp_json.get("transaction_id")
        # Normalize to a clean string; the model may emit the ID as a number.
        return str(transaction_id).strip() if transaction_id else None
    except (json.JSONDecodeError, TypeError):
        # Malformed JSON (or a non-string response) yields no ID.
        return None
210
+
211
  def predict(message, history):
212
  # Always inject the user message into the system prompt's {input} placeholder
213
  sys_prompt = system_prompt.replace("{input}", message)
 
232
 
233
  # Extract only the assistant's message (after the last user message)
234
  if "<|start_header_id|>assistant<|end_header_id|>" in decoded:
235
+ analysis_response = decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
236
+ analysis_response = analysis_response.replace("<|eot_id|>", "").strip()
237
  elif "<|im_start|>assistant" in decoded:
238
  # This works for most chat templates that append the assistant's reply at the end
239
+ analysis_response = decoded.split("<|im_start|>assistant")[-1]
240
+ analysis_response = analysis_response.replace("<|im_end|>", "").strip()
241
  else:
242
  # Fallback: just return the decoded output
243
+ analysis_response = decoded.strip()
244
+
245
+ transaction_id = extract_transaction_info(analysis_response)
246
+ if transaction_id == None:
247
+ return analysis_response
248
 
249
+ # If we successfully extracted a transaction ID, we can invoke the SQL tool
250
+ sql_response = agent.run(f"find a record with id {transaction_id} from correct table in the database")
251
+
252
+ print(f"SQL response: {sql_response}\n")
253
+
254
+ try:
255
+ analysis_json = json.loads(analysis_response)
256
+ analysis_json["database_result"] = sql_response
257
+ return json.dumps(analysis_json, ensure_ascii=False)
258
+ except Exception:
259
+ # If analysis_response is not valid JSON, return both as plain text
260
+ return analysis_response
261
 
262
  demo = gr.ChatInterface(predict, type="messages", examples=examples)
263
 
requirements.txt CHANGED
@@ -5,4 +5,6 @@ safetensors>=0.4.5
5
  transformers==4.49.0
6
  gradio>=5.23.0
7
  datasets
8
- bitsandbytes
 
 
 
5
  transformers==4.49.0
6
  gradio>=5.23.0
7
  datasets
8
+ bitsandbytes
9
+ smolagents
10
+ sqlalchemy