victor7246 committed on
Commit
a837b64
·
verified ·
1 Parent(s): 96f2e73

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +103 -132
utils.py CHANGED
@@ -142,144 +142,97 @@ def extract_question_list(llm, query):
142
  except:
143
  return query
144
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
147
 
148
  intermediate_steps_KEY = "intermediate_steps"
149
 
150
- class PatchedCreateTable(CreateTable):
151
- def __init__(
152
- self,
153
- element: Table,
154
- include_foreign_key_constraints = None,
155
- if_not_exists: bool = False,
156
- columns_to_ignore: List[str] = None,
157
- ):
158
- if columns_to_ignore is None:
159
- columns_to_ignore = []
160
- element.columns = [col for col in element.columns if col.name not in columns_to_ignore]
161
- super().__init__(element, if_not_exists=if_not_exists)
162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
163
 
164
  def num_tokens_from_string(string: str, encoding_name: str) -> int:
165
  encoding = tiktoken.get_encoding(encoding_name)
166
  num_tokens = len(encoding.encode(string))
167
  return num_tokens
168
 
169
-
170
- class PatchedSQLDatabase(SQLDatabase):
171
- def __init__(self, *args, **kwargs):
172
- super().__init__(*args, **kwargs)
173
-
174
- def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
175
- """Get information about specified tables.
176
-
177
- Follows best practices as specified in: Rajkumar et al, 2022
178
- (https://arxiv.org/abs/2204.00498)
179
-
180
- If `sample_rows_in_table_info`, the specified number of sample rows will be
181
- appended to each table description. This can increase performance as
182
- demonstrated in the paper.
183
- """
184
- all_table_names = self.get_usable_table_names()
185
- if table_names is not None:
186
- missing_tables = set(table_names).difference(all_table_names)
187
- if missing_tables:
188
- print('all_table_names', all_table_names)
189
- raise ValueError(f"table_names {missing_tables} not found in database")
190
- all_table_names = table_names
191
-
192
- meta_tables = [
193
- tbl
194
- for tbl in self._metadata.sorted_tables
195
- if tbl.name in set(all_table_names)
196
- and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
197
- ]
198
-
199
- tables = []
200
- for table in meta_tables:
201
- if self._custom_table_info and table.name in self._custom_table_info:
202
- tables.append(self._custom_table_info[table.name])
203
- continue
204
-
205
- # Ignore JSON datatyped columns
206
- _cols = table.columns
207
- if _cols is dict:
208
- for k, v in _cols.items():
209
- if type(v.type) is NullType:
210
- table._columns.remove(v)
211
-
212
- # add create table command
213
- create_table = str(PatchedCreateTable(
214
- table, columns_to_ignore=[]
215
- ).compile(self._engine))
216
- table_info = ""
217
- # table_info += f"{create_table.rstrip()}"
218
-
219
- has_extra_info = (
220
- self._indexes_in_table_info or self._sample_rows_in_table_info
221
- )
222
- if has_extra_info:
223
- table_info += "\n\n/*"
224
- if self._indexes_in_table_info:
225
- table_info += f"\n{self._get_table_indexes(table)}\n"
226
- if self._sample_rows_in_table_info:
227
- table_info += f"\n{self._get_sample_rows(table)}\n"
228
- if has_extra_info:
229
- table_info += "*/"
230
- table_info += self.get_columns_descriptions(table)
231
-
232
- tables.append(table_info)
233
- tables.sort()
234
- final_str = "\n\n".join(tables)
235
- return final_str
236
-
237
- def _get_sample_rows(self, table: Table) -> str:
238
- # build the select command
239
- command = select(table).order_by(func.random()).limit(self._sample_rows_in_table_info)
240
-
241
- # save the columns in string format
242
- columns_str = ";".join([f'"{col.name}"' for col in table.columns])
243
-
244
- try:
245
- # get the sample rows
246
- with self._engine.connect() as connection:
247
- sample_rows_result = connection.execute(command) # type: ignore
248
- # shorten values in the sample rows
249
- sample_rows = list(
250
- map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
251
- )
252
-
253
- # save the sample rows in string format
254
- sample_rows_str = "\n".join([";".join(row) for row in sample_rows])
255
-
256
- # in some dialects when there are no rows in the table a
257
- # 'ProgrammingError' is returned
258
- except ProgrammingError:
259
- sample_rows_str = ""
260
-
261
- return (
262
- f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
263
- f"{columns_str}\n"
264
- f"{sample_rows_str}"
265
- )
266
-
267
- def _execute(
268
- self,
269
- command: str,
270
- fetch: str = "all",
271
- ):
272
- with self._engine.begin() as connection:
273
- cursor = connection.execute(text(command))
274
- print(cursor.__dict__)
275
-
276
- if cursor.returns_rows:
277
- fields = list(cursor.keys())
278
- result = [dict(zip(fields,row)) for row in cursor.fetchall()]
279
- return result
280
- return []
281
-
282
  def clean_sql(s: str) -> str:
 
 
