File size: 10,624 Bytes
822c123
 
 
 
ee304fe
822c123
dd73159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e523149
 
dd73159
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e523149
dd73159
822c123
 
42337a3
 
 
 
 
 
 
 
 
 
 
822c123
42337a3
822c123
582f696
1453054
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
822c123
a7de299
eab7e53
f3c5d0a
eab7e53
 
 
 
 
 
f3c5d0a
eab7e53
 
 
f3c5d0a
1453054
 
f3c5d0a
a7de299
f3c5d0a
 
eab7e53
f3c5d0a
 
 
eab7e53
f3c5d0a
 
822c123
 
1453054
 
822c123
1453054
822c123
 
 
dd73159
 
 
 
 
 
 
 
 
e523149
dd73159
 
e523149
 
 
dd73159
e523149
dd73159
e523149
dd73159
822c123
f9e55e7
 
b1f07de
f9e55e7
 
 
f3c5d0a
822c123
16541ad
b1f07de
 
 
 
 
16541ad
2dca75f
 
16541ad
2dca75f
 
cb75d9f
 
2dca75f
dd73159
 
2dca75f
 
dd73159
 
cb75d9f
 
dd73159
 
e523149
dd73159
 
2dca75f
e523149
6e4167c
 
dd73159
6e4167c
 
 
 
e523149
6e4167c
 
 
 
e523149
6e4167c
 
 
 
 
 
 
 
 
 
 
e523149
 
dd73159
 
 
 
 
 
 
 
822c123
e86fb23
 
 
 
 
 
 
 
 
 
 
822c123
f727d69
822c123
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299

import os

from huggingface_hub import login
# Authenticate with the Hugging Face Hub only when a token is configured. Calling
# login() without a token falls back to an interactive prompt (widget or terminal),
# which would block a headless server; without a token, downloads of gated models
# will instead fail later with an explicit authorization error.
_hf_token = os.getenv("HUGGINGFACEHUB_API_KEY")
if _hf_token:
    login(token=_hf_token)


# %pip install gradio smolagents sqlalchemy transformers torch huggingface_hub

# "meta-llama/Llama-3.2-3B" (without the -Instruct suffix) does not work here.
model_name = "meta-llama/Llama-3.2-3B-Instruct"  # ~6.5G

from sqlalchemy import (
    create_engine,
    MetaData,
    Table,
    Column,
    DECIMAL,
    TEXT,
    insert,
    inspect,
    text,
)

def insert_rows_into_table(rows, table, engine):
    """Insert ``rows`` into ``table`` in a single transaction.

    Args:
        rows: Iterable of dicts mapping column name to value. All dicts are expected
            to share the same keys, since they are sent as one executemany batch.
        table: The SQLAlchemy table to insert into.
        engine: Engine used to open the transaction.

    All rows are committed together or, if any insert fails, none are. An empty
    ``rows`` is a no-op; it does not execute an INSERT with no parameters.
    """
    rows = list(rows)
    if not rows:
        return
    with engine.begin() as connection:
        # Passing a list of parameter dicts makes SQLAlchemy use executemany.
        connection.execute(table.insert(), rows)

def _prepare_transaction_table(engine, metadata, table_name, seed_rows):
    """Return table ``table_name``, creating and seeding it if it does not exist yet.

    A newly created table gets the columns ``id`` (TEXT, primary key), ``amount``
    (DECIMAL) and ``created_at`` (TEXT) and is populated with ``seed_rows``. An
    existing table is reflected from the database as-is and is NOT re-seeded, so
    calling this again on the same database does not collide on primary keys.
    """
    if inspect(engine).has_table(table_name):
        return Table(table_name, metadata, autoload_with=engine)

    table = Table(
        table_name,
        metadata,
        Column("id", TEXT, primary_key=True),
        Column("amount", DECIMAL),
        Column("created_at", TEXT),
    )
    # Create only this table (existence was checked above) rather than every table
    # registered on the shared metadata.
    table.create(engine)
    insert_rows_into_table(seed_rows, table, engine)
    return table


