File size: 7,419 Bytes
67270bc
 
 
 
4d8fe4d
8cb7dea
988418a
 
 
 
 
 
 
 
67270bc
 
7e69c1f
8cb7dea
7e69c1f
9143731
7e69c1f
988418a
 
 
 
 
 
 
 
 
 
 
 
 
4d8fe4d
 
9143731
8cb7dea
9143731
 
7e69c1f
67270bc
 
 
 
 
 
7e69c1f
 
 
 
 
988418a
 
7e69c1f
9143731
 
7e69c1f
8cb7dea
9143731
8cb7dea
8f2a5dc
988418a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e69c1f
 
 
 
 
 
988418a
7e69c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67270bc
7e69c1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67270bc
 
 
 
 
 
 
 
4d8fe4d
67270bc
33d8051
7e69c1f
 
 
 
 
 
 
 
9143731
7e69c1f
 
9143731
7e69c1f
9143731
 
 
7e69c1f
8cb7dea
33d8051
4d8fe4d
 
7e69c1f
 
 
67270bc
7e69c1f
9143731
7e69c1f
67270bc
7e69c1f
9143731
7e69c1f
4d8fe4d
 
988418a
4d8fe4d
67270bc
 
4d8fe4d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import gradio as gr
import os
import json
import requests
from langchain import FAISS
from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
from langchain import VectorDBQA
from langchain.chat_models import ChatOpenAI
from prompts import MyTemplate
from langchain.prompts.chat import (
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)

# Streaming endpoint for OpenAI chat completions (used with stream=True).
API_URL = "https://api.openai.com/v1/chat/completions"
# SECURITY: the original file committed a live Cohere API key in source.
# Read it from the environment instead; empty string when unset.
cohere_key = os.environ.get('COHERE_API_KEY', '')
# Per-topic FAISS index directory; filled in as faiss_store.format(topic).
faiss_store = './indexer/{}'
# Lazily-loaded FAISS index, cached at module level (see predict()).
docsearch = None


def gen_conversation(conversations):
    """Flatten chatbot history pairs into OpenAI chat-message dicts.

    Each element of *conversations* is a (user, assistant) pair; the
    result alternates {"role": "user"} / {"role": "assistant"} dicts
    in the original order.
    """
    return [
        {"role": role, "content": turn[idx]}
        for turn in conversations
        for idx, role in enumerate(("user", "assistant"))
    ]


def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic,
            chat_counter, chatbot=None, history=None):
    """Answer *inputs*, yielding (chat, history, chat_counter) for Gradio.

    Two modes:
      * enable_index=True  -- document QA over the selected topic's FAISS
        index via a LangChain map-reduce chain (single yield).
      * enable_index=False -- direct streaming call to the OpenAI chat
        completions endpoint, yielding partial answers as tokens arrive.

    Args:
        inputs: the user's question.
        top_p / temperature: OpenAI sampling parameters (chat mode).
        openai_api_key: key used for both embeddings and chat requests.
        enable_index: switches between document-QA and plain chat mode.
        max_tokens: completion cap for the document-QA LLM.
        topic: CheckboxGroup value (list); only topic[0] is used, and only
            in document-QA mode.
        chat_counter: number of completed dialogue rounds so far.
        chatbot: previous [[user, assistant], ...] pairs from the UI.
        history: flat [user, assistant, user, ...] transcript.

    Yields:
        (chat, history, chat_counter) tuples for the UI to render, where
        chat is a list of (user, assistant) tuples.
    """
    global docsearch

    # Fix: the original used mutable default arguments (chatbot=[],
    # history=[]), which silently share state across independent calls.
    if chatbot is None:
        chatbot = []
    if history is None:
        history = []

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {openai_api_key}"
    }

    print(f"chat_counter - {chat_counter}")
    print(f'History - {history}')  # History: original inputs and outputs in a flat list
    print(f'chatbot - {chatbot}')  # Chatbot: previous round's [[user, AI]] pairs

    history.append(inputs)
    if enable_index:
        # Fix: topic[0] was evaluated unconditionally at the top of the
        # function, raising IndexError in plain-chat mode when no topic
        # checkbox was selected. Only index it when it is actually needed.
        topic = topic[0]
        # FAISS: retrieve the nearest embeddings from the per-topic index.
        store = faiss_store.format(topic)
        # Fix: the original cached the first-loaded index forever, ignoring
        # later topic switches. Reload whenever the store path changes.
        if docsearch is None or getattr(predict, '_loaded_store', None) != store:
            print('Loading FAISS')
            docsearch = FAISS.load_local(store, OpenAIEmbeddings(openai_api_key=openai_api_key))
            predict._loaded_store = store
        else:
            print('Faiss already loaded')
        # Build the chat prompt templates for the map-reduce QA chain.
        llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
        messages_combine = [
            SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
            HumanMessagePromptTemplate.from_template("{question}")
        ]
        p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
        messages_reduce = [
            SystemMessagePromptTemplate.from_template(MyTemplate['chat_reduce_template']),
            HumanMessagePromptTemplate.from_template("{question}")
        ]
        p_chat_reduce = ChatPromptTemplate.from_messages(messages_reduce)
        chain = VectorDBQA.from_chain_type(llm=llm, chain_type="map_reduce", vectorstore=docsearch,
                                           k=4,
                                           chain_type_kwargs={"question_prompt": p_chat_reduce,
                                                              "combine_prompt": p_chat_combine}
                                           )
        result = chain({"query": inputs})
        print(result)
        result = result['result']
        # Record the answer and rebuild the (user, assistant) pair list.
        history.append(result)
        chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]
        chat_counter += 1
        yield chat, history, chat_counter
    else:
        if chat_counter == 0:
            messages = [{"role": "user", "content": f"{inputs}"}]
        else:
            # There is prior dialogue: splice it into the context window.
            messages = gen_conversation(chatbot)
            messages.append({'role': 'user', 'content': inputs})
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": messages,  # [{"role": "user", "content": f"{inputs}"}],
            "temperature": temperature,  # 1.0,
            "top_p": top_p,  # 1.0,
            "n": 1,
            "stream": True,
            "presence_penalty": 0,
            "frequency_penalty": 0,
        }
        print(f"payload is - {payload}")

        chat_counter += 1

        # Call the OpenAI streaming endpoint.
        response = requests.post(API_URL, headers=headers, json=payload, stream=True)
        token_counter = 0
        partial_words = ""

        # Stream the answer back token by token.
        counter = 0
        for chunk in response.iter_lines():
            # Skip the first SSE data line (it only carries the role delta).
            if counter == 0:
                counter += 1
                continue
            counter += 1
            # Ignore SSE keep-alive blank lines.
            if chunk:
                # Each data line looks like b"data: {json}"; strip the
                # 6-byte "data: " prefix before decoding.
                delta = json.loads(chunk.decode()[6:])['choices'][0]["delta"]
                # An empty delta (the finish_reason chunk) ends the stream,
                # before the trailing "[DONE]" sentinel line is reached.
                if len(delta) == 0:
                    break
                partial_words += delta["content"]
                # Grow the transcript in place so each yield shows progress.
                # NOTE(review): the leading space on the first token is kept
                # from the original -- presumably a display tweak; confirm.
                if token_counter == 0:
                    history.append(" " + partial_words)
                else:
                    history[-1] = partial_words
                chat = [(history[i], history[i + 1]) for i in
                        range(0, len(history) - 1, 2)]  # convert to tuples of list
                token_counter += 1
                yield chat, history, chat_counter


