File size: 5,227 Bytes
eedd6da
 
 
 
 
921463a
 
b6e5a8a
eedd6da
 
 
 
 
 
 
 
 
 
 
921463a
eedd6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e5a8a
e3d567c
eedd6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b6e5a8a
eedd6da
 
 
 
 
 
 
 
 
 
 
 
 
b6e5a8a
eedd6da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410971f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import gradio as gr
from text_generation import Client, InferenceAPIClient

# Pre-prompt placed in front of every conversation: a scripted first exchange
# that introduces the assistant persona before the real chat history.
# Adjacent literals are concatenated, so each sentence boundary needs an
# explicit space.
cntxt = (
    "\nHuman: Hi!\nAssistant: I'm Jarvis StarCoder, a 15.5B parameter Programming and Web Development model checkpoint trained on over 80 programming languages "
    "by BigCode! I was created to be an excellent expert assistant capable of carefully, logically, truthfully, methodically fulfilling any Human request. "
    "I'm capable of acting as an expert AI Writing model, acting as an expert AI Programming model, acting as an expert AI Web Development model and much more... "
    "I'm programmed to be helpful, polite, honest, and friendly.\n"
)

def get_client(model: str):
    """Build an ``InferenceAPIClient`` for *model*.

    The access token comes from the ``HF_TOKEN`` environment variable; when
    the variable is unset, ``None`` is passed as the token.
    """
    hf_token = os.environ.get("HF_TOKEN")
    return InferenceAPIClient(model, token=hf_token)

def get_usernames(model: str):
    """Return the prompt framing pieces used to build a chat prompt.

    The *model* argument is currently unused: every model gets the same
    pre-prompt and speaker labels.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    user_label, bot_label, separator = "Human: ", "Assistant: ", "\n"
    return cntxt, user_label, bot_label, separator

def _remove_suffix(text: str, suffix: str) -> str:
    """Return *text* without a trailing *suffix* (unchanged if absent or empty)."""
    # Unlike str.rstrip(), which strips any trailing characters in a *set*,
    # this removes exactly one occurrence of the whole suffix.
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream the assistant's reply to *inputs* and yield chat updates.

    Builds a prompt from the pre-prompt, the previous ``chatbot`` turns and
    the new user message. It then streams tokens from the model and yields
    the growing conversation after every non-special token.

    Args:
        model: Model id; only ``"bigcode/starcoder"`` is supported.
        inputs: The new user message.
        typical_p: Forwarded to ``generate_stream``.
        top_p, temperature, top_k, repetition_penalty, watermark: Accepted
            for the UI wiring but not forwarded. Only ``typical_p``,
            ``truncate=500``, ``max_new_tokens=500`` and ``watermark=False``
            are sent.
        chatbot: Previous ``(user, assistant)`` message pairs.
        history: Flat list of alternating user/assistant messages. It is
            mutated in place: the user message and the reply are appended.

    Yields:
        ``(chat, history)``, where ``chat`` is a list of stripped
        ``(user, assistant)`` pairs built from ``history``.

    Raises:
        ValueError: If *model* is not supported. The check runs before
            ``history`` is touched.
    """
    if model != "bigcode/starcoder":
        raise ValueError(f"Unsupported model: {model!r}")

    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)
    history.append(inputs)

    # Re-serialize earlier turns as "Human: ...\nAssistant: ...\n" blocks,
    # adding the speaker labels only where they are missing.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    # The prompt ends with "Assistant:" so the model continues as the bot.
    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    iterator = client.generate_stream(
        total_inputs,
        typical_p=typical_p,
        truncate=500,
        watermark=False,
        max_new_tokens=500,
    )

    user_marker = user_name.rstrip()
    assistant_marker = assistant_name.rstrip()
    partial_words = ""
    # Tracks whether the reply slot exists in `history`. A flag is used rather
    # than the token index because special tokens may be skipped before the
    # first real token arrives.
    reply_started = False

    for response in iterator:
        if response.token.special:
            continue

        partial_words += response.token.text
        # Drop a speaker label the model emitted at the end of its output.
        partial_words = _remove_suffix(partial_words, user_marker)
        partial_words = _remove_suffix(partial_words, assistant_marker)

        if not reply_started:
            history.append(" " + partial_words)
            reply_started = True
        elif response.token.text not in user_name:
            # Tokens that are substrings of the user label (fragments of
            # "Human: ") do not refresh the displayed reply. Their text is
            # still kept in partial_words.
            history[-1] = partial_words

        chat = [
            (user_msg.strip(), bot_msg.strip())
            for user_msg, bot_msg in zip(history[::2], history[1::2])
        ]
        yield chat, history

