| import datetime |
| import os |
| import random |
| import re |
| from io import StringIO |
|
|
| import gradio as gr |
| import pandas as pd |
| from huggingface_hub import upload_file |
| from text_generation import Client |
|
|
| from dialogues import DialogueTemplate |
| from share_btn import (community_icon_html, loading_icon_html, share_btn_css, |
| share_js) |
|
|
# Optional credentials taken from the environment; each is None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")    # used to push prompts/completions to the Hub
API_TOKEN = os.environ.get("API_TOKEN")  # sent as a Bearer token to the inference endpoint
# Hub dataset repository that receives the saved dialogues.
DIALOGUES_DATASET = "ehristoforu/dialogues"

# Display name -> Inference API URL for every selectable model.
model2endpoint = {
    "starchat-beta": (
        "https://api-inference.huggingface.co/models/HuggingFaceH4/starchat-beta"
    ),
}
# Model names in insertion order; the first one is the UI default.
model_names = [*model2endpoint]
|
|
|
|
def randomize_seed_generator():
    """Return a fresh sampling seed drawn uniformly from the inclusive range [0, 1_000_000]."""
    return random.randint(0, 1_000_000)
|
|
|
|
def save_inputs_and_outputs(now, inputs, outputs, generate_kwargs, model):
    """Upload one prompt/completion record to the ``DIALOGUES_DATASET`` Hub dataset.

    The record is serialized as a single JSON Lines row and uploaded to
    ``<now.date()>/<now.hour>/prompts_<timestamp>.jsonl``.

    Args:
        now: ``datetime.datetime`` used for both the folder path and the
            file-name timestamp.
        inputs: The full prompt that was sent to the model.
        outputs: The generated completion.
        generate_kwargs: Generation parameters used for this completion.
        model: Name of the model that produced the completion.
    """
    # Take the file-name timestamp from ``now`` as well, so the file name always
    # agrees with the date/hour folder it is stored in.
    timestamp = now.strftime("%Y-%m-%dT%H:%M:%S.%f")
    file_name = f"prompts_{timestamp}.jsonl"
    data = {"model": model, "inputs": inputs, "outputs": outputs, "generate_kwargs": generate_kwargs}

    # ``with`` closes the buffer even if serialization raises.
    with StringIO() as buffer:
        pd.DataFrame([data]).to_json(buffer, orient="records", lines=True)
        payload = buffer.getvalue().encode()

    upload_file(
        path_in_repo=f"{now.date()}/{now.hour}/{file_name}",
        path_or_fileobj=payload,
        repo_id=DIALOGUES_DATASET,
        token=HF_TOKEN,
        repo_type="dataset",
    )
