| | """ |
| | The gradio demo server for chatting with a single model. |
| | """ |
| |
|
| | import argparse |
| | from collections import defaultdict |
| | import datetime |
| | import hashlib |
| | import json |
| | import os |
| | import random |
| | import time |
| | import uuid |
| |
|
| | import gradio as gr |
| | import requests |
| |
|
| | from src.constants import ( |
| | LOGDIR, |
| | WORKER_API_TIMEOUT, |
| | ErrorCode, |
| | MODERATION_MSG, |
| | CONVERSATION_LIMIT_MSG, |
| | RATE_LIMIT_MSG, |
| | SERVER_ERROR_MSG, |
| | INPUT_CHAR_LEN_LIMIT, |
| | CONVERSATION_TURN_LIMIT, |
| | SESSION_EXPIRATION_TIME, |
| | ) |
| | from src.model.model_adapter import ( |
| | get_conversation_template, |
| | ) |
| | from src.model.model_registry import get_model_info, model_info |
| | from src.serve.api_provider import get_api_provider_stream_iter |
| | from src.serve.remote_logger import get_remote_logger |
| | from src.utils import ( |
| | build_logger, |
| | get_window_url_params_js, |
| | get_window_url_params_with_tos_js, |
| | moderation_filter, |
| | parse_gradio_auth_creds, |
| | load_image, |
| | ) |
| |
|
# Module-wide logger; build_logger receives a logger name and a log file name.
logger = build_logger("gradio_web_server", "gradio_web_server.log")

# HTTP headers attached to the streaming request sent to a model worker
# (see model_worker_stream_iter).
headers = {"User-Agent": "FastChat Client"}

# Shared button values returned by the event handlers below to update the
# five action buttons (upvote, downvote, flag, regenerate, clear).
no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True, visible=True)
disable_btn = gr.Button(interactive=False)
invisible_btn = gr.Button(interactive=False, visible=False)

# Process-wide settings; overwritten by set_global_vars() at start-up.
controller_url = None
enable_moderation = False
use_remote_storage = False

# Markdown rendered at the bottom of the single-model UI when promotion
# links are enabled (see build_single_model_ui).
acknowledgment_md = """
### Terms of Service

Placeholder
### Acknowledgment
Placeholder
"""

# Registry of API-served models keyed by model name. Replaced by
# get_model_list() with the parsed JSON registration file. Per-model keys
# read in this file: "vision-arena", "text-arena", "anony_only" and
# "recommended_config" (with "temperature", "top_p", "max_new_tokens").
api_endpoint_info = {}
| |
|
| |
|
class State:
    """Per-session chat state for one conversation with one model.

    Attributes set by ``__init__``:
        conv: conversation object returned by ``get_conversation_template``.
        conv_id: random hex identifier of this conversation.
        skip_next: when True, the next generation call is skipped.
        model_name: name of the model this conversation talks to.
        oai_thread_id: initialised to None.
        is_vision: whether this is a vision conversation.
        has_csam_image: initialised to False; included in ``dict()`` for
            vision conversations.
        regen_support: False for model names containing "browsing".
    """

    def __init__(self, model_name, is_vision=False):
        self.conv = get_conversation_template(model_name)
        self.conv_id = uuid.uuid4().hex
        self.skip_next = False
        self.model_name = model_name
        self.oai_thread_id = None
        self.is_vision = is_vision

        self.has_csam_image = False

        # Regeneration is switched off for "browsing" models.
        self.regen_support = "browsing" not in model_name
        self.init_system_prompt(self.conv)

    def init_system_prompt(self, conv):
        """Substitute today's date for ``{{currentDateTime}}`` in the system prompt.

        Conversations exposing ``get_system_message``/``set_system_message``
        are handled through those accessors; conversations that only carry a
        plain ``system`` attribute are read and written through it.
        Conversations with neither, or with an empty prompt, are left
        untouched.
        """
        if hasattr(conv, "get_system_message"):
            system_prompt = conv.get_system_message()
        elif hasattr(conv, "system"):
            system_prompt = conv.system
        else:
            # No known way to read a system prompt from this template.
            return
        if not system_prompt:
            return

        current_date = datetime.datetime.now().strftime("%Y-%m-%d")
        system_prompt = system_prompt.replace("{{currentDateTime}}", current_date)
        if hasattr(conv, "set_system_message"):
            conv.set_system_message(system_prompt)
        else:
            conv.system = system_prompt

    def to_gradio_chatbot(self):
        """Return the conversation in the form the Gradio chatbot displays."""
        return self.conv.to_gradio_chatbot()

    def dict(self):
        """Return a loggable dict: the conversation dict plus identifiers.

        ``has_csam_image`` is only included for vision conversations.
        """
        base = self.conv.dict()
        base.update(
            {
                "conv_id": self.conv_id,
                "model_name": self.model_name,
            }
        )

        if self.is_vision:
            base.update({"has_csam_image": self.has_csam_image})
        return base