def prepare_payment_table(engine, metadata):
    """Return the ``payments`` table, creating it with two sample payments if missing."""
    return _prepare_transaction_table(
        engine,
        metadata,
        "payments",
        [
            {"id": "payment-123", "amount": 100.0, "created_at": "2021-01-01 00:00:00"},
            {"id": "payment-abc-12", "amount": 200.0, "created_at": "2021-01-02 00:00:00"},
        ],
    )


def prepare_payout_table(engine, metadata):
    """Return the ``payouts`` table, creating it with two sample payouts if missing."""
    return _prepare_transaction_table(
        engine,
        metadata,
        "payouts",
        [
            {"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"},
            {"id": "payout-b2c2", "amount": 100.0, "created_at": "2021-01-02 00:00:00"},
        ],
    )

# SQLAlchemy's default pool for an in-memory SQLite URL keeps one connection per
# thread, and every SQLite :memory: connection is a separate, empty database. The
# SQL tool may run on a different thread than this setup code (presumably a Gradio
# worker thread), so share a single connection across threads with StaticPool.
from sqlalchemy.pool import StaticPool

engine = create_engine(
    "sqlite:///:memory:",
    connect_args={"check_same_thread": False},
    poolclass=StaticPool,
)
metadata = MetaData()

payment_table = prepare_payment_table(engine, metadata)
payout_table = prepare_payout_table(engine, metadata)

from smolagents import tool, CodeAgent, InferenceClientModel

@tool
def sql_engine(query: str) -> str:
    """
    Allows you to perform SQL queries on the table. Returns a string representation of the result.

    Args:
        query: The query to perform. This should be correct SQL.
    """
    # NOTE(review): `query` is written by the LLM and executed verbatim against the
    # database; treat it as untrusted input.
    with engine.connect() as con:
        rows = con.execute(text(query))
        # Each row on its own line, each line prefixed by "\n" (same format as before).
        return "".join("\n" + str(row) for row in rows)


# Build a tool description that lists every table and its columns so the agent can
# write correct SQL without guessing the schema.
tool_description = """Allows you to perform SQL queries on the table. Beware that this tool's output is a string representation of the execution output.
It can use the following tables:"""

inspector = inspect(engine)
for table in ["payments", "payouts"]:
    columns_info = [(col["name"], col["type"]) for col in inspector.get_columns(table)]

    table_description = f"Table '{table}':\n"

    table_description += "Columns:\n" + "\n".join([f"  - {name}: {col_type}" for name, col_type in columns_info])
    tool_description += "\n\n" + table_description

# The description was previously computed but never used. Attach it to the tool
# BEFORE constructing the agent so the schema is part of the agent's tool listing.
sql_engine.description = tool_description

agent = CodeAgent(
    tools=[sql_engine],
    model=InferenceClientModel(model_id=model_name),
)

# print("SQL tool description", tool_description)

from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import torch

def get_device_type() -> str:
    """Pick the torch device to run on.

    Preference order is CUDA, then Apple MPS, then CPU; the first backend that
    reports itself available wins.
    """
    # Probes are wrapped in lambdas so a backend is only touched if every
    # higher-priority backend was unavailable.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    for backend_name, probe in probes:
        if probe():
            return backend_name
    return "cpu"

device = get_device_type()
# Half precision only on accelerators: on CPU, float16 kernels may be missing or
# much slower than float32, so keep full precision there.
model_dtype = torch.float32 if device == "cpu" else torch.float16
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name).to(device, dtype=model_dtype)

# Few-shot demonstrations interpolated into system_prompt. They show the model the
# exact JSON shape expected for a payment ID, a payout ID, and a message with no ID.
# This text is sent to the model verbatim; edit with care.
few_shot_examples = """
Example 1:
USER INPUT: My transaction payment-a1c1 failed
Response:
{"found": true, "transaction_id": "payment-a1c1", "transaction_type": "payment", "justification": "The transaction ID starts with 'payment', is followed by a dash, and contains characters after the dash, matching all extraction rules."}

Example 2:
USER INPUT: Why is my withdrawal payout-b2c2 pending for 3 days
Response:
{"found": true, "transaction_id": "payout-b2c2", "transaction_type": "payout", "justification": "The transaction ID starts with 'payout', is followed by a dash, and contains characters after the dash, matching all extraction rules."}

Example 3:
USER INPUT: I am having trouble with my transaction
Response:
{"found": false, "justification": "No valid transaction ID matching the extraction rules was found in the user input."}
"""

