Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| from text_generation import Client, InferenceAPIClient | |
# Pre-prompt placed in front of every conversation: a seed Human/Assistant
# exchange in which the assistant introduces itself.
# The adjacent literals are joined by implicit concatenation, so every fragment
# that is followed by another sentence must end with a space.
cntxt = (
    "\nHuman: Hi!\nAssistant: I'm Jarvis StarCoder, a 15.5B parameter Programming and Web Development model checkpoint trained on over 80 programming languages "
    "by BigCode! I was created to be an excellent expert assistant capable of carefully, logically, truthfully, methodically fulfilling any Human request. "
    "I'm capable of acting as an expert AI Writing model, acting as an expert AI Programming model, acting as an expert AI Web Development model and much more... "
    "I'm programmed to be helpful, polite, honest, and friendly.\n"
)
def get_client(model: str):
    """Build an ``InferenceAPIClient`` for *model*.

    The token is read from the ``HF_TOKEN`` environment variable; when the
    variable is unset, ``None`` is passed as the token.
    """
    hf_token = os.environ.get("HF_TOKEN")
    return InferenceAPIClient(model, token=hf_token)
def get_usernames(model: str):
    """Return the pieces used to format a conversation as a prompt.

    ``model`` is accepted but not consulted: the same four values are
    returned for every model.

    Returns:
        (str, str, str, str): pre-prompt, username, bot name, separator
    """
    preprompt = cntxt
    user_prefix = "Human: "
    bot_prefix = "Assistant: "
    separator = "\n"
    return preprompt, user_prefix, bot_prefix, separator
# Models this handler knows how to query.
_SUPPORTED_MODELS = ("bigcode/starcoder",)


def _strip_suffix(text: str, suffix: str) -> str:
    """Return *text* without a trailing *suffix* (exact match, removed once)."""
    if suffix and text.endswith(suffix):
        return text[: -len(suffix)]
    return text


def predict(
    model: str,
    inputs: str,
    typical_p: float,
    top_p: float,
    temperature: float,
    top_k: int,
    repetition_penalty: float,
    watermark: bool,
    chatbot,
    history,
):
    """Stream the assistant's reply to ``inputs``, yielding UI updates.

    The prompt is the pre-prompt, followed by every (user, assistant) pair in
    ``chatbot``, followed by the new user message and the assistant prefix.
    Tokens are then streamed from the inference client and accumulated.

    Args:
        model: Model identifier; must be one of ``_SUPPORTED_MODELS``.
        inputs: The new user message.
        typical_p: Forwarded to ``generate_stream``.
        top_p, temperature, top_k, repetition_penalty, watermark: Accepted so
            the signature matches the UI wiring, but not forwarded;
            ``watermark`` is always sent as ``False``.
        chatbot: Iterable of (user_text, assistant_text) pairs already shown.
        history: Flat list alternating user / assistant messages. It is
            mutated in place: the user message is appended, then the reply.

    Yields:
        (chat, history): ``chat`` is a list of (user, assistant) tuples built
        from consecutive pairs of ``history``; ``history`` is the same list
        object that was passed in.

    Raises:
        ValueError: If ``model`` is not supported. Previously this surfaced
            as a NameError on the unbound ``iterator``.
    """
    if model not in _SUPPORTED_MODELS:
        raise ValueError(f"Unsupported model: {model!r}")

    client = get_client(model)
    preprompt, user_name, assistant_name, sep = get_usernames(model)

    history.append(inputs)

    # Rebuild the earlier turns in prompt form, adding the speaker prefixes
    # when the displayed text does not already carry them.
    past = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        past.append(user_data + model_data.rstrip() + sep)

    if not inputs.startswith(user_name):
        inputs = user_name + inputs

    total_inputs = preprompt + "".join(past) + inputs + sep + assistant_name.rstrip()

    iterator = client.generate_stream(
        total_inputs,
        typical_p=typical_p,
        truncate=500,
        watermark=False,
        max_new_tokens=500,
    )

    # Speaker markers that must not leak into the displayed reply.
    user_marker = user_name.rstrip()
    assistant_marker = assistant_name.rstrip()

    partial_words = ""
    # Tracks whether the reply slot exists in ``history``. Testing the token
    # index instead would skip the append when the first token is special,
    # and the next assignment would then overwrite the user's message.
    reply_started = False
    for response in iterator:
        if response.token.special:
            continue

        partial_words += response.token.text
        # Remove an exact trailing marker. ``str.rstrip(marker)`` would strip
        # any run of the marker's characters and could eat real reply text.
        partial_words = _strip_suffix(partial_words, user_marker)
        partial_words = _strip_suffix(partial_words, assistant_marker)

        if not reply_started:
            history.append(" " + partial_words)
            reply_started = True
        elif response.token.text not in user_name:
            # Substring test: the stored reply is not refreshed when the token
            # text occurs inside ``user_name`` -- presumably to avoid showing a
            # half-generated speaker marker (TODO confirm intent).
            history[-1] = partial_words

        chat = [
            (history[j].strip(), history[j + 1].strip())
            for j in range(0, len(history) - 1, 2)
        ]
        yield chat, history
