File size: 3,343 Bytes
a9249a5
a0b6257
1e0561f
 
 
 
 
 
355bd35
1e0561f
e0a657b
 
 
 
 
 
 
 
d62961b
e0a657b
d62961b
3c2d8de
d62961b
 
 
 
57f2bfd
8bcebd6
fcb1273
 
 
d10e90d
584b917
72cd4ab
584b917
d10e90d
584b917
d10e90d
 
6be832d
14a449c
52a0c67
3c2d8de
57f2bfd
 
102d868
bd8b19b
acc3cf5
bbe119e
acc3cf5
bbe119e
 
 
 
acc3cf5
584b917
cf08011
d56deb0
 
c486970
f8f7574
071de29
d62961b
1e0561f
355bd35
 
 
 
 
 
 
1e0561f
355bd35
1e0561f
355bd35
 
 
 
 
1e0561f
355bd35
1e0561f
355bd35
1e0561f
355bd35
 
 
 
 
 
 
 
bd6e005
 
 
1e0561f
355bd35
 
 
1e0561f
d62961b
 
355bd35
 
 
 
 
 
 
 
 
 
 
 
1e0561f
 
355bd35
1e0561f
355bd35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import json
import gradio as gr
from huggingface_hub import InferenceClient

"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Shared Inference API client for the Mixtral instruct model.
# The access token comes from the HUGGINGFACE_TOKEN environment variable.
# Surrounding whitespace (e.g. a trailing newline in a pasted secret) is stripped.
# If the variable is unset or blank, token=None is passed instead of calling
# .strip() on None, which would crash the app at import time.
# With token=None, huggingface_hub falls back to its own credential lookup
# (see the huggingface_hub docs).
client = InferenceClient(
    "mistralai/Mixtral-8x7B-Instruct-v0.1",
    token=(os.getenv("HUGGINGFACE_TOKEN") or "").strip() or None,
)

def format_prompt(message, history):
    """Render a chat transcript into a single ``[INST]``-style prompt string.

    Each earlier exchange becomes ``[INST] <user> [/INST] <bot></s> ``.
    The whole prompt opens with ``<s>`` and closes with the still-unanswered
    ``[INST] <message> [/INST]`` turn.

    Args:
        message: The newest user message, which the model should answer.
        history: Iterable of ``(user_prompt, bot_response)`` pairs, oldest first.

    Returns:
        The formatted prompt string.
    """
    past_turns = "".join(
        f"[INST] {user_turn} [/INST] {bot_turn}</s> "
        for user_turn, bot_turn in history
    )
    return f"<s>{past_turns}[INST] {message} [/INST]"

def generate_response(
    prompt,
    history: list[tuple[str, str]],
    system_prompt: str,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a reply for ``prompt`` using the raw text-generation endpoint.

    This is the callback wired into ``gr.ChatInterface``. Its trailing
    parameters line up, in order, with the interface's ``additional_inputs``.

    Args:
        prompt: The latest user message.
        history: Prior ``(user, assistant)`` turns supplied by the chat UI.
        system_prompt: Text from the "System message" box.
            If it parses as a JSON list of ``[user, assistant]`` pairs, that
            list replaces ``history``. This lets a caller inject a whole
            conversation.
            Otherwise the text is not sent to the model, which matches the
            original behaviour.
            NOTE(review): confirm whether plain text should be forwarded as
            a system message.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply generated so far, as cumulative text.
        The chat window displays the latest yielded value, so yielding
        single tokens would show only the last fragment.
        This mirrors ``respond``.
    """
    try:
        parsed = json.loads(system_prompt)
    except (TypeError, ValueError):
        # Plain text (the usual case) raises ValueError/JSONDecodeError.
        # A missing value (None) raises TypeError.
        parsed = None

    # Only accept a well-formed list of 2-item turns.
    # Anything else would crash format_prompt's pair unpacking.
    if isinstance(parsed, list) and all(
        isinstance(turn, (list, tuple)) and len(turn) == 2 for turn in parsed
    ):
        history = parsed

    # Send the prompt as-is. The old code prefixed an always-empty system
    # message, which left a stray leading ", " in the model input.
    formatted_prompt = format_prompt(prompt, history)
    stream = client.text_generation(
        formatted_prompt,
        stream=True,
        # Honour the UI sliders instead of a hard-coded 256-token limit.
        max_new_tokens=int(max_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        return_full_text=False,
    )
    output = ""
    for token in stream:
        output += token
        yield output


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a reply via the chat-completion endpoint (alternative callback).

    Builds an OpenAI-style message list in this order:
    - the system message;
    - each non-empty user/assistant turn from ``history``;
    - the new user ``message``.

    Args:
        message: The latest user message.
        history: Prior ``(user, assistant)`` turns. Empty entries are skipped.
        system_message: Content for the leading ``system`` role message.
        max_tokens: Maximum number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The cumulative reply text after each streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for turn in history:
        if turn[0]:
            messages.append({"role": "user", "content": turn[0]})
        if turn[1]:
            messages.append({"role": "assistant", "content": turn[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    print("============= make chat_completion  =============")
    # Use a distinct loop name so the ``message`` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        # A streamed chunk can arrive with an empty ``choices`` list or a
        # ``delta.content`` of None (e.g. a final/stop chunk).
        # Appending None would raise TypeError, so treat those chunks as
        # contributing no text.
        choices = chunk.choices
        token = choices[0].delta.content if choices else None
        if token:
            response += token
        yield response

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Chat UI served by this app, driven by ``generate_response``.
# The ``additional_inputs`` are supplied to the callback after
# (message, history), in list order. That order must therefore match the
# callback's parameters: system message, max tokens, temperature, top-p.
# ``respond`` is an alternative chat-completion callback that can be swapped in.
demo = gr.ChatInterface(
    generate_response,
    additional_inputs=[
        gr.Textbox(label="System message", value="You are a friendly Chatbot."),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=512,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=4.0,
            step=0.1,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=0.95,
        ),
    ],
)


# Script entry point: start the Gradio app defined above when run directly.
# share=True is passed through to Gradio's launch(); presumably this requests
# a public share link. See the Gradio launch() docs for its behaviour on
# hosted platforms.
if __name__ == "__main__":
    demo.launch(share=True)