Spaces:
Runtime error
Runtime error
| import argparse | |
| from collections import defaultdict | |
| import datetime | |
| import json | |
| import os | |
| import random | |
| import time | |
| import uuid | |
| import websocket | |
| from websocket import WebSocketConnectionClosedException | |
| import gradio as gr | |
| import requests | |
| import logging | |
| import re | |
| from fastchat.conversation import SeparatorStyle | |
| from fastchat.constants import ( | |
| LOGDIR, | |
| WORKER_API_TIMEOUT, | |
| ErrorCode, | |
| MODERATION_MSG, | |
| CONVERSATION_LIMIT_MSG, | |
| SERVER_ERROR_MSG, | |
| INACTIVE_MSG, | |
| INPUT_CHAR_LEN_LIMIT, | |
| CONVERSATION_TURN_LIMIT, | |
| SESSION_EXPIRATION_TIME, | |
| ) | |
| from fastchat.model.model_adapter import get_conversation_template | |
| from fastchat.model.model_registry import model_info | |
| from fastchat.serve.api_provider import ( | |
| anthropic_api_stream_iter, | |
| openai_api_stream_iter, | |
| palm_api_stream_iter, | |
| init_palm_chat, | |
| ) | |
| from fastchat.utils import ( | |
| build_logger, | |
| violates_moderation, | |
| get_window_url_params_js, | |
| parse_gradio_auth_creds, | |
| ) | |
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# Shared gradio update payloads returned by the event handlers below.
# The no_change_* values are argument-less .update() calls (used where a handler
# should leave a component as it is); enable_btn / disable_btn toggle a button's
# `interactive` flag.
no_change_dropdown, no_change_slider, no_change_textbox, no_change_btn = (
    component.update() for component in (gr.Dropdown, gr.Slider, gr.Textbox, gr.Button)
)
enable_btn, disable_btn = (gr.Button.update(interactive=flag) for flag in (True, False))
def get_internet_ip(timeout=5):
    """Return this host's public IPv4 address as reported by an external echo service.

    Args:
        timeout: seconds to wait for the lookup service before giving up.

    Returns:
        The first dotted-quad IPv4 address found in the service's response body,
        or None when the body contains no such address.

    Raises:
        requests.exceptions.RequestException: if the HTTP request fails or times out.
    """
    # Without a timeout this startup-time lookup could block the launch indefinitely.
    response = requests.get("http://txt.go.sohu.com/ip/soip", timeout=timeout)
    # Dots are escaped (a bare "." matches any character) and the digit runs are
    # anchored so e.g. "1234.5.6.7" does not yield "234.5.6.7".
    match = re.search(r"(?<!\d)\d{1,3}(?:\.\d{1,3}){3}(?!\d)", response.text)
    return match.group(0) if match else None
# --- Runtime configuration, read once from environment variables at import time ---
enable_moderation = os.environ.get("enable_moderation", "False") == "True"
concurrency_count = int(os.environ.get("concurrency_count", "10"))
model_list_mode = os.environ.get("model_list_mode", "reload")
midware_url = os.environ.get("midware_url", "")
preset_token = os.environ.get("preset_token", "")
worker_addr = os.environ.get("worker_addr", "")
allow_running = int(os.environ.get("allow_running", "1"))
ft_list_job_url = os.environ.get("ft_list_job_url", "")
ft_submit_job_url = os.environ.get("ft_submit_job_url", "")
ft_remove_job_url = os.environ.get("ft_remove_job_url", "")
ft_console_log_url = os.environ.get("ft_console_log_url", "")

# Sample records shown by the dataset preview; only "english" has samples.
dataset_sample = {"english": {"train": ["abcdef"], "valid": ["zxcvbn"]}}

# UI dataset name -> name sent to the fine-tuning middleware (currently identical).
dataset_to_midware_name = {name: name for name in ("english", "cat", "dog", "bird")}

# Hyper-parameter keys copied from a job's "parameter" dict into the jobs table.
hps_keys = [
    "epochs",
    "train_batch_size",
    "eval_batch_size",
    "gradient_accumulation_steps",
    "learning_rate",
    "weight_decay",
    "model_max_length",
]

# Headers sent with every request to the model worker.
headers = {"User-Agent": "FastChat Client", "PRIVATE-TOKEN": preset_token}

learn_more_md = """
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/LICENSE) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
"""

