Spaces:
Sleeping
Sleeping
File size: 7,419 Bytes
67270bc 4d8fe4d 8cb7dea 988418a 67270bc 7e69c1f 8cb7dea 7e69c1f 9143731 7e69c1f 988418a 4d8fe4d 9143731 8cb7dea 9143731 7e69c1f 67270bc 7e69c1f 988418a 7e69c1f 9143731 7e69c1f 8cb7dea 9143731 8cb7dea 8f2a5dc 988418a 7e69c1f 988418a 7e69c1f 67270bc 7e69c1f 67270bc 4d8fe4d 67270bc 33d8051 7e69c1f 9143731 7e69c1f 9143731 7e69c1f 9143731 7e69c1f 8cb7dea 33d8051 4d8fe4d 7e69c1f 67270bc 7e69c1f 9143731 7e69c1f 67270bc 7e69c1f 9143731 7e69c1f 4d8fe4d 988418a 4d8fe4d 67270bc 4d8fe4d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import gradio as gr
import os
import json
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from langchain.prompts.chat import (
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
# OpenAI chat-completions endpoint (requested with ``stream=True`` in chat mode).
API_URL = "https://api.openai.com/v1/chat/completions"
# Cohere API key, taken from the environment instead of being committed to
# source control.
# NOTE(review): a key used to be hard-coded here; since it was published it
# should be considered leaked and revoked.
cohere_key = os.environ.get('COHERE_API_KEY', '')
# Directory of the FAISS index for a topic; ``{}`` is replaced by the topic name.
faiss_store = './indexer/{}'
# Most recently loaded FAISS vector store (assigned by ``predict``).
docsearch = None
def gen_conversation(conversations):
    """Flatten ``[[user_text, assistant_text], ...]`` pairs into chat messages.

    Every pair contributes two dicts, in order: ``{"role": "user", ...}`` with
    the pair's first item as content, then ``{"role": "assistant", ...}`` with
    the second item.
    """
    return [
        {"role": role, "content": pair[position]}
        for pair in conversations
        for position, role in ((0, "user"), (1, "assistant"))
    ]
# FAISS indexes already loaded, keyed by their on-disk path, so that switching
# topics loads the matching index instead of reusing the first one loaded.
_faiss_cache = {}


def _load_docsearch(store, openai_api_key):
    """Return the FAISS index stored at ``store``, loading it on first use."""
    if store not in _faiss_cache:
        print('Loading FAISS')
        # NOTE(review): the embeddings object is built with the key of the
        # caller that first loads this index and is presumably what embeds
        # later queries against the cached index; confirm that is acceptable.
        _faiss_cache[store] = FAISS.load_local(
            store, OpenAIEmbeddings(openai_api_key=openai_api_key))
    else:
        print('Faiss already loaded')
    return _faiss_cache[store]


def _build_qa_chain(vectorstore, openai_api_key, max_tokens, temperature):
    """Build the map-reduce ``VectorDBQA`` chain used in document-QA mode.

    ``chat_reduce_template`` is supplied as ``question_prompt`` and
    ``chat_combine_template`` as ``combine_prompt``; both put the user's text in
    the human ``{question}`` message. The chain retrieves ``k=4`` chunks.
    """
    llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=int(max_tokens),
                     temperature=temperature)
    p_chat_combine = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ])
    p_chat_reduce = ChatPromptTemplate.from_messages([
        SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
        HumanMessagePromptTemplate.from_template("{question}"),
    ])
    return VectorDBQA.from_chain_type(
        llm=llm, chain_type="map_reduce", vectorstore=vectorstore, k=4,
        chain_type_kwargs={"question_prompt": p_chat_reduce,
                           "combine_prompt": p_chat_combine},
    )


