|
|
import random |
|
|
from collections.abc import Mapping |
|
|
from uuid import uuid4 |
|
|
|
|
|
from openai import OpenAI |
|
|
import gradio as gr |
|
|
import base64 |
|
|
import mimetypes |
|
|
import copy |
|
|
import os |
|
|
|
|
|
from theme import apriel |
|
|
from utils import COMMUNITY_POSTFIX_URL, get_model_config, check_format, models_config, \ |
|
|
logged_event_handler, DEBUG_MODE, DEBUG_MODEL, log_debug, log_info, log_error, log_warning |
|
|
from log_chat import log_chat |
|
|
|
|
|
# Sampling temperature sent with every chat completion request.
MODEL_TEMPERATURE = 0.8
# Pixel width of the Send/Stop and New Chat button columns in the layout.
BUTTON_WIDTH = 160
# Initial value of the chat-history opt-out checkbox and session flag.
# Follows DEBUG_MODE from utils, so debug runs start opted out of logging.
DEFAULT_OPT_OUT_VALUE = DEBUG_MODE

# Model selected when a page loads.
# NOTE(review): both branches of this conditional are identical, so DEBUG_MODEL
# currently has no effect here — confirm whether a different debug model was intended.
DEFAULT_MODEL_NAME = "Apriel-1.5-15B-thinker" if not DEBUG_MODEL else "Apriel-1.5-15B-thinker"

# Reusable gr.update() payloads that toggle component interactivity.
BUTTON_ENABLED = gr.update(interactive=True)
BUTTON_DISABLED = gr.update(interactive=False)
INPUT_ENABLED = gr.update(interactive=True)
INPUT_DISABLED = gr.update(interactive=False)
DROPDOWN_ENABLED = gr.update(interactive=True)
DROPDOWN_DISABLED = gr.update(interactive=False)

# Send and Stop are swapped by visibility while a reply streams. Both stay
# interactive, so the *_DISABLED variants only hide the button.
SEND_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
SEND_BUTTON_DISABLED = gr.update(interactive=True, visible=False)
STOP_BUTTON_ENABLED = gr.update(interactive=True, visible=True)
STOP_BUTTON_DISABLED = gr.update(interactive=True, visible=False)

# Module-level mutable state, shared by every session this process serves.
chat_start_count = 0   # chat turns started; incremented by run_chat_inference()
model_config = {}      # config dict of the active model; rebound by setup_model()
openai_client = None   # client for the active endpoint; rebound by setup_model()

# How setup_model() picks from VLLM_API_URL_LIST: random when True, round-robin otherwise.
USE_RANDOM_ENDPOINT = False
endpoint_rotation_count = 0  # round-robin cursor advanced by setup_model()

# Cap on files accepted per turn and on image-bearing messages sent to the model.
MAX_IMAGE_MESSAGES = 5
|
|
|
|
|
|
|
|
def app_loaded(state, request: gr.Request):
    """Page-load handler: activate the default model and tag the browser session.

    Args:
        state: Per-session dict held in gr.State; its "session" key is set here.
        request: Incoming Gradio request. Its session_hash becomes the session id;
            a random hex id is used when no request is available.

    Returns:
        ``(state, feedback_html)`` — the updated state and the model's feedback HTML.
    """
    feedback_html = setup_model(DEFAULT_MODEL_NAME, intial=False)
    state['session'] = uuid4().hex if not request else request.session_hash
    log_debug(f"app_loaded() --> Session: {state['session']}")
    return state, feedback_html
|
|
|
|
|
|
|
|
def update_model_and_clear_chat(model_name):
    """Handle a model-dropdown change: switch to the chosen model and empty the chat.

    Args:
        model_name: Dropdown choice, formatted as ``"Model: <model key>"``.

    Returns:
        ``(description_html, [])`` for the feedback HTML and the chatbot outputs.
    """
    prefix = "Model: "
    # Strip only the leading label. str.replace would also delete "Model: "
    # wherever it occurs inside the key itself.
    if model_name.startswith(prefix):
        actual_model_name = model_name[len(prefix):]
    else:
        actual_model_name = model_name
    desc = setup_model(actual_model_name)
    return desc, []