def reset_textbox():
    """Return a Gradio update whose value is the empty string."""
    cleared = gr.update(value="")
    return cleared
def radio_on_change(
    value: str,
    typical_p,
    top_p,
    top_k,
    temperature,
    repetition_penalty,
    watermark,
):
    """Compute updates for the sampling controls when the model choice changes.

    Args:
        value: The newly selected model identifier.
        typical_p, top_p, top_k, temperature, repetition_penalty, watermark:
            The control objects; each must expose an ``update(...)`` method.

    Returns:
        A 6-tuple in the order (typical_p, top_p, top_k, temperature,
        repetition_penalty, watermark). For ``"bigcode/starcoder"`` each entry
        is the result of that control's ``update(...)`` call; for any other
        value the arguments are returned unchanged.
    """
    # Test the selected *value*, not the module-level ``model`` component,
    # which the original condition compared against by mistake.
    if value == "bigcode/starcoder":
        typical_p = typical_p.update(value=0.2, visible=True)
        top_p = top_p.update(visible=False)
        top_k = top_k.update(visible=False)
        temperature = temperature.update(visible=False)
        repetition_penalty = repetition_penalty.update(visible=False)
        watermark = watermark.update(False)
    return (
        typical_p,
        top_p,
        top_k,
        temperature,
        repetition_penalty,
        watermark,
    )
# Build the chat UI and wire its events.
# NOTE(review): the nesting below (which widgets sit inside the column and the
# accordion, and the launch call sitting outside the Blocks context) was
# reconstructed from a whitespace-stripped copy -- confirm against the
# deployed layout.
with gr.Blocks(
    css="""#col_container {margin-left: auto; margin-right: auto;}
#chatbot {height: 520px; overflow: auto;}"""
) as demo:
    with gr.Column(elem_id="col_container"):
        # Single-choice, hidden model selector.
        model = gr.Radio(
            value="bigcode/starcoder",
            choices=["bigcode/starcoder"],
            label="Model",
            visible=False,
        )
        chatbot = gr.Chatbot(elem_id="chatbot")
        inputs = gr.Textbox(
            placeholder="Hi there!",
            label="Type an input and press Enter",
        )
        state = gr.State([])
        b1 = gr.Button()

        # Sampling controls; the accordion itself is created hidden.
        with gr.Accordion("Parameters", open=False, visible=False):
            typical_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.2,
                step=0.05,
                interactive=True,
                label="Typical P mass",
            )
            top_p = gr.Slider(
                minimum=0,
                maximum=1.0,
                value=0.25,
                step=0.05,
                interactive=True,
                label="Top-p (nucleus sampling)",
                visible=False,
            )
            temperature = gr.Slider(
                minimum=0,
                maximum=5.0,
                value=0.6,
                step=0.1,
                interactive=True,
                label="Temperature",
                visible=False,
            )
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
                visible=False,
            )
            repetition_penalty = gr.Slider(
                minimum=0.1,
                maximum=3.0,
                value=1.03,
                step=0.01,
                interactive=True,
                label="Repetition Penalty",
                visible=False,
            )
            watermark = gr.Checkbox(value=False, label="Text watermarking")

    # The controls that radio_on_change receives and whose updates it returns.
    sampling_controls = [
        typical_p,
        top_p,
        top_k,
        temperature,
        repetition_penalty,
        watermark,
    ]
    model.change(
        lambda value: radio_on_change(value, *sampling_controls),
        inputs=model,
        outputs=sampling_controls,
    )

    # Pressing Enter in the textbox and clicking the button both run predict
    # with the same inputs/outputs, and both clear the textbox.
    predict_inputs = [
        model,
        inputs,
        typical_p,
        top_p,
        temperature,
        top_k,
        repetition_penalty,
        watermark,
        chatbot,
        state,
    ]
    predict_outputs = [chatbot, state]

    inputs.submit(predict, predict_inputs, predict_outputs)
    b1.click(predict, predict_inputs, predict_outputs)
    b1.click(reset_textbox, [], [inputs])
    inputs.submit(reset_textbox, [], [inputs])

demo.queue(max_size=1, api_open=False).launch(max_threads=1)