# Per-client-IP session deadline (time.time() seconds). Unknown IPs read as 0,
# which add_text treats as an already-expired session.
ip_expiration_dict = defaultdict(int)
# Printable ASCII punctuation (same characters as string.punctuation).
_ASCII_PUNCTUATION = frozenset('!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~')
# Full-width forms (U+FF01..U+FF5E) of the ASCII punctuation above, e.g. "，" "！" "＂".
# Built programmatically: the previous hand-typed literal had been normalised to
# ASCII, and its embedded '"' ended the string early.
_FULLWIDTH_PUNCTUATION = frozenset(chr(ord(ch) + 0xFEE0) for ch in _ASCII_PUNCTUATION)
# Other CJK and typographic punctuation. Halfwidth forms and their normalised
# equivalents (｟｠/⦅⦆, ｢｣/「」, ､/、, ｡/。) are both listed.
_CJK_PUNCTUATION = frozenset(
    "。｡、､〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏｟｠｢｣⦅⦆"
)


def is_legal_char(c):
    """Return True if *c* is a character allowed to end a streamed model reply.

    Accepted: alphanumerics (str.isalnum, which includes full-width letters and
    digits), CJK unified ideographs U+4E00..U+9FFF, ASCII punctuation, full-width
    punctuation and common CJK punctuation. Anything else is rejected, for example
    the U+FFFD replacement character or the "▌" streaming cursor.

    Args:
        c: a single-character string.
    """
    if c.isalnum():
        return True
    if "\u4e00" <= c <= "\u9fff":
        return True
    return c in _ASCII_PUNCTUATION or c in _FULLWIDTH_PUNCTUATION or c in _CJK_PUNCTUATION
def str_filter(s):
    """Strip at most two trailing characters that is_legal_char rejects.

    Stops at the first trailing character that is legal, so legal text is never
    shortened.
    """
    budget = 2
    while budget and s and not is_legal_char(s[-1]):
        s = s[:-1]
        budget -= 1
    return s
def str_not_int(s):
    """Return True when int(s) raises ValueError, False when it parses."""
    try:
        int(s)
    except ValueError:
        return True
    else:
        return False
def str_not_float(s):
    """Return True when float(s) raises ValueError, False when it parses."""
    try:
        float(s)
    except ValueError:
        return True
    return False
class State:
    """Per-session chat state stored in the gradio State component.

    Attributes (set in __init__):
        conv: conversation object from get_conversation_template(model_name).
        conv_id: random uuid4 hex string identifying this conversation.
        skip_next: when True, the next bot_response call yields without generating.
        model_name: name of the model this conversation targets.
        palm_chat: only set when model_name == "palm-2"; chat session from init_palm_chat.
    """

    def __init__(self, model_name):
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        self.skip_next = False
        self.model_name = model_name
        if model_name != "palm-2":
            return
        # Per the Vertex AI release notes (2023-05-10), "chat-bison@001" is PaLM 2 for chat.
        # https://cloud.google.com/vertex-ai/docs/release-notes#May_10_2023
        self.palm_chat = init_palm_chat("chat-bison@001")

    def to_gradio_chatbot(self):
        """Return the conversation rendered in gradio Chatbot format."""
        return self.conv.to_gradio_chatbot()

    def dict(self):
        """Return conv.dict() extended with the conv_id and model_name keys."""
        payload = self.conv.dict()
        payload["conv_id"] = self.conv_id
        payload["model_name"] = self.model_name
        return payload
def get_conv_log_filename():
    """Return the path of today's (local time) conversation log: LOGDIR/YYYY-MM-DD-conv.json."""
    today = datetime.datetime.now()
    return os.path.join(LOGDIR, f"{today:%Y-%m-%d}-conv.json")
