|
|
|
|
|
import os |
|
|
|
|
|
from huggingface_hub import login |
|
|
# Authenticate with the Hugging Face Hub using the token from HUGGINGFACEHUB_API_KEY.
# huggingface_hub.login(token=None) falls back to an interactive prompt (terminal)
# or widget (notebook), which would block a headless launch, so only log in
# explicitly when a token is actually configured.
_hf_token = os.getenv("HUGGINGFACEHUB_API_KEY")
if _hf_token:
    login(token=_hf_token)
else:
    print("HUGGINGFACEHUB_API_KEY is not set; relying on any cached Hugging Face credentials.")

# Chat model used both for local transaction-ID extraction and for the SQL agent.
model_name = "meta-llama/Llama-3.2-3B-Instruct"
|
|
|
|
|
from sqlalchemy import ( |
|
|
create_engine, |
|
|
MetaData, |
|
|
Table, |
|
|
Column, |
|
|
DECIMAL, |
|
|
TEXT, |
|
|
insert, |
|
|
inspect, |
|
|
text, |
|
|
) |
|
|
|
|
|
def insert_rows_into_table(rows, table, engine):
    """Insert every row of ``rows`` into ``table`` within a single transaction.

    ``engine.begin()`` commits once at the end and rolls back if any insert
    raises, so the table is never left partially seeded (the previous version
    committed after every row).

    Args:
        rows: Iterable of dicts mapping column names to values.
        table: SQLAlchemy ``Table`` to insert into.
        engine: SQLAlchemy ``Engine`` to execute against.
    """
    with engine.begin() as connection:
        for row in rows:
            # values(**row) per row keeps support for rows that set different
            # column subsets (an executemany would require identical keys).
            connection.execute(insert(table).values(**row))
|
|
|
|
|
def _prepare_transaction_table(engine, metadata, table_name, rows):
    """Create or reflect a transaction table and insert the seed rows it lacks.

    Args:
        engine: SQLAlchemy engine the table lives on.
        metadata: ``MetaData`` collection the ``Table`` object is registered in.
        table_name: Name of the table to prepare.
        rows: Seed rows as dicts keyed by column name; each must carry an ``"id"``.

    Returns:
        The SQLAlchemy ``Table`` object.
    """
    from sqlalchemy import select  # only this helper needs it

    if inspect(engine).has_table(table_name):
        table = Table(table_name, metadata, autoload_with=engine)
    else:
        table = Table(
            table_name,
            metadata,
            Column("id", TEXT, primary_key=True),
            Column("amount", DECIMAL),
            Column("created_at", TEXT),
        )
        metadata.create_all(engine)

    # Only insert rows whose primary key is not present yet. Re-preparing an
    # already seeded table used to re-insert every row and fail with a
    # duplicate-primary-key IntegrityError.
    seed_ids = [row["id"] for row in rows]
    with engine.connect() as connection:
        existing_ids = set(
            connection.execute(
                select(table.c.id).where(table.c.id.in_(seed_ids))
            ).scalars()
        )
    missing_rows = [row for row in rows if row["id"] not in existing_ids]
    insert_rows_into_table(missing_rows, table, engine)

    return table


def prepare_payment_table(engine, metadata):
    """Create (if needed) and seed the ``payments`` demo table, then return it."""
    rows = [
        {"id": "payment-123", "amount": 100.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payment-abc-12", "amount": 200.0, "created_at": "2021-01-02 00:00:00"},
    ]
    return _prepare_transaction_table(engine, metadata, "payments", rows)


def prepare_payout_table(engine, metadata):
    """Create (if needed) and seed the ``payouts`` demo table, then return it."""
    rows = [
        {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"},
        {"id": "payout-b2c2", "amount": 100.0, "created_at": "2021-01-02 00:00:00"},
    ]
    return _prepare_transaction_table(engine, metadata, "payouts", rows)
|
|
|
|
|
from sqlalchemy.pool import StaticPool

# A SQLite ":memory:" database exists only inside the DB-API connection that
# created it. For ":memory:" URLs SQLAlchemy's default pool keeps one connection
# per thread, so queries issued from another thread would see a separate, empty
# database. Gradio may run predict() (and hence the sql_engine tool) in a worker
# thread, so share one connection via StaticPool. check_same_thread=False lets
# sqlite3 use that connection from threads other than the creating one.
engine = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
metadata = MetaData()

payment_table = prepare_payment_table(engine, metadata)
payout_table = prepare_payout_table(engine, metadata)
|
|
|
|
|
from smolagents import tool, CodeAgent, InferenceClientModel |
|
|
|
|
|
@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE: @tool parses the docstring above into the description shown to the
    # agent, so its wording is runtime behaviour, not just documentation.
    with engine.connect() as conn:
        result = conn.execute(text(query))
        # Every result row becomes "\n" + str(row); no rows yields "".
        return "".join("\n" + str(record) for record in result)
|
|
|
|
|
|
|
|
# Describe the real schema to the agent. The @tool docstring only mentions
# "the table", so without this text the agent has to guess table and column names.
tool_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""

inspector = inspect(engine)
for table in ["payments", "payouts"]:
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]
    table_description = f"Table '{table}':\n"
    table_description += "Columns:\n" + "\n".join([f" - {name}: {col_type}" for name, col_type in columns_info])
    tool_description += "\n\n" + table_description

# Previously tool_description was built but never used. Attach it to the tool
# *before* constructing the agent so the agent's prompt is built with it.
sql_engine.description = tool_description

agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id=model_name),
)
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import gradio as gr |
|
|
import torch |
|
|
|
|
|
def get_device_type() -> str:
    """Return the name of the torch device to run on.

    Preference order is CUDA, then Apple MPS, then CPU. A backend is only
    probed when every higher-priority backend is unavailable.
    """
    # Lambdas defer attribute lookup, so torch.backends.mps is touched only
    # when CUDA is absent.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for device_name, is_available in probes:
        if is_available():
            return device_name
    return "cpu"
