Hugging Face Spaces status: Runtime error (the Space crashes at startup; see fixes below).
| import gradio as gr | |
| from openai import AsyncInferenceClient | |
| # Assuming client is a global variable | |
| client = AsyncInferenceClient("meta-llama/Llama-2-70b-chat-hf") | |
| def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0): | |
| input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n " if system_prompt else "<s>[INST] " | |
| temperature = max(1e-2, float(temperature)) | |
| top_p = float(top_p) | |
| for interaction in chatbot: | |
| input_prompt += f"{interaction[0]} [/INST] {interaction[1]} </s><s>[INST] " | |
| input_prompt += f"{message} [/INST] " | |
| partial_message = "" | |
| for token in client.text_generation( | |
| prompt=input_prompt, | |
| max_new_tokens=max_new_tokens, | |
| stream=True, | |
| best_of=1, | |
| temperature=temperature, | |
| top_p=top_p, | |
| do_sample=True, | |
| repetition_penalty=repetition_penalty, | |
| ): | |
| partial_message += token | |
| yield partial_message | |
| # Create a Gradio interface | |
| iface = gr.Interface( | |
| fn=predict, | |
| inputs=[ | |
| gr.Textbox("text", label="Message"), | |
| gr.Textbox("text", label="Chatbot"), | |
| gr.Textbox("text", label="System Prompt"), | |
| gr.Number("slider", minimum=0.1, maximum=2, default=0.9, label="Temperature"), | |
| gr.Number("slider", minimum=1, maximum=1000, default=256, label="Max New Tokens"), | |
| gr.Number("slider", minimum=0.1, maximum=1, default=0.6, label="Top P"), | |
| gr.Number("slider", minimum=0.1, maximum=2, default=1.0, label="Repetition Penalty"), | |
| ], | |
| outputs=gr.Textbox(), | |
| ) | |
| iface.launch() | |