|
|
|
|
|
|
|
|
def setup_model(model_key, intial=False):
    """Switch the process-wide model: load its config and build a fresh OpenAI client.

    The endpoint comes from the comma-separated ``VLLM_API_URL_LIST`` config entry.
    When USE_RANDOM_ENDPOINT is set, a random entry is used. Otherwise the next
    entry in round-robin order is used, tracked by ``endpoint_rotation_count``.
    When that list is missing or has no non-blank entries, ``VLLM_API_URL`` is used.

    Side effects: rebinds the module globals ``model_config`` and ``openai_client``,
    may advance ``endpoint_rotation_count``, and stores the chosen URL in
    ``model_config['base_url']``.

    Args:
        model_key: Model identifier passed to ``get_model_config``.
        intial: (sic, "initial") When truthy, return None instead of the description.

    Returns:
        HTML feedback text linking to the model's Hugging Face community page,
        or None when ``intial`` is truthy.
    """
    global model_config, openai_client, endpoint_rotation_count

    model_config = get_model_config(model_key)
    log_debug(f"update_model() --> Model config: {model_config}")

    # "".split(",") == [""], so blanks must be filtered out. Otherwise an unset
    # list selects an empty URL and the VLLM_API_URL fallback is never reached.
    raw_urls = model_config.get('VLLM_API_URL_LIST') or ""
    url_list = [url.strip() for url in raw_urls.split(",") if url.strip()]
    if not url_list:
        base_url = model_config.get('VLLM_API_URL')
    elif USE_RANDOM_ENDPOINT:
        base_url = random.choice(url_list)
    else:
        base_url = url_list[endpoint_rotation_count % len(url_list)]
        endpoint_rotation_count += 1

    openai_client = OpenAI(
        api_key=model_config.get('AUTH_TOKEN'),
        base_url=base_url
    )

    # NOTE(review): if get_model_config returns a shared dict (e.g. straight from
    # models_config), this write mutates it for later callers too — confirm in utils.
    model_config['base_url'] = base_url
    log_debug(f"Switched to model {model_key} using endpoint {base_url}")

    if intial:
        return None

    # Tolerate a missing or non-Hugging-Face URL instead of raising IndexError.
    hf_url = model_config.get("MODEL_HF_URL") or ""
    hf_prefix = 'https://huggingface.co/'
    model_hf_name = hf_url[len(hf_prefix):] if hf_url.startswith(hf_prefix) else hf_url

    link = f"<a href='{hf_url}{COMMUNITY_POSTFIX_URL}' target='_blank'>{model_hf_name}</a>"
    return f"We'd love to hear your thoughts on the model. Click here to provide feedback - {link}"
|
|
|
|
|
|
|
|
def chat_started():
    """Lock the controls while a reply streams.

    Returns updates, in order, for: model dropdown (disabled), user input
    (cleared and disabled), Send (hidden), Stop (shown), New Chat (disabled).
    """
    cleared_input = gr.update(value="", interactive=False)
    return (
        DROPDOWN_DISABLED,
        cleared_input,
        SEND_BUTTON_DISABLED,
        STOP_BUTTON_ENABLED,
        BUTTON_DISABLED,
    )
|
|
|
|
|
|
|
|
def chat_finished():
    """Unlock the controls after a run ends.

    Returns updates, in order, for: model dropdown (enabled), user input
    (enabled), Send (shown), Stop (hidden), New Chat (enabled).
    """
    unlocked = (
        DROPDOWN_ENABLED,
        INPUT_ENABLED,
        SEND_BUTTON_ENABLED,
        STOP_BUTTON_DISABLED,
        BUTTON_ENABLED,
    )
    return unlocked
|
|
|
|
|
|
|
|
def stop_chat(state):
    """Ask the in-progress stream to stop and tell the user.

    Sets ``state["stop_flag"]``, which run_chat_inference checks before
    handling each streamed chunk.

    Returns:
        The same (mutated) state dict.
    """
    state.update(stop_flag=True)
    gr.Info("Chat stopped")
    return state