| |
|
| |
|
def set_global_vars(controller_url_, enable_moderation_, use_remote_storage_):
    """Store the launch-time settings in this module's global variables."""
    global controller_url, enable_moderation, use_remote_storage
    controller_url, enable_moderation, use_remote_storage = (
        controller_url_,
        enable_moderation_,
        use_remote_storage_,
    )
| |
|
| |
|
def get_conv_log_filename(is_vision=False, has_csam_image=False):
    """Return today's conversation log path inside LOGDIR.

    Text conversations use ``YYYY-MM-DD-conv.json``; vision conversations get
    a ``vision-tmp-`` prefix, or ``vision-csam-`` when ``has_csam_image`` is
    set.
    """
    today = datetime.datetime.now()
    base_name = f"{today.year}-{today.month:02d}-{today.day:02d}-conv.json"

    if not is_vision:
        prefix = ""
    elif has_csam_image:
        prefix = "vision-csam-"
    else:
        prefix = "vision-tmp-"

    return os.path.join(LOGDIR, f"{prefix}{base_name}")
| |
|
| |
|
def get_model_list(controller_url, register_api_endpoint_file, vision_arena):
    """Collect the models that can be served and split out the visible ones.

    Args:
        controller_url: base URL of the controller; when falsy, no worker
            models are queried.
        register_api_endpoint_file: optional path to a JSON file describing
            API-served models. When given, its content replaces the global
            ``api_endpoint_info``.
        vision_arena: select multimodal models (True) or language models
            (False), both from the controller and from the API registry.

    Returns:
        ``(visible_models, models)``: both sorted; ``visible_models`` omits
        API models whose ``"anony_only"`` flag is truthy.
    """
    global api_endpoint_info

    # Models served by workers registered with the controller.
    if controller_url:
        ret = requests.post(controller_url + "/refresh_all_workers")
        assert ret.status_code == 200

        list_endpoint = (
            "/list_multimodal_models" if vision_arena else "/list_language_models"
        )
        ret = requests.post(controller_url + list_endpoint)
        models = ret.json()["models"]
    else:
        models = []

    # Models served by API providers.
    if register_api_endpoint_file:
        # Close the registry file deterministically instead of leaking it.
        with open(register_api_endpoint_file) as fin:
            api_endpoint_info = json.load(fin)
        for mdl, mdl_dict in api_endpoint_info.items():
            mdl_vision = mdl_dict.get("vision-arena", False)
            mdl_text = mdl_dict.get("text-arena", True)
            if vision_arena and mdl_vision:
                models.append(mdl)
            if not vision_arena and mdl_text:
                models.append(mdl)

    # De-duplicate, then hide anonymous-only API models. A missing
    # "anony_only" key is treated as "not anonymous-only".
    models = list(set(models))
    visible_models = [
        mdl
        for mdl in models
        if not api_endpoint_info.get(mdl, {}).get("anony_only", False)
    ]

    # Models listed in model_info come first, in registry order (the "___NNN"
    # keys sort before ordinary names); the rest sort by name.
    priority = {k: f"___{i:03d}" for i, k in enumerate(model_info)}
    models.sort(key=lambda x: priority.get(x, x))
    visible_models.sort(key=lambda x: priority.get(x, x))
    logger.info(f"All models: {models}")
    logger.info(f"Visible models: {visible_models}")
    return visible_models, models
