Spaces:
Sleeping
Sleeping
Update utils.py
Browse files
utils.py
CHANGED
|
@@ -142,144 +142,97 @@ def extract_question_list(llm, query):
|
|
| 142 |
except:
|
| 143 |
return query
|
| 144 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
|
| 146 |
warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
|
| 147 |
|
| 148 |
intermediate_steps_KEY = "intermediate_steps"
|
| 149 |
|
| 150 |
-
|
| 151 |
-
def __init__(
|
| 152 |
-
self,
|
| 153 |
-
element: Table,
|
| 154 |
-
include_foreign_key_constraints = None,
|
| 155 |
-
if_not_exists: bool = False,
|
| 156 |
-
columns_to_ignore: List[str] = None,
|
| 157 |
-
):
|
| 158 |
-
if columns_to_ignore is None:
|
| 159 |
-
columns_to_ignore = []
|
| 160 |
-
element.columns = [col for col in element.columns if col.name not in columns_to_ignore]
|
| 161 |
-
super().__init__(element, if_not_exists=if_not_exists)
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
| 165 |
encoding = tiktoken.get_encoding(encoding_name)
|
| 166 |
num_tokens = len(encoding.encode(string))
|
| 167 |
return num_tokens
|
| 168 |
|
| 169 |
-
|
| 170 |
-
class PatchedSQLDatabase(SQLDatabase):
|
| 171 |
-
def __init__(self, *args, **kwargs):
|
| 172 |
-
super().__init__(*args, **kwargs)
|
| 173 |
-
|
| 174 |
-
def get_table_info(self, table_names: Optional[List[str]] = None) -> str:
|
| 175 |
-
"""Get information about specified tables.
|
| 176 |
-
|
| 177 |
-
Follows best practices as specified in: Rajkumar et al, 2022
|
| 178 |
-
(https://arxiv.org/abs/2204.00498)
|
| 179 |
-
|
| 180 |
-
If `sample_rows_in_table_info`, the specified number of sample rows will be
|
| 181 |
-
appended to each table description. This can increase performance as
|
| 182 |
-
demonstrated in the paper.
|
| 183 |
-
"""
|
| 184 |
-
all_table_names = self.get_usable_table_names()
|
| 185 |
-
if table_names is not None:
|
| 186 |
-
missing_tables = set(table_names).difference(all_table_names)
|
| 187 |
-
if missing_tables:
|
| 188 |
-
print('all_table_names', all_table_names)
|
| 189 |
-
raise ValueError(f"table_names {missing_tables} not found in database")
|
| 190 |
-
all_table_names = table_names
|
| 191 |
-
|
| 192 |
-
meta_tables = [
|
| 193 |
-
tbl
|
| 194 |
-
for tbl in self._metadata.sorted_tables
|
| 195 |
-
if tbl.name in set(all_table_names)
|
| 196 |
-
and not (self.dialect == "sqlite" and tbl.name.startswith("sqlite_"))
|
| 197 |
-
]
|
| 198 |
-
|
| 199 |
-
tables = []
|
| 200 |
-
for table in meta_tables:
|
| 201 |
-
if self._custom_table_info and table.name in self._custom_table_info:
|
| 202 |
-
tables.append(self._custom_table_info[table.name])
|
| 203 |
-
continue
|
| 204 |
-
|
| 205 |
-
# Ignore JSON datatyped columns
|
| 206 |
-
_cols = table.columns
|
| 207 |
-
if _cols is dict:
|
| 208 |
-
for k, v in _cols.items():
|
| 209 |
-
if type(v.type) is NullType:
|
| 210 |
-
table._columns.remove(v)
|
| 211 |
-
|
| 212 |
-
# add create table command
|
| 213 |
-
create_table = str(PatchedCreateTable(
|
| 214 |
-
table, columns_to_ignore=[]
|
| 215 |
-
).compile(self._engine))
|
| 216 |
-
table_info = ""
|
| 217 |
-
# table_info += f"{create_table.rstrip()}"
|
| 218 |
-
|
| 219 |
-
has_extra_info = (
|
| 220 |
-
self._indexes_in_table_info or self._sample_rows_in_table_info
|
| 221 |
-
)
|
| 222 |
-
if has_extra_info:
|
| 223 |
-
table_info += "\n\n/*"
|
| 224 |
-
if self._indexes_in_table_info:
|
| 225 |
-
table_info += f"\n{self._get_table_indexes(table)}\n"
|
| 226 |
-
if self._sample_rows_in_table_info:
|
| 227 |
-
table_info += f"\n{self._get_sample_rows(table)}\n"
|
| 228 |
-
if has_extra_info:
|
| 229 |
-
table_info += "*/"
|
| 230 |
-
table_info += self.get_columns_descriptions(table)
|
| 231 |
-
|
| 232 |
-
tables.append(table_info)
|
| 233 |
-
tables.sort()
|
| 234 |
-
final_str = "\n\n".join(tables)
|
| 235 |
-
return final_str
|
| 236 |
-
|
| 237 |
-
def _get_sample_rows(self, table: Table) -> str:
|
| 238 |
-
# build the select command
|
| 239 |
-
command = select(table).order_by(func.random()).limit(self._sample_rows_in_table_info)
|
| 240 |
-
|
| 241 |
-
# save the columns in string format
|
| 242 |
-
columns_str = ";".join([f'"{col.name}"' for col in table.columns])
|
| 243 |
-
|
| 244 |
-
try:
|
| 245 |
-
# get the sample rows
|
| 246 |
-
with self._engine.connect() as connection:
|
| 247 |
-
sample_rows_result = connection.execute(command) # type: ignore
|
| 248 |
-
# shorten values in the sample rows
|
| 249 |
-
sample_rows = list(
|
| 250 |
-
map(lambda ls: [str(i)[:100] for i in ls], sample_rows_result)
|
| 251 |
-
)
|
| 252 |
-
|
| 253 |
-
# save the sample rows in string format
|
| 254 |
-
sample_rows_str = "\n".join([";".join(row) for row in sample_rows])
|
| 255 |
-
|
| 256 |
-
# in some dialects when there are no rows in the table a
|
| 257 |
-
# 'ProgrammingError' is returned
|
| 258 |
-
except ProgrammingError:
|
| 259 |
-
sample_rows_str = ""
|
| 260 |
-
|
| 261 |
-
return (
|
| 262 |
-
f"{self._sample_rows_in_table_info} rows from {table.name} table:\n"
|
| 263 |
-
f"{columns_str}\n"
|
| 264 |
-
f"{sample_rows_str}"
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
def _execute(
|
| 268 |
-
self,
|
| 269 |
-
command: str,
|
| 270 |
-
fetch: str = "all",
|
| 271 |
-
):
|
| 272 |
-
with self._engine.begin() as connection:
|
| 273 |
-
cursor = connection.execute(text(command))
|
| 274 |
-
print(cursor.__dict__)
|
| 275 |
-
|
| 276 |
-
if cursor.returns_rows:
|
| 277 |
-
fields = list(cursor.keys())
|
| 278 |
-
result = [dict(zip(fields,row)) for row in cursor.fetchall()]
|
| 279 |
-
return result
|
| 280 |
-
return []
|
| 281 |
-
|
| 282 |
def clean_sql(s: str) -> str:
|
|
|
|
|
|
|
| 283 |
s = s.replace("```sql", "")
|
| 284 |
for symb in ["'", '"']:
|
| 285 |
if s.startswith(symb) and s.endswith(symb):
|
|
@@ -296,13 +249,19 @@ def clean_sql(s: str) -> str:
|
|
| 296 |
if s.endswith("TOP 1"):
|
| 297 |
s = s.replace("TOP 1","").strip()
|
| 298 |
s = s.replace("SELECT","SELECT TOP 1")
|
|
|
|
|
|
|
|
|
|
| 299 |
return s
|
| 300 |
|
| 301 |
def get_metadata_info(metadata_df, table_names):
|
| 302 |
str = ""
|
| 303 |
for table in table_names:
|
| 304 |
-
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
| 306 |
|
| 307 |
return str
|
| 308 |
|
|
@@ -356,7 +315,12 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
|
|
| 356 |
#print ("input key", self.input_key)
|
| 357 |
#print ("===============")
|
| 358 |
|
| 359 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 360 |
_run_manager.on_text(input_text, verbose=self.verbose)
|
| 361 |
# If not present, then defaults to None which is all tables.
|
| 362 |
table_names_to_use = inputs.get("table_names_to_use")
|
|
@@ -364,7 +328,7 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
|
|
| 364 |
table_info += get_metadata_info(metadata_df, table_names_to_use)
|
| 365 |
llm_inputs = {
|
| 366 |
"input": input_text,
|
| 367 |
-
"history":
|
| 368 |
"top_k": str(self.top_k),
|
| 369 |
"dialect": self.database.dialect,
|
| 370 |
"table_info": table_info,
|
|
@@ -426,6 +390,9 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
|
|
| 426 |
# self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
|
| 427 |
self.intermediate_steps['query_explanation'] = explanation
|
| 428 |
|
|
|
|
|
|
|
|
|
|
| 429 |
except:
|
| 430 |
try:
|
| 431 |
sql_data_new = sql_data[-20:] + sql_data[:20]
|
|
@@ -451,11 +418,15 @@ class SQLDatabaseChainPatched(SQLDatabaseChain):
|
|
| 451 |
).strip()
|
| 452 |
# self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
|
| 453 |
self.intermediate_steps['query_explanation'] = explanation
|
|
|
|
|
|
|
|
|
|
| 454 |
except Exception as exc:
|
| 455 |
# Append intermediate steps to exception, to aid in logging and later
|
| 456 |
# improvement of few shot prompt seeds
|
| 457 |
-
exc.intermediate_steps = self.intermediate_steps # type: ignore
|
| 458 |
-
raise exc
|
| 459 |
-
|
|
|
|
| 460 |
|
| 461 |
|
|
|
|
| 142 |
except:
|
| 143 |
return query
|
| 144 |
|
| 145 |
+
def translate_to_english(llm, user_query):
|
| 146 |
+
sys_prompt = """
|
| 147 |
+
You are an AI assistant that translates a text to English. \
|
| 148 |
+
Do not generate any irrelavant text, only return the translation."""
|
| 149 |
+
|
| 150 |
+
message1 = SystemMessage(content=sys_prompt)
|
| 151 |
+
|
| 152 |
+
message2 = HumanMessage(
|
| 153 |
+
content=user_query
|
| 154 |
+
)
|
| 155 |
+
message_log = [message1, message2]
|
| 156 |
+
|
| 157 |
+
output = llm.invoke(message_log)
|
| 158 |
+
pred = output.content
|
| 159 |
+
|
| 160 |
+
return pred
|
| 161 |
+
|
| 162 |
+
def translate(llm, user_query, to_translate):
|
| 163 |
+
sys_prompt = """
|
| 164 |
+
You are an AI assistant that determines the language given a user query - {} and translate the provided text in that target language. \
|
| 165 |
+
Do not generate any irrelavant text, only return the translation. \
|
| 166 |
+
If the user query is in English, then don't do anything just return the original text, no translation is required there.""".format(user_query)
|
| 167 |
+
|
| 168 |
+
message1 = SystemMessage(content=sys_prompt)
|
| 169 |
+
|
| 170 |
+
message2 = HumanMessage(
|
| 171 |
+
content=to_translate
|
| 172 |
+
)
|
| 173 |
+
message_log = [message1, message2]
|
| 174 |
+
|
| 175 |
+
output = llm.invoke(message_log)
|
| 176 |
+
pred = output.content
|
| 177 |
+
|
| 178 |
+
return pred
|
| 179 |
|
| 180 |
warnings.filterwarnings('ignore', message="pandas only supports SQLAlchemy connectable.*", category=UserWarning, module='chain')
|
| 181 |
|
| 182 |
intermediate_steps_KEY = "intermediate_steps"
|
| 183 |
|
| 184 |
+
template = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
+
You are a database expert. Given an input question, first create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
|
| 187 |
+
The final answer should be in a concise natural language.
|
| 188 |
+
|
| 189 |
+
Use the history if you can not understand the question.
|
| 190 |
+
|
| 191 |
+
If the question is in another language, translate it to English before proceeding.
|
| 192 |
+
|
| 193 |
+
Do not repeat the question while generating the SQL query.
|
| 194 |
+
|
| 195 |
+
Only generate a correct {dialect} query.
|
| 196 |
+
|
| 197 |
+
Once the SQLResult is available, generate the final answer in natural language format. Do not regenerate the question or SQL query in the final answer.
|
| 198 |
+
|
| 199 |
+
If the question asks any information for any particular number of days, use the lookback from the maximum date in the table, not from today's date.
|
| 200 |
+
|
| 201 |
+
Please note that MSSQL does not use LIMIT, but uses TOP clause.
|
| 202 |
+
|
| 203 |
+
You may also need to resolve the column name, as per the metadata. For instance, if the user asks about families and the column name is family, you should use family in the generated SQL.
|
| 204 |
+
|
| 205 |
+
Make sure that the column names are present in the table, by looking at the metadata.
|
| 206 |
+
|
| 207 |
+
If a question asks about availability over a period of time, you need to use SUM to calculate the total availability over that time period.
|
| 208 |
+
|
| 209 |
+
If a question mentions SKU, then use SKU column for filter, do not use any other column like comodity
|
| 210 |
+
|
| 211 |
+
If a question asks about AV of shortage, do not use AV in the SQL query as AV is not a valid column name. AV is the key in the Shortage column.
|
| 212 |
+
|
| 213 |
+
In the OpenOrderShotage table, the column Item should be used to extract the part ids, to answer questions related to shortage.
|
| 214 |
+
|
| 215 |
+
In the OpenOrderShotage table, Customer_Part_Name column is equivalent to SKU.
|
| 216 |
+
|
| 217 |
+
Use the following format:
|
| 218 |
+
Question: Question here
|
| 219 |
+
SQLQuery: SQL Query to run
|
| 220 |
+
SQLResult: Result of the SQLQuery
|
| 221 |
+
Answer: Final answer here.
|
| 222 |
+
|
| 223 |
+
Only use the following tables:
|
| 224 |
+
{table_info}
|
| 225 |
+
Question: {input}
|
| 226 |
+
"""
|
| 227 |
|
| 228 |
def num_tokens_from_string(string: str, encoding_name: str) -> int:
|
| 229 |
encoding = tiktoken.get_encoding(encoding_name)
|
| 230 |
num_tokens = len(encoding.encode(string))
|
| 231 |
return num_tokens
|
| 232 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
def clean_sql(s: str) -> str:
|
| 234 |
+
#s = s.replace("SQL:","").strip()
|
| 235 |
+
#s = s.replace("Let's execute these queries step-by-step to get the final answer.","").strip()
|
| 236 |
s = s.replace("```sql", "")
|
| 237 |
for symb in ["'", '"']:
|
| 238 |
if s.startswith(symb) and s.endswith(symb):
|
|
|
|
| 249 |
if s.endswith("TOP 1"):
|
| 250 |
s = s.replace("TOP 1","").strip()
|
| 251 |
s = s.replace("SELECT","SELECT TOP 1")
|
| 252 |
+
|
| 253 |
+
s = s.split("SQLQuery:")[-1].strip()
|
| 254 |
+
|
| 255 |
return s
|
| 256 |
|
| 257 |
def get_metadata_info(metadata_df, table_names):
|
| 258 |
str = ""
|
| 259 |
for table in table_names:
|
| 260 |
+
try:
|
| 261 |
+
str += "The following metadata is for the table " + table + "\n"
|
| 262 |
+
str += metadata_df[metadata_df.table == table].desc.iloc[0]
|
| 263 |
+
except:
|
| 264 |
+
pass
|
| 265 |
|
| 266 |
return str
|
| 267 |
|
|
|
|
| 315 |
#print ("input key", self.input_key)
|
| 316 |
#print ("===============")
|
| 317 |
|
| 318 |
+
orig_question = inputs[self.input_key]
|
| 319 |
+
history = inputs['history'].copy()
|
| 320 |
+
history.reverse()
|
| 321 |
+
|
| 322 |
+
inputs[self.input_key] = translate_to_english(self.llms['4k'], inputs[self.input_key])
|
| 323 |
+
input_text = f"{inputs[self.input_key]} \nHistory: {history} \nSQLQuery:"
|
| 324 |
_run_manager.on_text(input_text, verbose=self.verbose)
|
| 325 |
# If not present, then defaults to None which is all tables.
|
| 326 |
table_names_to_use = inputs.get("table_names_to_use")
|
|
|
|
| 328 |
table_info += get_metadata_info(metadata_df, table_names_to_use)
|
| 329 |
llm_inputs = {
|
| 330 |
"input": input_text,
|
| 331 |
+
"history": history,
|
| 332 |
"top_k": str(self.top_k),
|
| 333 |
"dialect": self.database.dialect,
|
| 334 |
"table_info": table_info,
|
|
|
|
| 390 |
# self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
|
| 391 |
self.intermediate_steps['query_explanation'] = explanation
|
| 392 |
|
| 393 |
+
if 'result' in self.intermediate_steps:
|
| 394 |
+
self.intermediate_steps['translated_result'] = translate(self.llms['4k'], orig_question, self.intermediate_steps['result'])
|
| 395 |
+
|
| 396 |
except:
|
| 397 |
try:
|
| 398 |
sql_data_new = sql_data[-20:] + sql_data[:20]
|
|
|
|
| 418 |
).strip()
|
| 419 |
# self.llm_chain = self.revert_to_small_model(chain=self.llm_chain)
|
| 420 |
self.intermediate_steps['query_explanation'] = explanation
|
| 421 |
+
if 'result' in self.intermediate_steps:
|
| 422 |
+
self.intermediate_steps['translated_result'] = translate(self.llms['4k'], orig_question, self.intermediate_steps['result'])
|
| 423 |
+
|
| 424 |
except Exception as exc:
|
| 425 |
# Append intermediate steps to exception, to aid in logging and later
|
| 426 |
# improvement of few shot prompt seeds
|
| 427 |
+
#exc.intermediate_steps = self.intermediate_steps # type: ignore
|
| 428 |
+
#raise exc
|
| 429 |
+
self.intermediate_steps['result'] = "I don't know the answer for this."
|
| 430 |
+
self.intermediate_steps['translated_result'] = "I don't know the answer for this."
|
| 431 |
|
| 432 |
|