Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import os | |
| from langchain.agents import load_tools | |
| from langchain.agents import initialize_agent | |
| from langchain import PromptTemplate, HuggingFaceHub, LLMChain, ConversationChain | |
| from langchain.llms import OpenAI | |
| from langchain.chains.conversation.memory import ConversationBufferMemory | |
| from threading import Lock | |
| import openai | |
| from openai.error import AuthenticationError, InvalidRequestError, RateLimitError | |
| from typing import Optional, Tuple | |
# The News API key is optional. Reading it with os.environ[...] raised KeyError
# at import time and took the whole app down when the variable was unset, even
# though the other default tools do not use it.
news_api_key = os.environ.get("NEWS_API_KEY")

# Tool names handed to langchain's load_tools() by load_chain() when an agent is
# built in set_openai_api_key(). 'news-api' is only offered when a key exists.
TOOLS_DEFAULT_LIST = ['serpapi', 'news-api', 'pal-math']
if not news_api_key:
    print("NEWS_API_KEY is not set; disabling the 'news-api' tool.")
    TOOLS_DEFAULT_LIST.remove('news-api')

# Passed as `max_tokens` to the OpenAI LLM built in set_openai_api_key().
MAX_TOKENS = 512

# Prompt of the auxiliary "express" LLMChain built in load_chain().
PROMPT_TEMPLATE = PromptTemplate(
    input_variables=["original_words"],
    template="Restate the following: \n{original_words}\n",
)

# Texts returned by run_chain() in place of a model answer when the chain fails.
BUG_FOUND_MSG = "Congratulations, you've found a bug in this application!"
AUTH_ERR_MSG = "Please paste your OpenAI key."
def run_chain(chain, inp, capture_hidden_text):
    """Run ``chain`` on ``inp`` and return ``(output, hidden_text)``.

    Any ``Exception`` raised by the chain is turned into a text message that is
    returned as ``output``, so the caller always receives something displayable.

    Args:
        chain: Object exposing ``run(input=...)``; in this app, the agent built
            by load_chain().
        inp: The user's message.
        capture_hidden_text: Currently ignored; kept for call-site compatibility.

    Returns:
        A pair ``(output, hidden_text)`` where ``hidden_text`` is always ``None``.
    """
    hidden_text = None
    try:
        answer = chain.run(input=inp)
    except AuthenticationError:
        answer = AUTH_ERR_MSG
    except RateLimitError as err:
        answer = f"\n\nRateLimitError: {err!s}"
    except ValueError as err:
        answer = f"\n\nValueError: {err!s}"
    except InvalidRequestError as err:
        answer = f"\n\nInvalidRequestError: {err!s}"
    except Exception as err:
        # Last-resort boundary: show the failure in the chat instead of crashing.
        answer = f"\n\n{BUG_FOUND_MSG}:\n\n{err!s}"
    return answer, hidden_text
def transform_text(desc, express_chain):
    """Prepare an answer for display in the chat window.

    Every newline in ``desc`` is doubled, presumably so the Chatbot's markdown
    rendering keeps single line breaks visible (TODO confirm intent).

    Args:
        desc: Text to transform.
        express_chain: Unused; accepted so existing callers keep working.

    Returns:
        ``desc`` with each newline character replaced by two newlines.
    """
    # The previous version also formatted PROMPT_TEMPLATE and discarded the
    # result, and its comment wrongly promised "<br>" substitution.
    return desc.replace("\n", "\n\n")
class ChatWrapper:
    """Callable Gradio handler that answers one chat message per call.

    Calls are serialized with a lock, presumably because each call assigns the
    caller's key to the process-wide ``openai.api_key``.
    """

    def __init__(self):
        self.lock = Lock()

    def __call__(
        self,
        api_key: str,
        inp: str,
        history: Optional[list],
        chain: Optional[ConversationChain],
        express_chain: Optional[LLMChain],
    ):
        """Run ``inp`` through ``chain`` and append the exchange to ``history``.

        Args:
            api_key: OpenAI key, assigned to ``openai.api_key`` before running.
            inp: The user's message.
            history: List of ``(user_message, reply)`` tuples, or ``None`` on
                the first call. A non-empty list is appended to in place.
            chain: Value from ``chain_state``, i.e. the agent returned by
                set_openai_api_key() (NOTE(review): not actually a
                ConversationChain). Falsy when no key has been entered.
            express_chain: Forwarded to transform_text().

        Returns:
            ``(history, history)``: the same list for the Chatbot component and
            for the history state.
        """
        with self.lock:
            history = history or []
            if chain:
                openai.api_key = api_key
                output, _hidden_text = run_chain(chain, inp, capture_hidden_text=False)
                print('output1', output)
                output = transform_text(output, express_chain)
                print('output2', output)
            else:
                # No chain means no API key was provided yet.
                output = "Please paste your OpenAI key to use this application."
            history.append((inp, output))
        return history, history


