victor7246 committed on
Commit
1cccdf6
·
verified ·
1 Parent(s): a29605a

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +69 -19
utils.py CHANGED
@@ -60,17 +60,23 @@ xls = pd.ExcelFile('SmartClever table explanations.xlsx')
60
  metadata_df = pd.DataFrame()
61
  i = 0
62
  sheet_to_df_map = {}
63
- for sheet_name in xls.sheet_names:
64
- sheet_to_df_map[sheet_name.strip()] = xls.parse(sheet_name, header=None)
65
- sheet_to_df_map[sheet_name.strip()].columns = sheet_to_df_map[sheet_name.strip()].iloc[1]
66
- sheet_to_df_map[sheet_name.strip()] = sheet_to_df_map[sheet_name.strip()].iloc[:1].fillna('')
67
- sheet_to_df_map[sheet_name.strip()]['metadata'] = sheet_to_df_map[sheet_name.strip()].apply(lambda x: \
68
- ". ".join([x[col] for col in sheet_to_df_map[sheet_name.strip()].columns]), axis=1)
69
-
70
- metadata_df.loc[i, "table"] = sheet_name.strip()
71
- metadata_df.loc[i, "desc"] = sheet_to_df_map[sheet_name.strip()]['metadata'].iloc[0]
72
-
73
- i += 1
 
 
 
 
 
 
74
 
75
  table_search = EmbeddingsSearch(metadata_df=metadata_df, emb_model=emb_model)
76
 
@@ -93,7 +99,24 @@ def extract_question_type(llm, query):
93
  return 'specific'
94
  else:
95
  return 'unknown'
96
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
98
 
99
  intermediate_steps_KEY = "intermediate_steps"
@@ -248,7 +271,7 @@ def clean_sql(s: str) -> str:
248
  s = s.replace("TOP 1","").strip()
249
  s = s.replace("SELECT","SELECT TOP 1")
250
  return s
251
-
252
  class SQLDatabaseChainPatched(SQLDatabaseChain):
253
  intermediate_steps: List[Any] = Field(default_factory=list)
254
  llms: Dict[str, Any] = Field(default_factory=dict)
@@ -270,6 +293,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
270
  # get number of tokens in the input prompt
271
  selected_inputs = {k: inputs[k] for k in chain.prompt.input_variables}
272
  prompt = chain.prompt.format_prompt(**selected_inputs)
 
273
  # https://stackoverflow.com/questions/75804599/openai-api-how-do-i-count-tokens-before-i-send-an-api-request
274
  n_tokens = num_tokens_from_string(string=prompt.text, encoding_name='cl100k_base')
275
  print(f"N tokens in input: {n_tokens}")
@@ -297,6 +321,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
297
  # If not present, then defaults to None which is all tables.
298
  table_names_to_use = inputs.get("table_names_to_use")
299
  table_info = self.database.get_table_info(table_names=table_names_to_use)
 
