| | """ |
| | This file defines a useful high-level abstraction to build Gradio chatbots: ChatInterface. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import inspect |
| | from typing import AsyncGenerator, Callable, Literal, Union, cast |
| |
|
| | import anyio |
| | from gradio_client.documentation import document |
| |
|
| | from gradio.blocks import Blocks |
| | from gradio.components import ( |
| | Button, |
| | Chatbot, |
| | Component, |
| | Markdown, |
| | MultimodalTextbox, |
| | State, |
| | Textbox, |
| | get_component_instance, |
| | Dataset, |
| | ) |
| | from gradio.events import Dependency, on |
| | from gradio.helpers import special_args |
| | from gradio.layouts import Accordion, Group, Row |
| | from gradio.routes import Request |
| | from gradio.themes import ThemeClass as Theme |
| | from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda |
| |
|
| |
|
| | @document() |
| | class ChatInterface(Blocks): |
| | """ |
| | ChatInterface is Gradio's high-level abstraction for creating chatbot UIs, and allows you to create |
| | a web-based demo around a chatbot model in a few lines of code. Only one parameter is required: fn, which |
| | takes a function that governs the response of the chatbot based on the user input and chat history. Additional |
| | parameters can be used to control the appearance and behavior of the demo. |
| | |
| | Example: |
| | import gradio as gr |
| | |
| | def echo(message, history): |
| | return message |
| | |
| | demo = gr.ChatInterface(fn=echo, examples=["hello", "hola", "merhaba"], title="Echo Bot") |
| | demo.launch() |
| | Demos: chatinterface_multimodal, chatinterface_random_response, chatinterface_streaming_echo |
| | Guides: creating-a-chatbot-fast, sharing-your-app |
| | """ |
| |
|
| | def __init__( |
| | self, |
| | fn: Callable, |
| | post_fn: Callable, |
| | pre_fn: Callable, |
| | chatbot: Chatbot, |
| | *, |
| | show_stop_button=True, |
| | post_fn_kwargs: dict = None, |
| | pre_fn_kwargs: dict = None, |
| | multimodal: bool = False, |
| | textbox: Textbox | MultimodalTextbox | None = None, |
| | additional_inputs: str | Component | list[str | Component] | None = None, |
| | additional_inputs_accordion_name: str | None = None, |
| | additional_inputs_accordion: str | Accordion | None = None, |
| | examples: Dataset = None, |
| | title: str | None = None, |
| | description: str | None = None, |
| | theme: Theme | str | None = None, |
| | css: str | None = None, |
| | js: str | None = None, |
| | head: str | None = None, |
| | analytics_enabled: bool | None = None, |
| | submit_btn: str | None | Button = "Submit", |
| | stop_btn: str | None | Button = "Stop", |
| | retry_btn: str | None | Button = "🔄 Retry", |
| | undo_btn: str | None | Button = "↩️ Undo", |
| | clear_btn: str | None | Button = "🗑️ Clear", |
| | autofocus: bool = True, |
| | concurrency_limit: int | None | Literal["default"] = "default", |
| | fill_height: bool = True, |
| | delete_cache: tuple[int, int] | None = None, |
| | ): |
| | super().__init__( |
| | analytics_enabled=analytics_enabled, |
| | mode="chat_interface", |
| | css=css, |
| | title=title or "Gradio", |
| | theme=theme, |
| | js=js, |
| | head=head, |
| | fill_height=fill_height, |
| | delete_cache=delete_cache, |
| | ) |
| |
|
| | if post_fn_kwargs is None: |
| | post_fn_kwargs = [] |
| |
|
| | self.post_fn = post_fn |
| | self.post_fn_kwargs = post_fn_kwargs |
| |
|
| | self.pre_fn = pre_fn |
| | self.pre_fn_kwargs = pre_fn_kwargs |
| |
|
| | self.show_stop_button = show_stop_button |
| |
|
| | self.interrupter = State(None) |
| |
|
| | self.multimodal = multimodal |
| | self.concurrency_limit = concurrency_limit |
| | self.fn = fn |
| | self.is_async = inspect.iscoroutinefunction( |
| | self.fn |
| | ) or inspect.isasyncgenfunction(self.fn) |
| | self.is_generator = inspect.isgeneratorfunction( |
| | self.fn |
| | ) or inspect.isasyncgenfunction(self.fn) |
| |
|
| | if additional_inputs: |
| | if not isinstance(additional_inputs, list): |
| | additional_inputs = [additional_inputs] |
| | self.additional_inputs = [ |
| | get_component_instance(i) |
| | for i in additional_inputs |
| | ] |
| | else: |
| | self.additional_inputs = [] |
| | if additional_inputs_accordion_name is not None: |
| | print( |
| | "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead." |
| | ) |
| | self.additional_inputs_accordion_params = { |
| | "label": additional_inputs_accordion_name |
| | } |
| | if additional_inputs_accordion is None: |
| | self.additional_inputs_accordion_params = { |
| | "label": "Additional Inputs", |
| | "open": False, |
| | } |
| | elif isinstance(additional_inputs_accordion, str): |
| | self.additional_inputs_accordion_params = { |
| | "label": additional_inputs_accordion |
| | } |
| | elif isinstance(additional_inputs_accordion, Accordion): |
| | self.additional_inputs_accordion_params = ( |
| | additional_inputs_accordion.recover_kwargs( |
| | additional_inputs_accordion.get_config() |
| | ) |
| | ) |
| | else: |
| | raise ValueError( |
| | f"The `additional_inputs_accordion` parameter must be a string or gr.Accordion, not {type(additional_inputs_accordion)}" |
| | ) |
| |
|
| | with self: |
| | if title: |
| | Markdown( |
| | f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>" |
| | ) |
| | if description: |
| | Markdown(description) |
| |
|
| | self.chatbot = chatbot.render() |
| |
|
| | self.buttons = [retry_btn, undo_btn, clear_btn] |
| |
|
| | with Group(): |
| | with Row(): |
| | if textbox: |
| | if self.multimodal: |
| | submit_btn = None |
| | else: |
| | textbox.container = False |
| | textbox.show_label = False |
| | textbox_ = textbox.render() |
| | if not isinstance(textbox_, (Textbox, MultimodalTextbox)): |
| | raise TypeError( |
| | f"Expected a gr.Textbox or gr.MultimodalTextbox component, but got {type(textbox_)}" |
| | ) |
| | self.textbox = textbox_ |
| | elif self.multimodal: |
| | submit_btn = None |
| | self.textbox = MultimodalTextbox( |
| | show_label=False, |
| | label="Message", |
| | placeholder="Type a message...", |
| | scale=7, |
| | autofocus=autofocus, |
| | ) |
| | else: |
| | self.textbox = Textbox( |
| | container=False, |
| | show_label=False, |
| | label="Message", |
| | placeholder="Type a message...", |
| | scale=7, |
| | autofocus=autofocus, |
| | ) |
| | if submit_btn is not None and not multimodal: |
| | if isinstance(submit_btn, Button): |
| | submit_btn.render() |
| | elif isinstance(submit_btn, str): |
| | submit_btn = Button( |
| | submit_btn, |
| | variant="primary", |
| | scale=1, |
| | min_width=150, |
| | ) |
| | else: |
| | raise ValueError( |
| | f"The submit_btn parameter must be a gr.Button, string, or None, not {type(submit_btn)}" |
| | ) |
| | if stop_btn is not None: |
| | if isinstance(stop_btn, Button): |
| | stop_btn.visible = False |
| | stop_btn.render() |
| | elif isinstance(stop_btn, str): |
| | stop_btn = Button( |
| | stop_btn, |
| | variant="stop", |
| | visible=False, |
| | scale=1, |
| | min_width=150, |
| | ) |
| | else: |
| | raise ValueError( |
| | f"The stop_btn parameter must be a gr.Button, string, or None, not {type(stop_btn)}" |
| | ) |
| | self.buttons.extend([submit_btn, stop_btn]) |
| |
|
| | self.fake_api_btn = Button("Fake API", visible=False) |
| | self.fake_response_textbox = Textbox(label="Response", visible=False) |
| | ( |
| | self.retry_btn, |
| | self.undo_btn, |
| | self.clear_btn, |
| | self.submit_btn, |
| | self.stop_btn, |
| | ) = self.buttons |
| |
|
| | any_unrendered_inputs = any( |
| | not inp.is_rendered for inp in self.additional_inputs |
| | ) |
| | if self.additional_inputs and any_unrendered_inputs: |
| | with Accordion(**self.additional_inputs_accordion_params): |
| | for input_component in self.additional_inputs: |
| | if not input_component.is_rendered: |
| | input_component.render() |
| |
|
| | self.saved_input = State() |
| | self.chatbot_state = ( |
| | State(self.chatbot.value) if self.chatbot.value else State([]) |
| | ) |
| |
|
| | self._setup_events() |
| | self._setup_api() |
| |
|
| | if examples: |
| | examples.click(lambda x: x[0], inputs=[examples], outputs=self.textbox, show_progress=False, queue=False) |
| |
|
    def _setup_events(self) -> None:
        """Wire the submit/retry/undo/clear event chains on the rendered components.

        Every chain invokes ``self.pre_fn`` (with ``**self.pre_fn_kwargs``) before
        its core step and ``self.post_fn`` (with ``**self.post_fn_kwargs``) after it.
        """
        # Streaming (generator) chat fns go through _stream_fn, otherwise _submit_fn.
        submit_fn = self._stream_fn if self.is_generator else self._submit_fn
        submit_triggers = (
            [self.textbox.submit, self.submit_btn.click]
            if self.submit_btn
            else [self.textbox.submit]
        )
        # Submit chain: clear textbox (saving the message) -> pre_fn -> echo the
        # user's message into the chatbot -> run fn -> post_fn.
        submit_event = (
            on(
                submit_triggers,
                self._clear_and_save_textbox,
                [self.textbox],
                [self.textbox, self.saved_input],
                show_api=False,
                queue=False,
            )
            .then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            )
            .then(
                self._display_input,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.chatbot_state],
                show_api=False,
                queue=False,
            )
            .then(
                submit_fn,
                [self.saved_input, self.chatbot_state] + self.additional_inputs,
                # _stream_fn yields a third value (the interrupter callable) that
                # is stored in the self.interrupter State.
                [self.chatbot, self.chatbot_state, self.interrupter],
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )
        )
        self._setup_stop_events(submit_triggers, submit_event)

        # Retry chain: drop the last exchange, then re-run fn on the saved input.
        if self.retry_btn:
            retry_event = (
                self.retry_btn.click(
                    self._delete_prev_fn,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.saved_input, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    self.pre_fn,
                    **self.pre_fn_kwargs,
                    show_api=False,
                    queue=False,
                )
                .then(
                    self._display_input,
                    [self.saved_input, self.chatbot_state],
                    [self.chatbot, self.chatbot_state],
                    show_api=False,
                    queue=False,
                )
                .then(
                    submit_fn,
                    [self.saved_input, self.chatbot_state] + self.additional_inputs,
                    # NOTE(review): unlike the submit chain above, this output list
                    # omits self.interrupter even though _stream_fn yields a third
                    # value in the non-multimodal case — confirm this is intended.
                    [self.chatbot, self.chatbot_state],
                    show_api=False,
                    concurrency_limit=cast(
                        Union[int, Literal["default"], None], self.concurrency_limit
                    ),
                ).then(
                    self.post_fn,
                    **self.post_fn_kwargs,
                    show_api=False,
                    concurrency_limit=cast(
                        Union[int, Literal["default"], None], self.concurrency_limit
                    ),
                )
            )
            self._setup_stop_events([self.retry_btn.click], retry_event)

        # Undo chain: drop the last exchange and restore the message to the textbox.
        if self.undo_btn:
            self.undo_btn.click(
                self._delete_prev_fn,
                [self.saved_input, self.chatbot_state],
                [self.chatbot, self.saved_input, self.chatbot_state],
                show_api=False,
                queue=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                async_lambda(lambda x: x),
                [self.saved_input],
                [self.textbox],
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )

        # Clear chain: wipe chatbot, state, and the saved input.
        if self.clear_btn:
            self.clear_btn.click(
                async_lambda(lambda: ([], [], None)),
                None,
                [self.chatbot, self.chatbot_state, self.saved_input],
                queue=False,
                show_api=False,
            ).then(
                self.pre_fn,
                **self.pre_fn_kwargs,
                show_api=False,
                queue=False,
            ).then(
                self.post_fn,
                **self.post_fn_kwargs,
                show_api=False,
                concurrency_limit=cast(
                    Union[int, Literal["default"], None], self.concurrency_limit
                ),
            )
|
    def _setup_stop_events(
        self, event_triggers: list[Callable], event_to_cancel: Dependency
    ) -> None:
        """Wire submit/stop button visibility toggles and stop-click handling.

        Args:
            event_triggers: the events (textbox submit / button clicks) that start
                a generation.
            event_to_cancel: the chained event to cancel when stop is clicked.
        """
        def perform_interrupt(ipc):
            # ipc is the interrupter callable held in the self.interrupter State;
            # it is None until the streaming fn yields one.
            if ipc is not None:
                ipc()
            return

        # Stop handling only applies to generator-backed (streaming) chat fns.
        if self.stop_btn and self.is_generator:
            if self.submit_btn:
                for event_trigger in event_triggers:
                    # On generation start: hide submit, reveal stop (if enabled).
                    event_trigger(
                        async_lambda(
                            lambda: (
                                Button(visible=False),
                                Button(visible=self.show_stop_button),
                            )
                        ),
                        None,
                        [self.submit_btn, self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                # On completion (or cancellation): restore submit, hide stop.
                event_to_cancel.then(
                    async_lambda(lambda: (Button(visible=True), Button(visible=False))),
                    None,
                    [self.submit_btn, self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            else:
                # No submit button rendered: only the stop button is toggled.
                for event_trigger in event_triggers:
                    event_trigger(
                        async_lambda(lambda: Button(visible=self.show_stop_button)),
                        None,
                        [self.stop_btn],
                        show_api=False,
                        queue=False,
                    )
                event_to_cancel.then(
                    async_lambda(lambda: Button(visible=False)),
                    None,
                    [self.stop_btn],
                    show_api=False,
                    queue=False,
                )
            # Clicking stop both invokes the interrupter and cancels the chain.
            self.stop_btn.click(
                fn=perform_interrupt,
                inputs=[self.interrupter],
                cancels=event_to_cancel,
                show_api=False,
            )
| |
|
| | def _setup_api(self) -> None: |
| | api_fn = self._api_stream_fn if self.is_generator else self._api_submit_fn |
| |
|
| | self.fake_api_btn.click( |
| | api_fn, |
| | [self.textbox, self.chatbot_state] + self.additional_inputs, |
| | [self.textbox, self.chatbot_state], |
| | api_name="chat", |
| | concurrency_limit=cast( |
| | Union[int, Literal["default"], None], self.concurrency_limit |
| | ), |
| | ) |
| |
|
| | def _clear_and_save_textbox(self, message: str) -> tuple[str | dict, str]: |
| | if self.multimodal: |
| | return {"text": "", "files": []}, message |
| | else: |
| | return "", message |
| |
|
| | def _append_multimodal_history( |
| | self, |
| | message: dict[str, list], |
| | response: str | None, |
| | history: list[list[str | tuple | None]], |
| | ): |
| | for x in message["files"]: |
| | history.append([(x,), None]) |
| | if message["text"] is None or not isinstance(message["text"], str): |
| | return |
| | elif message["text"] == "" and message["files"] != []: |
| | history.append([None, response]) |
| | else: |
| | history.append([message["text"], response]) |
| |
|
| | async def _display_input( |
| | self, message: str | dict[str, list], history: list[list[str | tuple | None]] |
| | ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]: |
| | if self.multimodal and isinstance(message, dict): |
| | self._append_multimodal_history(message, None, history) |
| | elif isinstance(message, str): |
| | history.append([message, None]) |
| | return history, history |
| |
|
| | async def _submit_fn( |
| | self, |
| | message: str | dict[str, list], |
| | history_with_input: list[list[str | tuple | None]], |
| | request: Request, |
| | *args, |
| | ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]: |
| | if self.multimodal and isinstance(message, dict): |
| | remove_input = ( |
| | len(message["files"]) + 1 |
| | if message["text"] is not None |
| | else len(message["files"]) |
| | ) |
| | history = history_with_input[:-remove_input] |
| | else: |
| | history = history_with_input[:-1] |
| | inputs, _, _ = special_args( |
| | self.fn, inputs=[message, history, *args], request=request |
| | ) |
| |
|
| | if self.is_async: |
| | response = await self.fn(*inputs) |
| | else: |
| | response = await anyio.to_thread.run_sync( |
| | self.fn, *inputs, limiter=self.limiter |
| | ) |
| |
|
| | if self.multimodal and isinstance(message, dict): |
| | self._append_multimodal_history(message, response, history) |
| | elif isinstance(message, str): |
| | history.append([message, response]) |
| | return history, history |
| |
|
| | async def _stream_fn( |
| | self, |
| | message: str | dict[str, list], |
| | history_with_input: list[list[str | tuple | None]], |
| | request: Request, |
| | *args, |
| | ) -> AsyncGenerator: |
| | if self.multimodal and isinstance(message, dict): |
| | remove_input = ( |
| | len(message["files"]) + 1 |
| | if message["text"] is not None |
| | else len(message["files"]) |
| | ) |
| | history = history_with_input[:-remove_input] |
| | else: |
| | history = history_with_input[:-1] |
| | inputs, _, _ = special_args( |
| | self.fn, inputs=[message, history, *args], request=request |
| | ) |
| |
|
| | if self.is_async: |
| | generator = self.fn(*inputs) |
| | else: |
| | generator = await anyio.to_thread.run_sync( |
| | self.fn, *inputs, limiter=self.limiter |
| | ) |
| | generator = SyncToAsyncIterator(generator, self.limiter) |
| | try: |
| | first_response, first_interrupter = await async_iteration(generator) |
| | if self.multimodal and isinstance(message, dict): |
| | for x in message["files"]: |
| | history.append([(x,), None]) |
| | update = history + [[message["text"], first_response]] |
| | yield update, update |
| | else: |
| | update = history + [[message, first_response]] |
| | yield update, update, first_interrupter |
| | except StopIteration: |
| | if self.multimodal and isinstance(message, dict): |
| | self._append_multimodal_history(message, None, history) |
| | yield history, history |
| | else: |
| | update = history + [[message, None]] |
| | yield update, update, first_interrupter |
| | async for response, interrupter in generator: |
| | if self.multimodal and isinstance(message, dict): |
| | update = history + [[message["text"], response]] |
| | yield update, update |
| | else: |
| | update = history + [[message, response]] |
| | yield update, update, interrupter |
| |
|
| | async def _api_submit_fn( |
| | self, message: str, history: list[list[str | None]], request: Request, *args |
| | ) -> tuple[str, list[list[str | None]]]: |
| | inputs, _, _ = special_args( |
| | self.fn, inputs=[message, history, *args], request=request |
| | ) |
| |
|
| | if self.is_async: |
| | response = await self.fn(*inputs) |
| | else: |
| | response = await anyio.to_thread.run_sync( |
| | self.fn, *inputs, limiter=self.limiter |
| | ) |
| | history.append([message, response]) |
| | return response, history |
| |
|
| | async def _api_stream_fn( |
| | self, message: str, history: list[list[str | None]], request: Request, *args |
| | ) -> AsyncGenerator: |
| | inputs, _, _ = special_args( |
| | self.fn, inputs=[message, history, *args], request=request |
| | ) |
| |
|
| | if self.is_async: |
| | generator = self.fn(*inputs) |
| | else: |
| | generator = await anyio.to_thread.run_sync( |
| | self.fn, *inputs, limiter=self.limiter |
| | ) |
| | generator = SyncToAsyncIterator(generator, self.limiter) |
| | try: |
| | first_response = await async_iteration(generator) |
| | yield first_response, history + [[message, first_response]] |
| | except StopIteration: |
| | yield None, history + [[message, None]] |
| | async for response in generator: |
| | yield response, history + [[message, response]] |
| |
|
| | async def _delete_prev_fn( |
| | self, |
| | message: str | dict[str, list], |
| | history: list[list[str | tuple | None]], |
| | ) -> tuple[ |
| | list[list[str | tuple | None]], |
| | str | dict[str, list], |
| | list[list[str | tuple | None]], |
| | ]: |
| | if self.multimodal and isinstance(message, dict): |
| | remove_input = ( |
| | len(message["files"]) + 1 |
| | if message["text"] is not None |
| | else len(message["files"]) |
| | ) |
| | history = history[:-remove_input] |
| | else: |
| | while history: |
| | deleted_a, deleted_b = history[-1] |
| | history = history[:-1] |
| | if isinstance(deleted_a, str) and isinstance(deleted_b, str): |
| | break |
| | return history, message or "", history |
| |
|