rajaykumar12959 committed on
Commit
4b9a558
·
verified ·
1 Parent(s): 331e226

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +87 -17
README.md CHANGED
@@ -96,37 +96,91 @@ model, tokenizer = FastLanguageModel.from_pretrained(
96
  FastLanguageModel.for_inference(model) # Enable faster inference
97
  ```
98
 
99
- ### Generating SQL Queries
100
 
101
  ```python
102
- def generate_sql(schema, question):
103
- gemma_prompt = """<start_of_turn>user
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
105
 
106
  ### Schema:
107
- {}
108
 
109
  ### Question:
110
- {}<end_of_turn>
111
  <start_of_turn>model
112
  """
113
 
114
- input_prompt = gemma_prompt.format(schema, question)
115
  inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda")
116
 
117
- outputs = model.generate(**inputs, max_new_tokens=300, use_cache=True)
 
 
 
 
 
 
 
 
 
 
 
 
118
  result = tokenizer.batch_decode(outputs)[0]
 
119
 
120
- # Extract the generated SQL
121
- sql_result = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()
122
- return sql_result
123
  ```
124
 
125
- ### Example: Complex Multi-Table Query
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  ```python
128
  # E-commerce Database Schema
129
- test_sql_context = """
130
  CREATE TABLE users (
131
  user_id INT PRIMARY KEY,
132
  username TEXT,
@@ -157,15 +211,14 @@ CREATE TABLE order_items (
157
  );
158
  """
159
 
160
- # Complex Question
161
- test_question = """
162
  List the usernames and emails of users who have spent more than $500 in total on products
163
  in the 'Electronics' category.
164
  """
165
 
166
- # Generate SQL
167
- sql_query = generate_sql(test_sql_context, test_question)
168
- print(sql_query)
169
  ```
170
 
171
  **Expected Output:**
@@ -180,6 +233,23 @@ GROUP BY u.user_id, u.username, u.email
180
  HAVING SUM(oi.quantity * p.price) > 500;
181
  ```
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
  ## Training Details
184
 
185
  ### Dataset
 
96
  FastLanguageModel.for_inference(model) # Enable faster inference
97
  ```
98
 
99
+ ### Inference Function
100
 
101
  ```python
102
+ def inference_text_to_sql(model, tokenizer, schema, question, max_new_tokens=300):
103
+ """
104
+ Perform inference to generate a SQL query from a natural-language question and a database schema.
105
+
106
+ Args:
107
+ model: Fine-tuned Gemma model
108
+ tokenizer: Model tokenizer
109
+ schema: Database schema as string
110
+ question: Natural language question
111
+ max_new_tokens: Maximum tokens to generate
112
+
113
+ Returns:
114
+ Generated SQL query as string
115
+ """
116
+ # Format the input prompt
117
+ input_prompt = f"""<start_of_turn>user
118
  You are a powerful text-to-SQL model. Your job is to answer questions about a database. You are given a question and context regarding one or more tables.
119
 
120
  ### Schema:
121
+ {schema}
122
 
123
  ### Question:
124
+ {question}<end_of_turn>
125
  <start_of_turn>model
126
  """
127
 
128
+ # Tokenize input
129
  inputs = tokenizer([input_prompt], return_tensors="pt").to("cuda")
130
 
131
+ # Generate output
132
+ with torch.no_grad():
133
+ outputs = model.generate(
134
+ **inputs,
135
+ max_new_tokens=max_new_tokens,
136
+ use_cache=True,
137
+ do_sample=True,
138
+ temperature=0.1, # Low temperature for more deterministic output
139
+ top_p=0.9,
140
+ pad_token_id=tokenizer.eos_token_id
141
+ )
142
+
143
+ # Decode and clean the result
144
  result = tokenizer.batch_decode(outputs)[0]
145
+ sql_query = result.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()
146
 
147
+ return sql_query
 
 
148
  ```
149
 
150
+ ### Example Usage
151
+
152
+ #### Example 1: Simple Single-Table Query
153
+
154
+ ```python
155
+ # Simple employee database
156
+ simple_schema = """
157
+ CREATE TABLE employees (
158
+ employee_id INT PRIMARY KEY,
159
+ name TEXT,
160
+ department TEXT,
161
+ salary DECIMAL,
162
+ hire_date DATE
163
+ );
164
+ """
165
+
166
+ simple_question = "Find all employees in the 'Engineering' department with salary greater than 75000"
167
+
168
+ sql_result = inference_text_to_sql(model, tokenizer, simple_schema, simple_question)
169
+ print(f"Generated SQL:\n{sql_result}")
170
+ ```
171
+
172
+ **Expected Output:**
173
+ ```sql
174
+ SELECT * FROM employees
175
+ WHERE department = 'Engineering'
176
+ AND salary > 75000;
177
+ ```
178
+
179
+ #### Example 2: Multi-Table JOIN Query
180
 
181
  ```python
182
  # E-commerce Database Schema
183
+ complex_schema = """
184
  CREATE TABLE users (
185
  user_id INT PRIMARY KEY,
186
  username TEXT,
 
211
  );
212
  """
213
 
214
+ # Complex question requiring a 4-table JOIN
215
+ complex_question = """
216
  List the usernames and emails of users who have spent more than $500 in total on products
217
  in the 'Electronics' category.
218
  """
219
 
220
+ sql_result = inference_text_to_sql(model, tokenizer, complex_schema, complex_question)
221
+ print(f"Generated SQL:\n{sql_result}")
 
222
  ```
223
 
224
  **Expected Output:**
 
233
  HAVING SUM(oi.quantity * p.price) > 500;
234
  ```
235
 
236
+ #### Example 3: Aggregation with GROUP BY
237
+
238
+ ```python
239
+ agg_question = "Find the average salary by department for departments with more than 5 employees"
240
+
241
+ sql_result = inference_text_to_sql(model, tokenizer, simple_schema, agg_question)
242
+ print(f"Generated SQL:\n{sql_result}")
243
+ ```
244
+
245
+ **Expected Output:**
246
+ ```sql
247
+ SELECT department, AVG(salary) as avg_salary
248
+ FROM employees
249
+ GROUP BY department
250
+ HAVING COUNT(*) > 5;
251
+ ```
252
+
253
  ## Training Details
254
 
255
  ### Dataset