# System prompt for the local extraction model. This is an f-string:
# {few_shot_examples} is interpolated immediately, while {{input}} renders as the
# literal text "{input}". predict() later fills that placeholder with the user's
# message via str.replace (not str.format, which would choke on the JSON braces in
# the few-shot examples).
system_prompt = f"""
You are a customer support officer. Extract the transaction ID from the USER INPUT and determine its type.

Extraction rules:
- Look for words starting with 'payout' or 'payment'.
- The next character must be a dash ('-').
- There must be at least one digit or character after the dash.
- The transaction ID must appear exactly in the USER INPUT.
- If found, set found to true; otherwise, set found to false.

Type rules:
- If the transaction ID starts with 'payout', type is payout.
- If it starts with 'payment', type is payment.

{few_shot_examples}

===>USER INPUT BEGINS
{{input}}
<===USER INPUT ENDS

Respond in valid JSON with these fields:
found: (boolean) Whether a valid transaction ID was found.
transaction_id: (string, if found) The extracted transaction ID.
transaction_type: (string, if found) The transaction type in lowercase.
justification: (string) Explain how you determined the transaction ID and type. If not found, do not fabricate an explanation.
Return only valid JSON and nothing else.
"""

# Sample user messages shown as clickable examples in the chat UI. They cover IDs
# present in the seeded tables, an ID that is absent (payout-87l2k3), and a message
# with no ID at all.
examples = list(
    (
        "My transaction payment-123 failed",
        "Why is my withdrawal payout-abc456 pending for 3 days",
        "There is an issue with my transaction payout-87l2k3",
        "My deposit payment-abc-12 succeeded",
        "I am having trouble with my transaction",
    )
)

import json, re

def extract_transaction_info(response):
    """Pull the transaction ID and type out of the model's JSON answer.

    The first JSON object embedded in ``response`` is used, so surrounding prose
    or trailing text containing braces does not break parsing.

    Args:
        response: Raw text produced by the extraction model.

    Returns:
        ``(transaction_id, transaction_type)`` with surrounding whitespace stripped
        and the type lower-cased. Returns ``(None, None)`` in these cases: no JSON
        object is found, ``found`` is false (a bool or the string "false"), or either
        field is missing or blank. It never returns a bare ``None``, so callers can
        always unpack the result.
    """
    if not isinstance(response, str):
        return None, None

    resp_json = None
    decoder = json.JSONDecoder()
    # Try every "{" as a potential start of a JSON object; keep the first that parses.
    for match in re.finditer(r"\{", response):
        try:
            candidate, _end = decoder.raw_decode(response, match.start())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(candidate, dict):
            resp_json = candidate
            break
    if resp_json is None:
        return None, None

    found = resp_json.get("found", False)
    if isinstance(found, str):
        # Models sometimes emit the boolean as a string; "false" must not be truthy.
        found = found.strip().lower() == "true"
    if not found:
        return None, None

    transaction_id = resp_json.get("transaction_id")
    transaction_type = resp_json.get("transaction_type")
    if not transaction_id or not transaction_type:
        return None, None

    transaction_id = str(transaction_id).strip()
    transaction_type = str(transaction_type).strip().lower()
    if not transaction_id or not transaction_type:
        return None, None
    return transaction_id, transaction_type

# Upper bound on generated tokens. The JSON answer plus its free-text justification
# can run past 100 tokens, and a truncated answer is no longer parseable JSON.
_MAX_NEW_TOKENS = 256


def _parse_json_object(text):
    """Return the first JSON object embedded in ``text`` as a dict, or None if none parses."""
    decoder = json.JSONDecoder()
    for match in re.finditer(r"\{", text):
        try:
            candidate, _end = decoder.raw_decode(text, match.start())
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            continue
        if isinstance(candidate, dict):
            return candidate
    return None