| |
|
| |
|
def load_demo_single(models, url_params):
    """Build the initial ``(state, dropdown)`` pair for a page load.

    The first model is pre-selected unless ``url_params["model"]`` names one
    of ``models``. The returned state is always None.
    """
    selected_model = models[0] if len(models) > 0 else ""
    if "model" in url_params and url_params["model"] in models:
        selected_model = url_params["model"]

    dropdown_update = gr.Dropdown(choices=models, value=selected_model, visible=True)
    return None, dropdown_update
| |
|
| |
|
def load_demo(url_params, request: gr.Request):
    """Page-load handler: optionally refresh the model list, then build the UI state.

    In "reload" mode the global ``models`` list is re-fetched on every load.
    """
    global models

    client_ip = get_ip(request)
    logger.info(f"load_demo. ip: {client_ip}. params: {url_params}")

    if args.model_list_mode == "reload":
        # Only the visible models are kept; the full list is not needed here.
        models, _all_models = get_model_list(
            controller_url, args.register_api_endpoint_file, vision_arena=False
        )

    return load_demo_single(models, url_params)
| |
|
| |
|
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append a vote record to today's conversation log and the remote logger.

    Args:
        state: conversation state; its ``dict()`` is stored in the record.
        vote_type: label stored under ``"type"`` (e.g. "upvote").
        model_selector: name of the selected model.
        request: Gradio request, used to resolve the client IP.
    """
    # Votes for "llava" models go to the "vision-tmp-" log. The file name is
    # derived from the helper rather than by patching a hard-coded year into
    # the path, so it keeps working after 2024 and never rewrites LOGDIR.
    filename = get_conv_log_filename(is_vision="llava" in model_selector)

    data = {
        "tstamp": round(time.time(), 4),
        "type": vote_type,
        "model": model_selector,
        "state": state.dict(),
        "ip": get_ip(request),
    }
    with open(filename, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)
| |
|
| |
|
def upvote_last_response(state, model_selector, request: gr.Request):
    """Record an upvote; clear the textbox and disable the three vote buttons."""
    logger.info(f"upvote. ip: {get_ip(request)}")
    vote_last_response(state, "upvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
| |
|
| |
|
def downvote_last_response(state, model_selector, request: gr.Request):
    """Record a downvote; clear the textbox and disable the three vote buttons."""
    logger.info(f"downvote. ip: {get_ip(request)}")
    vote_last_response(state, "downvote", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
| |
|
| |
|
def flag_last_response(state, model_selector, request: gr.Request):
    """Record a flag; clear the textbox and disable the three vote buttons."""
    logger.info(f"flag. ip: {get_ip(request)}")
    vote_last_response(state, "flag", model_selector, request)
    return ("", disable_btn, disable_btn, disable_btn)
| |
|
| |
|
def regenerate(state, request: gr.Request):
    """Prepare the conversation for regenerating the last reply.

    When the model does not support regeneration, the follow-up generation
    step is skipped and the buttons are left unchanged. Otherwise the last
    message is reset to None and all five buttons are disabled.
    """
    client_ip = get_ip(request)
    logger.info(f"regenerate. ip: {client_ip}")

    if state.regen_support:
        state.conv.update_last_message(None)
        buttons = (disable_btn,) * 5
    else:
        state.skip_next = True
        buttons = (no_change_btn,) * 5

    return (state, state.to_gradio_chatbot(), "", None) + buttons
| |
|
| |
|
def clear_history(request: gr.Request):
    """Reset state, chat history, textbox and image; disable all five buttons."""
    client_ip = get_ip(request)
    logger.info(f"clear_history. ip: {client_ip}")
    return (None, [], "", None) + (disable_btn,) * 5
| |
|
| |
|
def get_ip(request: gr.Request):
    """Return the client IP for a Gradio request.

    Lookup order: the ``cf-connecting-ip`` header, then the first entry of
    ``x-forwarded-for``, then the address of the direct peer.
    """
    request_headers = request.headers
    if "cf-connecting-ip" in request_headers:
        return request_headers["cf-connecting-ip"]
    if "x-forwarded-for" in request_headers:
        # The header may carry a comma-separated chain
        # ("client, proxy1, proxy2"); the client is the first entry.
        return request_headers["x-forwarded-for"].split(",")[0].strip()
    return request.client.host
| |
|
| |
|
| | |
def report_csam_image(state, image):
    """Placeholder reporting hook; currently does nothing and returns None."""
    return None
| |
|
| |
|
def _prepare_text_with_image(state, text, images, csam_flag):
    """Attach the first uploaded image (if any) to the user text.

    Returns ``text`` unchanged when no image is supplied; otherwise returns
    the tuple ``(text, [encoded_image])``. If the conversation already holds
    images, it is replaced by a fresh template for the same model first.
    """
    if images is None or len(images) == 0:
        return text

    first_image = images[0]

    # A new image starts a new conversation for this model.
    if len(state.conv.get_images()) > 0:
        state.conv = get_conversation_template(state.model_name)

    if hasattr(state.conv, "convert_image_to_base64"):
        encoded = state.conv.convert_image_to_base64(first_image)
    else:
        # Fall back to the module-level converter for templates without the
        # method.
        from src.conversation import convert_image_to_base64

        encoded = convert_image_to_base64(first_image, None)

    if csam_flag:
        state.has_csam_image = True
        report_csam_image(state, encoded)

    return text, [encoded]
| |
|
| |
|
def add_text(state, model_selector, text, image, request: gr.Request):
    """Append the user's message (and an empty reply slot) to the conversation.

    Returns a 9-tuple: state, chatbot messages, new textbox value, image
    value (always None) and five button updates. Empty input and reaching
    the turn limit both set ``state.skip_next`` so the follow-up generation
    step does nothing.
    """
    client_ip = get_ip(request)
    logger.info(f"add_text. ip: {client_ip}. len: {len(text)}; text: {text}")

    if state is None:
        state = State(model_selector)

    unchanged = (no_change_btn,) * 5

    if len(text) == 0:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + unchanged

    # Moderate the tail of the existing prompt together with the new message.
    moderation_input = state.conv.get_prompt()[-2000:] + "\nuser: " + text
    if moderation_filter(moderation_input, [state.model_name]):
        logger.info(f"violate moderation. ip: {client_ip}. text: {text}")
        # The flagged text is replaced by the moderation notice.
        text = MODERATION_MSG

    turns_so_far = (len(state.conv.messages) - state.conv.offset) // 2
    if turns_so_far >= CONVERSATION_TURN_LIMIT:
        logger.info(f"conversation turn limit. ip: {client_ip}. text: {text}")
        state.skip_next = True
        return (
            state,
            state.to_gradio_chatbot(),
            CONVERSATION_LIMIT_MSG,
            None,
        ) + unchanged

    message = _prepare_text_with_image(
        state, text[:INPUT_CHAR_LEN_LIMIT], image, csam_flag=False
    )
    user_role, bot_role = state.conv.roles[0], state.conv.roles[1]
    state.conv.append_message(user_role, message)
    state.conv.append_message(bot_role, None)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
| |
|
| |
|
def model_worker_stream_iter(
    conv,
    model_name,
    worker_addr,
    prompt,
    temperature,
    repetition_penalty,
    top_p,
    max_new_tokens,
    images,
):
    """Stream generation results from a model worker.

    Posts the generation parameters to ``worker_addr`` and yields each
    NUL-delimited chunk of the response, decoded from JSON.
    """
    gen_params = dict(
        model=model_name,
        prompt=prompt,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        top_p=top_p,
        max_new_tokens=max_new_tokens,
        stop=conv.stop_str,
        stop_token_ids=conv.stop_token_ids,
        echo=False,
    )

    # Logged before images are attached, so image payloads stay out of the log.
    logger.info(f"==== request ====\n{gen_params}")

    if len(images) > 0:
        gen_params["images"] = images

    resp = requests.post(
        worker_addr + "/worker_generate_stream",
        headers=headers,
        json=gen_params,
        stream=True,
        timeout=WORKER_API_TIMEOUT,
    )
    for raw_chunk in resp.iter_lines(decode_unicode=False, delimiter=b"\0"):
        if not raw_chunk:
            continue
        yield json.loads(raw_chunk.decode())
| |
|
| |
|
def is_limit_reached(model_name, ip):
    """Ask the local monitor service whether this user hit a rate limit.

    Returns the monitor's decoded JSON reply, or None when the monitor cannot
    be reached or answers with something that is not JSON (best effort: the
    error is logged and the caller proceeds without a limit).
    """
    monitor_url = "http://localhost:9090"
    try:
        # Pass values through ``params`` so they are URL-encoded; ``ip`` comes
        # from request headers and must not be spliced into the query string.
        ret = requests.get(
            f"{monitor_url}/is_limit_reached",
            params={"model": model_name, "user_id": ip},
            timeout=1,
        )
        return ret.json()
    except Exception as e:
        logger.info(f"monitor error: {e}")
        return None
| |
|
| |
|
def bot_response(
    state,
    temperature,
    top_p,
    max_new_tokens,
    request: gr.Request,
    apply_rate_limit=False,
    use_recommended_config=False,
):
    """Generate the model reply for the last turn and stream UI updates.

    Generator: each yielded value is ``(state, chatbot messages)`` followed
    by five button updates. On success the finished exchange is appended to
    the daily conversation log and sent to the remote logger.
    """
    ip = get_ip(request)
    logger.info(f"bot_response. ip: {ip}")
    start_tstamp = time.time()
    temperature = float(temperature)
    top_p = float(top_p)
    max_new_tokens = int(max_new_tokens)

    unchanged_btns = (no_change_btn,) * 5
    busy_btns = (disable_btn,) * 5
    # After a failure only "regenerate" and "clear" are usable.
    failure_btns = (disable_btn,) * 3 + (enable_btn,) * 2

    if state.skip_next:
        # The preceding handler marked this turn as invalid; do nothing.
        state.skip_next = False
        yield (state, state.to_gradio_chatbot()) + unchanged_btns
        return

    if apply_rate_limit:
        ret = is_limit_reached(state.model_name, ip)
        if ret is not None and ret["is_limit_reached"]:
            error_msg = RATE_LIMIT_MSG + "\n\n" + ret["reason"]
            logger.info(f"rate limit reached. ip: {ip}. error_msg: {ret['reason']}")
            state.conv.update_last_message(error_msg)
            yield (state, state.to_gradio_chatbot()) + unchanged_btns
            return

    conv, model_name = state.conv, state.model_name
    model_api_dict = api_endpoint_info.get(model_name)
    images = conv.get_images()
    logger.info(f"model_name: {model_name}; model_api_dict: {model_api_dict}; msg: {len(conv.messages)}; template: {conv.name}")

    if model_api_dict is None:
        # Not an API model: run local llava inference and wrap the single
        # result so it can be consumed like a stream.
        if model_name == "llava-original":
            from src.model.model_llava import inference_by_prompt_and_images

            logger.info(f"prompt for llava-original: {conv.get_prompt()}; images: {len(images)}")
            output_text = inference_by_prompt_and_images(conv.get_prompt(), images)[0]
        else:
            from src.model.model_llava import inference_by_prompt_and_images_fire

            logger.info(f"prompt for llava-fire: {conv.get_prompt()}; images: {len(images)}")
            output_text = inference_by_prompt_and_images_fire(conv.get_prompt(), images)[0]
        stream_iter = [{"error_code": 0, "text": output_text}]
    else:
        if use_recommended_config:
            recommended_config = model_api_dict.get("recommended_config", None)
            if recommended_config is not None:
                temperature = recommended_config.get("temperature", temperature)
                top_p = recommended_config.get("top_p", top_p)
                max_new_tokens = recommended_config.get(
                    "max_new_tokens", max_new_tokens
                )

        # NOTE(review): API-registered models currently get this fixed
        # placeholder reply; no provider is called here.
        stream_iter = [{"error_code": 0, "text": "hello"}]

    html_code = ' <span class="cursor"></span> '

    # Show the blinking cursor while waiting for the first chunk.
    conv.update_last_message(html_code)
    yield (state, state.to_gradio_chatbot()) + busy_btns

    try:
        data = {"text": ""}
        for data in stream_iter:
            if data["error_code"] != 0:
                output = data["text"] + f"\n\n(error_code: {data['error_code']})"
                conv.update_last_message(output)
                yield (state, state.to_gradio_chatbot()) + failure_btns
                return
            output = data["text"].strip()
            conv.update_last_message(output + html_code)
            yield (state, state.to_gradio_chatbot()) + busy_btns

        # Final frame: last chunk without the cursor, all buttons enabled.
        output = data["text"].strip()
        conv.update_last_message(output)
        yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5
    except requests.exceptions.RequestException as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_REQUEST_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + failure_btns
        return
    except Exception as e:
        conv.update_last_message(
            f"{SERVER_ERROR_MSG}\n\n"
            f"(error_code: {ErrorCode.GRADIO_STREAM_UNKNOWN_ERROR}, {e})"
        )
        yield (state, state.to_gradio_chatbot()) + failure_btns
        return

    finish_tstamp = time.time()
    logger.info(f"{output}")

    conv.save_new_images(
        has_csam_images=state.has_csam_image, use_remote_storage=use_remote_storage
    )

    filename = get_conv_log_filename(
        is_vision=state.is_vision, has_csam_image=state.has_csam_image
    )

    with open(filename, "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "gen_params": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
            },
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "ip": get_ip(request),
        }
        fout.write(json.dumps(data) + "\n")
    get_remote_logger().log(data)
| |
|
| |
|
# Custom stylesheet handed to gr.Blocks(css=...) in build_demo(). It sizes
# the markdown/leaderboard sections, styles headings inside the chatbot, and
# defines the blinking ".cursor" element that bot_response() appends to a
# reply while it is still being produced.
block_css = """
#notice_markdown .prose {
    font-size: 110% !important;
}
#notice_markdown th {
    display: none;
}
#notice_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#arena_leaderboard_dataframe table {
    font-size: 110%;
}
#full_leaderboard_dataframe table {
    font-size: 110%;
}
#model_description_markdown {
    font-size: 110% !important;
}
#leaderboard_markdown .prose {
    font-size: 110% !important;
}
#leaderboard_markdown td {
    padding-top: 6px;
    padding-bottom: 6px;
}
#leaderboard_dataframe td {
    line-height: 0.1em;
}
#about_markdown .prose {
    font-size: 110% !important;
}
#ack_markdown .prose {
    font-size: 110% !important;
}
#chatbot .prose {
    font-size: 105% !important;
}
.sponsor-image-about img {
    margin: 0 20px;
    margin-top: 20px;
    height: 40px;
    max-height: 100%;
    width: auto;
    float: left;
}

