SpyroSigma commited on
Commit
5d53755
·
verified ·
1 Parent(s): 2f7d72b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -2
app.py CHANGED
@@ -1,3 +1,105 @@
1
- import gradio as gr
 
2
 
3
- gr.load("models/defog/sqlcoder-7b-2").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
 
4
+ model_name = "defog/sqlcoder-7b-2"
5
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
6
+ if available_memory > 15e9:
7
+ # if you have atleast 15GB of GPU memory, run load the model in float16
8
+ model = AutoModelForCausalLM.from_pretrained(
9
+ model_name,
10
+ trust_remote_code=True,
11
+ torch_dtype=torch.float16,
12
+ device_map="auto",
13
+ use_cache=True,
14
+ )
15
+ else:
16
+ # else, load in 8 bits – this is a bit slower
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ model_name,
19
+ trust_remote_code=True,
20
+ # torch_dtype=torch.float16,
21
+ load_in_8bit=True,
22
+ device_map="auto",
23
+ use_cache=True,
24
+ )
25
+
26
+ prompt = """### Task
27
+ Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
28
+
29
+ ### Instructions
30
+ - If you cannot answer the question with the available database schema, return 'I do not know'
31
+ - Remember that revenue is price multiplied by quantity
32
+ - Remember that cost is supply_price multiplied by quantity
33
+
34
+ ### Database Schema
35
+ This query will run on a database whose schema is represented in this string:
36
+ CREATE TABLE products (
37
+ product_id INTEGER PRIMARY KEY, -- Unique ID for each product
38
+ name VARCHAR(50), -- Name of the product
39
+ price DECIMAL(10,2), -- Price of each unit of the product
40
+ quantity INTEGER -- Current quantity in stock
41
+ );
42
+
43
+ CREATE TABLE customers (
44
+ customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
45
+ name VARCHAR(50), -- Name of the customer
46
+ address VARCHAR(100) -- Mailing address of the customer
47
+ );
48
+
49
+ CREATE TABLE salespeople (
50
+ salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
51
+ name VARCHAR(50), -- Name of the salesperson
52
+ region VARCHAR(50) -- Geographic sales region
53
+ );
54
+
55
+ CREATE TABLE sales (
56
+ sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
57
+ product_id INTEGER, -- ID of product sold
58
+ customer_id INTEGER, -- ID of customer who made purchase
59
+ salesperson_id INTEGER, -- ID of salesperson who made the sale
60
+ sale_date DATE, -- Date the sale occurred
61
+ quantity INTEGER -- Quantity of product sold
62
+ );
63
+
64
+ CREATE TABLE product_suppliers (
65
+ supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
66
+ product_id INTEGER, -- Product ID supplied
67
+ supply_price DECIMAL(10,2) -- Unit price charged by supplier
68
+ );
69
+
70
+ -- sales.product_id can be joined with products.product_id
71
+ -- sales.customer_id can be joined with customers.customer_id
72
+ -- sales.salesperson_id can be joined with salespeople.salesperson_id
73
+ -- product_suppliers.product_id can be joined with products.product_id
74
+
75
+ ### Answer
76
+ Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
77
+ [SQL]
78
+ """
79
+
80
+ import sqlparse
81
+
82
+ def generate_query(question):
83
+ updated_prompt = prompt.format(question=question)
84
+ inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
85
+ generated_ids = model.generate(
86
+ **inputs,
87
+ num_return_sequences=1,
88
+ eos_token_id=tokenizer.eos_token_id,
89
+ pad_token_id=tokenizer.eos_token_id,
90
+ max_new_tokens=400,
91
+ do_sample=False,
92
+ num_beams=1,
93
+ )
94
+ outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
95
+
96
+ torch.cuda.empty_cache()
97
+ torch.cuda.synchronize()
98
+ # empty cache so that you do generate more results w/o memory crashing
99
+ # particularly important on Colab – memory management is much more straightforward
100
+ # when running on an inference service
101
+ return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
102
+
103
+ question = "What was our revenue by product in the New York region last month?"
104
+ generated_sql = generate_query(question)
105
+ print(generated_sql)