|
|
|
|
|
|
|
|
def toggle_opt_out(state, checkbox):
    """Store the opt-out checkbox value in the session state.

    Args:
        state: Per-session dict; its "opt_out" key is overwritten.
        checkbox: Current checkbox value.

    Returns:
        The same (mutated) state dict.
    """
    state.update({"opt_out": checkbox})
    return state
|
|
|
|
|
|
|
|
# End markers the model may leave on its final answer. Checked in this order:
# when the text ends with one, every occurrence of it is removed.
_FINAL_RESPONSE_END_MARKERS = (
    "[END FINAL RESPONSE]",
    "[END FINAL RESPONSE]\n<|end|>",
    "[END FINAL RESPONSE]\n<|end|>\n",
    "<|end|>",
    "<|end|>\n",
)
# End-of-turn markers stripped from non-reasoning output, checked in this order.
_PLAIN_END_MARKERS = ("<|end|>", "<|end|>\n")


def _strip_end_markers(text, markers):
    """Return *text* with end markers removed.

    For each marker in order, if the (possibly already trimmed) text ends
    with it, every occurrence of that marker is removed.
    """
    for marker in markers:
        if text.endswith(marker):
            text = text.replace(marker, "")
    return text


def _is_image_file(path):
    """Return True if *path* is a string naming an existing file with an image/* MIME type."""
    if not isinstance(path, str) or not os.path.isfile(path):
        return False
    mime, _ = mimetypes.guess_type(path)
    return bool(mime) and mime.startswith("image/")


def _to_image_part(path):
    """Encode the file at *path* as an OpenAI ``image_url`` content part (base64 data URL).

    Returns None, after logging the error, if the file cannot be read.
    """
    try:
        mime, _ = mimetypes.guess_type(path)
        mime = mime or "application/octet-stream"
        with open(path, "rb") as f:
            b64 = base64.b64encode(f.read()).decode("utf-8")
        return {"type": "image_url", "image_url": {"url": f"data:{mime};base64,{b64}"}}
    except Exception as e:
        log_error(f"Failed to load file '{path}': {e}")
        return None


def _normalize_msg(msg):
    """Return ``(role, content, as_dict)`` for a history entry.

    Dicts are returned unchanged. Objects with a ``role`` attribute (e.g.
    gr.ChatMessage) are reduced to a ``{"role", "content"}`` dict. Anything
    else yields ``(None, None, msg)``.
    """
    if isinstance(msg, dict):
        return msg.get("role"), msg.get("content"), msg
    role = getattr(msg, "role", None)
    content = getattr(msg, "content", None)
    if role is not None:
        return role, content, {"role": role, "content": content}
    return None, None, msg


def _split_media_items(items):
    """Split tuple/list message content into ``(image_parts, text_pieces)``.

    None items are skipped (e.g. an absent caption) so the literal text "None"
    is never sent.
    """
    image_parts = []
    texts = []
    for item in items:
        if item is None:
            continue
        part = _to_image_part(item) if _is_image_file(item) else None
        if part:
            image_parts.append(part)
        else:
            texts.append(str(item))
    return image_parts, texts