|
|
|
|
|
device = get_device_type()

# float16 halves memory on GPU/MPS, but half-precision inference on CPU is slow
# and some CPU kernels lack float16 support, so stay in float32 there.
model_dtype = torch.float16 if device in ("cuda", "mps") else torch.float32

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load weights directly in the target dtype instead of materialising float32
# weights first and casting afterwards (which doubles peak memory).
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(device)
|
|
|
|
|
# Few-shot demonstrations interpolated into system_prompt below: two positive
# extractions (one payment, one payout) and one "not found" case, each showing
# the exact JSON shape the model is asked to return.
few_shot_examples = """
Example 1:
USER INPUT: My transaction payment-a1c1 failed
Response:
{"found": true, "transaction_id": "payment-a1c1", "transaction_type": "payment", "justification": "The transaction ID starts with 'payment', is followed by a dash, and contains characters after the dash, matching all extraction rules."}

Example 2:
USER INPUT: Why is my withdrawal payout-b2c2 pending for 3 days
Response:
{"found": true, "transaction_id": "payout-b2c2", "transaction_type": "payout", "justification": "The transaction ID starts with 'payout', is followed by a dash, and contains characters after the dash, matching all extraction rules."}

Example 3:
USER INPUT: I am having trouble with my transaction
Response:
{"found": false, "justification": "No valid transaction ID matching the extraction rules was found in the user input."}
"""

# System prompt for the local extraction model. This is an f-string: {few_shot_examples}
# is substituted now, while {{input}} renders as a literal "{input}" placeholder
# that predict() later replaces with the user's message via str.replace.
system_prompt = f"""
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.

Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one digit or character after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.

Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.

{few_shot_examples}

===>USER INPUT BEGINS
{{input}}
<===USER INPUT ENDS

Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""

# Sample messages offered as clickable examples in the Gradio chat UI. The first,
# second and fourth IDs exist in the seeded tables; payout-87l2k3 does not, and
# the last message contains no transaction ID at all.
examples = [
    "My transaction payment-123 failed",
    "Why is my withdrawal payout-abc456 pending for 3 days",
    "There is an issue with my transaction payout-87l2k3",
    "My deposit payment-abc-12 succeeded",
    "I am having trouble with my transaction",
]
|
|
|
|
|
import json, re |
|
|
|
|
|
# Upper bound on tokens the extraction model may generate per reply. The old
# limit of 100 left little headroom for the JSON answer plus its free-text
# justification, and a truncated reply is unparseable JSON.
_MAX_NEW_TOKENS = 256

# Well-formed transaction IDs per the extraction rules in system_prompt: a
# 'payment' or 'payout' prefix, a dash, then at least one more character.
# Restricting the tail to letters, digits, '_' and '-' keeps model-extracted,
# user-derived text from smuggling extra instructions into the SQL agent prompt.
_TRANSACTION_ID_PATTERN = re.compile(r"(payment|payout)-[A-Za-z0-9_-]+")

# Prompt for the SQL agent; filled in with str.format (literal JSON braces are doubled).
_SQL_PROMPT_TEMPLATE = """
Given the '{transaction_type}', find the corresponding table name in the database, then find a record with the given id '{transaction_id}' from the table, return the record as JSON. If there is no record, return null.

Example:
Input: transaction_type: payout, transaction_id: payout-abc456
Record in payouts: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
Response: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}