.chatbot h1, h2, h3 {
    margin-top: 8px; /* Adjust the value as needed */
    margin-bottom: 0px; /* Adjust the value as needed */
    padding-bottom: 0px;
}

.chatbot h1 {
    font-size: 130%;
}
.chatbot h2 {
    font-size: 120%;
}
.chatbot h3 {
    font-size: 110%;
}
.chatbot p:not(:first-child) {
    margin-top: 8px;
}

.typing {
    display: inline-block;
}

.cursor {
    display: inline-block;
    width: 7px;
    height: 1em;
    background-color: black;
    vertical-align: middle;
    animation: blink 1s infinite;
}

.dark .cursor {
    display: inline-block;
    width: 7px;
    height: 1em;
    background-color: white;
    vertical-align: middle;
    animation: blink 1s infinite;
}

@keyframes blink {
    0%, 50% { opacity: 1; }
    50.1%, 100% { opacity: 0; }
}

.app {
    max-width: 100% !important;
    padding: 20px !important;
}

a {
    color: #1976D2; /* Your current link color, a shade of blue */
    text-decoration: none; /* Removes underline from links */
}
a:hover {
    color: #63A4FF; /* This can be any color you choose for hover */
    text-decoration: underline; /* Adds underline on hover */
}
"""
| |
|
| |
|
def get_model_description_md(models):
    """Render model descriptions as a three-column markdown table.

    Models sharing a ``simple_name`` are listed once, in first-seen order.
    """
    header = """