def reset_textbox():
    """Return a Gradio update that empties the bound textbox."""
    empty_text = ""
    return gr.update(value=empty_text)

def radio_on_change(
    value: str,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Adjust the generation-parameter controls after the model radio changes.

    Args:
        value: The newly selected model id.
        typical_p, top_p, top_k, temperature, repetition_penalty, watermark:
            The parameter components. Their ``update`` methods build the
            returned updates.

    Returns:
        A tuple in the order (typical_p, top_p, top_k, temperature,
        repetition_penalty, watermark). For StarCoder it holds component
        updates; for any other model it holds the arguments unchanged.
    """
    # Branch on the selected value. The previous code tested the module-level
    # `model` component, so the result did not depend on the selection.
    if value == "bigcode/starcoder":
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        watermark = watermark.update(False)
    return (typical_p, top_p, top_k, temperature, repetition_penalty, watermark)

# Build and launch the chat UI. Inside a Blocks context, components are laid
# out in creation order and event handlers are registered in the order below,
# so statement order here is significant.
with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
                #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        # Hidden, single-choice model selector; its value is the first input to predict().
        model = gr.Radio(value="bigcode/starcoder",choices=["bigcode/starcoder",],label="Model",visible=False,)
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(placeholder="Hi there!", label="Type an input and press Enter")
        # Flat list of alternating user/assistant messages, read and extended by predict().
        state = gr.State([])
        # NOTE(review): button is created without a label; confirm this is intended.
        b1 = gr.Button()

        # Generation controls. The whole accordion is hidden (visible=False), and
        # predict() only forwards typical_p from these to the model client.
        with gr.Accordion("Parameters", open=False, visible=False):
            typical_p = gr.Slider(minimum=-0,maximum=1.0,value=0.2,step=0.05,interactive=True,label="Typical P mass",)
            top_p = gr.Slider(minimum=-0,maximum=1.0,value=0.25,step=0.05,interactive=True,label="Top-p (nucleus sampling)",visible=False,)
            temperature = gr.Slider(minimum=-0,maximum=5.0,value=0.6,step=0.1,interactive=True,label="Temperature",visible=False,)
            top_k = gr.Slider(minimum=1,maximum=50,value=50,step=1,interactive=True,label="Top-k",visible=False,)
            repetition_penalty = gr.Slider(minimum=0.1,maximum=3.0,value=1.03,step=0.01,interactive=True,label="Repetition Penalty",visible=False,)
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    # On a model change, the lambda passes the selected value plus the component
    # objects themselves (not their values) to radio_on_change.
    model.change(lambda value: radio_on_change(value,typical_p,top_p,top_k,temperature,repetition_penalty,watermark,),inputs=model,outputs=[typical_p,top_p,top_k,temperature,repetition_penalty,watermark,],)
    # Pressing Enter and clicking the button both run predict(), which streams
    # (chatbot, state) updates. Each trigger also has a separate reset_textbox
    # handler that clears the textbox.
    inputs.submit(predict,[model,inputs,typical_p,top_p,temperature,top_k,repetition_penalty,watermark,chatbot,state,], [chatbot, state],)
    b1.click(predict,[model,inputs,typical_p,top_p,temperature,top_k,repetition_penalty,watermark,chatbot,state,], [chatbot, state],)
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

    # NOTE(review): max_size=1 / max_threads=1 appear to limit the app to a single
    # queued request and worker; api_open=False presumably closes the queue to
    # direct API calls. Confirm against the installed Gradio version.
    # NOTE(review): launch() is called inside the Blocks context, while Gradio's
    # examples call it after the `with` block; confirm this is intended.
    demo.queue(max_size=1,api_open=False).launch(max_threads=1,)