300
  llm_inputs = {
301
  "input": input_text,
302
  "history": inputs["history"],
@@ -319,6 +344,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
319
 
320
  # list to store estimated num of tokens
321
  self.intermediate_steps['n_tokens_list'] = []
 
322
  try:
323
  # get sql
324
  self.llm_chain, n_tokens1 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
@@ -360,10 +386,34 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
360
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
361
  self.intermediate_steps['query_explanation'] = explanation
362
 
363
- except Exception as exc:
364
- # Append intermediate steps to exception, to aid in logging and later
365
- # improvement of few shot prompt seeds
366
- exc.intermediate_steps = self.intermediate_steps # type: ignore
367
- raise exc
368
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
 
 
60
# Build a one-row-per-table metadata frame: for every sheet except the
# first, row 1 supplies the column headers, row 0 is kept as the data row,
# and all of its cell values are joined into a single "metadata" string.
metadata_df = pd.DataFrame()
i = 0
sheet_to_df_map = {}
for k, sheet_name in enumerate(xls.sheet_names):
    if k > 0:  # sheet 0 is the summary sheet; it is parsed separately below
        key = sheet_name.strip()
        df = xls.parse(sheet_name, header=None)
        df.columns = df.iloc[1]
        df = df.iloc[:1].fillna('')
        # Join every cell of the kept row into one description sentence.
        df['metadata'] = df.apply(
            lambda x: ". ".join([x[col] for col in df.columns]), axis=1)
        sheet_to_df_map[key] = df

        metadata_df.loc[i, "table"] = key
        metadata_df.loc[i, "desc"] = df['metadata'].iloc[0]
        i += 1

# The summary sheet maps table name -> free-text explanation; inner-merge it
# so metadata_df only keeps tables that appear in both sources.
metadata_df2 = xls.parse('Table explanations', header=1).dropna(axis=0, how='all').dropna(axis=1, how='all')
metadata_df2.columns = ['table', 'metadata']
metadata_df2.table = metadata_df2.table.apply(lambda x: x.strip())
metadata_df = pd.merge(metadata_df, metadata_df2, how='inner')

table_search = EmbeddingsSearch(metadata_df=metadata_df, emb_model=emb_model)
82
 
 
99
  return 'specific'
100
  else:
101
  return 'unknown'
102
+
103
def extract_table_name(query):
    """Pick the most relevant table name for a user query via the LLM.

    Embeds the table metadata (module-level ``metadata_df``) in a system
    prompt, asks the module-level ``llm`` which table fits ``query``, and
    scans the model's answer for a known table name.

    Returns:
        The matched table name, or ``None`` when the answer contains no
        known table (previously the function fell off the end implicitly).
    """
    messages = [
        (
            "system",
            """
            You are an AI assistant that determines the most relevant table name given a user query. Following is the metadata information you need to use to determine the most relevant table.\
            {}.""".format(metadata_df[['table','metadata']].to_string()),
        ),
        ("human", query),
    ]
    output = llm.invoke(messages)
    pred = output.content

    # Try longer names first so a table whose name is a substring of
    # another (e.g. "Sales" vs "Sales2024") cannot shadow the longer match.
    for table in sorted(metadata_df.table.unique(), key=len, reverse=True):
        if table in pred:
            return table
    return None
119
+
120
  warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
121
 
122
  intermediate_steps_KEY = "intermediate_steps"
 
271
  s = s.replace("TOP 1","").strip()
272
  s = s.replace("SELECT","SELECT TOP 1")
273
  return s
274
+
275
  class SQLDatabaseChainPatched(SQLDatabaseChain):
276
  intermediate_steps: List[Any] = Field(default_factory=list)
277
  llms: Dict[str, Any] = Field(default_factory=dict)
 
293
  # get number of tokens in the input prompt
294
  selected_inputs = {k: inputs[k] for k in chain.prompt.input_variables}
295
  prompt = chain.prompt.format_prompt(**selected_inputs)
296
+ #print (prompt)
297
  # https://stackoverflow.com/questions/75804599/openai-api-how-do-i-count-tokens-before-i-send-an-api-request
298
  n_tokens = num_tokens_from_string(string=prompt.text, encoding_name='cl100k_base')
299
  print(f"N tokens in input: {n_tokens}")
 
321
  # If not present, then defaults to None which is all tables.
322
  table_names_to_use = inputs.get("table_names_to_use")
323
  table_info = self.database.get_table_info(table_names=table_names_to_use)
324
+ table_info += get_metadata_info(metadata_df, table_names_to_use)
325
  llm_inputs = {
326
  "input": input_text,
327
  "history": inputs["history"],
 
344
 
345
  # list to store estimated num of tokens
346
  self.intermediate_steps['n_tokens_list'] = []
347
+ input_text_bkp = input_text
348
  try:
349
  # get sql
350
  self.llm_chain, n_tokens1 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
 
386
  # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
387
  self.intermediate_steps['query_explanation'] = explanation
388
 
389
+ except:
390
+ try:
391
+ sql_data_new = sql_data[-20:] + sql_data[:20]
392
+ input_text = input_text_bkp + f"{sql_cmd}\nSQLResult: {str(sql_data_new)}\nAnswer:"
393
+ llm_inputs["input"] = input_text
394
+ self.llm_chain, n_tokens3 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
395
+ # self.intermediate_steps['n_tokens_list'].append(n_tokens3)
396
+ final_result = self.llm_chain.predict(
397
+ callbacks=_run_manager.get_child(),
398
+ **llm_inputs,
399
+ ).strip()
400
+ # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
401
+ self.intermediate_steps['result'] = final_result
402
+
403
+ # provide explanation
404
+ input_text += f"{final_result}\nExplanation:"
405
+ llm_inputs["input"] = input_text
406
+ self.llm_chain, n_tokens4 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
407
+ # self.intermediate_steps['n_tokens_list'].append(n_tokens3)
408
+ explanation = self.llm_chain.predict(
409
+ callbacks=_run_manager.get_child(),
410
+ **llm_inputs,
411
+ ).strip()
412
+ # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
413
+ self.intermediate_steps['query_explanation'] = explanation
414
+ except Exception as exc:
415
+ # Append intermediate steps to exception, to aid in logging and later
416
+ # improvement of few shot prompt seeds
417
+ exc.intermediate_steps = self.intermediate_steps # type: ignore
418
+ raise exc
419