283
  s = s.replace("```sql", "")
284
  for symb in ["'", '"']:
285
  if s.startswith(symb) and s.endswith(symb):
@@ -296,13 +249,19 @@ def clean_sql(s: str) -> str:
296
  if s.endswith("TOP 1"):
297
  s = s.replace("TOP 1","").strip()
298
  s = s.replace("SELECT","SELECT TOP 1")
 
 
 
299
  return s
300
 
301
  def get_metadata_info(metadata_df, table_names):
302
  str = ""
303
  for table in table_names:
304
- str += "The following metadata is for the table " + table + "\n"
305
- str += metadata_df[metadata_df.table == table].desc.iloc[0]
 
 
 
306
 
307
  return str
308
 
@@ -356,7 +315,12 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
356
  #print ("input key", self.input_key)
357
  #print ("===============")
358
 
359
- input_text = f"{inputs[self.input_key]} \nHistory: {inputs['history']} \nSQLQuery:"
 
 
 
 
 
360
  _run_manager.on_text(input_text, verbose=self.verbose)
361
  # If not present, then defaults to None which is all tables.
362
  table_names_to_use = inputs.get("table_names_to_use")
@@ -364,7 +328,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
364
  table_info += get_metadata_info(metadata_df, table_names_to_use)
365
  llm_inputs = {
366
  "input": input_text,
367
- "history": inputs["history"],
368
  "top_k": str(self.top_k),
369
  "dialect": self.database.dialect,
370
  "table_info": table_info,
@@ -426,6 +390,9 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
426
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
427
  self.intermediate_steps['query_explanation'] = explanation
428
 
 
 
 
429
  except:
430
  try:
431
  sql_data_new = sql_data[-20:] + sql_data[:20]
@@ -451,11 +418,15 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
451
  ).strip()
452
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
453
  self.intermediate_steps['query_explanation'] = explanation
 
 
 
454
  except Exception as exc:
455
  # Append intermediate steps to exception, to aid in logging and later
456
  # improvement of few shot prompt seeds
457
- exc.intermediate_steps = self.intermediate_steps # type: ignore
458
- raise exc
459
-
 
460
 
461
 
 
142
  except:
143
  return query
144
 
145
def translate_to_english(llm, user_query):
    """Ask the chat model to translate ``user_query`` into English.

    A system message carrying the translation instruction and a human
    message carrying the query are passed to ``llm.invoke``; the
    ``content`` attribute of the reply is returned.

    Args:
        llm: Object exposing ``invoke(messages)`` whose result has a
            ``content`` attribute.
        user_query: Text to translate.

    Returns:
        The model's reply content, or ``user_query`` unchanged when it is
        ``None`` or a blank string.
    """
    # Nothing to translate: a blank prompt only invites the model to invent
    # text, so skip the round-trip and hand the input straight back.
    if user_query is None or (
        isinstance(user_query, str) and not user_query.strip()
    ):
        return user_query

    sys_prompt = """
    You are an AI assistant that translates a text to English. \
    Do not generate any irrelavant text, only return the translation."""

    message_log = [
        SystemMessage(content=sys_prompt),
        HumanMessage(content=user_query),
    ]

    output = llm.invoke(message_log)
    return output.content
