# ivxivx's picture
# chore: title and description
# e86fb23 unverified
import os
from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured:
# login(token=None) falls back to prompting for a token interactively, which
# cannot be answered when this script runs unattended.
_hf_token = os.getenv("HUGGINGFACEHUB_API_KEY")
if _hf_token:
    login(token=_hf_token)

# Install dependencies with: pip install gradio smolagents sqlalchemy
# The base model "meta-llama/Llama-3.2-3B" does not work for this app, so the
# instruction-tuned variant is used.
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # 6.5G
from sqlalchemy import (
create_engine,
MetaData,
Table,
Column,
DECIMAL,
TEXT,
insert,
inspect,
text,
)
def insert_rows_into_table(rows, table, engine):
    """Insert every mapping in ``rows`` into ``table`` within one transaction.

    Args:
        rows: Iterable of dicts mapping column names to values.
        table: SQLAlchemy ``Table`` to insert into.
        engine: SQLAlchemy engine used to open the connection.

    All rows are written in a single ``engine.begin()`` block. It commits when
    the block finishes and rolls back if any insert raises, so a failure never
    leaves only some of the rows from this call in the table.
    """
    with engine.begin() as connection:
        for row in rows:
            connection.execute(insert(table).values(**row))
def prepare_payment_table(engine, metadata):
    """Create or reflect the ``payments`` table and ensure the demo rows exist.

    Args:
        engine: SQLAlchemy engine the table lives in.
        metadata: ``MetaData`` collection the ``Table`` object is registered on.

    Returns:
        The ``payments`` ``Table``.

    If the table already exists in the database it is reflected instead of
    created. Only seed rows whose ``id`` is not yet present are inserted, so
    running this against an already-populated database does not fail with a
    duplicate-key error.
    """
    table_name = "payments"
    if inspect(engine).has_table(table_name):
        table = Table(table_name, metadata, autoload_with=engine)
    else:
        table = Table(
            table_name,
            metadata,
            Column("id", TEXT, primary_key=True),
            Column("amount", DECIMAL),
            Column("created_at", TEXT),
        )
        metadata.create_all(engine)

    seed_rows = [
        {"id": "payment-123", "amount": 100.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payment-abc-12", "amount": 200.0, "created_at": "2021-01-02 00:00:00"},
    ]
    seed_ids = [row["id"] for row in seed_rows]
    with engine.connect() as connection:
        existing_ids = {
            record.id
            for record in connection.execute(
                table.select().where(table.c.id.in_(seed_ids))
            )
        }
    missing_rows = [row for row in seed_rows if row["id"] not in existing_ids]
    insert_rows_into_table(missing_rows, table, engine)
    return table
def prepare_payout_table(engine, metadata):
    """Create or reflect the ``payouts`` table and ensure the demo rows exist.

    Args:
        engine: SQLAlchemy engine the table lives in.
        metadata: ``MetaData`` collection the ``Table`` object is registered on.

    Returns:
        The ``payouts`` ``Table``.

    If the table already exists in the database it is reflected instead of
    created. Only seed rows whose ``id`` is not yet present are inserted, so
    running this against an already-populated database does not fail with a
    duplicate-key error.
    """
    table_name = "payouts"
    if inspect(engine).has_table(table_name):
        table = Table(table_name, metadata, autoload_with=engine)
    else:
        table = Table(
            table_name,
            metadata,
            Column("id", TEXT, primary_key=True),
            Column("amount", DECIMAL),
            Column("created_at", TEXT),
        )
        metadata.create_all(engine)

    seed_rows = [
        {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payout-b2c2", "amount": 100.0, "created_at": "2021-01-02 00:00:00"},
    ]
    seed_ids = [row["id"] for row in seed_rows]
    with engine.connect() as connection:
        existing_ids = {
            record.id
            for record in connection.execute(
                table.select().where(table.c.id.in_(seed_ids))
            )
        }
    missing_rows = [row for row in seed_rows if row["id"] not in existing_ids]
    insert_rows_into_table(missing_rows, table, engine)
    return table
# Shared in-memory SQLite database holding the demo tables.
# With SQLAlchemy's default pool for ":memory:" URLs (SingletonThreadPool),
# every thread gets its own connection, and therefore its own empty
# in-memory database. Queries issued from another thread (Gradio typically runs
# synchronous handlers such as predict() in a worker thread pool) would not see
# the tables created here. StaticPool keeps one connection for the whole
# process, and check_same_thread=False lets sqlite3 use it from other threads.
from sqlalchemy.pool import StaticPool

engine = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
metadata = MetaData()
payment_table = prepare_payment_table(engine, metadata)
payout_table = prepare_payout_table(engine, metadata)
from smolagents import tool, CodeAgent, InferenceClientModel
# smolagents tool: executes the given SQL text against the module-level
# `engine` and returns every result row, each preceded by a newline.
# NOTE: @tool parses the docstring below into the tool's description and its
# argument schema, so its wording and "Args:" layout are runtime data.
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.
    Args:
        query: The query to perform. This should be correct SQL.
    """
    with engine.connect() as connection:
        result = connection.execute(text(query))
        return "".join("\n" + str(record) for record in result)
def _describe_sql_tables(table_names):
    """Build the ``sql_engine`` tool description listing each table's columns.

    Args:
        table_names: Names of tables in ``engine`` to describe.

    Returns:
        A short usage note followed by one "Table '<name>':" section per
        table. Each section lists the column names and types reported by
        SQLAlchemy's inspector.
    """
    description = (
        "Allows you to perform SQL queries on the tables. Beware that this tool's "
        "output is a string representation of the execution output.\n"
        "It can use the following tables:"
    )
    inspector = inspect(engine)
    for table_name in table_names:
        column_lines = "\n".join(
            f" - {column['name']}: {column['type']}"
            for column in inspector.get_columns(table_name)
        )
        description += f"\n\nTable '{table_name}':\nColumns:\n{column_lines}"
    return description


tool_description = _describe_sql_tables(["payments", "payouts"])
# The tool's own docstring only mentions "the table". Attach the real schema so
# the agent knows which tables and columns exist. It is set before the agent is
# constructed so it is in place whenever the agent builds its prompt.
sql_engine.description = tool_description

agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id=model_name),
)
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch
def get_device_type() -> str:
    """Pick the torch device name to run the model on.

    Backends are probed in order CUDA, then Apple MPS; the first one that
    reports itself available wins, otherwise the CPU is used.

    Returns:
        "cuda", "mps" or "cpu".
    """
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    return next((name for name, is_available in probes if is_available()), "cpu")
device = get_device_type()
# Use float16 only on an accelerator. On CPU, half precision is slow, and some
# PyTorch builds lack CPU kernels for it, so keep float32 weights there.
model_dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device, dtype=model_dtype)
# Few-shot examples interpolated into system_prompt. They show the JSON reply
# expected for a payment ID, for a payout ID, and for a message that contains
# no transaction ID. This is a plain (non-f) string, so its braces need no
# escaping; it is inserted into the f-string as a value.
few_shot_examples = """
Example 1:
USER INPUT: My transaction payment-a1c1 failed
Response:
{"found": true, "transaction_id": "payment-a1c1", "transaction_type": "payment", "justification": "The transaction ID starts with 'payment', is followed by a dash, and contains characters after the dash, matching all extraction rules."}
Example 2:
USER INPUT: Why is my withdrawal payout-b2c2 pending for 3 days
Response:
{"found": true, "transaction_id": "payout-b2c2", "transaction_type": "payout", "justification": "The transaction ID starts with 'payout', is followed by a dash, and contains characters after the dash, matching all extraction rules."}
Example 3:
USER INPUT: I am having trouble with my transaction
Response:
{"found": false, "justification": "No valid transaction ID matching the extraction rules was found in the user input."}
"""
# System prompt for the transaction-ID extraction step.
# This is an f-string: few_shot_examples is interpolated once, when the module
# runs. The doubled braces in "{{input}}" leave a literal "{input}"
# placeholder, which predict() fills with the user's message via str.replace
# on every call.
system_prompt = f"""
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.
Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one digit or character after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.
Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.
{few_shot_examples}
===>USER INPUT BEGINS
{{input}}
<===USER INPUT ENDS
Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""
# Sample messages offered as clickable examples in the chat UI.
# - payment-123, payment-abc-12 and payout-abc456 match rows seeded by
#   prepare_payment_table / prepare_payout_table.
# - payout-87l2k3 matches no seeded row.
# - The last message contains no transaction ID at all.
_EXAMPLE_MESSAGES = (
    "My transaction payment-123 failed",
    "Why is my withdrawal payout-abc456 pending for 3 days",
    "There is an issue with my transaction payout-87l2k3",
    "My deposit payment-abc-12 succeeded",
    "I am having trouble with my transaction",
)
examples = list(_EXAMPLE_MESSAGES)
import json, re
def extract_transaction_info(response):
    """Pull the transaction ID and type out of the model's JSON reply.

    The reply may surround the JSON object with extra text. Starting at each
    ``{`` in turn, the first complete JSON value that parses is used, and
    anything after it is ignored.

    Args:
        response: Text generated by the extraction model.

    Returns:
        ``(transaction_id, transaction_type)`` as stripped strings when the
        object has a truthy ``found`` and non-empty ``transaction_id`` and
        ``transaction_type``. Otherwise ``(None, None)``. The result is always
        a 2-tuple, so callers can unpack it directly.
    """
    if not isinstance(response, str):
        return None, None

    decoder = json.JSONDecoder()
    start = response.find("{")
    while start != -1:
        try:
            parsed, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            # Not the start of a valid object; try the next brace.
            start = response.find("{", start + 1)
            continue
        if not isinstance(parsed, dict) or not parsed.get("found", False):
            return None, None
        transaction_id = str(parsed.get("transaction_id") or "").strip()
        transaction_type = str(parsed.get("transaction_type") or "").strip()
        # Check after stripping so whitespace-only values count as missing.
        if transaction_id and transaction_type:
            return transaction_id, transaction_type
        return None, None
    return None, None
# Transaction IDs the database lookup accepts: a 'payment' or 'payout' prefix,
# a dash, then letters, digits, underscores or dashes. This mirrors the
# extraction rules in system_prompt. Anything else (quotes, spaces, semicolons)
# is rejected before the ID reaches the SQL agent, whose tool runs arbitrary SQL.
_TRANSACTION_ID_PATTERN = re.compile(r"(payment|payout)-[A-Za-z0-9_-]+")


def _build_chat_messages(message, history):
    """Return the chat messages for this turn without mutating ``history``."""
    sys_prompt = system_prompt.replace("{input}", message)
    prior_turns = list(history or [])
    if prior_turns and prior_turns[0].get("role") == "system":
        # A system turn is already present: replace it rather than edit it in place.
        prior_turns = prior_turns[1:]
    return (
        [{"role": "system", "content": sys_prompt}]
        + [{"role": turn["role"], "content": turn["content"]} for turn in prior_turns]
        + [{"role": "user", "content": message}]
    )


def _generate_analysis(messages):
    """Run the local model on ``messages`` and return only the newly generated text."""
    # add_generation_prompt=True ends the prompt with the assistant header so
    # the model answers as the assistant instead of having to emit that header.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # The rendered template already contains its special tokens (e.g. BOS);
    # don't let the tokenizer add them a second time.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    outputs = model.generate(
        **inputs, max_new_tokens=100, pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the tokens produced after the prompt; this works for any chat
    # template, without splitting on template-specific header markers.
    prompt_length = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()


def _merge_database_result(analysis_response, database_result):
    """Attach ``database_result`` to the model's JSON analysis.

    Returns the analysis re-serialized with an added ``database_result`` field.
    If the analysis is not a JSON object, or the combined object cannot be
    serialized, the two pieces are returned together as plain text, so the
    database result is never silently dropped.
    """
    try:
        analysis_json = json.loads(analysis_response)
    except json.JSONDecodeError:
        analysis_json = None
    if isinstance(analysis_json, dict):
        try:
            return json.dumps(
                {**analysis_json, "database_result": database_result},
                ensure_ascii=False,
            )
        except (TypeError, ValueError):
            # The agent returned something JSON cannot encode; fall back to text.
            pass
    return f"{analysis_response}\n\ndatabase_result: {database_result}"


def predict(message, history):
    """Gradio chat handler: extract a transaction ID, then look it up via the SQL agent.

    Args:
        message: The user's latest chat message.
        history: Prior turns as a list of ``{"role": ..., "content": ...}``
            dicts (Gradio ``type="messages"`` format). It is not modified.

    Returns:
        The model's analysis text. When a valid transaction ID that appears in
        ``message`` was extracted, the analysis gains a ``database_result``
        holding the SQL agent's answer, or the error it raised.
    """
    analysis_response = _generate_analysis(_build_chat_messages(message, history))

    # `or (None, None)` tolerates an extractor that returns a bare None.
    transaction_id, transaction_type = (
        extract_transaction_info(analysis_response) or (None, None)
    )
    if transaction_id is None:
        return analysis_response

    # Don't trust the model's extraction blindly. The ID must have the expected
    # shape and literally appear in the user's message (per the extraction rules).
    id_match = _TRANSACTION_ID_PATTERN.fullmatch(transaction_id)
    if id_match is None or transaction_id not in message:
        return analysis_response
    # The type follows from the prefix (see the type rules in system_prompt).
    transaction_type = id_match.group(1)

    # NOTE: the hosted inference used by the SQL agent may reject calls once
    # free usage runs out (e.g. "Subscribe to PRO to get 20x more monthly ...").
    # Such failures are reported in database_result below.
    sql_prompt = (
        f"Given the '{transaction_type}', find the corresponding table name in the database, "
        f"then find a record with given id '{transaction_id}' from the table, return the record as JSON. "
        "If there is no record, return null.\n"
        "Example:\n"
        "Input: transaction_type: payout, transaction_id: payout-abc456\n"
        'Record in payouts: {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}\n'
        'Response: {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}\n'
        "Input: transaction_type: payout, transaction_id: null\n"
        "Record in payouts: None\n"
        "Response: null\n"
    )
    try:
        sql_response = agent.run(sql_prompt)
    except Exception as e:  # UI boundary: report any agent/API failure to the user
        print(f"An error occurred while running the SQL agent: {e}")
        return _merge_database_result(
            analysis_response, "Error running SQL agent: " + str(e)
        )
    return _merge_database_result(analysis_response, sql_response)
# Gradio chat UI: each submitted message goes to predict() together with the
# conversation so far (type="messages" passes history as role/content dicts).
# The strings in `examples` are shown as one-click sample prompts.
chat_description = (
    "This chatbot extracts transaction IDs from your message, determines their type (payment or payout), "
    "and retrieves the corresponding record from the database using natural language and SQL tools. "
    "Try asking about a transaction, e.g., 'Why is my withdrawal payout-b2c2 pending?'"
)
demo = gr.ChatInterface(
    predict,
    type="messages",
    title="💬 Customer Service Chatbot",
    description=chat_description,
    examples=examples,
)
# share=True asks Gradio to also create a public share link.
demo.launch(share=True)