|
|
|
|
|
from contextlib import nullcontext |
|
|
from typing import List, Union |
|
|
|
|
|
import gradio |
|
|
from packaging import version |
|
|
|
|
|
from swift.utils import get_logger |
|
|
from ..argument import AppArguments |
|
|
from ..base import SwiftPipeline |
|
|
from ..infer import run_deploy |
|
|
from .build_ui import build_ui |
|
|
|
|
|
# Module-level logger obtained from swift.utils.get_logger.
logger = get_logger()
|
|
|
|
|
|
|
|
class SwiftApp(SwiftPipeline):
    """Pipeline that builds a Gradio UI via ``build_ui`` and launches it.

    The UI is pointed at ``args.base_url`` when that is set; otherwise
    ``run_deploy(args, return_url=True)`` is entered as a context manager and
    the value it yields is used as the base URL.
    """

    args_class = AppArguments
    args: args_class

    def run(self):
        """Build the demo UI, configure its queue and launch the server."""
        args = self.args

        # With a user-supplied base_url there is nothing to deploy, so a no-op
        # context is used (it yields None). Otherwise run_deploy provides the
        # context and its entered value is taken as the URL.
        if args.base_url:
            deploy_context = nullcontext()
        else:
            deploy_context = run_deploy(args, return_url=True)

        with deploy_context as deployed_url:
            base_url = deployed_url or args.base_url
            demo = build_ui(
                base_url,
                args.model_suffix,
                request_config=args.get_request_config(),
                is_multimodal=args.is_multimodal,
                studio_title=args.studio_title,
                lang=args.lang,
                default_system=args.system,
            )

            # The 'pt' backend gets a limit of 1; every other backend gets 16.
            if args.infer_backend == 'pt':
                concurrency_count = 1
            else:
                concurrency_count = 16

            # The queue() keyword carrying the limit is chosen from the
            # installed Gradio version: 'concurrency_count' before 4,
            # 'default_concurrency_limit' from 4 onwards.
            is_pre_v4 = version.parse(gradio.__version__) < version.parse('4')
            limit_key = 'concurrency_count' if is_pre_v4 else 'default_concurrency_limit'
            queued_demo = demo.queue(**{limit_key: concurrency_count})

            # Launch inside the ``with`` block so the deploy context is still
            # open while the UI is launched.
            queued_demo.launch(
                server_name=args.server_name,
                server_port=args.server_port,
                share=args.share,
            )
|
|
|
|
|
|
|
|
def app_main(args: Union[List[str], AppArguments, None] = None):
    """Construct a ``SwiftApp`` from ``args`` and return the result of its ``main()``.

    Args:
        args: A list of argument strings, an ``AppArguments`` instance, or
            ``None``; passed unchanged to the ``SwiftApp`` constructor.

    Returns:
        Whatever ``SwiftApp(args).main()`` returns.
    """
    app = SwiftApp(args)
    return app.main()
|
|
|