chatbot / gradio /block_function.py
siddharth24m's picture
Upload 3001 files
69ba728 verified
from __future__ import annotations
import inspect
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Literal
from . import utils
try:
import spaces # type: ignore
except Exception:
spaces = None
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from gradio.components.base import Component
from gradio.renderable import Renderable
from .blocks import BlockContext
class BlockFunction:
    """A callable bundled with the blocks and options it was registered with.

    The constructor stores every option on the instance, derives a few
    values from ``fn`` (display name, generator flag, default concurrency
    id) and finally hands ``fn`` to ``spaces_auto_wrap``. ``get_config``
    turns the stored state into a plain dictionary.
    """

    def __init__(
        self,
        fn: Callable | None,
        inputs: Sequence[Component | BlockContext],
        outputs: Sequence[Component | BlockContext],
        preprocess: bool,
        postprocess: bool,
        inputs_as_dict: bool,
        targets: list[tuple[int | None, str]],
        _id: int,
        batch: bool = False,
        max_batch_size: int = 4,
        concurrency_limit: int | None | Literal["default"] = "default",
        concurrency_id: str | None = None,
        tracks_progress: bool = False,
        api_name: str | None = None,
        api_description: str | None | Literal[False] = None,
        js: str | Literal[True] | None = None,
        show_progress: Literal["full", "minimal", "hidden"] = "full",
        show_progress_on: Sequence[Component] | None = None,
        cancels: list[int] | None = None,
        collects_event_data: bool = False,
        trigger_after: int | None = None,
        trigger_only_on_success: bool = False,
        trigger_only_on_failure: bool = False,
        trigger_mode: Literal["always_last", "once", "multiple"] = "once",
        queue: bool = True,
        scroll_to_output: bool = False,
        api_visibility: Literal["public", "private", "undocumented"] = "public",
        renderable: Renderable | None = None,
        rendered_in: Renderable | None = None,
        render_iteration: int | None = None,
        is_cancel_function: bool = False,
        connection: Literal["stream", "sse"] = "sse",
        time_limit: float | None = None,
        stream_every: float = 0.5,
        event_specific_args: list[str] | None = None,
        component_prop_inputs: list[int] | None = None,
        page: str = "",
        js_implementation: str | None = None,
        key: str | int | tuple[int | str, ...] | None = None,
        validator: Callable | None = None,
    ):
        """Store the function, its wiring and all event options.

        Most arguments are copied onto the instance unchanged. The
        exceptions are noted inline below.
        """
        # --- the callable itself and what is derived from it -------------
        self.fn = fn
        self._id = _id
        # No function means no name; otherwise fall back to "fn" for
        # callables that lack a ``__name__``.
        self.name = None if fn is None else getattr(fn, "__name__", "fn")
        self.types_generator = any(
            is_generator(self.fn)
            for is_generator in (
                inspect.isgeneratorfunction,
                inspect.isasyncgenfunction,
            )
        )
        self.validator = validator
        if js_implementation:
            setattr(self.fn, "__js_implementation__", js_implementation)

        # --- wiring -------------------------------------------------------
        self.inputs = inputs
        self.outputs = outputs
        self.inputs_as_dict = inputs_as_dict
        self.targets = targets
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.component_prop_inputs = (
            component_prop_inputs if component_prop_inputs else []
        )
        self.event_specific_args = event_specific_args
        self.collects_event_data = collects_event_data

        # --- queueing, batching and concurrency ---------------------------
        # Without a backend function there is nothing to queue.
        self.queue = queue if fn is not None else False
        self.batch = batch
        self.max_batch_size = max_batch_size
        self.concurrency_limit: int | None | Literal["default"] = concurrency_limit
        # Default the concurrency id to the identity of ``fn``.
        self.concurrency_id = concurrency_id if concurrency_id else str(id(fn))
        self.tracks_progress = tracks_progress
        self.total_runtime = 0
        self.total_runs = 0

        # --- triggering and cancellation ----------------------------------
        self.trigger_after = trigger_after
        self.trigger_only_on_success = trigger_only_on_success
        self.trigger_only_on_failure = trigger_only_on_failure
        self.trigger_mode = trigger_mode
        self.cancels = cancels if cancels else []
        # We need to keep track of which events are cancel events
        # so that the client can call the /cancel route directly
        self.is_cancel_function = is_cancel_function

        # --- API exposure -------------------------------------------------
        self.api_name = api_name
        self.api_description = api_description
        self.api_visibility = api_visibility

        # --- front-end behaviour ------------------------------------------
        self.js = js
        self.show_progress = show_progress
        self.show_progress_on = show_progress_on
        # Forced off whenever ``utils.get_space()`` returns something truthy.
        self.scroll_to_output = scroll_to_output if not utils.get_space() else False

        # --- rendering context --------------------------------------------
        self.renderable = renderable
        self.rendered_in = rendered_in
        self.render_iteration = render_iteration
        self.page = page
        self.key = key

        # --- streaming ----------------------------------------------------
        self.connection = connection
        self.time_limit = time_limit
        self.stream_every = stream_every

        # Must stay last: it may replace ``self.fn`` with a wrapper.
        self.spaces_auto_wrap()

    def spaces_auto_wrap(self):
        """Wrap ``self.fn`` with ``spaces.gradio_auto_wrap`` when applicable.

        Nothing happens if the optional ``spaces`` module could not be
        imported or if ``utils.get_space()`` returns ``None``.
        """
        if spaces is not None and utils.get_space() is not None:
            self.fn = spaces.gradio_auto_wrap(self.fn)

    def __str__(self):
        """Return the string form of a dict with the name and processing flags."""
        summary = {
            "fn": self.name,
            "preprocess": self.preprocess,
            "postprocess": self.postprocess,
        }
        return str(summary)

    def __repr__(self):
        """Return the same text as ``str(self)``."""
        return f"{self!s}"

    def get_config(self):
        """Return the stored state as a dictionary.

        Blocks and renderables are represented by their ``_id``; the
        function is represented only by whether it exists and by its
        ``__js_implementation__`` attribute, if any.
        """
        input_ids = [block._id for block in self.inputs]
        output_ids = [block._id for block in self.outputs]

        if self.show_progress_on is None:
            progress_ids = None
        else:
            progress_ids = [block._id for block in self.show_progress_on]

        rendered_in_id = self.rendered_in._id if self.rendered_in else None
        render_id = self.renderable._id if self.renderable else None

        type_flags = {
            "generator": self.types_generator,
            "cancel": self.is_cancel_function,
        }

        return {
            "id": self._id,
            "targets": self.targets,
            "inputs": input_ids,
            "outputs": output_ids,
            "backend_fn": self.fn is not None,
            "js": self.js,
            "queue": self.queue,
            "api_name": self.api_name,
            "api_description": self.api_description,
            "scroll_to_output": self.scroll_to_output,
            "show_progress": self.show_progress,
            "show_progress_on": progress_ids,
            "batch": self.batch,
            "max_batch_size": self.max_batch_size,
            "cancels": self.cancels,
            "types": type_flags,
            "collects_event_data": self.collects_event_data,
            "trigger_after": self.trigger_after,
            "trigger_only_on_success": self.trigger_only_on_success,
            "trigger_only_on_failure": self.trigger_only_on_failure,
            "trigger_mode": self.trigger_mode,
            "api_visibility": self.api_visibility,
            "rendered_in": rendered_in_id,
            "render_id": render_id,
            "connection": self.connection,
            "time_limit": self.time_limit,
            "stream_every": self.stream_every,
            "event_specific_args": self.event_specific_args,
            "component_prop_inputs": self.component_prop_inputs,
            "js_implementation": getattr(self.fn, "__js_implementation__", None),
        }