| | | |
| ---- | ---- | ---- |
"""
    # Collect one "[name](link): description" cell per distinct model.
    cells = []
    seen_names = set()
    for name in models:
        info = get_model_info(name)
        if info.simple_name in seen_names:
            continue
        seen_names.add(info.simple_name)
        cells.append(f"[{info.simple_name}]({info.link}): {info.description}")

    # Lay the cells out three per row.
    pieces = [header]
    for position, cell in enumerate(cells):
        column = position % 3
        if column == 0:
            pieces.append("|")
        pieces.append(f" {cell} |")
        if column == 2:
            pieces.append("\n")
    return "".join(pieces)
| |
|
| |
|
def build_about():
    """Add the static "About Us" markdown section to the current Gradio layout."""
    page_md = """
# About Us
Placeholder
## Arena Core Team
Placeholder
## Past Members
Placeholder
## Learn more
Placeholder

## Contact Us
Placeholder

## Acknowledgment
Placeholder
"""
    gr.Markdown(page_md, elem_id="about_markdown")
| |
|
| |
|
def build_single_model_ui(models, add_promotion_links=False):
    """Create the single-model chat UI and wire up its event handlers.

    Must be called inside an active ``gr.Blocks`` context. Returns
    ``[state, model_selector]`` so the caller can target them on page load.
    """
    if add_promotion_links:
        promotion = """