def get_model_list(midware_url):
    """Fetch the served model names from the middleware and order them for display.

    Args:
        midware_url: URL returning JSON of the form {"data": [model names, ...]},
            or a dict whose "code" contains "invalid" when the token is rejected.

    Returns:
        The model names sorted by a fixed priority (unlisted models get 100).
        Returns ["CANNOT GET MODEL"] when the request fails, the body is not JSON,
        the token is rejected, or the payload has no "data" list.
    """
    # Lower value = shown earlier; models not listed here sort after these.
    model_order = {
        "vicuna-7b-v1.5-16k": 10,
        "vicuna-13b-v1.5": 90,
    }
    try:
        ret = requests.get(midware_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=5)
        # Parse once. ValueError covers non-JSON bodies on every requests version.
        payload = ret.json()
    except (requests.exceptions.RequestException, ValueError):
        payload = None

    if isinstance(payload, dict) and "invalid" in str(payload.get("code", "")):
        gr.Warning("Invalid preset token.")
        models = ["CANNOT GET MODEL"]
    elif isinstance(payload, dict) and isinstance(payload.get("data"), list):
        models = payload["data"]
    else:
        models = ["CANNOT GET MODEL"]

    models = sorted(models, key=lambda x: model_order.get(x, 100))
    logger.info(f"Models: {models}")
    return models
def load_demo_single(models, url_params):
    """Build the initial UI updates for a freshly loaded page.

    Args:
        models: available model names.
        url_params: mapping of page parameters (presumably the URL query string
            collected by get_window_url_params_js). A "model" entry preselects
            that model when it is one of *models*.

    Returns:
        A 7-tuple: state reset to None, the model dropdown update, then
        visible=True updates for the chatbot, textbox, send button, button row
        and parameter accordion.
    """
    selected_model = models[0] if models else ""
    if "model" in url_params and url_params["model"] in models:
        selected_model = url_params["model"]

    shown = dict(visible=True)
    return (
        None,
        gr.Dropdown.update(choices=models, value=selected_model, visible=True),
        gr.Chatbot.update(**shown),
        gr.Textbox.update(**shown),
        gr.Button.update(**shown),
        gr.Row.update(**shown),
        gr.Accordion.update(**shown),
    )
def load_demo(url_params, request: gr.Request):
    """Page-load handler: start the client's session and build the initial UI.

    Sets the client IP's session deadline to now + SESSION_EXPIRATION_TIME and,
    in "reload" mode, refreshes the module-level model list before building
    the updates with load_demo_single.
    """
    global models

    client_ip = request.client.host
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")
    ip_expiration_dict[client_ip] = time.time() + SESSION_EXPIRATION_TIME

    if model_list_mode == "reload":
        models = get_model_list(midware_url)
    return load_demo_single(models, url_params)
def regenerate(state, request: gr.Request):
    """Clear the last (assistant) message so the chained bot_response regenerates it.

    Returns (state, chatbot history, cleared textbox, two disabled-button updates).
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.conv.update_last_message(None)
    buttons = (disable_btn, disable_btn)
    return (state, state.to_gradio_chatbot(), "", *buttons)
def clear_history(request: gr.Request):
    """Reset the session: no state, empty chat and textbox, buttons disabled.

    Returning None as the state makes the next add_text call create a fresh State.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    return (None, [], "", disable_btn, disable_btn)
