binzhango committed on
Commit
f0afc51
·
1 Parent(s): 1cf303c
Files changed (1) hide show
  1. chat_app.py +70 -55
chat_app.py CHANGED
@@ -1,9 +1,10 @@
1
  import time
2
  import os
3
  import gradio as gr
4
- from typing import List
5
 
6
  import langchain_core.callbacks
 
7
  from langchain_huggingface import HuggingFaceEndpoint
8
 
9
  from langchain.schema import BaseMessage
@@ -31,6 +32,7 @@ class InMemoryHistory(BaseChatMessageHistory, BaseModel):
31
 
32
  # In-memory storage for session history
33
  store = {}
 
34
 
35
  def get_session_history(
36
  user_id: str, conversation_id: str
@@ -39,65 +41,71 @@ def get_session_history(
39
  store[(user_id, conversation_id)] = InMemoryHistory()
40
  return store[(user_id, conversation_id)]
41
 
42
- prompt = ChatPromptTemplate.from_messages([
43
- ("system", "[INST] You're an assistant who's good at everything"),
44
- MessagesPlaceholder(variable_name="history"),
45
- ("human", "{question} [/INST]"),
46
- ])
47
-
48
- model_id="mistralai/Mistral-7B-Instruct-v0.3"
49
- callbacks = [langchain_core.callbacks.StreamingStdOutCallbackHandler()]
50
- llm = HuggingFaceEndpoint(
51
- repo_id=model_id,
52
- max_new_tokens=512,
53
- temperature=0.1,
54
- repetition_penalty=1.03,
55
- callbacks=callbacks,
56
- streaming=True,
57
- huggingfacehub_api_token=os.getenv('HF_TOKEN'),
58
- )
59
 
60
- chain = prompt | llm
61
-
62
- with_message_history = RunnableWithMessageHistory(
63
- chain,
64
- get_session_history=get_session_history,
65
- input_messages_key="question",
66
- history_messages_key="history",
67
- history_factory_config=[
68
- ConfigurableFieldSpec(
69
- id="user_id",
70
- annotation=str,
71
- name="User ID",
72
- description="Unique identifier for the user.",
73
- default="",
74
- is_shared=True,
75
- ),
76
- ConfigurableFieldSpec(
77
- id="conversation_id",
78
- annotation=str,
79
- name="Conversation ID",
80
- description="Unique identifier for the conversation.",
81
- default="",
82
- is_shared=True,
83
- ),
84
- ],
85
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
  with gr.Blocks() as demo:
 
88
  chatbot = gr.Chatbot(type="messages")
89
- msg = gr.Textbox()
90
- stop = gr.Button("Stop")
91
- clear = gr.Button("Clear")
92
 
93
  def user(user_message, history: list):
94
  return "", history + [{"role": "user", "content": user_message}]
95
 
96
  def bot(history: list):
97
-
98
  question = history[-1]['content']
99
-
100
- answer = with_message_history.stream(
101
  {"ability": "everything", "question": question},
102
  config={"configurable": {"user_id": "123", "conversation_id": "1"}}
103
  )
@@ -106,11 +114,18 @@ with gr.Blocks() as demo:
106
  history[-1]['content'] += character
107
  time.sleep(0.05)
108
  yield history
109
- # for item in answer:
110
- # for character in item.content:
111
- # history[-1]['content'] += character
112
- # time.sleep(0.05)
113
- # yield history
 
 
 
 
 
 
 
114
 
115
  submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
116
  bot, chatbot, chatbot
 
1
  import time
2
  import os
3
  import gradio as gr
4
+ from typing import List, Optional
5
 
6
  import langchain_core.callbacks
7
+ import markdown_it.cli.parse
8
  from langchain_huggingface import HuggingFaceEndpoint
9
 
10
  from langchain.schema import BaseMessage
 
32
 
33
  # In-memory storage for session history
34
  store = {}
35
+ bot_llm:Optional[RunnableWithMessageHistory] = None
36
 
37
  def get_session_history(
38
  user_id: str, conversation_id: str
 
41
  store[(user_id, conversation_id)] = InMemoryHistory()
42
  return store[(user_id, conversation_id)]
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def init_llm(k, p, t, model_id="mistralai/Mistral-7B-Instruct-v0.3"):
    """Build the prompt/LLM chain and publish it via the module-global ``bot_llm``.

    Called from the sidebar "Confirm" button with the three slider values.

    Parameters
    ----------
    k : float
        ``top_k`` value from the UI slider. The slider yields a float, so it
        is cast to ``int`` before being handed to the endpoint.
    p : float
        ``top_p`` (nucleus sampling) value.
    t : float
        Sampling temperature.
    model_id : str, optional
        Hugging Face repo id of the instruct model to query. Defaults to the
        previously hard-coded Mistral-7B-Instruct model.

    Returns
    -------
    tuple
        Four ``gr.update`` objects consumed by the button's ``outputs``:
        enable the message textbox and the Stop/Clear buttons, and collapse
        the configuration sidebar.
    """
    global bot_llm

    # Mistral-instruct style prompt; prior turns are injected by
    # RunnableWithMessageHistory under the "history" key.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "[INST] You're an assistant who's good at everything"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question} [/INST]"),
    ])

    # Mirrors streamed tokens to stdout for server-side debugging.
    callbacks = [langchain_core.callbacks.StreamingStdOutCallbackHandler()]

    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=4096,
        temperature=t,
        top_p=p,
        top_k=int(k),  # fix: slider provides a float; the endpoint expects an int
        repetition_penalty=1.03,
        callbacks=callbacks,
        streaming=True,
        huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    )

    chain = prompt | llm

    # Wrap the chain so each (user_id, conversation_id) pair keeps its own
    # in-memory message history (see get_session_history / store).
    bot_llm = RunnableWithMessageHistory(
        chain,
        get_session_history=get_session_history,
        input_messages_key="question",
        history_messages_key="history",
        history_factory_config=[
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                default="",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="conversation_id",
                annotation=str,
                name="Conversation ID",
                description="Unique identifier for the conversation.",
                default="",
                is_shared=True,
            ),
        ],
    )

    # Unlock the chat controls and close the sidebar once the LLM is ready.
    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(open=False),
    )
