Spaces:
Runtime error
Runtime error
File size: 3,330 Bytes
5d53755 2f7d72b 5d53755 491b679 5d53755 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 |
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
model_name = "defog/sqlcoder-7b-2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
model_name,
trust_remote_code=True,
# torch_dtype=torch.float16,
load_in_8bit=True,
device_map="auto",
use_cache=True,
)
prompt = """### Task
Generate a SQL query to answer [QUESTION]{question}[/QUESTION]
### Instructions
- If you cannot answer the question with the available database schema, return 'I do not know'
- Remember that revenue is price multiplied by quantity
- Remember that cost is supply_price multiplied by quantity
### Database Schema
This query will run on a database whose schema is represented in this string:
CREATE TABLE products (
product_id INTEGER PRIMARY KEY, -- Unique ID for each product
name VARCHAR(50), -- Name of the product
price DECIMAL(10,2), -- Price of each unit of the product
quantity INTEGER -- Current quantity in stock
);
CREATE TABLE customers (
customer_id INTEGER PRIMARY KEY, -- Unique ID for each customer
name VARCHAR(50), -- Name of the customer
address VARCHAR(100) -- Mailing address of the customer
);
CREATE TABLE salespeople (
salesperson_id INTEGER PRIMARY KEY, -- Unique ID for each salesperson
name VARCHAR(50), -- Name of the salesperson
region VARCHAR(50) -- Geographic sales region
);
CREATE TABLE sales (
sale_id INTEGER PRIMARY KEY, -- Unique ID for each sale
product_id INTEGER, -- ID of product sold
customer_id INTEGER, -- ID of customer who made purchase
salesperson_id INTEGER, -- ID of salesperson who made the sale
sale_date DATE, -- Date the sale occurred
quantity INTEGER -- Quantity of product sold
);
CREATE TABLE product_suppliers (
supplier_id INTEGER PRIMARY KEY, -- Unique ID for each supplier
product_id INTEGER, -- Product ID supplied
supply_price DECIMAL(10,2) -- Unit price charged by supplier
);
-- sales.product_id can be joined with products.product_id
-- sales.customer_id can be joined with customers.customer_id
-- sales.salesperson_id can be joined with salespeople.salesperson_id
-- product_suppliers.product_id can be joined with products.product_id
### Answer
Given the database schema, here is the SQL query that answers [QUESTION]{question}[/QUESTION]
[SQL]
"""
import sqlparse
def generate_query(question):
updated_prompt = prompt.format(question=question)
inputs = tokenizer(updated_prompt, return_tensors="pt").to("cuda")
generated_ids = model.generate(
**inputs,
num_return_sequences=1,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.eos_token_id,
max_new_tokens=400,
do_sample=False,
num_beams=1,
)
outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
torch.cuda.empty_cache()
torch.cuda.synchronize()
# empty cache so that you do generate more results w/o memory crashing
# particularly important on Colab – memory management is much more straightforward
# when running on an inference service
return sqlparse.format(outputs[0].split("[SQL]")[-1], reindent=True)
question = "What was our revenue by product in the New York region last month?"
generated_sql = generate_query(question)
print(generated_sql) |