|
|
|
|
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Flatten past turns plus the new user input into a single prompt string.

    Each past ``(user, reply)`` pair is rendered as
    ``<user_name><user><sep><assistant_name><reply>`` with trailing whitespace
    stripped, followed by ``sep``. The role prefixes are only added when the
    text does not already start with them. The prompt ends with the new input,
    ``sep`` and the right-stripped assistant name.
    """
    def _ensure_prefix(text, prefix):
        # Prepend ``prefix`` unless ``text`` already starts with it.
        return text if text.startswith(prefix) else prefix + text

    reply_prefix = sep + assistant_name
    rendered_turns = "".join(
        _ensure_prefix(user_text, user_name) + _ensure_prefix(reply, reply_prefix).rstrip() + sep
        for user_text, reply in chatbot
    )
    new_turn = _ensure_prefix(inputs, user_name)
    return preprompt + rendered_turns + new_turn + sep + assistant_name.rstrip()
|
|
|
|
def wrap_html_code(text):
    """Fence ``text`` in triple backticks when it contains an HTML/XML-like tag.

    A "tag" is any ``<...>`` span on a single line, because ``.`` does not match
    newlines. Text without such a span is returned unchanged. Presumably this
    makes the chat UI show the markup literally instead of rendering it
    (confirm against the Chatbot component's behaviour).
    """
    if re.search(r"<.*?>", text) is None:
        return text
    return f"```{text}```"
|
|
|
|
def has_no_history(chatbot, history):
    """Return True when both the displayed chat and the raw history are empty (falsy)."""
    return not (chatbot or history)
|
|
|
|
def generate(
    RETRY_FLAG,
    model_name,
    system_message,
    user_message,
    chatbot,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save=True,
):
    """Stream a model reply to ``user_message`` and yield UI updates as it grows.

    ``history`` is a flat list alternating user messages and model replies and
    is mutated in place. ``chatbot`` holds the ``(user, reply)`` pairs currently
    shown in the UI and is used to rebuild the prompt.

    Args:
        RETRY_FLAG: When true, ``user_message`` is already the last entry of
            ``history`` (regeneration). It is not appended again, and a random
            seed is used instead of the fixed seed 42.
        model_name: Key into ``model2endpoint`` selecting the endpoint.
        system_message: System prompt passed to ``DialogueTemplate``.
        user_message: The new user message.
        chatbot: Past ``(user, reply)`` pairs as shown in the UI.
        history: Flat list of raw user/reply strings (mutated in place).
        temperature: Sampling temperature; values below 0.01 are raised to 0.01.
        top_k: Top-k cut-off; values <= 0 disable top-k filtering.
        top_p: Nucleus-sampling mass; values >= 1 disable nucleus filtering.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Repetition penalty forwarded to the endpoint.
        do_save: When true and ``HF_TOKEN`` is set, upload the prompt and
            completion to the dialogues dataset after generation (best effort).

    Yields:
        ``(chat, history, user_message, "")`` after every visible streamed token.
        ``chat`` is the list of display pairs; ``""`` clears the input box.
    """
    if not user_message and not RETRY_FLAG:
        # Nothing to send: leave the conversation untouched instead of asking
        # the model to answer an empty turn.
        print("Empty input")
        yield chatbot, history, user_message, ""
        return

    client = Client(
        model2endpoint[model_name],
        headers={"Authorization": f"Bearer {API_TOKEN}"},
        timeout=60,
    )

    if RETRY_FLAG:
        seed = randomize_seed_generator()
    else:
        history.append(user_message)
        seed = 42

    prompt = _build_prompt(system_message, user_message, chatbot)
    generate_kwargs = _build_generate_kwargs(
        temperature, top_k, top_p, max_new_tokens, repetition_penalty, seed
    )

    stream = client.generate_stream(prompt, **generate_kwargs)

    output = ""
    reply_started = False
    for response in stream:
        # Special tokens (e.g. the end-of-turn marker) are never displayed.
        if response.token.special:
            continue
        output += response.token.text
        if reply_started:
            history[-1] = output
        else:
            # The first *visible* token opens the reply slot. Tracking this with
            # a flag rather than the stream index avoids overwriting the user's
            # message when the stream begins with a special token.
            history.append(" " + output)
            reply_started = True
        yield _history_to_chat(history), history, user_message, ""

    if not reply_started:
        # No visible text came back: close the turn with an empty reply so that
        # ``history`` keeps its user/reply pairing.
        history.append("")
        yield _history_to_chat(history), history, user_message, ""

    if HF_TOKEN and do_save:
        try:
            now = datetime.datetime.now()
            current_time = now.strftime("%Y-%m-%d %H:%M:%S")
            print(f"[{current_time}] Pushing prompt and completion to the Hub")
            save_inputs_and_outputs(now, prompt, output, generate_kwargs, model_name)
        except Exception as e:
            # Saving is best effort; a failed upload must not break the chat.
            print(e)


def _build_prompt(system_message, user_message, chatbot):
    """Render the system prompt, past pairs and the new user message via DialogueTemplate."""
    # NOTE(review): ``chatbot`` holds the *display* strings, which are stripped
    # and possibly fenced with ``` by wrap_html_code; confirm the prompt should
    # see that formatted text rather than the raw ``history`` entries.
    messages = []
    for user_text, reply in chatbot:
        messages.append({"role": "user", "content": user_text})
        messages.append({"role": "assistant", "content": reply.rstrip()})
    messages.append({"role": "user", "content": user_message})
    return DialogueTemplate(system=system_message, messages=messages).get_inference_prompt()


def _build_generate_kwargs(temperature, top_k, top_p, max_new_tokens, repetition_penalty, seed):
    """Translate the UI parameter values into keyword arguments for ``Client.generate_stream``."""
    # The client rejects a non-positive temperature, so clamp it to a small positive value.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    top_k = int(top_k) if top_k is not None else 0
    return dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        # top_k <= 0 and top_p >= 1 both mean "no filtering". The client only
        # accepts top_k > 0 and 0 < top_p < 1, so send None to disable them.
        top_k=top_k if top_k > 0 else None,
        top_p=top_p if top_p < 1.0 else None,
        repetition_penalty=repetition_penalty,
        do_sample=True,
        truncate=4096,
        seed=seed,
        stop_sequences=["<|end|>"],
    )


def _history_to_chat(history):
    """Pair the flat ``history`` into ``(user, reply)`` display tuples, stripped and passed through wrap_html_code.

    A trailing unpaired entry (a user message still awaiting its reply) is omitted.
    """
    return [
        (wrap_html_code(user_text.strip()), wrap_html_code(reply.strip()))
        for user_text, reply in zip(history[::2], history[1::2])
    ]
|
|
|
|
# Long data-plotting prompt, split across lines for readability (implicit concatenation).
_SEABORN_PLOT_PROMPT = (
    "Give the following data {'Name':['Tom', 'Brad', 'Kyle', 'Jerry'], "
    "'Age':[20, 21, 19, 18], 'Height' : [6.1, 5.9, 6.0, 6.1]}. "
    "Can you plot one graph with two subplots as columns. "
    "The first is a bar graph showing the height of each person. "
    "The second is a bargraph showing the age of each person? "
    "Draw the graph in seaborn talk mode."
)

# Example user prompts offered below the chat via gr.Examples.
examples = [
    "How can I write a Python function to generate the nth Fibonacci number?",
    "How do I get the current date using shell commands? Explain how it works.",
    "What's the meaning of life?",
    "Write a function in Javascript to reverse words in a given string.",
    _SEABORN_PLOT_PROMPT,
    "Create a regex to extract dates from logs",
    "How to decode JSON into a typescript object",
    "Write a list into a jsonlines file and save locally",
]
|
|
|
|
def clear_chat():
    """Return a brand-new, empty ``(chatbot, history)`` pair to reset the conversation."""
    fresh_chat = []
    fresh_history = []
    return fresh_chat, fresh_history
|
|
|
|
def delete_last_turn(chat, history):
    """Remove the most recent exchange from the conversation, in place.

    Pops the last ``(user, reply)`` pair from ``chat`` and the last two entries
    (user message and reply) from ``history``. Nothing is removed unless both
    lists are non-empty.

    Returns:
        The same ``chat`` and ``history`` list objects, possibly shortened.
    """
    if not chat or not history:
        return chat, history
    chat.pop()
    for _ in range(2):
        history.pop()
    return chat, history
|
|
|
|
def process_example(args):
    """Run one example prompt through ``generate`` with the UI's default settings.

    Args:
        args: The example user message.

    Returns:
        ``[chat, history]`` from the last streamed update. Both are empty lists
        if ``generate`` yielded nothing.
    """
    chat, history = [], []
    # The values mirror the initial values of the model selector and the
    # Parameters controls. Saving is disabled because no user has agreed to
    # store example runs.
    for chat, history, *_ in generate(
        RETRY_FLAG=False,
        model_name=model_names[0],
        system_message="",
        user_message=args,
        chatbot=[],
        history=[],
        temperature=0.2,
        top_k=50,
        top_p=0.95,
        max_new_tokens=512,
        repetition_penalty=1.2,
        do_save=False,
    ):
        pass
    return [chat, history]
|
|
|
|
| |
def retry_last_answer(
    selected_model,
    system_message,
    user_message,
    chat,
    history,
    temperature,
    top_k,
    top_p,
    max_new_tokens,
    repetition_penalty,
    do_save,
):
    """Regenerate the model's reply to the most recent user message.

    Drops the last displayed pair from ``chat`` and the last model reply from
    ``history`` (both in place). It then re-runs ``generate`` in retry mode on
    the user message left at the end of ``history``.

    When there is no complete turn to retry, the current state is yielded back
    unchanged.

    Yields:
        The same ``(chat, history, last_user_message, textbox)`` updates as ``generate``.
    """
    if not chat or len(history) < 2:
        # Nothing to regenerate. Without this guard the retry flag would be
        # unbound (or history[-1] would fail) once generation is reached.
        yield chat, history, user_message, user_message
        return

    chat.pop(-1)      # drop the last (user, reply) pair from the display
    history.pop(-1)   # drop the model reply; the user message is now last
    user_message = history[-1]

    yield from generate(
        True,  # RETRY_FLAG: message is already in history, use a fresh seed
        selected_model,
        system_message,
        user_message,
        chat,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    )
|
|
with gr.Blocks(analytics_enabled=False, css="style.css") as demo:
    # Header: logo, duplicate button and the (non-interactive) model selector.
    with gr.Row():
        with gr.Column():
            gr.Image("StarChat_logo.png", elem_id="banner-image", show_label=False, show_share_button=False, show_download_button=False)
    with gr.Row():
        with gr.Column():
            gr.DuplicateButton(value='Duplicate Space for private use',
                               elem_id='duplicate-button')
    with gr.Row():
        selected_model = gr.Radio(choices=model_names, value=model_names[0], label="Current Model", interactive=False)

    # Conversation display.
    with gr.Row():
        with gr.Column():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Playground")

    # Input box, action buttons and generation parameters.
    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input", lines=2)
            with gr.Row():
                send_button = gr.Button("▶️ Send", elem_id="send-btn", visible=True)
                regenerate_button = gr.Button("🔄 Regenerate", elem_id="retry-btn", visible=True)
                delete_turn_button = gr.Button("↩️ Delete last turn", elem_id="delete-btn", visible=True)
                clear_chat_button = gr.Button("🗑 Clear chat", elem_id="clear-btn", visible=True)

            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                system_message = gr.Textbox(
                    elem_id="system-message",
                    placeholder="Below is a conversation between a human user and a helpful AI coding assistant.",
                    label="System Prompt",
                    lines=2,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.2,
                    minimum=0.0,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_k = gr.Slider(
                    label="Top-k",
                    value=50,
                    minimum=0.0,
                    maximum=100,
                    step=1,
                    interactive=True,
                    info="Sample from a shortlist of top-k tokens",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.95,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )
                do_save = gr.Checkbox(
                    value=True,
                    label="Store data",
                    info="You agree to the storage of your prompt and generated text for research and development purposes:",
                )

    with gr.Row():
        gr.Examples(
            examples=examples,
            inputs=[user_message],
            cache_examples=False,
            fn=process_example,
            outputs=[output],
        )

    # Per-session flat list alternating user messages and model replies.
    history = gr.State([])
    # Hidden, always-False flag: sends through the textbox/button are never retries.
    RETRY_FLAG = gr.Checkbox(value=False, visible=False)
    last_user_message = gr.State("")

    # Shared wiring so the submit, send and regenerate events cannot drift apart.
    generate_inputs = [
        RETRY_FLAG,
        selected_model,
        system_message,
        user_message,
        chatbot,
        history,
        temperature,
        top_k,
        top_p,
        max_new_tokens,
        repetition_penalty,
        do_save,
    ]
    chat_outputs = [chatbot, history, last_user_message, user_message]

    user_message.submit(generate, inputs=generate_inputs, outputs=chat_outputs)
    send_button.click(generate, inputs=generate_inputs, outputs=chat_outputs)
    # retry_last_answer takes the same inputs as generate, minus the retry flag.
    regenerate_button.click(retry_last_answer, inputs=generate_inputs[1:], outputs=chat_outputs)

    delete_turn_button.click(delete_last_turn, [chatbot, history], [chatbot, history])
    clear_chat_button.click(clear_chat, outputs=[chatbot, history])
    selected_model.change(clear_chat, outputs=[chatbot, history])

# generate / retry_last_answer stream via generators, which need the request
# queue in Gradio 3.x. Newer versions enable it by default, so calling it
# explicitly is harmless.
demo.queue()
demo.launch(show_api=False)
|
|