# Copyright (c) Alibaba, Inc. and its affiliates.
from contextlib import nullcontext
from typing import List, Union

import gradio
from packaging import version

from swift.utils import get_logger
from ..argument import AppArguments
from ..base import SwiftPipeline
from ..infer import run_deploy
from .build_ui import build_ui

logger = get_logger()


class SwiftApp(SwiftPipeline):
    """Pipeline that builds the UI via ``build_ui`` and launches it.

    The arguments are described by ``AppArguments``.
    """
    args_class = AppArguments
    args: args_class

    def run(self):
        """Build the demo UI and launch it.

        When ``args.base_url`` is set it is used directly. Otherwise
        ``run_deploy(args, return_url=True)`` is entered as a context manager
        and the value it yields is used as the base URL. The UI is built and
        launched inside that context.
        """
        args = self.args

        # Only enter a deployment context when no base URL was supplied.
        if args.base_url:
            deploy_context = nullcontext()
        else:
            deploy_context = run_deploy(args, return_url=True)

        with deploy_context as deployed_url:
            # nullcontext() yields None, so fall back to the configured URL.
            endpoint = deployed_url or args.base_url
            demo = build_ui(
                endpoint,
                args.model_suffix,
                request_config=args.get_request_config(),
                is_multimodal=args.is_multimodal,
                studio_title=args.studio_title,
                lang=args.lang,
                default_system=args.system)

            # The 'pt' backend is limited to one concurrent request here;
            # every other backend gets 16.
            max_concurrency = 16
            if args.infer_backend == 'pt':
                max_concurrency = 1

            # The queue keyword used depends on the installed gradio version:
            # one name before major version 4, another from 4 onwards.
            is_legacy_gradio = version.parse(gradio.__version__) < version.parse('4')
            limit_key = 'concurrency_count' if is_legacy_gradio else 'default_concurrency_limit'

            queued_demo = demo.queue(**{limit_key: max_concurrency})
            queued_demo.launch(
                server_name=args.server_name,
                server_port=args.server_port,
                share=args.share)


def app_main(args: Union[List[str], AppArguments, None] = None):
    """Construct a ``SwiftApp`` from ``args`` and return the result of its ``main()``."""
    app = SwiftApp(args)
    return app.main()