feat: sql tool

Files changed:
- app.py: +154 -12
- requirements.txt: +3 -1

app.py

@@ -4,6 +4,119 @@ import os
 from huggingface_hub import login
 login(token=os.getenv("HUGGINGFACEHUB_API_KEY"))
 
+
+# %pip install gradio smolagents sqlalchemy
+
+# "meta-llama/Llama-3.2-3B does not work
+model_name="meta-llama/Llama-3.2-3B-Instruct" # 6.5G
+
+from sqlalchemy import (
+    create_engine,
+    MetaData,
+    Table,
+    Column,
+    DECIMAL,
+    TEXT,
+    insert,
+    inspect,
+    text,
+)
+
+def insert_rows_into_table(rows, table, engine):
+    for row in rows:
+        stmt = insert(table).values(**row)
+        with engine.begin() as connection:
+            connection.execute(stmt)
+
+def prepare_payment_table(engine, metadata):
+    inspector = inspect(engine)
+    table_name = "payments"
+    if not inspector.has_table(table_name):
+        table = Table(
+            table_name,
+            metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("amount", DECIMAL),
+            Column("created_at", TEXT),
+        )
+        metadata.create_all(engine)
+    else:
+        table = Table(table_name, metadata, autoload_with=engine)
+
+    rows = [
+        {"id": "payment-123", "amount": 100.0, "created_at": "2021-01-01 00:00:00"},
+        {"id": "payment-abc-12", "amount": 200.0, "created_at": "2021-01-02 00:00:00"},
+    ]
+    insert_rows_into_table(rows, table, engine)
+
+    return table
+
+def prepare_payout_table(engine, metadata):
+    inspector = inspect(engine)
+    table_name = "payouts"
+    if not inspector.has_table(table_name):
+        table = Table(
+            table_name,
+            metadata,
+            Column("id", TEXT, primary_key=True),
+            Column("amount", DECIMAL),
+            Column("created_at", TEXT),
+        )
+        metadata.create_all(engine)
+    else:
+        table = Table(table_name, metadata, autoload_with=engine)
+
+    rows = [
+        {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"},
+        {"id": "payout-b2c2", "amount": 100.0, "created_at": "2021-01-02 00:00:00"},
+    ]
+    insert_rows_into_table(rows, table, engine)
+
+    return table
+
+engine = create_engine("sqlite:///:memory:")
+metadata = MetaData()
+
+prepare_payment_table(engine, metadata)
+prepare_payout_table(engine, metadata)
+
+from smolagents import tool, CodeAgent, InferenceClientModel
+
+@tool
+def sql_engine(query: str) -> str:
+    """
+    Allows you to perform SQL queries on the table. Returns a string representation of the result.
+
+    Args:
+        query: The query to perform. This should be correct SQL.
+    """
+    output = ""
+    with engine.connect() as con:
+        rows = con.execute(text(query))
+        for row in rows:
+            output += "\n" + str(row)
+    return output
+
+
+agent = CodeAgent(
+    tools=[sql_engine],
+    model=InferenceClientModel(model_id=model_name),
+)
+
+tool_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
+It can use the following tables:"""
+
+inspector = inspect(engine)
+for table in ["payments", "payouts"]:
+    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
+
+    table_description = f"Table '{table}':\n"
+
+    table_description += "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
+    tool_description += "\n\n" + table_description
+
+print("SQL tool description", tool_description)
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import gradio as gr
 import torch
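
This hunk seeds an in-memory SQLite database with two toy tables (payments and payouts) and hands raw SQL access to a smolagents CodeAgent via the sql_engine tool. Since @tool-wrapped functions stay directly callable, a quick sanity check against the seeded rows might look like the sketch below (the query is illustrative, and the exact row formatting depends on SQLAlchemy's Row repr):

    # Hypothetical direct call to sql_engine once the tables above are seeded.
    print(sql_engine("SELECT id, amount FROM payments WHERE amount > 150"))
    # One stringified row per line, shaped roughly like:
    # ('payment-abc-12', Decimal('200.0'))

Note that tool_description is only built and printed for inspection; it is never attached to the agent, which sees sql_engine's docstring instead.
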
@@ -18,12 +131,6 @@ def get_device_type() -> str:
     else:
         return "cpu"
 
-# # HuggingFaceTB/SmolLM2-135M-Instruct
-# model_name="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B" # 15G
-
-# "meta-llama/Llama-3.2-3B does not work
-model_name="meta-llama/Llama-3.2-3B-Instruct" # 6.5G
-
 device = get_device_type()
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name).to(device, dtype=torch.float16)
@@ -81,6 +188,26 @@ examples = [
     "I am having trouble with my transaction",
 ]
 
+import json, re
+
+def extract_transaction_info(response):
+    try:
+        json_match = re.search(r'\{.*\}', response, re.DOTALL)
+        if json_match:
+            resp_json = json.loads(json_match.group())
+            found = resp_json.get("found", False)
+            if found == False:
+                return None
+
+            transaction_id = resp_json.get("transaction_id")
+            if transaction_id:
+                # Ensure the transaction ID is a string
+                return str(transaction_id).strip()
+            else:
+                return None
+    except Exception as e:
+        return None
+
 def predict(message, history):
     # Always inject the user message into the system prompt's {input} placeholder
     sys_prompt = system_prompt.replace("{input}", message)
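
extract_transaction_info pulls the first {...} span out of the analysis reply and returns a transaction ID only when the JSON marks "found" as truthy; a missing JSON object, a false "found", or any parse error all end in None. A sketch with hypothetical model replies:

    # Hypothetical analysis replies fed to extract_transaction_info.
    reply = 'Sure: {"found": true, "transaction_id": " payment-123 "}'
    print(extract_transaction_info(reply))      # payment-123 (cast to str and stripped)
    print(extract_transaction_info("no JSON"))  # None (no {...} span to parse)
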
@@ -105,17 +232,32 @@ def predict(message, history):
 
     # Extract only the assistant's message (after the last user message)
     if "<|start_header_id|>assistant<|end_header_id|>" in decoded:
-
-
+        analysis_response = decoded.split("<|start_header_id|>assistant<|end_header_id|>")[-1]
+        analysis_response = analysis_response.replace("<|eot_id|>", "").strip()
     elif "<|im_start|>assistant" in decoded:
         # This works for most chat templates that append the assistant's reply at the end
-
-
+        analysis_response = decoded.split("<|im_start|>assistant")[-1]
+        analysis_response = analysis_response.replace("<|im_end|>", "").strip()
     else:
         # Fallback: just return the decoded output
-
+        analysis_response = decoded.strip()
+
+    transaction_id = extract_transaction_info(analysis_response)
+    if transaction_id == None:
+        return analysis_response
 
-
+    # If we successfully extracted a transaction ID, we can invoke the SQL tool
+    sql_response = agent.run(f"find a record with id {transaction_id} from correct table in the database")
+
+    print(f"SQL response: {sql_response}\n")
+
+    try:
+        analysis_json = json.loads(analysis_response)
+        analysis_json["database_result"] = sql_response
+        return json.dumps(analysis_json, ensure_ascii=False)
+    except Exception:
+        # If analysis_response is not valid JSON, return both as plain text
+        return analysis_response
 
 demo = gr.ChatInterface(predict, type="messages", examples=examples)
 
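
The tail of predict() is now a two-stage pipeline: when the analysis reply yields a transaction ID, agent.run() is asked to look the record up, and its answer is merged into the analysis JSON under "database_result". A sketch of the happy path, with a hypothetical string standing in for the agent's answer:

    import json

    # Hypothetical end-of-predict values: the parsed analysis reply plus a
    # stand-in for what agent.run() returned.
    analysis_response = '{"found": true, "transaction_id": "payment-123"}'
    sql_response = "('payment-123', 100, '2021-01-01 00:00:00')"

    analysis_json = json.loads(analysis_response)
    analysis_json["database_result"] = sql_response
    print(json.dumps(analysis_json, ensure_ascii=False))
    # {"found": true, "transaction_id": "payment-123", "database_result": "('payment-123', 100, '2021-01-01 00:00:00')"}

If no ID is found, or the analysis reply is not valid JSON on its own, the function degrades gracefully and returns the plain analysis text.
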
requirements.txt

@@ -5,4 +5,6 @@ safetensors>=0.4.5
 transformers==4.49.0
 gradio>=5.23.0
 datasets
-bitsandbytes
+bitsandbytes
+smolagents
+sqlalchemy