Spaces:
Build error
Build error
| """ | |
| Defines helper methods useful for setting up ports, launching servers, and | |
| creating tunnels. | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import socket | |
| import threading | |
| import time | |
| import warnings | |
| from functools import partial | |
| from typing import TYPE_CHECKING | |
| import requests | |
| import uvicorn | |
| from uvicorn.config import Config | |
| from gradio.exceptions import ServerFailedToStartError | |
| from gradio.routes import App | |
| from gradio.tunneling import Tunnel | |
| from gradio.utils import SourceFileReloader, watchfn | |
| if TYPE_CHECKING: # Only import for type checking (to avoid circular imports). | |
| from gradio.blocks import Blocks | |
# Port selection: with the default settings the local server first tries
# localhost:7860 and, if that port is busy, walks upward through 7861, 7862, ...
# 7959 (GRADIO_NUM_PORTS consecutive ports starting at GRADIO_SERVER_PORT).
INITIAL_PORT_VALUE = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))
TRY_NUM_PORTS = int(os.environ.get("GRADIO_NUM_PORTS", "100"))
LOCALHOST_NAME = os.environ.get("GRADIO_SERVER_NAME", "127.0.0.1")

# Endpoint queried to obtain a share-tunnel server.
GRADIO_API_SERVER = "https://api.gradio.app/v2/tunnel-request"

# Hot reload: watching is enabled whenever GRADIO_WATCH_DIRS is set to a
# non-empty value, which is interpreted as a comma-separated list of directories.
should_watch = bool(os.environ.get("GRADIO_WATCH_DIRS"))
GRADIO_WATCH_DIRS = os.environ["GRADIO_WATCH_DIRS"].split(",") if should_watch else []
GRADIO_WATCH_FILE = os.environ.get("GRADIO_WATCH_FILE", "app")
GRADIO_WATCH_DEMO_NAME = os.environ.get("GRADIO_WATCH_DEMO_NAME", "demo")
class Server(uvicorn.Server):
    """A uvicorn server that is run from a background daemon thread.

    When a ``SourceFileReloader`` is supplied, a second daemon thread runs
    ``watchfn`` with that reloader alongside the server.
    """

    def __init__(
        self, config: Config, reloader: SourceFileReloader | None = None
    ) -> None:
        # Keep a direct handle on the app object the config was built with.
        self.running_app = config.app
        super().__init__(config)
        self.reloader = reloader
        if self.reloader:
            self.event = threading.Event()
            self.watch = partial(watchfn, self.reloader)

    def install_signal_handlers(self):
        # No-op override. This server is run from a worker thread (see
        # run_in_thread), and Python only allows installing signal handlers
        # from the main thread, presumably why uvicorn's handlers are skipped.
        pass

    def run_in_thread(self, timeout: float = 5) -> None:
        """Start the server (and the file watcher, if any) in daemon threads.

        Blocks until uvicorn reports ``started``.

        Parameters:
            timeout: maximum number of seconds to wait for the server to start.
        Raises:
            ServerFailedToStartError: if the server thread exits before starting,
                or has not started within ``timeout`` seconds. Before raising, the
                server is asked to exit and the reloader (if any) is stopped, so
                nothing keeps running for a port the caller is about to abandon.
        """
        self.thread = threading.Thread(target=self.run, daemon=True)
        if self.reloader:
            self.watch_thread = threading.Thread(target=self.watch, daemon=True)
            self.watch_thread.start()
        self.thread.start()

        # monotonic() is unaffected by wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while not self.started:
            # If the server thread has already exited (e.g. it could not bind
            # the port) there is no point in waiting out the whole timeout.
            gave_up = not self.thread.is_alive() or time.monotonic() > deadline
            # Re-check `started` to avoid racing a thread that just came up.
            if gave_up and not self.started:
                # Stop a server that may still come up late, and the watcher,
                # so a failed attempt leaves no background work behind.
                self.should_exit = True
                if self.reloader:
                    self.reloader.stop()
                raise ServerFailedToStartError(
                    "Server failed to start. Please check that the port is available."
                )
            time.sleep(1e-3)

    def close(self) -> None:
        """Ask the server to exit, stop the watcher, and wait for both threads.

        Must only be called after ``run_in_thread`` has been called.
        """
        self.should_exit = True
        if self.reloader:
            self.reloader.stop()
            self.watch_thread.join()
        self.thread.join()
def get_first_available_port(
    initial: int, final: int, host: str | None = None
) -> int:
    """
    Gets the first open port in a specified range of port numbers
    Parameters:
        initial: the initial value in the range of port numbers
        final: final (exclusive) value in the range of port numbers, should be greater than `initial`
        host: host name/address to probe; defaults to LOCALHOST_NAME
    Returns:
        port: the first open port in the range
    Raises:
        OSError: if no port in [initial, final) could be bound
    """
    if host is None:
        host = LOCALHOST_NAME
    for port in range(initial, final):
        try:
            # The context manager closes the probe socket even when bind()
            # fails, so scanning many busy ports does not leak descriptors.
            with socket.socket() as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind((host, port))
        except OSError:
            continue
        return port
    raise OSError(
        f"All ports from {initial} to {final - 1} are in use. Please close a port."
    )
def configure_app(app: App, blocks: Blocks) -> App:
    """Copy per-launch settings from ``blocks`` onto ``app`` and return ``app``.

    ``app.auth`` becomes None when ``blocks.auth`` is None. If ``blocks.auth``
    is callable, it is used as-is. Otherwise it is treated as an iterable of
    (username, password) pairs and turned into a {username: password} dict.
    Also sets ``blocks``, the current working directory, ``favicon_path``, and
    a fresh empty ``tokens`` dict.
    """
    credentials = blocks.auth
    if credentials is None:
        app.auth = None
    elif callable(credentials):
        app.auth = credentials
    else:
        app.auth = {pair[0]: pair[1] for pair in credentials}

    app.blocks = blocks
    app.cwd = os.getcwd()
    app.favicon_path = blocks.favicon_path
    app.tokens = {}
    return app
def start_server(
    blocks: Blocks,
    server_name: str | None = None,
    server_port: int | None = None,
    ssl_keyfile: str | None = None,
    ssl_certfile: str | None = None,
    ssl_keyfile_password: str | None = None,
    app_kwargs: dict | None = None,
) -> tuple[str, int, str, App, Server]:
    """Launches a local server running the provided Blocks
    Parameters:
        blocks: The Blocks object to run on the server
        server_name: to make app accessible on local network, set this to "0.0.0.0". Can be set by environment variable GRADIO_SERVER_NAME.
        server_port: will start gradio app on this port (if available). Can be set by environment variable GRADIO_SERVER_PORT.
        ssl_keyfile: If a path to a file is provided, will use this as the private key file to create a local server running on https.
        ssl_certfile: If a path to a file is provided, will use this as the signed certificate for https. Needs to be provided if ssl_keyfile is provided.
        ssl_keyfile_password: If a password is provided, will use this with the ssl certificate for https.
        app_kwargs: Additional keyword arguments to pass to the gradio.routes.App constructor.
    Returns:
        server_name: the host name the server was started with (LOCALHOST_NAME if none was given)
        port: the port number the server is running on
        path_to_local_server: the complete address that the local server can be accessed at
        app: the FastAPI app object
        server: the server object that is a subclass of uvicorn.Server (used to close the server)
    Raises:
        ValueError: if ssl_keyfile is provided without ssl_certfile
        OSError: if the server could not be started on any candidate port
    """
    if ssl_keyfile is not None and ssl_certfile is None:
        raise ValueError("ssl_certfile must be provided if ssl_keyfile is provided.")

    server_name = server_name or LOCALHOST_NAME
    url_host_name = "localhost" if server_name == "0.0.0.0" else server_name

    # Strip IPv6 brackets from the address if they exist.
    # This is needed as http://[::1]:port/ is a valid browser address,
    # but not a valid IPv6 address, so asyncio will throw an exception.
    if server_name.startswith("[") and server_name.endswith("]"):
        host = server_name[1:-1]
    else:
        host = server_name

    # Conversely, a bare IPv6 literal (e.g. "::1") must be bracketed to form a
    # valid URL; "http://::1:7860/" cannot be parsed.
    if ":" in url_host_name and not url_host_name.startswith("["):
        url_host_name = f"[{url_host_name}]"

    app = App.create_app(blocks, app_kwargs=app_kwargs)

    server_ports = (
        [server_port]
        if server_port is not None
        else range(INITIAL_PORT_VALUE, INITIAL_PORT_VALUE + TRY_NUM_PORTS)
    )

    for port in server_ports:
        try:
            # The fastest way to check if a port is available is to try to bind
            # to it with socket; if it is not available, bind() raises OSError.
            # Really, we should be checking if (server_name, port) is available,
            # but socket.bind() doesn't seem to throw an OSError with ipv6
            # addresses, so we just check if the port is available on localhost.
            # The context manager closes the probe socket even when bind fails.
            with socket.socket() as s:
                s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
                s.bind((LOCALHOST_NAME, port))

            # The port can still be taken between the probe and uvicorn's own
            # bind, so starting the server is the real check; a failure there
            # surfaces as ServerFailedToStartError.
            config = uvicorn.Config(
                app=app,
                port=port,
                host=host,
                log_level="warning",
                ssl_keyfile=ssl_keyfile,
                ssl_certfile=ssl_certfile,
                ssl_keyfile_password=ssl_keyfile_password,
                ws_max_size=1024 * 1024 * 1024,  # Setting max websocket size to be 1 GB
            )
            reloader = None
            if GRADIO_WATCH_DIRS:
                change_event = threading.Event()
                app.change_event = change_event
                reloader = SourceFileReloader(
                    app=app,
                    watch_dirs=GRADIO_WATCH_DIRS,
                    watch_file=GRADIO_WATCH_FILE,
                    demo_name=GRADIO_WATCH_DEMO_NAME,
                    stop_event=threading.Event(),
                    change_event=change_event,
                )
            server = Server(config=config, reloader=reloader)
            server.run_in_thread()
            break
        except (OSError, ServerFailedToStartError):
            pass
    else:
        # An empty candidate range (e.g. GRADIO_NUM_PORTS=0) would make min()/max()
        # raise ValueError; report it through the same OSError instead.
        if server_ports:
            port_range = f"{min(server_ports)}-{max(server_ports)}"
        else:
            port_range = "(no ports to try; check GRADIO_NUM_PORTS)"
        raise OSError(
            f"Cannot find empty port in range: {port_range}. You can specify a different port by setting the GRADIO_SERVER_PORT environment variable or passing the `server_port` parameter to `launch()`."
        )

    scheme = "https" if ssl_keyfile is not None else "http"
    path_to_local_server = f"{scheme}://{url_host_name}:{port}/"
    return server_name, port, path_to_local_server, app, server
def setup_tunnel(
    local_host: str, local_port: int, share_token: str, *, timeout: float = 10
) -> str:
    """Ask the Gradio API server for a share server and open a tunnel to it.

    Parameters:
        local_host: host of the local server to expose
        local_port: port of the local server to expose
        share_token: token passed through to the Tunnel
        timeout: seconds to wait on the Gradio API server (connect and read)
            before giving up, so launch cannot hang indefinitely
    Returns:
        the address returned by ``Tunnel.start_tunnel()``
    Raises:
        RuntimeError: if the API server cannot be reached, does not answer with
            HTTP 200, returns an unexpected payload, or the tunnel fails to start.
    """
    try:
        response = requests.get(GRADIO_API_SERVER, timeout=timeout)
    except requests.exceptions.RequestException as e:
        # Report network failures through the same exception type as the
        # other failure paths below.
        raise RuntimeError("Could not get share link from Gradio API Server.") from e
    if response.status_code == 200:
        try:
            # The first entry of the JSON list describes the share server.
            payload = response.json()[0]
            remote_host, remote_port = payload["host"], int(payload["port"])
            tunnel = Tunnel(
                remote_host, remote_port, local_host, local_port, share_token
            )
            return tunnel.start_tunnel()
        except Exception as e:
            raise RuntimeError(str(e)) from e
    raise RuntimeError("Could not get share link from Gradio API Server.")
def url_ok(url: str) -> bool:
    """Return True if ``url`` answers a HEAD request, retrying a few times.

    Sends up to 5 HEAD requests, each with a 3 s timeout and TLS verification
    disabled, pausing 0.5 s between attempts. Returns True as soon as a
    response has status 200, 401, or 302 (401/302 are what an auth-protected
    app answers with). Returns False if every attempt gets some other status,
    or as soon as a request fails to connect or times out.
    """
    attempts = 5
    try:
        for attempt in range(attempts):
            with warnings.catch_warnings():
                # Presumably silences the insecure-request warning that
                # verify=False produces.
                warnings.filterwarnings("ignore")
                r = requests.head(url, timeout=3, verify=False)
            if r.status_code in (200, 401, 302):  # 401 or 302 if auth is set
                return True
            # No point sleeping after the final attempt.
            if attempt < attempts - 1:
                time.sleep(0.500)
    except (
        ConnectionError,
        requests.exceptions.ConnectionError,
        # A read timeout is not a ConnectionError; without this it would
        # escape from what is meant to be a boolean check.
        requests.exceptions.Timeout,
    ):
        return False
    return False