| |
| from dataclasses import dataclass |
| from typing import Literal, Optional |
|
|
| from swift.utils import find_free_port, get_logger |
| from ..model import get_matched_model_meta |
| from ..template import get_template_meta |
| from .deploy_args import DeployArguments |
| from .webui_args import WebUIArguments |
|
|
# Module-level logger obtained from swift.utils.get_logger.
logger = get_logger()
|
|
|
|
@dataclass
class AppArguments(WebUIArguments, DeployArguments):
    """Arguments for the app entry point, combining web-UI and deployment options.

    After initialisation, ``server_port`` is replaced by the value returned by
    ``find_free_port``. When model metadata is available, ``system`` and
    ``is_multimodal`` are filled in from it if they were left as ``None``;
    ``is_multimodal`` finally defaults to ``False`` if still undecided.
    """
    # When set, `_init_torch_dtype` only looks up model metadata and skips the
    # parent class's dtype initialisation (presumably the model is already
    # served at this URL — TODO confirm with callers).
    base_url: Optional[str] = None
    # Not read in this class; presumably a display title for the UI — TODO confirm.
    studio_title: Optional[str] = None
    # None means "infer": taken from model metadata when present, else False.
    is_multimodal: Optional[bool] = None

    # Language code, 'en' or 'zh' (presumably the UI language — confirm).
    lang: Literal['en', 'zh'] = 'en'
    # Not read in this class.
    verbose: bool = False

    def _init_torch_dtype(self) -> None:
        """Resolve only model metadata when a base URL is given; otherwise defer to the parent."""
        if not self.base_url:
            super()._init_torch_dtype()
            return
        self.model_meta = get_matched_model_meta(self.model)

    def __post_init__(self):
        super().__post_init__()
        self.server_port = find_free_port(self.server_port)
        meta = self.model_meta
        if meta:
            # Fill user-unset options from the matched model metadata.
            if self.system is None:
                template_meta = get_template_meta(meta.template)
                self.system = template_meta.default_system
            if self.is_multimodal is None:
                self.is_multimodal = meta.is_multimodal
        if self.is_multimodal is None:
            # Neither the user nor the model metadata decided; default to False.
            self.is_multimodal = False
|
|