def add_text(state, model_selector, text, request: gr.Request):
    """Validate the user's message and append it (plus an empty reply slot) to the conversation.

    The message is rejected when:
    - it is empty;
    - the client's session has expired;
    - moderation is enabled and flags it;
    - the conversation has reached CONVERSATION_TURN_LIMIT.
    A rejection sets state.skip_next so the chained bot_response does nothing, and
    shows the matching notice in the textbox.

    Returns:
        (state, chatbot history, textbox value, regenerate_btn update, clear_btn update).
    """
    client_ip = request.client.host
    logger.info(f"add_text. ip: {client_ip}. len: {len(text)}")

    if state is None:
        state = State(model_selector)

    def _skip(notice):
        # Reject this message: bot_response will skip, buttons stay unchanged.
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), notice) + (no_change_btn,) * 2

    if len(text) <= 0:
        return _skip("")

    if ip_expiration_dict[client_ip] < time.time():
        logger.info(f"inactive. ip: {client_ip}. text: {text}")
        return _skip(INACTIVE_MSG)

    if enable_moderation and violates_moderation(text):
        logger.info(f"violate moderation. ip: {client_ip}. text: {text}")
        return _skip(MODERATION_MSG)

    conv = state.conv
    completed_turns = (len(conv.messages) - conv.offset) // 2
    if completed_turns >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {client_ip}. text: {text}")
        return _skip(CONVERSATION_LIMIT_MSG)

    conv.append_message(conv.roles[0], text[:INPUT_CHAR_LEN_LIMIT])  # Hard cut-off
    conv.append_message(conv.roles[1], None)
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 2
def post_process_code(code):
    """Unescape "\\_" to "_" inside fenced code blocks of a markdown reply.

    The text is split on "\\n```". Only when the fences are balanced (an odd
    number of parts) are the odd-indexed parts, the code bodies, rewritten.
    Otherwise the text is returned unchanged.
    """
    fence = "\n```"
    if fence not in code:
        return code
    parts = code.split(fence)
    if len(parts) % 2 == 0:
        return code
    return fence.join(
        part.replace("\\_", "_") if index % 2 else part
        for index, part in enumerate(parts)
    )
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
):
    """Stream generation results from the model worker endpoint.

    Args:
        conv: conversation whose stop_str / stop_token_ids are forwarded.
        model_name: model requested from the worker.
        worker_addr: URL the generation request is POSTed to.
        prompt: full prompt text (sent under the "question" key).
        temperature, repetition_penalty, top_p, max_new_tokens: sampling settings.

    Yields:
        One dict per NUL-delimited JSON chunk of the streamed response.
    """
    gen_params = {
        "model_name": model_name,
        "question": prompt,
        # Previously hard-coded to 1e-6, which silently ignored the UI temperature
        # even though bot_response logs it as used. Keep 1e-6 as a floor so the
        # worker still never receives exactly 0.
        "temperature": max(float(temperature), 1e-6),
        "repetition_penalty": repetition_penalty,
        "top_p": top_p,
        "max_new_tokens": max_new_tokens,
        "stop": conv.stop_str,
        "stop_token_ids": conv.stop_token_ids,
        "echo": False,
    }
    logger.info(f"==== request ====\n{gen_params}")

    # The context manager releases the streamed connection even when the consumer
    # stops iterating early (e.g. bot_response returns on an error chunk).
    with requests.post(
        worker_addr,
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    ) as response:
        for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"):
            if chunk:
                yield json.loads(chunk.decode())
def bot_response(state, temperature, top_p, max_new_tokens, request: gr.Request):
    """Stream the model's reply into the chatbot, then append a record to the conversation log.

    Generator event handler. Every yielded value is
    (state, chatbot history, regenerate_btn update, clear_btn update).

    Args:
        state: session State. Its last message is the assistant slot to fill
            (appended as None by add_text, or reset to None by regenerate).
        temperature, top_p, max_new_tokens: values from the Parameters sliders.
        request: gradio request; only the client host is used (logging / log record).
    """
    logger.info(f"bot_response. ip: {request.client.host}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 2
        return

    conv, model_name = state.conv, state.model_name
    if model_name == "gpt-3.5-turbo" or model_name == "gpt-4":
        prompt = conv.to_openai_api_messages()
        stream_iter = openai_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name == "claude-2" or model_name == "claude-instant-1":
        prompt = conv.get_prompt()
        stream_iter = anthropic_api_stream_iter(
            model_name, prompt, temperature, top_p, max_new_tokens
        )
    elif model_name == "palm-2":
        # Only the latest user message (conv.messages[-2]) is sent; the same
        # palm_chat session object from State.__init__ is reused across turns.
        stream_iter = palm_api_stream_iter(
            state.palm_chat, conv.messages[-2][1], temperature, top_p, max_new_tokens
        )
    else:
        # Get worker address
        logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")

        # No available worker
        if worker_addr == "":
            conv.update_last_message(SERVER_ERROR_MSG)
            yield (
                state,
                state.to_gradio_chatbot(),
                enable_btn,
                enable_btn,
            )
            return

        # Construct prompt.
        # We need to call it here, so it will not be affected by "▌".
        prompt = conv.get_prompt()

        # Set repetition_penalty
        if "t5" in model_name:
            repetition_penalty = 1.2
        else:
            repetition_penalty = 1.0

        stream_iter = model_worker_stream_iter(
            conv,
            model_name,
            worker_addr,
            prompt,
            temperature,
            repetition_penalty,
            top_p,
            max_new_tokens,
        )

    conv.update_last_message("▌")
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2

    # Bound before streaming so the final log call cannot hit an unbound name
    # when the backend stream yields no chunks at all.
    output = ""
    try:
        for data in stream_iter:
            if data["error_code"] == 0:
                if data.get("finish_reason") == "length":
                    gr.Warning("Answer interrupted because the setting of [Max output tokens], try set a larger value.")
                output = data["text"].strip()
                if "vicuna" in model_name:
                    output = post_process_code(output)
                # Trim up to two trailing characters rejected by is_legal_char.
                output = str_filter(output)
                conv.update_last_message(output + "▌")
                yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 2
            else:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + (
                    enable_btn,
                    enable_btn,
                )
                return
            time.sleep(0.015)
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + (
            enable_btn,
            enable_btn,
        )
        return

    # Delete "▌"
    conv.update_last_message(conv.messages[-1][-1][:-1])
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 2

    finish_tstamp = time.time()
    logger.info(f"{output}")

    # One JSON object per line, appended to today's log file.
    with open(get_conv_log_filename(), "a") as fout:
        record = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": request.client.host,
        }
        fout.write(json.dumps(record) + "\n")