def reset_textbox():
    """Return a Gradio update that clears the input textbox."""
    cleared = gr.update(value='')
    return cleared


# Top-level Gradio UI: layout, event wiring, and app launch.
with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}""") as demo:
    gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
    with gr.Column(elem_id="col_container"):
        # User-supplied OpenAI key; passed through to predict() per call.
        openai_api_key = gr.Textbox(type='password', label="输入OPEN API Key")

        # inputs, top_p, temperature, top_k, repetition_penalty
        with gr.Accordion("Parameters", open=True):
            with gr.Row():
                top_p = gr.Slider(minimum=-0, maximum=1.0, value=0.9, step=0.05, interactive=True,
                                  label="Top-p (nucleus sampling)", )
                temperature = gr.Slider(minimum=-0, maximum=5.0, value=0.8, step=0.1, interactive=True,
                                        label="Temperature", )
                max_tokens = gr.Slider(minimum=100, maximum=1000, value=200, step=100, interactive=True,
                                       label="Max Tokens", )
                chat_counter = gr.Number(value=0, precision=0, label='对话轮数')

            with gr.Row():
                # Toggles document-QA mode vs. plain chat in predict().
                enable_index = gr.Checkbox(label='是', info='开启文档问答模式/聊天模式')
                # NOTE(review): enable_search is defined but never passed to
                # predict() below -- the search toggle currently has no effect.
                enable_search = gr.Checkbox(label='是', info='是否使用搜索')
                topic = gr.CheckboxGroup(["两会", "数字经济", "硅谷银行"], label='使用文档索引')

        chatbot = gr.Chatbot(elem_id='chatbot')
        inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
        # Flat [user, assistant, ...] transcript carried across calls.
        state = gr.State([])

        with gr.Row():
            clear = gr.Button("Clear Conversation")
            run = gr.Button("Run")

    # Both Enter-to-submit and the Run button invoke the same predict()
    # generator with identical inputs/outputs.
    inputs.submit(predict,
                  [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
                   state],
                  [chatbot, state, chat_counter], )
    run.click(predict,
              [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, topic, chat_counter, chatbot,
               state],
              [chatbot, state, chat_counter], )

    # Clear the input textbox on "Clear Conversation" and after each submit.
    clear.click(reset_textbox, [], [inputs], queue=False)
    inputs.submit(reset_textbox, [], [inputs])

    demo.queue().launch(debug=True)