"""
@Desc: Global configuration file loader.

Parses ``-y/--yml_config`` (default ``config.yml``) from the command line at
import time, loads that YAML file, and exposes the result as the module-level
``config`` object.
"""
import argparse
import os
import shutil
import sys
from typing import Any, Dict, List

import yaml


class Webui_config:
    """WebUI configuration."""

    def __init__(
        self,
        model: str,
        config_path: str,
        port: int = 7860,
        share: bool = False,
        debug: bool = False,
    ):
        self.model: str = model  # model path
        self.config_path: str = config_path  # config file path
        self.port: int = port  # port number
        self.share: bool = share  # whether to deploy publicly (expose to external network)
        self.debug: bool = debug  # whether debug mode is enabled

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, Any]) -> "Webui_config":
        """Build a WebUI config from a mapping, resolving paths under a dataset dir.

        Args:
            dataset_path: Directory that ``data["config_path"]`` and
                ``data["model"]`` are joined onto.
            data: Keyword arguments for ``__init__``; must contain
                ``config_path`` and ``model``. It is NOT modified.

        Returns:
            A new ``Webui_config`` instance.

        Raises:
            KeyError: if ``config_path`` or ``model`` is missing.
            TypeError: if ``data`` contains keys ``__init__`` does not accept.
        """
        # Work on a shallow copy so the caller's mapping (e.g. the parsed YAML)
        # is not rewritten; repeated calls would otherwise re-join the paths.
        resolved = dict(data)
        resolved["config_path"] = os.path.join(dataset_path, resolved["config_path"])
        resolved["model"] = os.path.join(dataset_path, resolved["model"])
        return cls(**resolved)


class Server_config:
    """HTTP server configuration."""

    def __init__(
        self, models: List[Dict[str, Any]], port: int = 5000, device: str = "cuda"
    ):
        self.models: List[Dict[str, Any]] = models  # configs of all models to load
        self.port: int = port  # port number
        # Stored exactly as given (previously accepted and silently discarded).
        # Unlike Config.device, an "auto" value is not resolved here.
        self.device: str = device

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "Server_config":
        """Build a server config by passing ``data`` as keyword arguments."""
        return cls(**data)


class Config:
    """Top-level configuration loaded from a YAML file."""

    def __init__(self, config_path: str):
        """Load and parse the YAML config at ``config_path``.

        The file must provide ``dataset_path``, ``webui``, ``server`` and
        ``device`` keys. ``device == "auto"`` is resolved to ``"cuda"`` when
        ``torch.cuda.is_available()`` and ``"cpu"`` otherwise.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if a required top-level key is missing.
        """
        with open(file=config_path, mode="r", encoding="utf-8") as file:
            yaml_config: Dict[str, Any] = yaml.safe_load(file)
        dataset_path: str = yaml_config["dataset_path"]
        self.dataset_path: str = dataset_path
        self.webui_config: Webui_config = Webui_config.from_dict(
            dataset_path, yaml_config["webui"]
        )
        self.server_config: Server_config = Server_config.from_dict(
            yaml_config["server"]
        )
        self.device: str = yaml_config["device"]
        if self.device == "auto":
            # Imported lazily so torch is only required when auto-detection is used.
            import torch

            self.device = "cuda" if torch.cuda.is_available() else "cpu"


parser = argparse.ArgumentParser()
# Renamed (from config.json) to avoid clashing with the older config.json file.
parser.add_argument("-y", "--yml_config", type=str, default="config.yml")
# parse_known_args so that unrelated CLI flags of the host script are tolerated.
args, _ = parser.parse_known_args()
config = Config(args.yml_config)