# Extra CSS passed to gr.Blocks(css=...) in build_demo. It targets elements by
# elem_id; NOTE(review): none of these ids (dialog_notice_markdown,
# leaderboard_markdown, leaderboard_dataframe) are created in this file's visible UI.
block_css = """
#dialog_notice_markdown {
font-size: 104%
}
#dialog_notice_markdown th {
display: none;
}
#dialog_notice_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#leaderboard_markdown {
font-size: 104%
}
#leaderboard_markdown td {
padding-top: 6px;
padding-bottom: 6px;
}
#leaderboard_dataframe td {
line-height: 0.1em;
}
"""
def get_model_description_md(models):
    """Build a three-column markdown table describing *models*.

    Models registered in model_info are rendered as "[simple_name](link): description"
    and de-duplicated by simple_name. Unregistered names get a placeholder asking
    for a registry entry and are de-duplicated by name. Both kinds share one
    "visited" set, so previously unregistered duplicates are now skipped too.

    Args:
        models: model names in display order.

    Returns:
        The markdown table as a string (the last row may be shorter than three cells).
    """
    model_description_md = """
| | | |
| ---- | ---- | ---- |
"""
    ct = 0
    visited = set()
    for name in models:
        if name in model_info:
            minfo = model_info[name]
            dedup_key = minfo.simple_name
            one_model_md = f"[{minfo.simple_name}]({minfo.link}): {minfo.description}"
        else:
            dedup_key = name
            one_model_md = (
                f"[{name}](): Add the description at fastchat/model/model_registry.py"
            )
        if dedup_key in visited:
            continue
        visited.add(dedup_key)

        # Open a row before every third cell, close it after the third.
        if ct % 3 == 0:
            model_description_md += "|"
        model_description_md += f" {one_model_md} |"
        if ct % 3 == 2:
            model_description_md += "\n"
        ct += 1
    return model_description_md