- | [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
- Introducing Llama 2: The Next Generation Open Source Large Language Model. [[Website]](https://ai.meta.com/llama/)
- Vicuna: An Open-Source Chatbot Impressing GPT-4 with 90% ChatGPT Quality. [[Blog]](https://lmsys.org/blog/2023-03-30-vicuna/)

## 🤖 Choose any model to chat
"""
    else:
        promotion = ""

    notice_markdown = f"""
# 🏔️ Chat with Open Large Language Models
{promotion}
"""

    state = gr.State()
    gr.Markdown(notice_markdown, elem_id="notice_markdown")

    # --- Layout -----------------------------------------------------------
    with gr.Group(elem_id="share-region-named"):
        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(
                choices=models,
                value=models[0] if len(models) > 0 else "",
                interactive=True,
                show_label=False,
                container=False,
            )
        with gr.Row():
            with gr.Accordion(
                f"🔍 Expand to see the descriptions of {len(models)} models",
                open=False,
            ):
                model_description_md = get_model_description_md(models)
                gr.Markdown(model_description_md, elem_id="model_description_markdown")

        chatbot = gr.Chatbot(
            elem_id="chatbot",
            label="Scroll down and start chatting",
            height=550,
            show_copy_button=True,
        )
    with gr.Row():
        textbox = gr.Textbox(
            show_label=False,
            placeholder="👉 Enter your prompt and press ENTER",
            elem_id="input_box",
        )
        send_btn = gr.Button(value="Send", variant="primary", scale=0)

    with gr.Row() as button_row:
        upvote_btn = gr.Button(value="👍  Upvote", interactive=False)
        downvote_btn = gr.Button(value="👎  Downvote", interactive=False)
        flag_btn = gr.Button(value="⚠️  Flag", interactive=False)
        regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=False)
        clear_btn = gr.Button(value="🗑️  Clear history", interactive=False)

    with gr.Accordion("Parameters", open=False) as parameter_row:
        temperature = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.7,
            step=0.1,
            interactive=True,
            label="Temperature",
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=1.0,
            step=0.1,
            interactive=True,
            label="Top P",
        )
        max_output_tokens = gr.Slider(
            minimum=16,
            maximum=2048,
            value=1024,
            step=64,
            interactive=True,
            label="Max output tokens",
        )

    if add_promotion_links:
        gr.Markdown(acknowledgment_md, elem_id="ack_markdown")

    # --- Event wiring -----------------------------------------------------
    # This UI has no image input; a None-valued state fills the image slot
    # that the handlers expect.
    imagebox = gr.State(None)
    btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]

    vote_handlers = (
        (upvote_btn, upvote_last_response),
        (downvote_btn, downvote_last_response),
        (flag_btn, flag_last_response),
    )
    for vote_button, vote_handler in vote_handlers:
        vote_button.click(
            vote_handler,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn],
        )

    regenerate_btn.click(
        regenerate, state, [state, chatbot, textbox, imagebox] + btn_list
    ).then(
        bot_response,
        [state, temperature, top_p, max_output_tokens],
        [state, chatbot] + btn_list,
    )
    clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)

    # Switching models starts a fresh conversation.
    model_selector.change(
        clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
    )

    # Pressing ENTER and clicking "Send" run the same two-step pipeline.
    for register_send in (textbox.submit, send_btn.click):
        register_send(
            add_text,
            [state, model_selector, textbox, imagebox],
            [state, chatbot, textbox, imagebox] + btn_list,
        ).then(
            bot_response,
            [state, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
        )

    return [state, model_selector]
| |
|
| |
|
def build_demo(models):
    """Assemble the Gradio Blocks app for the single-model chat page.

    Reads the module-level ``args`` for the model list mode and the
    terms-of-use setting. Raises ValueError for an unknown model list mode.
    """
    with gr.Blocks(
        title="Chat with Open Large Language Models",
        theme=gr.themes.Default(),
        css=block_css,
    ) as demo:
        # Hidden component that receives the page's URL parameters.
        url_params = gr.JSON(visible=False)

        state, model_selector = build_single_model_ui(models)

        if args.model_list_mode not in ["once", "reload"]:
            raise ValueError(f"Unknown model list mode: {args.model_list_mode}")

        load_js = (
            get_window_url_params_with_tos_js
            if args.show_terms_of_use
            else get_window_url_params_js
        )

        demo.load(
            load_demo,
            [url_params],
            [state, model_selector],
            js=load_js,
        )

    return demo
| |
|
| |
|
if __name__ == "__main__":
    # NOTE: ``args`` and ``models`` are deliberately module-level names;
    # load_demo() and build_demo() read them as globals.
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument(
        "--share",
        action="store_true",
        help="Whether to generate a public, shareable link",
    )
    parser.add_argument(
        "--controller-url",
        type=str,
        default="http://localhost:21001",
        help="The address of the controller",
    )
    parser.add_argument(
        "--concurrency-count",
        type=int,
        default=10,
        help="The concurrency count of the gradio queue",
    )
    parser.add_argument(
        "--model-list-mode",
        type=str,
        default="once",
        choices=["once", "reload"],
        help="Whether to load the model list once or reload the model list every time",
    )
    parser.add_argument(
        "--moderate",
        action="store_true",
        help="Enable content moderation to block unsafe inputs",
    )
    parser.add_argument(
        "--show-terms-of-use",
        action="store_true",
        help="Shows term of use before loading the demo",
    )
    parser.add_argument(
        "--register-api-endpoint-file",
        type=str,
        help="Register API-based model endpoints from a JSON file",
    )
    parser.add_argument(
        "--gradio-auth-path",
        type=str,
        help='Set the gradio authentication file path. The file should contain one or more user:password pairs in this format: "u1:p1,u2:p2,u3:p3"',
    )
    parser.add_argument(
        "--gradio-root-path",
        type=str,
        help="Sets the gradio root path, eg /abc/def. Useful when running behind a reverse-proxy or at a custom URL path prefix",
    )
    parser.add_argument(
        "--use-remote-storage",
        action="store_true",
        default=False,
        help="Uploads image files to google cloud storage if set to true",
    )
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Publish the settings and fetch the initial model list.
    set_global_vars(args.controller_url, args.moderate, args.use_remote_storage)
    models, all_models = get_model_list(
        args.controller_url, args.register_api_endpoint_file, vision_arena=False
    )

    # Optional login credentials for the web UI.
    auth = (
        parse_gradio_auth_creds(args.gradio_auth_path)
        if args.gradio_auth_path is not None
        else None
    )

    # Build the app, enable its request queue and start serving.
    demo = build_demo(models)
    queued_demo = demo.queue(
        default_concurrency_limit=args.concurrency_count,
        status_update_rate=10,
        api_open=False,
    )
    queued_demo.launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share,
        max_threads=200,
        auth=auth,
        root_path=args.gradio_root_path,
    )
| |
|