|
|
import time |
|
|
import os |
|
|
import gradio as gr |
|
|
from typing import List, Optional |
|
|
|
|
|
import langchain_core.callbacks |
|
|
import markdown_it.cli.parse |
|
|
from langchain_huggingface import HuggingFaceEndpoint |
|
|
|
|
|
from langchain.schema import BaseMessage |
|
|
from langchain_core.chat_history import BaseChatMessageHistory |
|
|
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder |
|
|
from langchain_core.runnables import ( |
|
|
ConfigurableFieldSpec, |
|
|
) |
|
|
from langchain_core.runnables.history import RunnableWithMessageHistory |
|
|
|
|
|
from pydantic import BaseModel, Field |
|
|
|
|
|
|
|
|
class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Volatile, per-conversation chat history backed by a plain list.

    Implements the minimal ``BaseChatMessageHistory`` contract required by
    ``RunnableWithMessageHistory``; nothing survives a process restart.
    """

    # pydantic's default_factory gives every instance its own empty list.
    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append the given messages to the stored history, preserving order."""
        self.messages += list(messages)

    def clear(self) -> None:
        """Forget every stored message."""
        self.messages = []
|
|
|
|
|
|
|
|
# (user_id, conversation_id) -> InMemoryHistory; one dict shared by every
# Gradio session of this server process.
store = {}
# Populated by init_llm(); stays None until the user confirms the model
# settings in the sidebar.
bot_llm:Optional[RunnableWithMessageHistory] = None
|
|
|
|
|
def get_session_history(
    user_id: str, conversation_id: str
) -> BaseChatMessageHistory:
    """Return the history for one (user_id, conversation_id) pair.

    Lazily creates an empty InMemoryHistory in the module-level ``store``
    on first access.
    """
    key = (user_id, conversation_id)
    if key not in store:
        store[key] = InMemoryHistory()
    return store[key]
|
|
|
|
|
|
|
|
def init_llm(k, p, t):
    """Build the Mistral chat chain and publish it via the global ``bot_llm``.

    Wired to the sidebar's Confirm button; the arguments come from the three
    configuration sliders:
        k: top_k sampling cutoff.
        p: top_p (nucleus sampling) threshold.
        t: sampling temperature.

    Returns four gr.update(...) values that enable the message textbox and
    the Stop/Clear buttons and collapse the sidebar.
    """
    global bot_llm
    # [INST] ... [/INST] is the Mistral-Instruct prompt format; the stored
    # turns are spliced in between the system message and the new question.
    prompt = ChatPromptTemplate.from_messages([
        ("system", "[INST] You're an assistant who's good at everything"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question} [/INST]"),
    ])

    model_id="mistralai/Mistral-7B-Instruct-v0.3"
    # Echo streamed tokens to the server's stdout (generation visibility).
    callbacks = [langchain_core.callbacks.StreamingStdOutCallbackHandler()]

    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        max_new_tokens=4096,
        temperature=t,
        top_p=p,
        top_k=k,
        repetition_penalty=1.03,
        callbacks=callbacks,
        streaming=True,
        # NOTE(review): os.getenv returns None when HF_TOKEN is unset; the
        # failure then only surfaces on the first request — consider
        # failing fast here instead.
        huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    )

    chain = prompt | llm
    # Keyed per-conversation memory: get_session_history() supplies one
    # InMemoryHistory per (user_id, conversation_id) pair, injected into the
    # "history" placeholder of the prompt above.
    with_message_history = RunnableWithMessageHistory(
        chain,
        get_session_history=get_session_history,
        input_messages_key="question",
        history_messages_key="history",
        history_factory_config=[
            ConfigurableFieldSpec(
                id="user_id",
                annotation=str,
                name="User ID",
                description="Unique identifier for the user.",
                default="",
                is_shared=True,
            ),
            ConfigurableFieldSpec(
                id="conversation_id",
                annotation=str,
                name="Conversation ID",
                description="Unique identifier for the conversation.",
                default="",
                is_shared=True,
            ),
        ],
    )
    bot_llm = with_message_history
    # Unlock msg / stop / clear and fold the sidebar away.
    return gr.update(interactive=True), gr.update(interactive=True), gr.update(interactive=True), gr.update(open=False)
|
|
|
|
|
with gr.Blocks() as demo:
    gr.HTML("<center><h1>Chat with a Smart Assistant</h1></center>")
    # Conversation pane; type="messages" expects {"role", "content"} dicts.
    chatbot = gr.Chatbot(type="messages")
    # Input and controls start disabled; init_llm() re-enables them once the
    # model has been configured via the sidebar's Confirm button.
    msg = gr.Textbox(placeholder="Enter text and press enter", interactive=False)
    stop = gr.Button("Stop", interactive=False)
    clear = gr.Button("Clear",interactive=False)

    def user(user_message, history: list):
        """Clear the textbox and append the submitted text as a user turn."""
        return "", history + [{"role": "user", "content": user_message}]

    def bot(history: list):
        """Generator that streams the assistant's answer into the last entry.

        Assumes init_llm() has already run (msg stays disabled until then,
        so bot_llm cannot be None when this fires).
        """
        question = history[-1]['content']
        # NOTE(review): the "ability" key is not referenced by the prompt
        # template built in init_llm(); it is carried along unused.
        # user_id / conversation_id are hard-coded, so every browser session
        # shares one server-side history — confirm this is intended.
        answer = bot_llm.stream(
            {"ability": "everything", "question": question},
            config={"configurable": {"user_id": "123", "conversation_id": "1"}}
        )
        history.append({"role": "assistant", "content": ""})
        for character in answer:
            # Append each streamed chunk; the short sleep throttles UI
            # updates for a typewriter effect.
            history[-1]['content'] += character
            time.sleep(0.05)
            yield history

    with gr.Sidebar() as s:
        # BUG FIX: the heading was never closed ("<h1>...<h1>"); use "</h1>".
        gr.HTML("<h1>Model Configuration</h1>")
        k = gr.Slider(0.0, 100.0, label="top_k", value=50, interactive=True,
                      info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)")
        p = gr.Slider(0.0, 1.0, label="top_p", value=0.9, interactive=True,
                      info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)")
        t = gr.Slider(0.0, 1.0, label="temperature", value=0.4, interactive=True,
                      info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)")

        bnt1 = gr.Button("Confirm")
        # Confirm builds the chain (init_llm) and unlocks the chat controls.
        bnt1.click(init_llm, inputs=[k, p, t], outputs=[msg, stop, clear, s])

    # Submit: append the user turn, then stream the bot reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )

    # Stop cancels an in-flight generation; Clear empties the chat pane.
    stop.click(None, None, None, cancels=[submit_event], queue=False)
    clear.click(lambda: None, None, chatbot, queue=True)
|
|
|
|
|
if __name__ == "__main__":
    # Start the Gradio server only when executed as a script (not on import).
    demo.launch()