161
+
162
def translate(llm, user_query, to_translate):
    """Translate ``to_translate`` into the language of ``user_query``.

    ``user_query`` is embedded in the system prompt, which instructs the
    model to determine the query's language, translate the provided text
    into it, and return the text untouched when the query is in English.
    ``to_translate`` is sent as the human message and the ``content`` of
    the reply from ``llm.invoke`` is returned.

    Args:
        llm: Object exposing ``invoke(messages)`` whose result has a
            ``content`` attribute.
        user_query: The user's original question; only used to pick the
            target language.
        to_translate: Text that should be rendered in that language.

    Returns:
        The model's reply content, or ``to_translate`` unchanged when it is
        ``None`` or a blank string.
    """
    # Nothing to translate: avoid a model call that could only fabricate
    # output for an empty answer.
    if to_translate is None or (
        isinstance(to_translate, str) and not to_translate.strip()
    ):
        return to_translate

    sys_prompt = """
    You are an AI assistant that determines the language given a user query - {} and translate the provided text in that target language. \
    Do not generate any irrelavant text, only return the translation. \
    If the user query is in English, then don't do anything just return the original text, no translation is required there.""".format(user_query)

    message_log = [
        SystemMessage(content=sys_prompt),
        HumanMessage(content=to_translate),
    ]

    output = llm.invoke(message_log)
    return output.content
179
 
180
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
181
 
182
  intermediate_steps_KEY = "intermediate_steps"
183
 
184
+ template = """
 
 
 
 
 
 
 
 
 
 
 
185
 
186
+ You are a database expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
187
+ The final answer should be in a concise natural language.
188
+
189
+ Use the history if you can not understand the question.
190
+
191
+ If the question is in another language, translate it to English before proceeding.
192
+
193
+ Do not repeat the question while generating the SQL query.
194
+
195
+ Only generate a correct {dialect} query.
196
+
197
+ Once the SQLResult is available, generate the final answer in natural language format. Do not regenerate the question or SQL query in the final answer.
198
+
199
+ If the question asks any information for any particular number of days, use the lookback from the maximum date in the table, not from today's date.
200
+
201
+ Please note that MSSQL does not use LIMIT, but uses TOP clause.
202
+
203
+ You may also need to resolve the column name, as per the metadata. For instance, if the user asks about families and the column name is family, you should use family in the generated SQL.
204
+
205
+ Make sure that the column names are present in the table, by looking at the metadata.
206
+
207
+ If a question asks about availability over a period of time, you need to use SUM to calculate the total availability over that time period.
208
+
209
+ If a question mentions SKU, then use SKU column for filter, do not use any other column like comodity
210
+
211
+ If a question asks about AV of shortage, do not use AV in the SQL query as AV is not a valid column name. AV is the key in the Shortage column.
212
+
213
+ In the OpenOrderShotage table, the column Item should be used to extract the part ids, to answer questions related to shortage.
214
+
215
+ In the OpenOrderShotage table, Customer_Part_Name column is equivalent to SKU.
216
+
217
+ Use the following format:
218
+ Question: Question here
219
+ SQLQuery: SQL Query to run
220
+ SQLResult: Result of the SQLQuery
221
+ Answer: Final answer here.
222
+
223
+ Only use the following tables:
224
+ {table_info}
225
+ Question: {input}
226
+ """
227
 