95
 
96
  with gr.Blocks() as demo:
97
+ gr.HTML("<center><h1>Chat with a Smart Assistant</h1></center>")
98
  chatbot = gr.Chatbot(type="messages")
99
+ msg = gr.Textbox(placeholder="Enter text and press enter", interactive=False)
100
+ stop = gr.Button("Stop", interactive=False)
101
+ clear = gr.Button("Clear",interactive=False)
102
 
103
def user(user_message, history: list):
    """Record a submitted chat message.

    Returns a pair for Gradio's (textbox, chatbot) outputs: an empty string
    (clearing the input box) and a fresh history list with the user's
    message appended in OpenAI-style ``{"role", "content"}`` form. The
    incoming ``history`` list is left unmodified.
    """
    updated = list(history)
    updated.append({"role": "user", "content": user_message})
    return "", updated
105
 
106
  def bot(history: list):
 
107
  question = history[-1]['content']
108
+ answer = bot_llm.stream(
 
109
  {"ability": "everything", "question": question},
110
  config={"configurable": {"user_id": "123", "conversation_id": "1"}}
111
  )
 
114
  history[-1]['content'] += character
115
  time.sleep(0.05)
116
  yield history
117
+
118
+ with gr.Sidebar() as s:
119
+ gr.HTML("<h1>Model Configuration<h1>")
120
+ k = gr.Slider(0.0, 100.0, label="top_k", value=50, interactive=True,
121
+ info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)")
122
+ p = gr.Slider(0.0, 1.0, label="top_p", value=0.9, interactive=True,
123
+ info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)")
124
+ t = gr.Slider(0.0, 1.0, label="temperature", value=0.4, interactive=True,
125
+ info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)")
126
+
127
+ bnt1 = gr.Button("Confirm")
128
+ bnt1.click(init_llm, inputs=[k, p, t], outputs=[msg, stop, clear, s])
129
 
130
  submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
131
  bot, chatbot, chatbot