def build_single_model_ui(models, add_promotion_links=False):
    """Create the single-model chat tab and register its event handlers.

    Must be called inside an active gr.Blocks context (build_demo does this).

    Args:
        models: model names offered in the dropdown; the first one is preselected.
        add_promotion_links: accepted for interface compatibility but not used by this body.

    Returns:
        (state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row),
        the components build_demo passes as outputs of load_demo.
    """
    with gr.Column():
        with gr.Tab("🧠 模型对话 Dialog"):
            # Per-session State object; None until add_text creates one.
            state = gr.State()

            with gr.Row(elem_id="model_selector_row"):
                model_selector = gr.Dropdown(
                    choices=models,
                    value=models[0] if len(models) > 0 else "",
                    interactive=True,
                    show_label=False,
                    container=False,
                )

            # The chatbot, input row, button row and parameter accordion start hidden;
            # load_demo_single returns visible=True updates for them on page load.
            chatbot = gr.Chatbot(
                elem_id="chatbot",
                label="Scroll down and start chatting",
                visible=False,
                height=550,
            )
            with gr.Row():
                with gr.Column(scale=20):
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press ENTER",
                        visible=False,
                        container=False,
                    )
                with gr.Column(scale=1, min_width=50):
                    send_btn = gr.Button(value="Send", visible=False)

            with gr.Row(visible=False) as button_row:
                # Both start non-interactive; handlers toggle them via enable_btn/disable_btn.
                regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

            gr.Examples(
                examples=["如何变得富有?", "你能用Python写一段快速排序吗?", "How to be rich?", "Can you write a quicksort code in Python?"],
                inputs=textbox,
            )

            with gr.Accordion("Parameters", open=False, visible=False) as parameter_row:
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=0.7,
                    step=0.1,
                    interactive=True,
                    label="Temperature",
                )
                top_p = gr.Slider(
                    minimum=0.0,
                    maximum=1.0,
                    value=1.0,
                    step=0.1,
                    interactive=True,
                    label="Top P",
                )
                max_output_tokens = gr.Slider(
                    minimum=16,
                    maximum=1024,
                    value=512,
                    step=64,
                    interactive=True,
                    label="Max output tokens",
                )

            gr.Markdown(learn_more_md)

    # Register listeners
    # Handler outputs always end with updates for these two buttons, in this order.
    btn_list = [regenerate_btn, clear_btn]
    # Regenerate: drop the last reply, then stream a new one.
    regenerate_btn.click(regenerate, state, [state, chatbot, textbox] + btn_list).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    # Clearing history or switching model both reset the session state.
    clear_btn.click(clear_history, None, [state, chatbot, textbox] + btn_list)

    model_selector.change(clear_history, None, [state, chatbot, textbox] + btn_list)

    # ENTER in the textbox and the Send button share the same add_text -> bot_response chain.
    textbox.submit(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    send_btn.click(
        add_text, [state, model_selector, textbox], [state, chatbot, textbox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )

    return state, model_selector, chatbot, textbox, send_btn, button_row, parameter_row
def ft_get_job_data():
    """Fetch the fine-tuning job list from the middleware.

    Returns:
        (rows, running).
        rows: lists of [jobName, username, created_at, model, dataset, status,
            hyper-parameters JSON]. They are ordered by status descending, and
            by created_at descending within the same status.
        running: number of jobs whose status is "running" (case-insensitive).
        On request or parse failure, a rejected token, or a non-list payload,
        returns ([], 0). Malformed job entries are skipped and logged.
    """
    running = 0
    res_lst = []
    try:
        r = requests.get(ft_list_job_url, headers={"PRIVATE-TOKEN": preset_token}, timeout=8)
        # Parse once. ValueError covers non-JSON bodies on every requests version.
        payload = r.json()
    except (requests.exceptions.RequestException, ValueError):
        logger.info("Get job list fail")
        return res_lst, running

    if isinstance(payload, dict) and "invalid" in str(payload.get("code", "")):
        gr.Warning("Invalid preset token.")
        return res_lst, running
    if not isinstance(payload, list):
        # Iterating a dict here would walk its keys and crash on d['status'].
        logger.info(f"Get job list fail, unexpected payload: {payload}")
        return res_lst, running

    for d in payload:
        try:
            status = d["status"]
            params = d.get("parameter") or {}
            hps = {key: params[key] for key in hps_keys if key in params}
            row = [d["jobName"], d["username"], d["created_at"], d["model"], d["dataset"], status, json.dumps(hps)]
        except (KeyError, TypeError, AttributeError):
            logger.info(f"Skip malformed job entry: {d}")
            continue
        if isinstance(status, str) and status.lower() == "running":
            running += 1
        res_lst.append(row)

    # Two stable sorts: the final order is by status, ties broken by created_at (both descending).
    res_lst = sorted(res_lst, key=(lambda x: x[2]), reverse=True)
    res_lst = sorted(res_lst, key=(lambda x: x[5]), reverse=True)
    return res_lst, running
def ft_refresh_click():
    """Refresh button handler: return (job rows, running-job count) from ft_get_job_data."""
    rows, running_cnt = ft_get_job_data()
    return rows, running_cnt
def ft_cease_click(ft_console):
    """Append a 'ceased by user' marker line to the console text."""
    return "\n".join((ft_console, "** Streaming output ceased by user **"))
def console_generator(addr, sleep_time):
    """Stream a job's console log from a websocket as an ever-growing string.

    Args:
        addr: websocket URL of the console-log endpoint.
        sleep_time: seconds to pause after each received message before yielding.

    Yields:
        The concatenation of every message received so far.

    Iteration ends when the server closes the connection. The socket is closed
    in every case: normal end, a propagating error, or the consumer stopping early.
    """
    total_str = ""
    ws = websocket.WebSocket()
    ws.connect(addr, header={"PRIVATE-TOKEN": preset_token})
    try:
        while True:
            try:
                new_str = ws.recv()
            except WebSocketConnectionClosedException:
                break
            if isinstance(new_str, bytes):
                # Binary frames arrive as bytes; decode so concatenation cannot raise TypeError.
                new_str = new_str.decode("utf-8", errors="replace")
            total_str += new_str
            time.sleep(sleep_time)
            yield total_str
    finally:
        ws.close()
def ft_submit_click(ft_latest_running_cnt, ft_user_name, ft_model, ft_dataset_name, ft_token, ft_epochs,
                    ft_train_batch_size, ft_eval_batch_size, ft_gradient_accumulation_steps,
                    ft_learning_rate, ft_weight_decay, ft_model_max_length):
    """Validate the fine-tuning form, submit the job, and stream its console log.

    Generator event handler: each yielded value is (job rows, running count, console text).
    The function contains `yield`, so a plain `return value` would only set
    StopIteration.value and the UI would receive nothing. Every exit path
    therefore yields its final update and then returns.

    Args:
        ft_latest_running_cnt: number of running jobs currently shown in the UI.
        ft_user_name, ft_model, ft_dataset_name: job owner, base model and UI dataset name.
        ft_token: secret sent as the FINETUNE-SECRET header.
        ft_epochs .. ft_model_max_length: hyper-parameters. learning rate and
            weight decay must parse as float; batch sizes, gradient accumulation
            steps and max length must parse as int.
    """
    if ft_user_name == "":
        gr.Warning("Submit fail, empty username.")
        yield (*ft_get_job_data(), no_change_textbox)
        return
    if (
        str_not_int(ft_train_batch_size)
        or str_not_int(ft_eval_batch_size)
        or str_not_int(ft_gradient_accumulation_steps)
        or str_not_float(ft_learning_rate)
        or str_not_float(ft_weight_decay)
        or str_not_int(ft_model_max_length)
    ):
        gr.Warning("Submit fail, check the types. [learning rate] and [weight decay] should be float, others HPs should be int.")
        yield (*ft_get_job_data(), no_change_textbox)
        return
    if ft_latest_running_cnt >= int(allow_running):
        gr.Warning(f"Only allow {str(allow_running)} job(s) running simultaneously, please wait.")
        yield (*ft_get_job_data(), no_change_textbox)
        return

    midware_header = {"FINETUNE-SECRET": ft_token, "PRIVATE-TOKEN": preset_token}
    hps_json = {
        "epochs": str(ft_epochs),
        "train_batch_size": str(ft_train_batch_size),
        "eval_batch_size": str(ft_eval_batch_size),
        "gradient_accumulation_steps": str(ft_gradient_accumulation_steps),
        "learning_rate": str(ft_learning_rate),
        "weight_decay": str(ft_weight_decay),
        "model_max_length": str(ft_model_max_length),
    }
    json_data = {
        "dataset": dataset_to_midware_name[ft_dataset_name],
        "model": ft_model,
        "parameter": hps_json,
        "username": ft_user_name,
    }

    try:
        r = requests.post(ft_submit_job_url, json=json_data, headers=midware_header, timeout=120)
    except requests.exceptions.RequestException:
        gr.Warning("Connection Failure.")
        yield (*ft_get_job_data(), "")
        return
    # Separate try: a rejected submission (non-JSON body or no "jobName") must not
    # crash the handler. Same message format as ft_remove_click's failure warning.
    try:
        job_name = r.json()["jobName"]
    except (ValueError, KeyError, TypeError):
        gr.Warning(f"Submit fail. {r.status_code} {r.reason}.")
        yield (*ft_get_job_data(), no_change_textbox)
        return

    gr.Info(f"Job {job_name} submit success.")
    res_lst, running = ft_get_job_data()
    total_str = ""
    try:
        for total_str in console_generator(f"{ft_console_log_url}{job_name}", 1):
            yield res_lst, running, total_str
    except (websocket.WebSocketException, OSError) as e:
        # The job was submitted; a console-stream failure should not crash the handler.
        logger.info(f"Console stream for job {job_name} failed: {e}")
    yield (*ft_get_job_data(), total_str)
def ft_show_click(ft_selected_row_data):
    """Stream the console log of the selected job (row[0] is its job name), pausing 0.2 s per message."""
    job_name = ft_selected_row_data[0]
    yield from console_generator(ft_console_log_url + job_name, 0.2)
def ft_remove_click(ft_selected_row_data, ft_token):
    """Ask the middleware to remove the selected job, if it is running.

    Args:
        ft_selected_row_data: a jobs-table row; [0] is the job name, [5] the status.
        ft_token: secret sent as the FINETUNE-SECRET header.

    Returns:
        The refreshed (job rows, running count) from ft_get_job_data. Success or
        failure is reported to the user via gr.Info / gr.Warning.
    """
    status = ft_selected_row_data[5]
    if isinstance(status, str) and status.lower() == "running":
        try:
            # Bounded wait: without a timeout a stalled middleware would hang the handler.
            r = requests.delete(
                ft_remove_job_url + ft_selected_row_data[0],
                headers={'FINETUNE-SECRET': ft_token, "PRIVATE-TOKEN": preset_token},
                timeout=30,
            )
        except requests.exceptions.RequestException:
            gr.Warning("Connection Failure.")
        else:
            if r.status_code == 200:
                gr.Info("Remove success.")
            else:
                gr.Warning(f"Remove fail. {r.status_code} {r.reason}.")
    else:
        gr.Warning("Remove fail. Can only remove a running job.")
    return ft_get_job_data()
def ft_jobs_info_select(ft_jobs_info, evt: gr.SelectData):
    """Handle a click on the jobs table.

    The clicked row is always returned first. A click on column 3 (model),
    4 (dataset) or 6 (hyper-parameters JSON) also returns the row's model,
    dataset and seven hyper-parameters, to refill the submit form. A JSON parse
    failure yields "" for each hyper-parameter. Any other column returns
    "no change" updates for those nine form fields.
    """
    selected_row = ft_jobs_info[evt.index[0]]
    clicked_column = evt.index[1]
    if clicked_column not in (3, 4, 6):
        return [selected_row, no_change_dropdown, no_change_dropdown, no_change_slider] + [no_change_textbox] * 6

    try:
        hps = json.loads(selected_row[6])
    except json.decoder.JSONDecodeError:
        hps = dict()
    form_fields = (
        "epochs",
        "train_batch_size",
        "eval_batch_size",
        "gradient_accumulation_steps",
        "learning_rate",
        "weight_decay",
        "model_max_length",
    )
    return [selected_row, selected_row[3], selected_row[4]] + [hps.get(field, '') for field in form_fields]
def ft_dataset_preview_click(ft_dataset_name):
    """Show the sample records registered for the dataset ({} when there are none)."""
    samples = dataset_sample.get(ft_dataset_name, {})
    return gr.JSON.update(visible=True, value=samples)
def ft_hide_dataset_click():
    """Hide the dataset preview JSON component."""
    hidden = gr.JSON.update(visible=False)
    return hidden
def build_demo(models):
    """Create the gradio Blocks app containing the single-model chat UI.

    Args:
        models: initial model names for the dropdown. load_demo refreshes the
            module-level list on each page load when model_list_mode == "reload".

    Returns:
        The gr.Blocks instance (queue/launch are done by the caller).

    Raises:
        ValueError: if model_list_mode is neither "once" nor "reload".
    """
    with gr.Blocks(
        title="Vicuna Test",
        theme=gr.themes.Base(),
        css = block_css
    ) as demo:
        # Hidden input for load_demo; its value comes from the
        # get_window_url_params_js client-side hook registered below.
        url_params = gr.JSON(visible=False)

        (
            state,
            model_selector,
            chatbot,
            textbox,
            send_btn,
            button_row,
            parameter_row,
        ) = build_single_model_ui(models)

        if model_list_mode not in ["once", "reload"]:
            raise ValueError(f"Unknown model list mode: {model_list_mode}")

        # On every page load, load_demo resets the state, fills the dropdown and
        # reveals the components build_single_model_ui created hidden.
        demo.load(
            load_demo,
            [url_params],
            [
                state,
                model_selector,
                chatbot,
                textbox,
                send_btn,
                button_row,
                parameter_row,
            ],
            _js=get_window_url_params_js,
        )

    return demo
# Startup: report the public IP (best effort, never fatal), fetch the model list,
# build the UI and serve it.
try:
    public_ip = get_internet_ip()
except Exception as e:
    print(f"Get Internet IP error: {e}")
else:
    print("Internet IP:", public_ip)

models = get_model_list(midware_url)

# Launch the demo
demo = build_demo(models)
demo.queue(
    concurrency_count=concurrency_count,
    status_update_rate=10,
    api_open=False,
).launch(max_threads=200)