def _build_api_messages(messages):
    """Convert Gradio "messages"-format history into OpenAI chat-completion messages.

    Consecutive user images are merged into one multi-part user message.
    Non-user entries pass through unchanged. Only strings that name an existing
    image file are treated as images, so user-typed text that happens to be a
    server path is never read from disk.
    """
    api_messages = []
    image_buffer = []

    def flush_images():
        # Emit the buffered run of user images as one multi-part message.
        if image_buffer:
            api_messages.append({"role": "user", "content": list(image_buffer)})
            image_buffer.clear()

    def add_user_text(text):
        # Flush images first so message order matches the history.
        flush_images()
        api_messages.append({"role": "user", "content": text})

    for msg in copy.deepcopy(messages):
        role, content, as_dict = _normalize_msg(msg)

        if role != "user":
            flush_images()
            api_messages.append(as_dict)
            continue

        if isinstance(content, dict) and isinstance(content.get("path"), str):
            # File entries appended by run_chat_inference for this turn: {"path": ...}
            path = content["path"]
            part = _to_image_part(path) if _is_image_file(path) else None
            if part:
                image_buffer.append(part)
            else:
                add_user_text(str(content))
        elif isinstance(content, str):
            part = _to_image_part(content) if _is_image_file(content) else None
            if part:
                image_buffer.append(part)
            else:
                add_user_text(content)
        elif isinstance(content, list) and all(isinstance(c, dict) for c in content):
            # Already a list of content-part dicts: pass through untouched.
            flush_images()
            api_messages.append({"role": "user", "content": content})
        elif isinstance(content, (tuple, list)):
            image_parts, texts = _split_media_items(content)
            if image_parts:
                flush_images()
                api_messages.append({"role": "user", "content": image_parts})
            if texts:
                add_user_text("\n".join(texts))
        else:
            flush_images()
            api_messages.append(as_dict)

    flush_images()
    return api_messages


def _trim_image_messages(api_messages, limit):
    """Drop the oldest image-bearing messages (list content) so at most *limit* remain.

    Mutates *api_messages* in place and returns how many messages were removed.
    """
    image_indices = [i for i, m in enumerate(api_messages)
                     if isinstance(m, dict) and isinstance(m.get("content"), list)]
    excess = len(image_indices) - limit
    if excess <= 0:
        return 0
    dropped = set(image_indices[:excess])
    api_messages[:] = [m for i, m in enumerate(api_messages) if i not in dropped]
    return excess


