Update utils.py
utils.py CHANGED
@@ -60,17 +60,23 @@ xls = pd.ExcelFile('SmartClever table explanations.xlsx')
 metadata_df = pd.DataFrame()
 i = 0
 sheet_to_df_map = {}
-for sheet_name in xls.sheet_names:
-
-
-
-
-
-
-
-
-
-
+for k, sheet_name in enumerate(xls.sheet_names):
+    if k > 0:
+        sheet_to_df_map[sheet_name.strip()] = xls.parse(sheet_name, header=None)
+        sheet_to_df_map[sheet_name.strip()].columns = sheet_to_df_map[sheet_name.strip()].iloc[1]
+        sheet_to_df_map[sheet_name.strip()] = sheet_to_df_map[sheet_name.strip()].iloc[:1].fillna('')
+        sheet_to_df_map[sheet_name.strip()]['metadata'] = sheet_to_df_map[sheet_name.strip()].apply(lambda x: \
+            ". ".join([x[col] for col in sheet_to_df_map[sheet_name.strip()].columns]), axis=1)
+
+        metadata_df.loc[i, "table"] = sheet_name.strip()
+        metadata_df.loc[i, "desc"] = sheet_to_df_map[sheet_name.strip()]['metadata'].iloc[0]
+
+        i += 1
+
+metadata_df2 = xls.parse('Table explanations',header=1).dropna(axis=0,how='all').dropna(axis=1,how='all')
+metadata_df2.columns = ['table','metadata']
+metadata_df2.table = metadata_df2.table.apply(lambda x: x.strip())
+metadata_df = pd.merge(metadata_df, metadata_df2, how='inner')
 
 table_search = EmbeddingsSearch(metadata_df=metadata_df, emb_model=emb_model)
 
@@ -93,7 +99,24 @@ def extract_question_type(llm, query):
         return 'specific'
     else:
         return 'unknown'
-
+
+def extract_table_name(query):
+    messages = [
+        (
+            "system",
+            """
+            You are an AI assistant that determines the most relevant table name given a user query. Following is the metadata information you need to use to determine the most relevant table.\
+            {}.""".format(metadata_df[['table','metadata']].to_string()),
+        ),
+        ("human", query),
+    ]
+    output = llm.invoke(messages)
+    pred = output.content
+
+    for table in metadata_df.table.unique():
+        if table in pred:
+            return table
+
 warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
 
 intermediate_steps_KEY = "intermediate_steps"
@@ -248,7 +271,7 @@ def clean_sql(s: str) -> str:
     s = s.replace("TOP 1","").strip()
     s = s.replace("SELECT","SELECT TOP 1")
     return s
-
+
 class SQLDatabaseChainPatched(SQLDatabaseChain):
     intermediate_steps: List[Any] = Field(default_factory=list)
     llms: Dict[str, Any] = Field(default_factory=dict)
@@ -270,6 +293,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
         # get number of tokens in the input prompt
         selected_inputs = {k: inputs[k] for k in chain.prompt.input_variables}
         prompt = chain.prompt.format_prompt(**selected_inputs)
+        #print (prompt)
         # https://stackoverflow.com/questions/75804599/openai-api-how-do-i-count-tokens-before-i-send-an-api-request
         n_tokens = num_tokens_from_string(string=prompt.text, encoding_name='cl100k_base')
         print(f"N tokens in input: {n_tokens}")
@@ -297,6 +321,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
         # If not present, then defaults to None which is all tables.
         table_names_to_use = inputs.get("table_names_to_use")
         table_info = self.database.get_table_info(table_names=table_names_to_use)
+        table_info += get_metadata_info(metadata_df, table_names_to_use)
         llm_inputs = {
             "input": input_text,
             "history": inputs["history"],
@@ -319,6 +344,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
 
         # list to store estimated num of tokens
         self.intermediate_steps['n_tokens_list'] = []
+        input_text_bkp = input_text
         try:
             # get sql
             self.llm_chain, n_tokens1 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
@@ -360,10 +386,34 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
             # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
             self.intermediate_steps['query_explanation'] = explanation
 
-        except
-
-
-
-
-
+        except:
+            try:
+                sql_data_new = sql_data[-20:] + sql_data[:20]
+                input_text = input_text_bkp + f"{sql_cmd}\nSQLResult: {str(sql_data_new)}\nAnswer:"
+                llm_inputs["input"] = input_text
+                self.llm_chain, n_tokens3 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
+                # self.intermediate_steps['n_tokens_list'].append(n_tokens3)
+                final_result = self.llm_chain.predict(
+                    callbacks=_run_manager.get_child(),
+                    **llm_inputs,
+                ).strip()
+                # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
+                self.intermediate_steps['result'] = final_result
+
+                # provide explanation
+                input_text += f"{final_result}\nExplanation:"
+                llm_inputs["input"] = input_text
+                self.llm_chain, n_tokens4 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
+                # self.intermediate_steps['n_tokens_list'].append(n_tokens3)
+                explanation = self.llm_chain.predict(
+                    callbacks=_run_manager.get_child(),
+                    **llm_inputs,
+                ).strip()
+                # self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
+                self.intermediate_steps['query_explanation'] = explanation
+            except Exception as exc:
+                # Append intermediate steps to exception, to aid in logging and later
+                # improvement of few shot prompt seeds
+                exc.intermediate_steps = self.intermediate_steps  # type: ignore
+                raise exc
 
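
The commit adds extract_table_name(query), which asks the chat model to pick a table using the metadata dumped into the system prompt and then returns the first entry of metadata_df.table that appears verbatim in the reply (falling through to None when nothing matches). A minimal sketch of how the result might be fed back into the patched chain, assuming the caller passes it through the existing table_names_to_use input; the input key names below are illustrative and not taken from the diff:

# Hypothetical wiring (not in the diff): use the LLM-picked table to limit
# which schemas SQLDatabaseChainPatched pulls into its prompt.
user_query = "What was the average order value last month?"

table = extract_table_name(user_query)  # may be None if no table name is matched

chain_inputs = {
    "query": user_query,          # assumed input key for the user question
    "history": "",
    # get_table_info() treats None as "all tables", so only narrow when a match exists.
    "table_names_to_use": [table] if table else None,
}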
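
The patched _call also appends get_metadata_info(metadata_df, table_names_to_use) to table_info, but that helper is not defined anywhere in this diff. A minimal sketch of what such a helper could look like, assuming it simply renders the merged metadata column for the selected tables; the body below is an assumption, not the repository's implementation:

import pandas as pd

def get_metadata_info(metadata_df: pd.DataFrame, table_names_to_use=None) -> str:
    # Assumed behaviour: return a text block of per-table descriptions that can
    # be concatenated onto the schema string built by get_table_info().
    df = metadata_df
    if table_names_to_use:
        df = df[df["table"].isin(table_names_to_use)]
    if df.empty:
        return ""
    lines = [f"Table {row.table}: {row.metadata}" for row in df.itertuples(index=False)]
    return "\n\n" + "\n".join(lines)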