Spaces:
Sleeping
Sleeping
| import os | |
| import pandas as pd | |
| import gradio as gr | |
| from langchain.vectorstores import Chroma | |
| from langchain.embeddings import OpenAIEmbeddings | |
| from langchain.llms import OpenAI | |
| from langchain.chat_models import ChatOpenAI | |
| from langchain.chains import RetrievalQA | |
def Loading():
    """Return the status-box text for the "loading" state.

    In this file it is only referenced by a commented-out ``chk_key.click``
    handler in the UI section.
    """
    # NOTE(review): this literal appears mis-encoded (UTF-8 Korean read through
    # a Thai codepage); it probably read "데이터 로딩 중..." ("Loading data...").
    # Confirm against the original file before changing it.
    loading_message = "๋ฐ์ดํฐ ๋ก๋ฉ ์ค..."
    return loading_message
| def LoadData(openai_key): | |
| if openai_key is not None: | |
| os.environ["OPENAI_API_KEY"] = openai_key | |
| persist_directory = 'realdb_LLM' | |
| embedding = OpenAIEmbeddings() | |
| vectordb = Chroma( | |
| persist_directory=persist_directory, | |
| embedding_function=embedding | |
| ) | |
| global retriever | |
| retriever = vectordb.as_retriever(search_kwargs={"k": 1}) | |
| return "์ค๋น ์๋ฃ" | |
| else: | |
| return "์ฌ์ฉํ์๋ API Key๋ฅผ ์ ๋ ฅํ์ฌ ์ฃผ์๊ธฐ ๋ฐ๋๋๋ค." | |
# Chat handler: answers one user message and records the exchange.
def respond(message, chat_history, temperature, top_p):
    """Answer *message* from the loaded vector store and append it to the chat.

    Args:
        message: Text submitted by the user.
        chat_history: The Chatbot component's value, a list of
            ``(user_message, bot_message)`` pairs. It is appended to in place
            and returned. ``None`` (e.g. after the reset button) is treated
            as an empty history.
        temperature: Sampling temperature forwarded to the OpenAI LLM.
        top_p: Nucleus-sampling value forwarded to the OpenAI LLM.

    Returns:
        ``("", chat_history)``. The empty string clears the input textbox on
        both the success and the failure path.
    """
    if chat_history is None:
        chat_history = []

    try:
        # A fresh chain per message so the current slider values are used.
        # ``retriever`` is the module global bound by LoadData(); until that
        # has succeeded the name is undefined and the NameError lands below.
        qa_chain = RetrievalQA.from_chain_type(
            llm=OpenAI(temperature=temperature, top_p=top_p),
            chain_type="stuff",
            retriever=retriever,
        )
        bot_message = qa_chain(message)['result']
    except Exception:
        # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
        # propagate. Log the traceback: the UI message below is shown for
        # every failure (network, quota, bad key), not only a missing key.
        import logging
        logging.getLogger(__name__).exception("respond() failed")
        # NOTE(review): literal appears mis-encoded; probably "API Key 입력
        # 요망" ("API key input required") - confirm before changing.
        chat_history.append(("", "API Key ์ ๋ ฅ ์๋ง"))
        return "", chat_history

    # Append the user's message and the bot's answer to the chat history.
    chat_history.append((message, bot_message))
    return "", chat_history
| # ์ฑ๋ด ์ค๋ช | |
| title = """ | |
| <div style="text-align: center; max-width: 500px; margin: 0 auto;"> | |
| <div> | |
| <h1>Pretraining Chatbot V2 Real</h1> | |
| </div> | |
| <p style="margin-bottom: 10px; font-size: 94%"> | |
| OpenAI LLM๋ฅผ ์ด์ฉํ Chatbot (Similarity) | |
| </p> | |
| </div> | |
| """ | |
| # ๊พธ๋ฏธ๊ธฐ | |
| css=""" | |
| #col-container {max-width: 700px; margin-left: auto; margin-right: auto;} | |
| """ | |
# Page layout, created top to bottom:
# - the API-key row,
# - the sampling sliders with the key-confirm button,
# - the chat window,
# - the message box,
# - the send/reset buttons.
# Event handlers are wired at the end.
# NOTE(review): this copy of the file lost its indentation, so the nesting of
# the Row/Column contexts below is a reconstruction - confirm against the
# deployed app.py.
with gr.Blocks(css=css) as UnivChatbot:
    with gr.Column(elem_id="col-container"):
        gr.HTML(title)
        # Row 1: API-key input (password field) and the read-only status box
        # that LoadData writes to.
        with gr.Row():
            with gr.Column(scale=3):
                openai_key = gr.Textbox(label="You OpenAI API key", type="password", placeholder="OpenAI Key Type", elem_id="InputKey", show_label=False, container=False)
            with gr.Column(scale=1):
                langchain_status = gr.Textbox(placeholder="Status", interactive=False, show_label=False, container=False)
        # Row 2: sampling controls passed to respond() with every message,
        # plus the button (label presumably "확인", "Confirm") that runs
        # LoadData.
        with gr.Row():
            with gr.Column(scale=4):
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0,
                    maximum=2.0,
                    step=0.01,
                    value=0.7,
                )
            with gr.Column(scale=4):
                top_p = gr.Slider(
                    label="Top_p",
                    minimum=0,
                    maximum=1,
                    step=0.01,
                    value=0.5,
                )
            with gr.Column(scale=1):
                chk_key = gr.Button("ํ์ธ", variant="primary")
        # Chat window (label presumably "대화 챗봇시스템 (OpenAI LLM)").
        chatbot = gr.Chatbot(label="๋ํ ์ฑ๋ด์์คํ (OpenAI LLM)", elem_id="chatbot") # top left
        # Message input (placeholder presumably "궁금하신 내역을 입력하여
        # 주세요.", i.e. "Enter your question").
        with gr.Row():
            with gr.Column(scale=9):
                msg = gr.Textbox(label="์ ๋ ฅ", placeholder="๊ถ๊ธํ์ ๋ด์ญ์ ์ ๋ ฅํ์ฌ ์ฃผ์ธ์.", elem_id="InputQuery", show_label=False, container=False)
        # Send and reset buttons (labels presumably "전송" / "초기화").
        with gr.Row():
            with gr.Column(scale=1):
                submit = gr.Button("์ ์ก", variant="primary")
            with gr.Column(scale=1):
                clear = gr.Button("์ด๊ธฐํ", variant="stop")
    # Disabled alternative handler that would write Loading()'s status text
    # into the status box.
    #chk_key.click(Loading, None, langchain_status, queue=False)
    # Confirm button: store the key, open the vector store, report status.
    chk_key.click(
        fn=LoadData,
        inputs=[openai_key],
        outputs=[langchain_status],
        queue=False
    )
    # Submitting the message box calls respond(). Its first return value
    # replaces the box text and its second refreshes the chat window.
    msg.submit(
        fn=respond,
        inputs=[msg, chatbot, temperature, top_p],
        outputs=[msg, chatbot]
    )
    # The send button triggers the same handler as submitting the box.
    submit.click(respond, [msg, chatbot, temperature, top_p], [msg, chatbot])
    # Clicking the reset button clears the chat history.
    clear.click(lambda: None, None, chatbot, queue=False)
# Start the Gradio server. This is a module-level call, so it also runs on
# import.
UnivChatbot.launch()