import time
import os
import gradio as gr
from typing import List, Optional

import langchain_core.callbacks
import markdown_it.cli.parse
from langchain_huggingface import HuggingFaceEndpoint

from langchain.schema import BaseMessage
from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.runnables import (
    ConfigurableFieldSpec,
)
from langchain_core.runnables.history import RunnableWithMessageHistory

from pydantic import BaseModel, Field


class InMemoryHistory(BaseChatMessageHistory, BaseModel):
    """Chat message history held entirely in process memory.

    Implements the ``BaseChatMessageHistory`` interface by keeping the
    messages in a plain list field on the model; nothing is persisted.
    """

    # All messages for one session, oldest first.
    messages: List[BaseMessage] = Field(default_factory=list)

    def add_messages(self, messages: List[BaseMessage]) -> None:
        """Append *messages* to the stored history, preserving order."""
        self.messages += messages

    def clear(self) -> None:
        """Forget every stored message."""
        self.messages = []

# In-memory storage for session history.
# Maps (user_id, conversation_id) -> InMemoryHistory for that session.
store = {}
# History-aware LLM chain; None until init_llm() builds it after the user
# confirms the sampling settings in the sidebar.
bot_llm:Optional[RunnableWithMessageHistory] = None

def get_session_history(
    user_id: str, conversation_id: str
) -> BaseChatMessageHistory:
    """Return the chat history for this user/conversation, creating it lazily."""
    key = (user_id, conversation_id)
    if key not in store:
        store[key] = InMemoryHistory()
    return store[key]


def init_llm(k, p, t):
    """Build the Mistral endpoint and history-aware chain, then enable the UI.

    Parameters are the sidebar sliders: ``k`` (top_k), ``p`` (top_p) and
    ``t`` (temperature).  Stores the finished runnable in the module-level
    ``bot_llm`` and returns four Gradio updates that unlock the textbox,
    Stop and Clear buttons and collapse the sidebar.
    """
    global bot_llm

    chat_prompt = ChatPromptTemplate.from_messages([
        ("system", "[INST] You're an assistant who's good at everything"),
        MessagesPlaceholder(variable_name="history"),
        ("human", "{question} [/INST]"),
    ])

    endpoint = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.3",
        max_new_tokens=4096,
        temperature=t,
        top_p=p,
        top_k=k,
        repetition_penalty=1.03,
        # Echo tokens to stdout as they arrive, alongside the Gradio stream.
        callbacks=[langchain_core.callbacks.StreamingStdOutCallbackHandler()],
        streaming=True,
        huggingfacehub_api_token=os.getenv('HF_TOKEN'),
    )

    # Describe the two config fields that key the per-session history store.
    session_specs = [
        ConfigurableFieldSpec(
            id="user_id",
            annotation=str,
            name="User ID",
            description="Unique identifier for the user.",
            default="",
            is_shared=True,
        ),
        ConfigurableFieldSpec(
            id="conversation_id",
            annotation=str,
            name="Conversation ID",
            description="Unique identifier for the conversation.",
            default="",
            is_shared=True,
        ),
    ]

    bot_llm = RunnableWithMessageHistory(
        chat_prompt | endpoint,
        get_session_history=get_session_history,
        input_messages_key="question",
        history_messages_key="history",
        history_factory_config=session_specs,
    )

    # Enable msg / stop / clear and fold the sidebar away.
    return (
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(interactive=True),
        gr.update(open=False),
    )

with gr.Blocks() as demo:
    gr.HTML("<center><h1>Chat with a Smart Assistant</h1></center>")
    # Conversation display using openai-style {"role": ..., "content": ...} dicts.
    chatbot = gr.Chatbot(type="messages")
    # Input and control widgets start disabled; init_llm() (wired to the
    # sidebar Confirm button below) enables them once the model is ready.
    msg = gr.Textbox(placeholder="Enter text and press enter", interactive=False)
    stop = gr.Button("Stop", interactive=False)
    clear = gr.Button("Clear",interactive=False)
    def user(user_message, history: list):
        """Append the submitted message to the history and clear the textbox."""
        updated = [*history, {"role": "user", "content": user_message}]
        return "", updated

    def bot(history: list):
        """Stream the assistant's reply into the last chat message.

        Reads the newest user turn from *history*, streams the model's
        answer chunk by chunk, and yields the growing history so Gradio
        re-renders after each chunk.
        """
        question = history[-1]["content"]
        stream = bot_llm.stream(
            {"ability": "everything", "question": question},
            config={"configurable": {"user_id": "123", "conversation_id": "1"}}
        )
        history.append({"role": "assistant", "content": ""})
        for chunk in stream:
            history[-1]["content"] += chunk
            # Small delay gives a visible typewriter effect in the UI.
            time.sleep(0.05)
            yield history

    # Sidebar collects sampling hyper-parameters before the chat is enabled.
    with gr.Sidebar() as s:
        gr.HTML("<h1>Model Configuration<h1>")
        k = gr.Slider(0.0, 100.0, label="top_k", value=50, interactive=True,
                  info="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)")
        p = gr.Slider(0.0, 1.0, label="top_p", value=0.9, interactive=True,
                  info=" Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)")
        t = gr.Slider(0.0, 1.0, label="temperature", value=0.4, interactive=True,
                  info="The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8)")

        bnt1 = gr.Button("Confirm")
        # Confirm builds the chain; its outputs enable msg/stop/clear and
        # collapse this sidebar (see init_llm's four gr.update returns).
        bnt1.click(init_llm, inputs=[k, p, t], outputs=[msg, stop, clear, s])

    # Enter in the textbox: record the user turn, then stream the bot reply.
    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=True).then(
        bot, chatbot, chatbot
    )

    # Stop cancels an in-flight submit/stream; Clear empties the chat display.
    stop.click(None, None, None, cancels=[submit_event], queue=False)
    clear.click(lambda: None, None, chatbot, queue=True)

if __name__ == "__main__":
    # Launch the Gradio app (blocking) when run as a script.
    demo.launch()