| import os |
| import time |
| import requests |
| import gradio as gr |
| from huggingface_hub import get_inference_endpoint |
|
|
# --- Inference endpoint connection settings, all read from the environment ---
endpoint_name = os.getenv("ENDPOINT_NAME")
endpoint_url = os.getenv("ENDPOINT_URL")
personal_secret_token = os.getenv("PERSONAL_HF_TOKEN")

# --- Prompt-template markers used when assembling the chat prompt ---
turn_breaker = os.getenv("TURN_BREAKER")
system_symbol = os.getenv("SYSTEM_SYMBOL")
user_symbol = os.getenv("USER_SYMBOL")
assistant_symbol = os.getenv("ASSISTANT_SYMBOL")

# HTTP headers sent with every endpoint request.
# NOTE: if PERSONAL_HF_TOKEN is unset, the Authorization value is the literal "Bearer None".
headers = {
    "Accept": "application/json",
    "Authorization": "Bearer " + str(personal_secret_token),
    "Content-Type": "application/json",
}
|
|
def query(payload, timeout=(10, 300)):
    """POST *payload* as JSON to the inference endpoint and return the decoded body.

    Args:
        payload: JSON-serialisable request body, e.g.
            ``{"inputs": ..., "parameters": ...}``.
        timeout: ``requests`` timeout in seconds, either a single number or a
            ``(connect, read)`` tuple. Without a timeout a stalled endpoint
            would block the calling Gradio worker indefinitely.

    Returns:
        The endpoint's JSON response, decoded. The HTTP status code is not
        checked here, so error bodies are handed back to the caller as-is.

    Raises:
        requests.RequestException: on connection failure or timeout.
        ValueError: if the response body is not valid JSON.
    """
    response = requests.post(
        endpoint_url,
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    return response.json()
|
|
def get_status():
    """Return the current status of the configured inference endpoint.

    Looks the endpoint up by ``endpoint_name`` with ``get_inference_endpoint``,
    authenticating with ``personal_secret_token``. It then returns the
    object's ``status`` attribute, which callers in this file compare against
    ``"running"``.
    """
    return get_inference_endpoint(endpoint_name, token=personal_secret_token).status
|
|
def _wait_until_running(timeout=600, poll_interval=1.0):
    """Poll the endpoint until its status is ``"running"``.

    Raises:
        gr.Error: if the endpoint reports ``"failed"``, or is still not running
            after ``timeout`` seconds. The original loop had no exit condition
            and could spin forever.
    """
    deadline = time.monotonic() + timeout
    while (status := get_status()) != "running":
        if status == "failed":
            raise gr.Error("Inference endpoint failed to start.")
        if time.monotonic() >= deadline:
            raise gr.Error(
                f"Inference endpoint is still not running after {timeout} seconds."
            )
        time.sleep(poll_interval)


def _extract_generated_text(response):
    """Return the reply text from a decoded text-generation response.

    Presumably the endpoint answers with ``[{"generated_text": ...}]`` on
    success, or ``{"error": ...}`` on failure. Confirm this against the
    deployed container. Unknown shapes are stringified rather than crashing
    the chat UI.

    Raises:
        gr.Error: if the response is a dict carrying an ``"error"`` key.
    """
    if isinstance(response, list) and response:
        response = response[0]
    if isinstance(response, dict):
        if "generated_text" in response:
            return response["generated_text"]
        if "error" in response:
            raise gr.Error(f"Inference endpoint error: {response['error']}")
    if isinstance(response, str):
        return response
    return str(response)


def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_new_tokens,
    temperature,
    top_p,
    progress=gr.Progress()
):
    """Generate the assistant's reply for one ChatInterface turn.

    Wakes the inference endpoint if it is not running. Then flattens the
    system message, the chat history and the new message into one prompt
    string, sends it to the endpoint, and returns the generated text.

    Args:
        message: The new user message.
        history: Previous ``(user, assistant)`` pairs; either side may be empty.
        system_message: Instruction/persona text placed first in the prompt.
        max_new_tokens: Generation length limit forwarded to the endpoint.
        temperature: Sampling temperature; sampling is turned off when it is 0.
        top_p: Nucleus-sampling probability mass.
        progress: Progress tracker; Gradio injects it via this default.

    Returns:
        The reply as a ``str``. Returning the raw JSON list/dict, as before,
        does not give ChatInterface plain text to display.

    Raises:
        gr.Error: if the endpoint never becomes ready or returns an error body.
    """
    progress(0, desc="Starting")

    if get_status() != "running":
        # This request only serves to trigger a cold start; its reply is
        # ignored. Failures are therefore expected and must not abort the turn.
        try:
            query({"inputs": "wake up!"})
        except (requests.RequestException, ValueError):
            pass
        progress(0.25, desc="Waking up model")
        _wait_until_running()

    progress(0.5, desc="Generating")

    # The system turn gets its role marker like the user/assistant turns do.
    # If SYSTEM_SYMBOL is unset, fall back to "" so the prompt is unchanged.
    all_messages = [(system_symbol or "") + system_message]
    for user_turn, assistant_turn in history:
        if user_turn:
            all_messages.append(user_symbol + user_turn)
        if assistant_turn:
            all_messages.append(assistant_symbol + assistant_turn)
    all_messages.append(user_symbol + message)
    # NOTE(review): the prompt ends on the user turn with no trailing assistant
    # marker; confirm this matches the model's expected chat template.

    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=temperature > 0,
        top_p=top_p,
        temperature=temperature,
    )

    response = query({
        "inputs": turn_breaker.join(all_messages),
        "parameters": generation_kwargs,
    })

    progress(1, desc="Done")

    return _extract_generated_text(response)
|
|
|
|
# Chat UI wiring. For information on how to customize the ChatInterface,
# see the Gradio docs: https://www.gradio.app/docs/chatinterface
#
# The extra inputs are passed to `respond` positionally after (message, history):
# system message, max new tokens, temperature, top-p.
# The default system message (Chinese) asks the model to play a cheerful,
# positive 26-year-old programmer named 贺英旭 chatting with friends.
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(
            value="请你扮演一个开心,积极的角色,名叫贺英旭,今年26岁,工作是程序员。你需要以这个身份和朋友们进行对话。",
            label="System message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=64,
            step=1,
            label="Max new tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.7,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    show_progress="full",
)


if __name__ == "__main__":
    # Start the web server only when run as a script, not on import.
    demo.launch()