import os
import pprint
import warnings
from ast import literal_eval
from typing import Any, Dict, List, Optional

import numpy as np
import pandas as pd
import tiktoken
from pydantic import BaseModel, Field
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm

from sqlalchemy import MetaData, Table, create_engine, inspect, select, text
from sqlalchemy.schema import CreateColumn, CreateTable
from sqlalchemy.sql.expression import func
from sqlalchemy.types import NullType

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sql_database.prompt import DECIDER_PROMPT, PROMPT, SQL_PROMPTS
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools.sql_database.prompt import QUERY_CHECKER
from langchain.utilities.sql_database import SQLDatabase
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator
from langchain_experimental.sql import SQLDatabaseChain
# Sentence-embedding model used to match user questions against table metadata.
emb_model = SentenceTransformer("all-MiniLM-L6-v2")


class EmbeddingsSearch:
    """Ranks tables by cosine similarity between a query and table metadata."""

    def __init__(self, metadata_df, emb_model):
        self.model = emb_model
        self.metadata_df = metadata_df
        # Pre-compute one embedding per table description.
        self.embeddings = self.model.encode(self.metadata_df['final_metadata'].tolist())

    def __call__(self, text: str, topk: int = 5):
        q_emb = self.model.encode([text])
        # Cosine similarities between the query and every table description.
        similarities = cosine_similarity(q_emb, self.embeddings)
        # Indices and scores ordered from most to least similar.
        idx = np.flip(similarities.argsort())[0]
        similarities.sort()
        similarities = np.flip(similarities)[0]
        results = pd.DataFrame()
        results['idx'] = idx.tolist()[:topk]
        results['distances'] = similarities.tolist()[:topk]
        results['table'] = [
            self.metadata_df.loc[i, "table"] for i in results['idx']
        ]
        return results
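
# Illustrative sanity check for EmbeddingsSearch (not part of the app flow;
# the toy table names and descriptions below are made up). Uncomment to try
# it in isolation:
#
# demo_df = pd.DataFrame({
#     "table": ["DAILY_INVENTORY", "OPEN_ORDERS"],
#     "final_metadata": [
#         "Daily stock levels per SKU and warehouse.",
#         "Open customer orders with requested ship dates.",
#     ],
# })
# demo_search = EmbeddingsSearch(metadata_df=demo_df, emb_model=emb_model)
# print(demo_search("current inventory by SKU", topk=1))
# # -> a DataFrame with columns ['idx', 'distances', 'table'], best match first.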
# Build per-table metadata from the explanations workbook: the first sheet is
# a summary; each subsequent sheet describes one table's columns.
xls = pd.ExcelFile('SmartClever table explanations_V5.xlsx')
metadata_df = pd.DataFrame()
sheet_to_df_map = {}
for i, sheet_name in enumerate(xls.sheet_names[1:]):
    name = sheet_name.strip()
    df = xls.parse(sheet_name, header=None)
    # Row 1 holds the column names; row 0 holds the descriptions we keep.
    df.columns = df.iloc[1]
    df = df.iloc[:1].fillna('')
    # Concatenate every cell of the description row into one metadata string.
    df['metadata'] = df.apply(
        lambda x: ". ".join([x[col] for col in df.columns]), axis=1
    )
    sheet_to_df_map[name] = df
    metadata_df.loc[i, "table"] = name
    metadata_df.loc[i, "desc"] = df['metadata'].iloc[0]

# The summary sheet maps table name -> nickname -> long explanation.
metadata_df2 = xls.parse('Table explanations', header=1).dropna(axis=0, how='all').dropna(axis=1, how='all')
metadata_df2.columns = ['table', 'nickname', 'metadata']
metadata_df2.table = metadata_df2.table.apply(lambda x: x.strip())
metadata_df = pd.merge(metadata_df, metadata_df2, how='inner')
# Replace the sheet-derived descriptions with curated ones from table_desc.csv.
table_desc = pd.read_csv("table_desc.csv", lineterminator='\n')
table_desc.columns = ['table', 'desc']
metadata_df = metadata_df.drop(['desc'], axis=1)
metadata_df = pd.merge(metadata_df, table_desc, how='inner')
metadata_df['final_metadata'] = metadata_df.apply(lambda x: x["desc"] + "\n" + x['metadata'], axis=1)
# Align the spreadsheet's sheet name with the actual database table name.
metadata_df.loc[metadata_df.table == 'daily_inventory', 'table'] = 'DAILY_INVENTORY'

table_search = EmbeddingsSearch(metadata_df=metadata_df, emb_model=emb_model)

def extract_question_type(llm, query):
    # Classify whether the question is answerable from the known tables.
    sys_prompt = """
    You are an AI assistant that determines whether a user-provided question can be answered from the given tables.
    The metadata of the tables is provided here - {}.
    If the question can be answered, return yes.
    If the question is generic and cannot be answered using these tables, return no.
    Note that any question specific to families, commodities, products, forecasts, or SKUs can be related to the tables, so return yes.""".format(metadata_df[['table', 'metadata']].to_string())
    messages = [
        ("system", sys_prompt),
        ("human", query),
    ]
    output = llm.invoke(messages)
    return output.content

def extract_table_name(llm, query):
    # Ask the LLM for the most relevant table, then match its free-text answer
    # back against the known table names.
    messages = [
        (
            "system",
            """
            You are an AI assistant that determines the most relevant table name for a given user query. Use the following metadata to decide:
            {}.""".format(metadata_df[['table', 'metadata']].to_string()),
        ),
        ("human", query),
    ]
    output = llm.invoke(messages)
    pred = output.content
    tables = []
    for table in metadata_df.table.unique():
        if table in pred:
            tables.append(table)
    return tables

def extract_question_list(llm, query):
    # Split a compound question into sub-questions; expect a Python list back.
    sys_prompt = """You are an AI assistant that determines whether multiple questions are stacked in a single question. If so, split the question into sub-questions and return them as a valid Python list.
    If it is a single question, return the original question.
    Do not add any additional text; only return the final response."""
    messages = [
        ("system", sys_prompt),
        ("human", query),
    ]
    output = llm.invoke(messages)
    pred = output.content
    try:
        return literal_eval(pred)
    except (ValueError, SyntaxError):
        # The model did not return a parseable list; fall back to the raw query.
        return query
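
# Illustrative (no LLM call; toy strings): how the literal_eval fallback behaves.
#
# literal_eval('["How many SKUs are there?", "What is the forecast?"]')
# # -> ['How many SKUs are there?', 'What is the forecast?']
# literal_eval('Sure! Here are the sub-questions: ...')
# # -> raises SyntaxError, so extract_question_list returns the raw query instead.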

def translate_to_english(llm, user_query):
    sys_prompt = """
    You are an AI assistant that translates a text to English.
    Do not generate any irrelevant text; only return the translation."""
    messages = [
        ("system", sys_prompt),
        ("human", user_query),
    ]
    output = llm.invoke(messages)
    return output.content

def translate(llm, user_query, to_translate):
    # Detect the language of the user's original query, then translate the
    # given text into that language.
    sys_prompt = """
    You are an AI assistant that determines the language of the following text - {} - and translates the user-provided text into that language.
    Do not generate any irrelevant text; only return the translation.""".format(user_query)
    messages = [
        ("system", sys_prompt),
        ("human", to_translate),
    ]
    output = llm.invoke(messages)
    return output.content

def num_tokens_from_string(string: str, encoding_name: str) -> int:
    """Count the tokens in a string using the given tiktoken encoding."""
    encoding = tiktoken.get_encoding(encoding_name)
    return len(encoding.encode(string))
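
# Quick check of the token counter ('cl100k_base' is the tiktoken encoding
# used by the GPT-3.5/GPT-4 chat models this app appears to target):
#
# num_tokens_from_string("SELECT TOP 1 * FROM DAILY_INVENTORY", "cl100k_base")
# # -> a small integer (roughly one token per short word or symbol)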

def clean_sql(s: str) -> str:
    # Strip code fences, wrapping quotes, and trailing semicolons from the LLM
    # output, and rewrite MySQL-style LIMIT 1 into T-SQL's SELECT TOP 1.
    s = s.replace("```sql", "")
    for symb in ["'", '"']:
        if s.startswith(symb) and s.endswith(symb):
            s = s[1:-1]
    if s.endswith(";"):
        s = s[:-1]
    s = s.replace("```", "")
    s = s.replace("\n", " ")
    s = s.replace("\t", " ")
    if "LIMIT 1" in s:
        s = s.replace("LIMIT 1", "").strip()
        s = s.replace("SELECT", "SELECT TOP 1")
    if s.endswith("TOP 1"):
        # A stray TOP 1 at the end of the statement: move it after SELECT.
        s = s.replace("TOP 1", "").strip()
        s = s.replace("SELECT", "SELECT TOP 1")
    s = s.split("SQLQuery:")[-1].strip()
    return s
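
# Illustrative round trip through clean_sql (pure string handling, safe to run):
#
# clean_sql("```sql\nSELECT name FROM products LIMIT 1\n```")
# # -> 'SELECT TOP 1 name FROM products'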

def get_metadata_info(metadata_df, table_names):
    # Collect the curated description for each selected table.
    info = ""
    for table in table_names:
        try:
            info += "The following metadata is for the table " + table + "\n"
            info += metadata_df[metadata_df.table == table].desc.iloc[0]
        except IndexError:
            # Table not present in metadata_df; skip it.
            pass
    return info

class SQLDatabaseChainPatched(SQLDatabaseChain):
    """SQLDatabaseChain variant that records intermediate steps and can swap
    between small- and large-context LLMs when prompts grow too long."""

    intermediate_steps: Dict[str, Any] = Field(default_factory=dict)
    llms: Dict[str, Any] = Field(default_factory=dict)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.return_intermediate_steps = True
        self.intermediate_steps = {}

    def set_llms(self, llms):
        # Register the available models, keyed by context size (e.g. '4k', '16k').
        self.llms = llms
        print("Set llms")
        print(self.llms)
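
    # Illustrative wiring (a sketch; `llm_4k` and `llm_16k` are assumed names
    # for any LangChain chat models with small and large context windows):
    #
    #   chain = SQLDatabaseChainPatched.from_llm(llm_4k, db)
    #   chain.set_llms({'4k': llm_4k, '16k': llm_16k})
    #
    # prepare_llm() below then swaps in llms['16k'] whenever the rendered
    # prompt approaches the small model's window, and revert_to_small_model()
    # switches back to llms['4k'].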

    def prepare_llm(self, inputs, chain, replace_llm: bool = False):
        # Monkey-patch the chain's LLM when the prompt's token count exceeds
        # what the small-context model can handle; call `revert_to_small_model`
        # after the LLM call to switch back.
        selected_inputs = {k: inputs[k] for k in chain.prompt.input_variables}
        prompt = chain.prompt.format_prompt(**selected_inputs)
        # Estimate the prompt's token count before sending it to the API; see
        # https://stackoverflow.com/questions/75804599/openai-api-how-do-i-count-tokens-before-i-send-an-api-request
        n_tokens = num_tokens_from_string(string=prompt.text, encoding_name='cl100k_base')
        print(f"N tokens in input: {n_tokens}")
        if replace_llm:
            max_tokens_small_model = 8000
            if n_tokens > max_tokens_small_model * 0.9:
                chain.llm = self.llms['16k']
                print("Using large model")
        return chain, n_tokens

    def revert_to_small_model(self, chain):
        chain.llm = self.llms['4k']
        print("Reverted model to 4k")
        return chain

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        # Retained from the (currently disabled) answer-translation flow.
        orig_question = inputs[self.input_key]
        # Most recent conversation turns first.
        history = inputs['history'].copy()
        history.reverse()
        input_text = f"{inputs[self.input_key]} \nHistory: {history} \nSQLQuery:"
        _run_manager.on_text(input_text, verbose=self.verbose)
        # If not present, defaults to None, which means all tables.
        table_names_to_use = inputs.get("table_names_to_use")
        table_info = self.database.get_table_info(table_names=table_names_to_use)
        table_info += get_metadata_info(metadata_df, table_names_to_use)
        llm_inputs = {
            "input": input_text,
            "history": history,
            "top_k": str(self.top_k),
            "dialect": self.database.dialect,
            "table_info": table_info,
            "stop": ["\nSQLResult:"],
        }
        if self.memory is not None:
            for k in self.memory.memory_variables:
                llm_inputs[k] = inputs[k]
        self.intermediate_steps = {}
        # Log the LLM inputs, excluding table_info due to its size.
        self.intermediate_steps['llm_inputs'] = {}
        for k, v in llm_inputs.items():
            if k not in ['table_info']:
                self.intermediate_steps['llm_inputs'][k] = v
        # Estimated token counts, one entry per LLM call.
        self.intermediate_steps['n_tokens_list'] = []
        input_text_bkp = input_text
        try:
            # Step 1: generate the SQL command.
            self.llm_chain, n_tokens1 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
            sql_cmd = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            self.intermediate_steps['sql_cmd_unchecked'] = sql_cmd
            self.intermediate_steps['sql_cmd'] = clean_sql(sql_cmd)
            # Step 2: run the SQL against the database.
            sql_data = self.database._execute(self.intermediate_steps['sql_cmd'], fetch='all')
            self.intermediate_steps['sql_data'] = sql_data
            # Step 3: produce a human-readable answer from the SQL result.
            input_text += f"{sql_cmd}\nSQLResult: {str(sql_data)}\nAnswer:"
            llm_inputs["input"] = input_text
            self.llm_chain, n_tokens3 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
            final_result = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            self.intermediate_steps['result'] = final_result
            # Step 4: ask for an explanation of the query.
            input_text += f"{final_result}\nExplanation:"
            llm_inputs["input"] = input_text
            self.llm_chain, n_tokens4 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
            explanation = self.llm_chain.predict(
                callbacks=_run_manager.get_child(),
                **llm_inputs,
            ).strip()
            self.intermediate_steps['query_explanation'] = explanation
        except Exception:
            # Most likely the full SQL result blew the context window; retry
            # with a truncated result (the last 20 plus the first 20 rows).
            try:
                sql_data_new = sql_data[-20:] + sql_data[:20]
                input_text = input_text_bkp + f"{sql_cmd}\nSQLResult: {str(sql_data_new)}\nAnswer:"
                llm_inputs["input"] = input_text
                self.llm_chain, n_tokens3 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
                final_result = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                self.intermediate_steps['result'] = final_result
                # Ask for an explanation of the query, as in the happy path.
                input_text += f"{final_result}\nExplanation:"
                llm_inputs["input"] = input_text
                self.llm_chain, n_tokens4 = self.prepare_llm(llm_inputs, chain=self.llm_chain)
                explanation = self.llm_chain.predict(
                    callbacks=_run_manager.get_child(),
                    **llm_inputs,
                ).strip()
                self.intermediate_steps['query_explanation'] = explanation
            except Exception:
                # The retry also failed (e.g. sql_cmd/sql_data were never set),
                # so fall back to a canned answer.
                self.intermediate_steps['result'] = "I don't know the answer for this."