Input: transaction_type: payout, transaction_id: null
Record in payouts: None
Response: null
"""


def _first_json_object(text):
    """Return the first JSON object (dict) embedded in ``text``, or None.

    Each '{' is tried as a start position, so leading prose, trailing prose, and
    stray braces around the object are tolerated (a greedy ``{.*}`` regex is not).
    """
    if not isinstance(text, str):
        return None
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            value, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(value, dict):
                return value
        start = text.find("{", start + 1)
    return None


def extract_transaction_info(response):
    """Pull the transaction ID and type out of the extraction model's reply.

    The reply is expected to contain a JSON object shaped like the few-shot
    examples (``found``, ``transaction_id``, ``transaction_type``,
    ``justification``). Text around the object is ignored.

    Args:
        response: The model's reply text.

    Returns:
        ``(transaction_id, transaction_type)`` as stripped, non-empty strings,
        or ``(None, None)`` if no JSON object is present, ``found`` is falsy
        (including ``null``), or either field is missing, empty or not a string.
    """
    payload = _first_json_object(response)
    if payload is None or not payload.get("found", False):
        return None, None

    transaction_id = payload.get("transaction_id")
    transaction_type = payload.get("transaction_type")
    if not transaction_id or not isinstance(transaction_type, str):
        return None, None

    transaction_id = str(transaction_id).strip()
    transaction_type = transaction_type.strip()
    # Whitespace-only values are as good as missing.
    if not transaction_id or not transaction_type:
        return None, None
    return transaction_id, transaction_type


def _transaction_type_if_valid(message, transaction_id):
    """Return 'payment'/'payout' for a well-formed ID that occurs in ``message``, else None.

    Enforces the prompt's extraction rules deterministically and derives the
    type from the ID prefix (the prompt's type rules) rather than trusting the model.
    """
    match = _TRANSACTION_ID_PATTERN.fullmatch(transaction_id)
    if match is None or transaction_id not in message:
        return None
    return match.group(1)


def _generate_analysis(message, history):
    """Run the local chat model on the extraction prompt and return its new reply text."""
    prior_turns = list(history or [])
    if prior_turns and prior_turns[0].get("role") == "system":
        # Superseded by the system prompt built for this message.
        prior_turns = prior_turns[1:]
    # Build a fresh list so the caller's history is never mutated.
    messages = [
        {"role": "system", "content": system_prompt.replace("{input}", message)},
        *prior_turns,
        {"role": "user", "content": message},
    ]

    # add_generation_prompt appends the assistant header so the model answers
    # directly instead of first having to emit that header itself.
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # The rendered template already carries the BOS token; don't add a second one.
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    outputs = model.generate(**inputs, max_new_tokens=_MAX_NEW_TOKENS)

    # Decode only the newly generated tokens, so earlier turns never leak into
    # the reply and no template-specific header splitting is needed.
    new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()


def predict(message, history):
    """Gradio chat callback: extract a transaction ID and look it up in the database.

    Args:
        message: The latest user message.
        history: Prior chat turns as ``{"role", "content"}`` dicts (Gradio
            ``type="messages"`` format). Not mutated.

    Returns:
        The model's raw reply when no transaction ID is extracted. Otherwise,
        the model's JSON analysis re-serialised with an extra
        ``database_result`` key: the SQL agent's answer, or an explanatory
        error/skip message.
    """
    analysis_response = _generate_analysis(message, history)

    transaction_id, _ = extract_transaction_info(analysis_response)
    if transaction_id is None:
        return analysis_response

    # Non-None: extract_transaction_info just parsed this same object.
    analysis_json = _first_json_object(analysis_response)

    transaction_type = _transaction_type_if_valid(message, transaction_id)
    if transaction_type is None:
        analysis_json["database_result"] = (
            "Skipped database lookup: the extracted transaction ID is not a "
            "well-formed ID that appears in the user input."
        )
        return json.dumps(analysis_json, ensure_ascii=False)

    sql_prompt = _SQL_PROMPT_TEMPLATE.format(
        transaction_type=transaction_type, transaction_id=transaction_id
    )
    try:
        sql_response = agent.run(sql_prompt)
    except Exception as e:  # top-level boundary: model, tool, or parsing failures
        print(f"An error occurred while running the SQL agent: {e}")
        analysis_json["database_result"] = "Error running SQL agent: " + str(e)
    else:
        analysis_json["database_result"] = sql_response

    # default=str deliberately stringifies non-JSON agent answers (e.g. Decimal
    # values from the DECIMAL column) instead of failing and dropping the result.
    return json.dumps(analysis_json, ensure_ascii=False, default=str)
|
|
|
|
|
# Chat UI: every user message is handed to predict() together with the
# message-format history.
_description = (
    "This chatbot extracts transaction IDs from your message, determines their type (payment or payout), "
    "and retrieves the corresponding record from the database using natural language and SQL tools. "
    "Try asking about a transaction, e.g., 'Why is my withdrawal payout-b2c2 pending?'"
)

demo = gr.ChatInterface(
    fn=predict,
    type="messages",
    examples=examples,
    title="💬 Customer Service Chatbot",
    description=_description,
)

# NOTE(review): share=True requests a publicly reachable Gradio share URL;
# confirm that exposing the model and SQL agent this way is intended.
demo.launch(share=True)
|
|
|
|
|
|