def run_chat_inference(history, message, state):
    """Stream the model's reply to *message* into *history* (Gradio generator handler).

    Every yield is a 6-tuple matching the event outputs:
    ``(chatbot history, user_input update, send_btn update, stop_btn update,
    clear_btn update, session state)``.

    Args:
        history: Chatbot value in "messages" format; the user turn and the
            assistant reply are appended to it in place.
        message: MultimodalTextbox value (a mapping with "text" and "files")
            or plain text.
        state: Per-session dict. Keys used: is_streaming, stop_flag, chat_id,
            session, opt_out.

    NOTE(review): the selected model and client are process-global, so a model
    switch in one session affects requests from every session.
    """
    global chat_start_count

    state["is_streaming"] = True
    state["stop_flag"] = False
    error = None           # set when the completion request fails; that path logs the chat itself
    stream_error = None    # set when the stream breaks after the request succeeded
    turn_recorded = False  # True once this turn's user message is in history; only then is it logged
    config = model_config
    model_name = config.get('MODEL_NAME')

    def idle_outputs():
        # Outputs that leave the input area ready for the next message.
        return history, INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state

    try:
        # Check this before setup_model(), which always creates a client.
        # Checking afterwards made the test dead code.
        if openai_client is None:
            log_info("Client UI is stale, letting user know to refresh the page")
            gr.Warning("Client UI is stale, please refresh the page")
            yield idle_outputs()
            return

        # Rebuild the client per request so setup_model() can rotate endpoints.
        setup_model(model_config.get('MODEL_KEY'))
        # Snapshot the globals: another session may call setup_model() mid-stream.
        config = model_config
        client = openai_client
        model_name = config.get('MODEL_NAME')
        log_info(f"Using model {model_name} with endpoint {config.get('base_url')}")

        if len(history) == 0:
            state["chat_id"] = uuid4().hex

        files = []
        log_debug(f"{'-' * 80}")
        log_debug(f"chat_fn() --> Message: {message}")
        log_debug(f"chat_fn() --> History: {history}")

        # MultimodalTextbox values carry the text and the uploaded file paths together.
        if isinstance(message, Mapping):
            files = message.get("files") or []
            message = message.get("text") or ""
            log_debug(f"chat_fn() --> Message (text only): {message}")
            log_debug(f"chat_fn() --> Files: {files}")

        if files:
            invalid_files = []
            for path in files:
                try:
                    mime = mimetypes.guess_type(path)[0] or ""
                    if not mime.startswith("image/"):
                        invalid_files.append((os.path.basename(path), mime or "unknown"))
                except Exception as e:
                    log_error(f"Failed to inspect file '{path}': {e}")
                    invalid_files.append((os.path.basename(str(path)), "unknown"))

            if invalid_files:
                msg = "Only image files are allowed. Invalid uploads: " + \
                      ", ".join([f"{p} (type: {m})" for p, m in invalid_files])
                log_warning(msg)
                gr.Warning(msg)
                yield idle_outputs()
                return

        if len(files) > MAX_IMAGE_MESSAGES:
            gr.Warning(f"Too many images provided; keeping only the first {MAX_IMAGE_MESSAGES} file(s).")
            files = files[:MAX_IMAGE_MESSAGES]

        if not message.strip() and len(files) == 0:
            gr.Info("Please enter a message before sending")
            yield idle_outputs()
            return

        chat_start_count = chat_start_count + 1
        user_messages_count = sum(1 for item in history if isinstance(item, dict) and item.get("role") == "user")
        log_info(f"chat_start_count: {chat_start_count}, turns: {user_messages_count}, model: {model_name}")

        is_reasoning = config.get("REASONING")

        log_debug(f"Initial History: {history}")
        check_format(history, "messages")

        # The text (if any) and each uploaded file become separate user messages.
        if len(files) == 0:
            history.append({"role": "user", "content": message})
        else:
            if message.strip():
                history.append({"role": "user", "content": message})
            for path in files:
                history.append({"role": "user", "content": {"path": path}})
        turn_recorded = True

        log_debug(f"History with user message: {history}")
        check_format(history, "messages")

        try:
            # Assistant messages with a metadata title are the collapsible
            # "thought" bubbles; they are not sent back to the model.
            history_no_thoughts = [item for item in history if
                                   not (isinstance(item, dict) and
                                        item.get("role") == "assistant" and
                                        isinstance(item.get("metadata"), dict) and
                                        item.get("metadata", {}).get("title") is not None)]
            log_debug(f"Updated History: {history_no_thoughts}")
            check_format(history_no_thoughts, "messages")
            log_debug(f"history_no_thoughts with user message: {history_no_thoughts}")

            api_messages = _build_api_messages(history_no_thoughts)
            log_debug(f"sending api_messages to model {model_name}: {api_messages}")

            removed = _trim_image_messages(api_messages, MAX_IMAGE_MESSAGES)
            if removed:
                gr.Warning(f"Too many images provided; keeping the latest {MAX_IMAGE_MESSAGES} and dropped {removed} older image message(s).")

            stream = client.chat.completions.create(
                model=model_name,
                messages=api_messages,
                temperature=MODEL_TEMPERATURE,
                stream=True
            )
        except Exception as e:
            log_error(f"Error:\n\t{e}\n\tInference failed for model {model_name} and endpoint {config.get('base_url')}")
            error = str(e)
            # NOTE(review): this replaces the visible chat with a single error
            # bubble, hiding earlier turns in the UI — confirm that is intended.
            yield ([{"role": "assistant",
                     "content": "😔 The model is unavailable at the moment. Please try again later."}],
                   INPUT_ENABLED, SEND_BUTTON_ENABLED, STOP_BUTTON_DISABLED, BUTTON_ENABLED, state)
            if state["opt_out"] is not True:
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info={"is_reasoning": config.get("REASONING"), "temperature": MODEL_TEMPERATURE,
                               "stopped": True, "error": str(e)},
                         )
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
            return

        # Placeholder that the stream overwrites: a thought bubble for reasoning
        # models, otherwise an empty reply.
        if is_reasoning:
            history.append(gr.ChatMessage(
                role="assistant",
                content="Thinking...",
                metadata={"title": "🧠 Thought"}
            ))
            log_debug(f"History added thinking: {history}")
        else:
            history.append(gr.ChatMessage(
                role="assistant",
                content="",
            ))
            log_debug(f"History added empty assistant: {history}")
        check_format(history, "messages")

        output = ""
        completion_started = False
        try:
            for chunk in stream:
                if state["stop_flag"]:
                    log_debug("chat_fn() --> Stopping streaming...")
                    break

                # Some servers send chunks with no choices (e.g. usage-only); skip them.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                content = getattr(delta, "content", "") or ""
                reasoning_content = getattr(delta, "reasoning_content", "") or ""
                output += reasoning_content + content

                if is_reasoning:
                    # Text before "[BEGIN FINAL RESPONSE]" is the thought; text after it is the answer.
                    parts = output.split("[BEGIN FINAL RESPONSE]")
                    if len(parts) > 1:
                        parts[1] = _strip_end_markers(parts[1], _FINAL_RESPONSE_END_MARKERS)

                    history[-2 if completion_started else -1] = gr.ChatMessage(
                        role="assistant",
                        content=parts[0],
                        metadata={"title": "🧠 Thought"}
                    )
                    if completion_started:
                        history[-1] = gr.ChatMessage(
                            role="assistant",
                            content=parts[1]
                        )
                    elif len(parts) > 1:
                        completion_started = True
                        history.append(gr.ChatMessage(
                            role="assistant",
                            content=parts[1]
                        ))
                else:
                    output = _strip_end_markers(output, _PLAIN_END_MARKERS)
                    history[-1] = gr.ChatMessage(
                        role="assistant",
                        content=output
                    )

                yield history, INPUT_DISABLED, SEND_BUTTON_DISABLED, STOP_BUTTON_ENABLED, BUTTON_DISABLED, state
        except Exception as e:
            # Previously a bare `return` in `finally` discarded this error silently.
            stream_error = str(e)
            log_error(f"Streaming failed for model {model_name} and endpoint {config.get('base_url')}: {e}")
            gr.Warning("The response was interrupted. Please try again.")

        log_debug(f"Final History: {history}")
        check_format(history, "messages")
        yield idle_outputs()
    finally:
        # No `return` here: returning from `finally` would swallow in-flight exceptions.
        if error is None and turn_recorded:
            log_debug(f"chat_fn() --> Finished streaming. {chat_start_count} chats started.")
            if state["opt_out"] is not True:
                info = {"is_reasoning": config.get("REASONING"), "temperature": MODEL_TEMPERATURE,
                        "stopped": state["stop_flag"]}
                if stream_error is not None:
                    info["error"] = stream_error
                log_chat(chat_id=state["chat_id"],
                         session_id=state["session"],
                         model_name=model_name,
                         prompt=message,
                         history=history,
                         info=info,
                         )
            else:
                log_info(f"User opted out of chat history. Not logging chat. model: {model_name}")
        state["is_streaming"] = False
        state["stop_flag"] = False
