Spaces:
Sleeping
Sleeping
| import json | |
| import gradio as gr | |
| import os | |
| import requests | |
| # We get the token and the models API url | |
# Credentials and per-model inference endpoint URLs, injected via the
# Space's environment (each is None when the variable is unset).
hf_token = os.getenv("HF_TOKEN")
llama_7b = os.getenv("API_URL_LLAMA_7")
llama_13b = os.getenv("API_URL_LLAMA_13")
zephyr_7b = os.getenv("API_URL_ZEPHYR_7")

# Every endpoint request carries a JSON payload.
headers = {'Content-Type': 'application/json'}

# --- Chat function ---------------------------------------------------------
def chat(message,
         chatbot,
         model=llama_13b,
         system_prompt="",
         temperature=0.9,
         max_new_tokens=256,
         top_p=0.6,
         repetition_penalty=1.0
         ):
    """Send the conversation to the selected model endpoint and return its reply.

    Args:
        message: Latest user message.
        chatbot: Conversation history as (user, assistant) pairs.
        model: Dropdown key ("llama_7b", "llama_13b", "zephyr_7b") or,
            by default, the llama_13b endpoint URL itself.
        system_prompt: Optional system prompt prepended to the conversation.
        temperature: Sampling temperature; clamped to a minimum of 1e-2.
        max_new_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty factor for repeated tokens.

    Returns:
        The assistant's complete reply assembled from the streamed tokens,
        or the endpoint's error message when the request fails.
    """
    # Build the Llama-style chat prompt, optionally with a system section.
    if system_prompt != "":
        input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
    else:
        input_prompt = "<s>[INST] "
    # The endpoint rejects temperatures that are effectively zero.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    # Replay the conversation history in the expected chat template.
    for user_turn, bot_turn in chatbot:
        input_prompt += str(user_turn) + " [/INST] " + str(bot_turn) + " </s><s>[INST] "
    input_prompt += str(message) + " [/INST] "
    data = {
        "inputs": input_prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
            "repetition_penalty": repetition_penalty,
            "do_sample": True,
        },
    }
    # BUG FIX: the original `print("MODEL" + model)` raised TypeError when the
    # env-derived default was None; an f-string renders None safely.
    print(f"MODEL {model}")
    # Map the dropdown key to its endpoint URL; an unknown value (e.g. the
    # default URL itself) is passed through unchanged.
    endpoints = {"zephyr_7b": zephyr_7b, "llama_7b": llama_7b, "llama_13b": llama_13b}
    model = endpoints.get(model, model)
    response = requests.post(model, headers=headers, data=json.dumps(data),
                             auth=("hf", hf_token), stream=True)
    partial_message = ""
    for line in response.iter_lines():
        if not line:  # filter out keep-alive new lines
            continue
        decoded_line = line.decode('utf-8')
        # Server-sent events prefix every payload line with 'data:'.
        if not decoded_line.startswith('data:'):
            gr.Warning(f"This line does not start with 'data:': {decoded_line}")
            continue
        json_line = decoded_line[5:]  # Exclude the first 5 characters ('data:')
        try:
            json_obj = json.loads(json_line)
        except json.JSONDecodeError:
            gr.Warning(f"This line is not valid JSON: {json_line}")
            continue
        if 'token' in json_obj:
            # BUG FIX: the original returned here, inside the loop, so the
            # reply was truncated to its first streamed token. Accumulate
            # every token and return only after the stream ends.
            try:
                partial_message += json_obj['token']['text']
            except KeyError as e:
                gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
                continue
        elif 'error' in json_obj:
            return json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
        else:
            gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
    return partial_message
# Extra ChatInterface controls: model picker, optional system prompt,
# and the sampling hyper-parameters forwarded to `chat`.
additional_inputs = [
    gr.Dropdown(
        choices=["llama_7b", "llama_13b", "zephyr_7b"],
        label="Model",
        info="Which model do you want to use?",
    ),
    gr.Textbox("", label="Optional system prompt"),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=4096,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.6,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]
# --- Gradio demo wiring ----------------------------------------------------
title = "Find the password 🔒"
description = "In this game prototype, your goal is to discuss with the intercom to find the correct password"

# Chat window with custom avatars for the player and the intercom bot.
chatbot = gr.Chatbot(avatar_images=('user.png', 'bot2.png'), bubble_full_width=False)

chat_interface = gr.ChatInterface(
    chat,
    title=title,
    description=description,
    textbox=gr.Textbox(),
    chatbot=chatbot,
    additional_inputs=additional_inputs,
)

with gr.Blocks() as demo:
    chat_interface.render()

demo.launch(debug=True)