228
def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Return how many tokens ``string`` encodes to.

    The encoding is looked up by name via ``tiktoken.get_encoding`` and the
    length of the encoded sequence is returned.
    """
    tokenizer = tiktoken.get_encoding(encoding_name)
    return len(tokenizer.encode(string))
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def clean_sql(s: str) -> str:
234
+ #s = s.replace("SQL:","").strip()
235
+ #s = s.replace("Let's execute these queries step-by-step to get the final answer.","").strip()
236
  s = s.replace("```sql", "")
237
  for symb in ["'", '"']:
238
  if s.startswith(symb) and s.endswith(symb):
 
249
  if s.endswith("TOP 1"):
250
  s = s.replace("TOP 1","").strip()
251
  s = s.replace("SELECT","SELECT TOP 1")
252
+
253
+ s = s.split("SQLQuery:")[-1].strip()
254
+
255
  return s
256
 
257
def get_metadata_info(metadata_df, table_names):
    """Build the metadata text for the requested tables.

    For every name in ``table_names`` the first row of ``metadata_df`` whose
    ``table`` column equals that name is looked up, and its ``desc`` value
    is emitted after a one-line header naming the table. Sections are
    concatenated with no separator between them.

    Tables that cannot be resolved (no matching row, missing ``table`` /
    ``desc`` column, or a non-string description) are skipped entirely, so
    the output never contains a header without its description.

    Args:
        metadata_df: DataFrame-like object with ``table`` and ``desc``
            columns.
        table_names: Iterable of table names, or ``None``.

    Returns:
        The concatenated metadata text; ``""`` when ``table_names`` is
        ``None`` or nothing could be resolved.
    """
    if table_names is None:
        return ""

    sections = []
    for table in table_names:
        try:
            description = metadata_df[metadata_df.table == table].desc.iloc[0]
        except (AttributeError, KeyError, IndexError):
            # Best effort: an unknown table or an unexpected frame layout
            # must not abort prompt construction.
            continue
        if not isinstance(description, str):
            # e.g. NaN for an empty cell - there is nothing useful to emit.
            continue
        sections.append(
            f"The following metadata is for the table {table}\n{description}"
        )

    return "".join(sections)
267
 
 
315
  #print ("input key", self.input_key)
316
  #print ("===============")
317
 
318
+ orig_question = inputs[self.input_key]
319
+ history = inputs['history'].copy()
320
+ history.reverse()
321
+
322
+ inputs[self.input_key] = translate_to_english(self.llms['4k'], inputs[self.input_key])
323
+ input_text = f"{inputs[self.input_key]} \nHistory: {history} \nSQLQuery:"
324
  _run_manager.on_text(input_text, verbose=self.verbose)
325
  # If not present, then defaults to None which is all tables.
326
  table_names_to_use = inputs.get("table_names_to_use")
 
328
  table_info += get_metadata_info(metadata_df, table_names_to_use)
329
  llm_inputs = {
330
  "input": input_text,
331
+ "history": history,
332
  "top_k": str(self.top_k),
333
  "dialect": self.database.dialect,
334
  "table_info": table_info,
 
390
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
391
  self.intermediate_steps['query_explanation'] = explanation
392
 
393
+ if 'result' in self.intermediate_steps:
394
+ self.intermediate_steps['translated_result'] = translate(self.llms['4k'], orig_question, self.intermediate_steps['result'])
395
+
396
  except:
397
  try:
398
  sql_data_new = sql_data[-20:] + sql_data[:20]
 
418
  ).strip()
419
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
420
  self.intermediate_steps['query_explanation'] = explanation
421
+ if 'result' in self.intermediate_steps:
422
+ self.intermediate_steps['translated_result'] = translate(self.llms['4k'], orig_question, self.intermediate_steps['result'])
423
+
424
  except Exception as exc:
425
  # Append intermediate steps to exception, to aid in logging and later
426
  # improvement of few shot prompt seeds
427
+ #exc.intermediate_steps = self.intermediate_steps # type: ignore
428
+ #raise exc
429
+ self.intermediate_steps['result'] = "I don't know the answer for this."
430
+ self.intermediate_steps['translated_result'] = "I don't know the answer for this."
431
 
432