def _attach_database_result(analysis_response, database_result):
    """Add ``database_result`` to the model's JSON answer and re-serialize it.

    Falls back to the unmodified ``analysis_response`` when it contains no JSON object.
    """
    analysis_json = _parse_json_object(analysis_response)
    if analysis_json is None:
        return analysis_response
    analysis_json["database_result"] = database_result
    # default=str: the agent's result is arbitrary (it may e.g. contain Decimal values
    # read from the DECIMAL "amount" column) and must not make the reply unserializable.
    return json.dumps(analysis_json, ensure_ascii=False, default=str)


def predict(message, history):
    """Handle one chat turn: extract a transaction ID locally, then look it up via the SQL agent.

    Args:
        message: The new user message.
        history: Prior turns as a list of ``{"role": ..., "content": ...}`` dicts
            (Gradio ``type="messages"``). It is read but never modified.

    Returns:
        The model's answer as text. When a transaction ID was extracted and the answer
        contains a JSON object, that object gains a ``database_result`` field with the
        SQL agent's output, or an error description if the agent failed.
    """
    # The system prompt carries the current message in its {input} placeholder.
    sys_prompt = system_prompt.replace("{input}", message)

    # Build a fresh conversation instead of mutating Gradio's history list. Drop any
    # earlier system turn and any non-text content (e.g. attachments).
    conversation = [{"role": "system", "content": sys_prompt}]
    for turn in history or []:
        role, content = turn.get("role"), turn.get("content")
        if role != "system" and isinstance(content, str):
            conversation.append({"role": role, "content": content})
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True opens the assistant turn so the model answers instead
    # of continuing the user turn. The rendered template already contains the special
    # tokens (e.g. BOS), so they must not be added a second time when tokenizing.
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(device)
    # Greedy decoding: extraction should be deterministic for a given input.
    outputs = model.generate(**inputs, max_new_tokens=_MAX_NEW_TOKENS, do_sample=False)

    # Decode only the newly generated tokens, so no template-specific markers have to
    # be split off the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    analysis_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # `or (None, None)` guards against an extractor that returns a bare None.
    transaction_id, transaction_type = extract_transaction_info(analysis_response) or (None, None)
    if transaction_id is None:
        return analysis_response

    # The hosted model behind the SQL agent may require a paid Hugging Face plan
    # (observed error: "Subscribe to PRO to get 20x more monthly").
    # NOTE(review): transaction_id/transaction_type come from model output and are
    # pasted into the agent's prompt, which in turn writes SQL; treat them as untrusted.
    sql_prompt = f"""
        Given the '{transaction_type}', find the corresponding table name in the database, then find a record with given id '{transaction_id}' from the table, return the record as JSON. If there is no record, return null.

        Example:
        Input: transaction_type: payout, transaction_id: payout-abc456
        Record in payouts: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}
        Response: {{"id": "payout-abc456", "amount": 50.0, "created_at": "2021-01-01 00:00:00"}}

        Input: transaction_type: payout, transaction_id: null
        Record in payouts: None
        Response: null
    """

    try:
        sql_response = agent.run(sql_prompt)
    except Exception as e:  # boundary: any agent/model failure is reported to the user
        print(f"An error occurred while running the SQL agent: {e}")
        return _attach_database_result(analysis_response, "Error running SQL agent: " + str(e))

    return _attach_database_result(analysis_response, sql_response)

# Chat UI wiring: each submitted message is routed through predict().
demo = gr.ChatInterface(
    predict,
    type="messages",
    examples=examples,
    title="💬 Customer Service Chatbot",
    description=(
        "This chatbot extracts transaction IDs from your message, determines their type (payment or payout), "
        "and retrieves the corresponding record from the database using natural language and SQL tools. "
        "Try asking about a transaction, e.g., 'Why is my withdrawal payout-b2c2 pending?'"
    ),
)

# Only start the server when run as a script, not when this module is imported.
if __name__ == "__main__":
    # NOTE(review): share=True publishes a public link to an app whose agent executes
    # model-written SQL through sql_engine; confirm this exposure is intended.
    demo.launch(share=True)