xl2533 commited on
Commit
8cb7dea
·
1 Parent(s): 8f2a5dc

add model selection

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -3,7 +3,7 @@ import os
3
  import json
4
  import requests
5
  from langchain import FAISS
6
- from langchain.embeddings import CohereEmbeddings
7
  from langchain import VectorDBQA
8
  from langchain.chat_models import ChatOpenAI
9
  from prompts import MyTemplate
@@ -15,7 +15,7 @@ from langchain.prompts.chat import (
15
 
16
  # Streaming endpoint
17
  API_URL = "https://api.openai.com/v1/chat/completions" # os.getenv("API_URL") + "/generate_stream"
18
- embeddings_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
19
  faiss_store = './indexer'
20
 
21
  def gen_conversation(conversations):
@@ -32,8 +32,9 @@ def gen_conversation(conversations):
32
  return messages
33
 
34
 
35
- def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens,
36
- chat_counter, chatbot=[], history=[]): # repetition_penalty, top_k
 
37
  headers = {
38
  "Content-Type": "application/json",
39
  "Authorization": f"Bearer {openai_api_key}"
@@ -43,7 +44,11 @@ def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens
43
  #Debugging
44
  if enable_index:
45
  # Faiss 检索最近的embedding
46
- docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=embeddings_key))
 
 
 
 
47
  llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
48
  messages_combine = [
49
  SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
@@ -133,7 +138,7 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
133
  gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
134
  with gr.Column(elem_id="col_container"):
135
  openai_api_key = gr.Textbox(type='password', label="输入OPEN API Key")
136
- chatbot = gr.Chatbot(elem_id='chatbot') # c
137
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
138
  state = gr.State([])
139
 
@@ -148,16 +153,15 @@ with gr.Blocks(css="""#col_container {width: 1000px; margin-left: auto; margin-r
148
  label="Max Tokens", )
149
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True,
150
  label="Temperature", )
151
- # top_k = gr.Slider( minimum=1, maximum=50, value=4, step=1, interactive=True, label="Top-k",)
152
- # repetition_penalty = gr.Slider( minimum=0.1, maximum=3.0, value=1.03, step=0.01, interactive=True, label="Repetition Penalty", )
153
  chat_counter = gr.Number(value=0, precision=0)
154
  enable_index = gr.Checkbox(label='是', info='是否使用研报等金融数据')
155
  # 后续考虑加入搜索结果
156
  enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
157
 
158
- inputs.submit(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, chat_counter, chatbot, state],
159
  [chatbot, state, chat_counter], )
160
- run.click(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, chat_counter, chatbot, state],
161
  [chatbot, state, chat_counter], )
162
 
163
  # 每次对话结束都重置对话
 
3
  import json
4
  import requests
5
  from langchain import FAISS
6
+ from langchain.embeddings import CohereEmbeddings, OpenAIEmbeddings
7
  from langchain import VectorDBQA
8
  from langchain.chat_models import ChatOpenAI
9
  from prompts import MyTemplate
 
15
 
16
  # Streaming endpoint
17
  API_URL = "https://api.openai.com/v1/chat/completions" # os.getenv("API_URL") + "/generate_stream"
18
+ cohere_key = '5IRbILAbjTI0VcqTsktBfKsr13Lych9iBAFbLpkj'
19
  faiss_store = './indexer'
20
 
21
  def gen_conversation(conversations):
 
32
  return messages
33
 
34
 
35
+ def predict(inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model,
36
+ chat_counter, chatbot=[], history=[]):
37
+ model = model[0]
38
  headers = {
39
  "Content-Type": "application/json",
40
  "Authorization": f"Bearer {openai_api_key}"
 
44
  #Debugging
45
  if enable_index:
46
  # Faiss 检索最近的embedding
47
+ if model =='openai':
48
+ docsearch = FAISS.load_local(faiss_store, OpenAIEmbeddings(openai_api_key=openai_api_key))
49
+ else:
50
+ docsearch = FAISS.load_local(faiss_store, CohereEmbeddings(cohere_api_key=cohere_key ))
51
+ # 构建模板
52
  llm = ChatOpenAI(openai_api_key=openai_api_key, max_tokens=max_tokens)
53
  messages_combine = [
54
  SystemMessagePromptTemplate.from_template(MyTemplate['chat_combine_template']),
 
138
  gr.HTML("""<h1 align="center">🚀Finance ChatBot🚀</h1>""")
139
  with gr.Column(elem_id="col_container"):
140
  openai_api_key = gr.Textbox(type='password', label="输入OPEN API Key")
141
+ chatbot = gr.Chatbot(elem_id='chatbot')
142
  inputs = gr.Textbox(placeholder="您有什么问题可以问我", label="输入数字经济,两会,硅谷银行相关的提问")
143
  state = gr.State([])
144
 
 
153
  label="Max Tokens", )
154
  temperature = gr.Slider(minimum=-0, maximum=5.0, value=1.0, step=0.1, interactive=True,
155
  label="Temperature", )
156
+ model = gr.CheckboxGroup(["cohere", "openai", "mpnet"])
 
157
  chat_counter = gr.Number(value=0, precision=0)
158
  enable_index = gr.Checkbox(label='是', info='是否使用研报等金融数据')
159
  # 后续考虑加入搜索结果
160
  enable_search = gr.Checkbox(label='是', info='是否使用搜索结果')
161
 
162
+ inputs.submit(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, chat_counter, chatbot, state],
163
  [chatbot, state, chat_counter], )
164
+ run.click(predict, [inputs, top_p, temperature, openai_api_key, enable_index, max_tokens, model, chat_counter, chatbot, state],
165
  [chatbot, state, chat_counter], )
166
 
167
  # 每次对话结束都重置对话