chat = ChatWrapper()
def load_chain(tools_list, llm):
    """Build the agent and the auxiliary "express" chain around ``llm``.

    Args:
        tools_list: Tool names passed to langchain's ``load_tools``.
        llm: LLM shared by the tools, the agent and the express chain.

    Returns:
        ``(chain, express_chain)``: an agent of type
        ``"zero-shot-react-description"`` with a ConversationBufferMemory keyed
        ``"chat_history"``, and an LLMChain over PROMPT_TEMPLATE.
    """
    print("\ntools_list", tools_list)
    agent_tools = load_tools(tools_list, llm=llm, news_api_key=news_api_key)
    agent = initialize_agent(
        agent_tools,
        llm,
        agent="zero-shot-react-description",
        verbose=True,
        memory=ConversationBufferMemory(memory_key="chat_history"),
    )
    restate_chain = LLMChain(llm=llm, prompt=PROMPT_TEMPLATE, verbose=True)
    return agent, restate_chain
def set_openai_api_key(api_key):
    """Build the agent, express chain and LLM for ``api_key``.

    If ``api_key`` is empty, ``(None, None, None)`` is returned and nothing is
    built. The chat handler treats a missing chain as "no key provided".

    Args:
        api_key: OpenAI key entered by the user.

    Returns:
        ``(chain, express_chain, llm)``, or ``(None, None, None)`` for an
        empty key.
    """
    if not api_key:
        return None, None, None
    # The key is exposed through the environment only while the LLM and chains
    # are built (presumably OpenAI(...) reads OPENAI_API_KEY at construction).
    # NOTE(review): os.environ is process-wide, so concurrent sessions could race.
    os.environ["OPENAI_API_KEY"] = api_key
    try:
        llm = OpenAI(temperature=0, max_tokens=MAX_TOKENS)
        chain, express_chain = load_chain(TOOLS_DEFAULT_LIST, llm)
    finally:
        # Scrub the key even when construction fails, so it never lingers.
        os.environ["OPENAI_API_KEY"] = ""
    return chain, express_chain, llm
with gr.Blocks() as app:
    # Per-session state: LLM, chat history, agent and express chain.
    llm_state = gr.State()
    history_state = gr.State()
    chain_state = gr.State()
    express_chain_state = gr.State()

    # Title and API-key entry.
    with gr.Row():
        with gr.Column():
            gr.HTML("""<b><center>GPT + Google</center></b>""")
            openai_api_key_textbox = gr.Textbox(
                placeholder="Paste your OpenAI API key (sk-...)",
                show_label=False,
                lines=1,
                type='password',
            )

    # Chat transcript.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()

    # Message entry and send button.
    with gr.Row():
        message = gr.Textbox(
            label="What's on your mind??",
            placeholder="What's the answer to life, the universe, and everything?",
            lines=1,
        )
        submit = gr.Button(value="Send", variant="secondary").style(full_width=False)

    gr.Examples(
        examples=[
            "How many people live in Canada?",
            "What is 2 to the 30th power?",
            "If x+y=10 and x-y=4, what are x and y?",
            "How much did it rain in SF today?",
            "Get me information about the movie 'Avatar'",
            "What are the top tech headlines in the US?",
            "On the desk, you see two blue booklets, two purple booklets, and two yellow pairs of sunglasses - "
            "if I remove all the pairs of sunglasses from the desk, how many purple items remain on it?",
        ],
        inputs=message,
    )

    # Pressing Enter and clicking "Send" run the same handler on the same data.
    chat_inputs = [openai_api_key_textbox, message, history_state, chain_state, express_chain_state]
    chat_outputs = [chatbot, history_state]
    message.submit(chat, inputs=chat_inputs, outputs=chat_outputs)
    submit.click(chat, inputs=chat_inputs, outputs=chat_outputs)

    # Rebuild the session's agent, express chain and LLM whenever the key changes.
    openai_api_key_textbox.change(
        set_openai_api_key,
        inputs=[openai_api_key_textbox],
        outputs=[chain_state, express_chain_state, llm_state],
    )

app.launch(debug=True)