|
|
|
|
|
|
|
|
log_info(f"Gradio version: {gr.__version__}")

title = None
description = None
theme = apriel

# Resolve styles.css next to this file (not the CWD) and read it as UTF-8,
# so the app starts regardless of working directory or platform encoding.
_STYLES_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'styles.css')
with open(_STYLES_PATH, 'r', encoding='utf-8') as f:
    custom_css = f.read()

with gr.Blocks(theme=theme, css=custom_css) as demo:
    # Per-browser-session state passed to and returned from the handlers.
    session_state = gr.State(value={
        "is_streaming": False,
        "stop_flag": False,
        "chat_id": None,
        "session": None,
        "opt_out": DEFAULT_OPT_OUT_VALUE,
    })

    # On wide screens, cap the width of the button columns.
    gr.HTML(f"""
    <style>
    @media (min-width: 1024px) {{
        .send-button-container, .clear-button-container {{
            max-width: {BUTTON_WIDTH}px;
        }}
    }}
    </style>
    """, elem_classes="css-styles")

    # Header: model picker and feedback link.
    with gr.Row(variant="panel", elem_classes="responsive-row"):
        with gr.Column(scale=1, min_width=400, elem_classes="model-dropdown-container"):
            model_dropdown = gr.Dropdown(
                choices=[f"Model: {model}" for model in models_config.keys()],
                value=f"Model: {DEFAULT_MODEL_NAME}",
                label=None,
                interactive=True,
                container=False,
                scale=0,
                min_width=400
            )
        with gr.Column(scale=4, min_width=0):
            feedback_message_html = gr.HTML(description, elem_classes="model-message")

    chatbot = gr.Chatbot(
        type="messages",
        height="calc(100dvh - 310px)",
        elem_classes="chatbot",
    )

    # Input row: multimodal textbox plus Send/Stop and New Chat buttons.
    with gr.Row():
        with gr.Column(scale=10, min_width=400, elem_classes="user-input-container"):
            with gr.Row():
                user_input = gr.MultimodalTextbox(
                    interactive=True,
                    container=False,
                    file_count="multiple",
                    placeholder="Type your message here and press Enter or upload file...",
                    show_label=False,
                    sources=["upload"],
                    max_plain_text_length=100000
                )
        with gr.Column(scale=1, min_width=BUTTON_WIDTH * 2 + 20):
            with gr.Row():
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="send-button-container"):
                    send_btn = gr.Button("Send", variant="primary", elem_classes="control-button")
                    stop_btn = gr.Button("Stop", variant="cancel", elem_classes="control-button", visible=False)
                with gr.Column(scale=1, min_width=BUTTON_WIDTH, elem_classes="clear-button-container"):
                    clear_btn = gr.ClearButton(chatbot, value="New Chat", variant="secondary", elem_classes="control-button")

    # Chat-history opt-out notice and checkbox.
    with gr.Row():
        with gr.Column(min_width=400, elem_classes="opt-out-container"):
            with gr.Row():
                gr.HTML(
                    "We may use your chats to improve our AI. You may opt out if you don’t want your conversations saved.",
                    elem_classes="opt-out-message")
            with gr.Row():
                opt_out_checkbox = gr.Checkbox(
                    label="Don’t save my chat history for improvements or training",
                    value=DEFAULT_OPT_OUT_VALUE,
                    elem_classes="opt-out-checkbox",
                    interactive=True,
                    container=False
                )

    # Streaming inference (queued); chat_finished re-enables the controls afterwards.
    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=run_chat_inference,
        inputs=[chatbot, user_input, session_state],
        outputs=[chatbot, user_input, send_btn, stop_btn, clear_btn, session_state],
        concurrency_limit=4,
        api_name=False
    ).then(
        fn=chat_finished, inputs=None, outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn], queue=False)

    # Same triggers, unqueued: lock the controls immediately when a message is sent.
    gr.on(
        triggers=[send_btn.click, user_input.submit],
        fn=chat_started,
        inputs=None,
        outputs=[model_dropdown, user_input, send_btn, stop_btn, clear_btn],
        queue=False,
        show_progress='hidden',
        api_name=False
    )

    stop_btn.click(
        fn=stop_chat,
        inputs=[session_state],
        outputs=[session_state],
        api_name=False
    )

    opt_out_checkbox.change(fn=toggle_opt_out, inputs=[session_state, opt_out_checkbox], outputs=[session_state],
                            api_name=False)

    demo.load(
        fn=logged_event_handler(
            log_msg="Browser session started",
            event_handler=app_loaded
        ),
        inputs=[session_state],
        outputs=[session_state, feedback_message_html],
        queue=True,
        api_name=False
    )

    model_dropdown.change(
        fn=update_model_and_clear_chat,
        inputs=[model_dropdown],
        outputs=[feedback_message_html, chatbot],
        api_name=False
    )

# NOTE(review): when run as a script, launch() typically blocks until the server
# stops, so the log line below may only appear at shutdown.
demo.queue(default_concurrency_limit=2).launch(ssr_mode=False, show_api=False, max_file_size="10mb")
log_info("Gradio app launched")
|
|
|