Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import sqlite3 | |
| import pandas as pd | |
| import openai | |
| import os | |
| from langchain.embeddings.openai import OpenAIEmbeddings | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.chains.question_answering import load_qa_chain | |
| from langchain.text_splitter import CharacterTextSplitter | |
| from langchain.vectorstores import Chroma | |
# Expose the Space secret under the variable name the OpenAI SDK reads.
# os.environ only accepts str values, so skip the copy when SECRET_KEY is
# unset instead of crashing at import time with a TypeError.
if "SECRET_KEY" in os.environ:
    os.environ["OPENAI_API_KEY"] = os.environ["SECRET_KEY"]
def init_database(db_path='GPTPromptTemplates.db'):
    """Create the template schema in the SQLite file *db_path* if missing.

    Creates the USERS and TEMPLATES tables and a unique index on
    TEMPLATES (USER_ID, Prompt_Name). Every statement uses IF NOT EXISTS,
    so calling this repeatedly (e.g. on every page load) is harmless.

    Args:
        db_path: SQLite database file; defaults to the app's shared file.
    """
    conn = sqlite3.connect(db_path)
    try:
        # ``with conn`` commits on success and rolls back on error, but it
        # does not close the connection -- hence the outer try/finally.
        with conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS USERS (
                    USER_ID INTEGER PRIMARY KEY,
                    User_Name VARCHAR(255)
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS TEMPLATES (
                    TEMPLATE_ID INTEGER PRIMARY KEY,
                    USER_ID INTEGER,
                    Prompt_Name VARCHAR(255),
                    Prompt_Text TEXT
                )
            ''')
            # One template name per user; INSERT OR REPLACE relies on this.
            conn.execute('''
                CREATE UNIQUE INDEX IF NOT EXISTS idx_templates_prompt_name ON TEMPLATES (USER_ID, Prompt_Name)
            ''')
    finally:
        conn.close()
def insert_prompt_template(user_id, prompt_name, prompt_text, db_path='GPTPromptTemplates.db'):
    """Create or overwrite template *prompt_name* for *user_id*.

    Uses INSERT OR REPLACE. With the unique (USER_ID, Prompt_Name) index
    that init_database creates, an existing row with the same key is
    replaced rather than duplicated.

    Args:
        user_id: Owner of the template.
        prompt_name: Template name (unique per user).
        prompt_text: Template body.
        db_path: SQLite database file; defaults to the app's shared file.
    """
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                'INSERT OR REPLACE INTO TEMPLATES (USER_ID, Prompt_Name, Prompt_Text) VALUES (?, ?, ?)',
                (user_id, prompt_name, prompt_text),
            )
    finally:
        conn.close()
def delete_prompt_template(user_id, prompt_name, db_path='GPTPromptTemplates.db'):
    """Delete template *prompt_name* belonging to *user_id*, if it exists.

    Deleting a name that does not exist is a silent no-op.

    Args:
        user_id: Owner of the template.
        prompt_name: Name of the template to remove.
        db_path: SQLite database file; defaults to the app's shared file.
    """
    conn = sqlite3.connect(db_path)
    try:
        with conn:  # commit on success, roll back on error
            conn.execute(
                'DELETE FROM TEMPLATES WHERE USER_ID = ? AND Prompt_Name = ?',
                (user_id, prompt_name),
            )
    finally:
        conn.close()
def get_prompt(user_id, prompt_name, db_path='GPTPromptTemplates.db'):
    """Look up template *prompt_name* for *user_id*.

    Args:
        user_id: Owner of the template.
        prompt_name: Template name to fetch.
        db_path: SQLite database file; defaults to the app's shared file.

    Returns:
        ``(Prompt_Name, Prompt_Text)`` for the matching row, or ``('', '')``
        when no such template exists.
    """
    conn = sqlite3.connect(db_path)
    try:
        template = conn.execute(
            'SELECT Prompt_Name, Prompt_Text FROM TEMPLATES WHERE Prompt_Name = ? AND USER_ID = ?',
            (prompt_name, user_id),
        ).fetchone()
    finally:
        conn.close()
    if template is None:
        return '', ''
    return template[0], template[1]
def get_default_prompt(user_id, db_path='GPTPromptTemplates.db'):
    """Return the alphabetically-first template of *user_id*.

    Args:
        user_id: Owner whose templates are searched.
        db_path: SQLite database file; defaults to the app's shared file.

    Returns:
        ``(Prompt_Name, Prompt_Text)`` of the template whose name sorts
        first, or ``('', '')`` when the user has no templates.
    """
    conn = sqlite3.connect(db_path)
    try:
        template = conn.execute(
            'SELECT Prompt_Name, Prompt_Text FROM TEMPLATES WHERE USER_ID = ? ORDER BY Prompt_Name ASC LIMIT 1',
            (user_id,),
        ).fetchone()
    finally:
        conn.close()
    if template is None:
        return '', ''
    return template[0], template[1]
def get_prompt_list(user_id, db_path='GPTPromptTemplates.db'):
    """List the distinct template names owned by *user_id*, sorted ascending.

    Args:
        user_id: Owner whose template names are listed. It is passed as a
            bound SQL parameter rather than formatted into the query text,
            so it cannot alter the statement.
        db_path: SQLite database file; defaults to the app's shared file.

    Returns:
        A pandas DataFrame with a single ``Prompt_Name`` column (empty when
        the user has no templates).
    """
    conn = sqlite3.connect(db_path)
    try:
        return pd.read_sql_query(
            'SELECT DISTINCT Prompt_Name FROM TEMPLATES WHERE USER_ID = ? ORDER BY Prompt_Name ASC',
            conn,
            params=(user_id,),
        )
    finally:
        conn.close()
def template_change_value():
    """Reload the currently selected template into session state.

    Reads ``user_id`` and ``template_select`` from ``st.session_state``,
    fetches that template via ``get_prompt``, and stores its name and text
    in ``st.session_state.name`` / ``st.session_state.prompt``.
    """
    state = st.session_state
    state.name, state.prompt = get_prompt(state.user_id, state.template_select)
def template_return_value(template_name):
    """Make *template_name* the active template and load it into session state.

    Stores *template_name* under ``st.session_state.template_select``, then
    copies that template's name and text (as returned by ``get_prompt``)
    into ``st.session_state.name`` and ``st.session_state.prompt``.
    """
    state = st.session_state
    state.template_select = template_name
    loaded_name, loaded_text = get_prompt(state.user_id, state.template_select)
    state.name = loaded_name
    state.prompt = loaded_text
def _max_tokens_for(model_name):
    """Return the completion-token budget used for *model_name*."""
    # gpt-4 gets a larger budget; every other model in the picker gets 3000.
    return 5500 if model_name == 'gpt-4' else 3000


def _run_completion(system_prompt, user_text, model_name):
    """Send one system+user exchange to the chat API and return the reply text.

    Literal backslash-n sequences in the reply are converted to real
    newlines before returning.
    """
    # .get() so a missing key surfaces as an OpenAI authentication error,
    # which main() reports in the UI, rather than a bare KeyError.
    openai.api_key = os.environ.get("OPENAI_API_KEY")
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_text},
        ],
        temperature=0.7,
        max_tokens=_max_tokens_for(model_name),
    )
    return response["choices"][0]["message"]["content"].replace("\\n", "\n")


def main():
    """Render the template-driven ChatGPT page.

    Left column: input text plus an "Apply" button that sends it to the
    selected model, using the selected template as the system prompt.
    Middle column: model/template pickers and a template editor with
    Save/Delete. Right column: the most recent model output.
    """
    st.title("Working with Chat GPT with templates")
    init_database()
    col1, col2, col3 = st.columns([1, 1, 1])
    user_id = 1  # all templates are stored under this fixed user id
    model_names = ['gpt-4', 'gpt-3.5-turbo', 'gpt-3.5-turbo-16k']

    # Seed session state only once; later runs reuse it, so the default
    # template and the template list are only queried here on the first run.
    if "initialized" not in st.session_state:
        name, prompt = get_default_prompt(user_id)
        st.session_state.user_id = user_id
        st.session_state.name = name
        st.session_state.prompt = prompt
        st.session_state.prompt_list = get_prompt_list(user_id)
        st.session_state.template_select = name
        st.session_state.output = ''
        st.session_state.model_name = 'gpt-4'
        st.session_state.initialized = True

    with col1:
        input_text = st.text_area('Please insert data for transforming', '', key="input_data", height=450)
        if st.button("Apply"):
            with st.spinner('In progress...'):
                try:
                    # Use the template selected in col2 (template_return_value
                    # stored its text in session state on the previous run),
                    # not the default template loaded at startup.
                    output = _run_completion(st.session_state.prompt, input_text, st.session_state.model_name)
                except openai.error.OpenAIError as err:
                    st.error(f"OpenAI request failed: {err}")
                else:
                    st.session_state.output = output
                    # Restart the script so the result column shows the new
                    # output; nothing after this call runs.
                    st.experimental_rerun()

    with col2:
        st.session_state.model_name = st.selectbox("GPT model: ", model_names, key="gpt_model")
        template_return_value(st.selectbox("Template: ", st.session_state.prompt_list, key="prompt_template"))
        new_name = st.text_input("Template name:", value=st.session_state.name, key="template_name")
        input_query = st.text_area("Prompt:", value=st.session_state.prompt, key="template_text", height=200)
        col4, col5 = st.columns([1, 1])
        if col4.button("Save"):
            insert_prompt_template(user_id, new_name, input_query)
            st.session_state.prompt_list = get_prompt_list(user_id)
            st.success("Prompt saved!")
            st.experimental_rerun()
        if col5.button("Delete"):
            delete_prompt_template(user_id, new_name)
            st.session_state.prompt_list = get_prompt_list(user_id)
            st.success("Prompt deleted!")
            st.experimental_rerun()

    with col3:
        st.text_area('Result', value=st.session_state.output, key="output_data", height=450)


if __name__ == "__main__":
    main()