Spaces:
Build error
Build error
| import streamlit as st | |
| from langchain_community.utilities.sql_database import SQLDatabase | |
| from langchain.chains import create_sql_query_chain | |
| from langchain_openai import ChatOpenAI | |
| from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool | |
| from langchain.memory import ChatMessageHistory | |
| from operator import itemgetter | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.runnables import RunnablePassthrough | |
| from table_details import create_table_chain | |
| from prompts import create_prompts | |
def get_db_uri(credentials):
    """Build a SQLAlchemy ``postgresql+psycopg2`` connection URI.

    Args:
        credentials: Mapping with the keys ``user``, ``password``, ``host``,
            ``port`` and ``database``.

    Returns:
        A string of the form
        ``postgresql+psycopg2://<user>:<password>@<host>:<port>/<database>``.

    Raises:
        KeyError: If any of the required keys is missing.

    The user name and password are percent-encoded, so characters that have
    a special meaning inside a URL (``@``, ``:``, ``/``, ``#``, ``%`` ...)
    cannot corrupt the URI. Values made only of letters, digits and
    ``_.-~`` are emitted unchanged.
    """
    from urllib.parse import quote

    # safe="" so that "/" is escaped too; a raw "/" in a password would
    # otherwise be taken as the start of the database path.
    user = quote(str(credentials["user"]), safe="")
    password = quote(str(credentials["password"]), safe="")
    return (
        f"postgresql+psycopg2://{user}:{password}"
        f"@{credentials['host']}:{credentials['port']}/{credentials['database']}"
    )
def get_chain(_db_uri, api_key):
    """Assemble the question -> SQL -> answer LangChain pipeline.

    The returned runnable takes an input mapping (``invoke_chain`` supplies
    ``question``, ``top_k`` and ``messages``) and, step by step:

    1. adds ``table_names_to_use`` using the chain from ``create_table_chain``;
    2. adds ``query``, the SQL produced by ``create_sql_query_chain``;
    3. adds ``result`` by running ``query`` through ``QuerySQLDataBaseTool``;
    4. feeds everything to the answer prompt, the LLM and ``StrOutputParser``.

    Args:
        _db_uri: SQLAlchemy connection URI of the target database.
        api_key: OpenAI API key. It is passed to ``ChatOpenAI``,
            ``create_table_chain`` and ``create_prompts``.

    Returns:
        The composed runnable. Returns ``None`` if any construction step
        raised; in that case the error is shown with ``st.error``.

    NOTE(review): the leading underscore on ``_db_uri`` matches Streamlit's
    convention for arguments excluded from ``st.cache_*`` hashing. This
    function is not cached here. If caching is ever added, that name would
    make every URI share a single cached chain.
    """
    try:
        database = SQLDatabase.from_uri(_db_uri)
        model = ChatOpenAI(temperature=0.7, model="gpt-3.5-turbo", api_key=api_key)

        # Project-local helpers: a table-selection chain plus the SQL and answer prompts.
        table_selector = create_table_chain(api_key)
        final_prompt, answer_prompt = create_prompts(api_key)

        sql_writer = create_sql_query_chain(model, database, final_prompt)
        sql_runner = QuerySQLDataBaseTool(db=database)
        answer_writer = answer_prompt | model | StrOutputParser()

        pick_tables = RunnablePassthrough.assign(table_names_to_use=table_selector)
        # Generate the SQL first, then execute that exact query string.
        write_and_run = RunnablePassthrough.assign(query=sql_writer).assign(
            result=itemgetter("query") | sql_runner
        )
        return pick_tables | write_and_run | answer_writer
    except Exception as exc:
        st.error(f"Error creating chain: {str(exc)}")
        return None
def create_history(messages):
    """Convert chat-style message dicts into a ``ChatMessageHistory``.

    Args:
        messages: Iterable of dicts, each with ``role`` and ``content`` keys.
            ``role == "user"`` becomes a user message. Every other role is
            recorded as an AI message.

    Returns:
        A new ``ChatMessageHistory`` containing the messages in order.
    """
    chat_history = ChatMessageHistory()
    for entry in messages:
        append = (
            chat_history.add_user_message
            if entry["role"] == "user"
            else chat_history.add_ai_message
        )
        append(entry["content"])
    return chat_history
def invoke_chain(question, messages, db_credentials, api_key, *, top_k=100):
    """Answer a natural-language question against the user's database.

    A fresh chain is built for every call via ``get_chain``.

    Args:
        question: The user's question.
        messages: Prior conversation as dicts with ``role`` and ``content``
            keys. They are converted with ``create_history`` and passed to
            the chain under ``messages``.
        db_credentials: Mapping accepted by ``get_db_uri``.
        api_key: OpenAI API key.
        top_k: Value passed to the chain under ``top_k``. The default of 100
            matches the previous hard-coded value. Presumably it limits the
            rows requested by the generated SQL; confirm against the prompts.

    Returns:
        The chain's answer. If the chain cannot be built, or if anything
        raises, a user-facing error message string is returned instead.
        ``Exception`` subclasses are never propagated.
    """
    try:
        db_uri = get_db_uri(db_credentials)
        chain = get_chain(db_uri, api_key)
        if chain is None:
            # get_chain has already reported the underlying error via st.error.
            return "Sorry, I couldn't connect to the database. Please check your credentials."
        # `history` is only a converter for `messages`. It is a local object,
        # so persisting the new question/answer turn is the caller's job.
        history = create_history(messages)
        return chain.invoke({
            "question": question,
            "top_k": top_k,
            "messages": history.messages,
        })
    except Exception as e:
        # Top-level UI boundary: surface the error as a chat reply instead of crashing.
        return f"An error occurred: {str(e)}"