def _stream_chat_completion(openai_api_key, payload):
    """POST ``payload`` to ``API_URL`` and yield each non-empty text fragment.

    The streamed response is read line by line: blank lines and lines that are
    not ``data:`` events are skipped, deltas without ``content`` (such as a
    role-only delta) produce nothing, and iteration stops at an empty delta or
    at the ``[DONE]`` sentinel.

    Raises:
        requests.HTTPError: the API answered with an error status (e.g. an
            invalid key), instead of failing later while decoding the body.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}"
    }
    # The timeout bounds connecting and each wait between chunks, not the
    # total duration of the stream.
    response = requests.post(API_URL, headers=headers, json=payload, stream=True, timeout=60)
    try:
        response.raise_for_status()
        for line in response.iter_lines():
            if not line:
                continue
            event = line.decode('utf-8')
            if not event.startswith('data:'):
                continue
            data = event[len('data:'):].strip()
            if data == '[DONE]':
                break
            delta = json.loads(data)['choices'][0]['delta']
            if not delta:
                # An empty delta is treated as the end of the stream.
                break
            content = delta.get('content')
            if content:
                yield content
    finally:
        # Release the streamed connection even if the consumer stops early.
        response.close()


def _pair_history(history):
    """Regroup a flat ``[user, bot, user, bot, ...]`` list into tuples."""
    return [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]


def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic,
            chat_counter, chatbot=None, history=None):
    """Answer ``inputs`` and yield UI updates (a Gradio generator handler).

    Two modes:

    * ``enable_index`` truthy: answer from the FAISS index of the first
      selected ``topic`` with a map-reduce ``VectorDBQA`` chain; yields once.
    * otherwise: send the previous ``chatbot`` rounds plus the new input to the
      OpenAI chat-completions API with streaming, yielding after every
      received text fragment.

    Args:
        inputs: The user's new message.
        top_p: Nucleus-sampling value (chat mode).
        temperature: Sampling temperature (both modes).
        openai_api_key: Key used for the chat API, the LLM and embeddings.
        enable_index: Selects document-QA mode instead of plain chat.
        max_tokens: Upper bound on generated tokens (both modes).
        topic: Selected topic names; only ``topic[0]`` is used, and only in
            document-QA mode.
        chat_counter: Number of completed rounds; incremented by one here.
        chatbot: Previous rounds as ``[[user, assistant], ...]``.
        history: Flat ``[user, assistant, ...]`` list; mutated in place.

    Yields:
        ``(chat, history, chat_counter)`` where ``chat`` is ``history``
        regrouped into ``(user, assistant)`` tuples.

    Raises:
        ValueError: document-QA mode is enabled but no topic is selected.
        requests.HTTPError: the chat API returned an error status.
    """
    global docsearch
    # Fresh lists per call; list defaults would be shared between calls.
    chatbot = [] if chatbot is None else chatbot
    history = [] if history is None else history
    print(f"chat_counter - {chat_counter}")
    print(f'Histroy - {history}')  # history: flat list of user inputs and bot outputs
    print(f'chatbot - {chatbot}')  # chatbot: previous rounds as [[user, AI], ...]
    if enable_index and not topic:
        # Validate before touching history so a rejected request leaves it intact.
        raise ValueError('Document QA mode requires selecting a topic.')
    history.append(inputs)

    if enable_index:
        # Retrieve from the nearest embeddings in the selected topic's index.
        docsearch = _load_docsearch(faiss_store.format(topic[0]), openai_api_key)
        chain = _build_qa_chain(docsearch, openai_api_key, max_tokens, temperature)
        result = chain({"query": inputs})
        print(result)
        history.append(result['result'])
        chat_counter += 1
        yield _pair_history(history), history, chat_counter
        return

    if chat_counter == 0:
        messages = [{"role": "user", "content": f"{inputs}"}]
    else:
        # Replay earlier rounds so the model sees the conversation context.
        messages = gen_conversation(chatbot)
        messages.append({'role': 'user', 'content': inputs})
    payload = {
        "model": "gpt-3.5-turbo",
        "messages": messages,
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": int(max_tokens),
        "n": 1,
        "stream": True,
        "presence_penalty": 0,
        "frequency_penalty": 0,
    }
    print(f"payload is - {payload}")
    chat_counter += 1

    # Stream the reply, replacing the last history entry as text accumulates.
    token_counter = 0
    partial_words = ""
    for fragment in _stream_chat_completion(openai_api_key, payload):
        partial_words += fragment
        if token_counter == 0:
            history.append(" " + partial_words)
        else:
            history[-1] = partial_words
        token_counter += 1
        yield _pair_history(history), history, chat_counter
    if token_counter == 0:
        # Keep history strictly alternating user/assistant even when no text
        # arrived, so later rounds still pair up correctly.
        history.append("")
        yield _pair_history(history), history, chat_counter
def reset_textbox():
    """Return a Gradio update that empties the bound textbox."""
    cleared_value = ''
    return gr.update(value=cleared_value)
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        openai_api_key = gr.Textbox(type='password', label="输入OPEN API Key")
        # Generation parameters forwarded to predict().
        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                top_p = gr.Slider(minimum=0, maximum=1.0, value=0.9, step=0.05, interactive=True,
                                  label="Top-p (nucleus sampling)", )
                # Capped at 2.0: the OpenAI API rejects larger temperatures.
                temperature = gr.Slider(minimum=0, maximum=2.0, value=0.8, step=0.1, interactive=True,
                                        label="Temperature", )
                max_tokens = gr.Slider(minimum=100, maximum=1000, value=200, step=100, interactive=True,
                                       label="Max Tokens", )
                # Number of completed dialogue rounds.
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')
            with gr.Row():
                # Toggles document-QA mode versus plain chat mode in predict().
                enable_index = gr.Checkbox(label='是', info='开启文档问答模式/聊天模式')
                # NOTE(review): enable_search is not passed to any handler below.
                enable_search = gr.Checkbox(label='是', info='是否使用搜索')
                topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"], label='使用文档索引')
        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
        # Flat [user, bot, user, bot, ...] history passed to and returned by predict().
        state = gr.State([])
        with gr.Row():
            clear = gr.Button("Clear Conversation")
            run = gr.Button("Run")

    # Pressing Enter and clicking "Run" trigger the same handler.
    predict_inputs = [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic,
                      chat_counter, chatbot, state]
    predict_outputs = [chatbot, state, chat_counter]
    inputs.submit(predict, predict_inputs, predict_outputs, )
    run.click(predict, predict_inputs, predict_outputs, )
    # Empty the textbox once a question has been sent, whichever way it was sent.
    inputs.submit(reset_textbox, [], [inputs])
    run.click(reset_textbox, [], [inputs])
    # "Clear Conversation" resets the whole dialogue (textbox, chat display,
    # flat history and round counter), not just the textbox.
    clear.click(lambda: ('', [], [], 0), [], [inputs, chatbot, state, chat_counter], queue=False)

if __name__ == "__main__":
    demo.queue().launch(debug=True)
|