"""Gradio streaming chat front-end for StarCoder served through the HF Inference API."""
import os

import gradio as gr
from text_generation import Client, InferenceAPIClient

# Pre-prompt prepended to every conversation; it primes the assistant persona.
cntxt = (
    "\nHuman: Hi!\nAssistant: I'm Jarvis StarCoder, a 15.5B parameter Programming and Web Development model checkpoint trained on over 80 programming languages "
    "by BigCode! I was created to be an excellent expert assistant capable of carefully, logically, truthfully, methodically fulfilling any Human request. "
    "I'm capable of acting as an expert AI Writing model, acting as an expert AI Programming model, acting as an expert AI Web Development model and much more... "
    "I'm programmed to be helpful, polite, honest, and friendly.\n"
)

# Models that are queried with typical-p decoding only (and whose UI shows only
# the "Typical P mass" slider). Any other model uses the sampling sliders.
TYPICAL_P_MODELS = ("bigcode/starcoder",)


def get_client(model: str):
    """Return an Inference API client for *model*, authenticated with $HF_TOKEN if it is set."""
    return InferenceAPIClient(model, token=os.getenv("HF_TOKEN", None))


def get_usernames(model: str):
    """Return the prompt framing used for *model*.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    return cntxt, "Human: ", "Assistant: ", "\n"


def _remove_suffix(text: str, suffix: str) -> str:
    """Return *text* without a trailing *suffix*.

    Unlike ``str.rstrip(suffix)``, which strips any trailing characters that
    appear in *suffix*, this removes the exact suffix string only once.
    """
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def _pair_history(history):
    """Group a flat ``[user, bot, user, bot, ...]`` list into stripped ``(user, bot)`` tuples.

    A trailing unpaired entry is ignored.
    """
    return [(user.strip(), bot.strip()) for user, bot in zip(history[::2], history[1::2])]


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream the assistant's reply to *inputs*, yielding ``(chat, history)`` after each token.

    Args:
        model: model id passed to the Inference API.
        inputs: the new user message.
        typical_p: typical-p mass (used for models in ``TYPICAL_P_MODELS``).
        top_p, temperature, top_k, repetition_penalty, watermark: sampling
            settings used for any other model.
        chatbot: previous turns as ``(user, bot)`` pairs, used to rebuild the prompt.
        history: flat ``[user, bot, ...]`` list; mutated in place (the user
            message is appended, then the assistant reply is appended/updated).

    Yields:
        ``(chat, history)`` where ``chat`` is *history* regrouped into pairs.
    """
    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Rebuild the transcript of earlier turns with explicit role prefixes.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    if model in TYPICAL_P_MODELS:
        iterator = client.generate_stream(
            total_inputs,
            typical_p=typical_p,
            truncate=500,
            watermark=False,
            max_new_tokens=500,
        )
    else:
        # The server rejects top_p outside (0, 1) and temperature <= 0, so
        # treat those slider extremes as "not set".
        iterator = client.generate_stream(
            total_inputs,
            top_p=top_p if 0 < top_p < 1.0 else None,
            top_k=top_k,
            truncate=500,
            repetition_penalty=repetition_penalty,
            watermark=watermark,
            temperature=temperature if temperature > 0 else None,
            max_new_tokens=500,
            stop_sequences=[user_name.rstrip(), assistant_name.rstrip()],
        )

    partial_words = ""
    # Tracks whether the assistant's entry has been appended to history yet.
    # A flag (not the token index) is needed because special tokens are
    # skipped, so the first *visible* token may not be token 0.
    started = False
    for response in iterator:
        if response.token.special:
            continue
        partial_words += response.token.text
        # Drop a trailing role marker the model emitted for the next turn.
        partial_words = _remove_suffix(partial_words, user_name.rstrip())
        partial_words = _remove_suffix(partial_words, assistant_name.rstrip())
        if not started:
            history.append(" " + partial_words)
            started = True
        elif response.token.text not in user_name:
            # Skip UI updates for tokens that could be a partial "Human: " marker.
            history[-1] = partial_words
        yield _pair_history(history), history

    if not started:
        # No visible tokens were produced: keep history strictly paired so
        # later turns don't shift user/assistant roles.
        history.append("")
        yield _pair_history(history), history


def reset_textbox():
    """Clear the input textbox."""
    return gr.update(value="")


def radio_on_change(
    value: str,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Show the parameter controls relevant to the selected model *value*.

    Returns update objects for the six parameter components, in the order
    (typical_p, top_p, top_k, temperature, repetition_penalty, watermark).
    """
    if value in TYPICAL_P_MODELS:
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        watermark = watermark.update(False)
    else:
        typical_p = typical_p.update(visible=False)
        top_p = top_p.update(visible=True)
        top_k = top_k.update(visible=True)
        temperature = temperature.update(visible=True)
        repetition_penalty = repetition_penalty.update(visible=True)
        watermark = watermark.update()
    return (
        typical_p,
        top_p,
        top_k,
        temperature,
        repetition_penalty,
        watermark,
    )


with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
    #chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        model = gr.Radio(
            value="bigcode/starcoder",
            choices=["bigcode/starcoder"],
            label="Model",
            visible=False,
        )
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(placeholder="Hi there!", label="Type an input and press Enter")
        state = gr.State([])
        b1 = gr.Button()

        with gr.Accordion("Parameters", open=False, visible=False):
            typical_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    model.change(
        lambda value: radio_on_change(
            value,
            typical_p,
            top_p,
            top_k,
            temperature,
            repetition_penalty,
            watermark,
        ),
        inputs=model,
        outputs=[typical_p, top_p, top_k, temperature, repetition_penalty, watermark],
    )

    predict_inputs = [
        model,
        inputs,
        typical_p,
        top_p,
        temperature,
        top_k,
        repetition_penalty,
        watermark,
        chatbot,
        state,
    ]
    inputs.submit(predict, predict_inputs, [chatbot, state])
    b1.click(predict, predict_inputs, [chatbot, state])
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=1, api_open=False).launch(max_threads=1)