# Spaces:
# Runtime error
# Runtime error
"""
@Desc: 全局配置文件读取 (global configuration file loading)
"""
import argparse
import os
import shutil
import sys
from typing import Any, Dict, List

import yaml
class Webui_config:
    """WebUI configuration, built from the ``webui`` section of the YAML config."""

    def __init__(
        self,
        model: str,
        config_path: str,
        port: int = 7860,
        share: bool = False,
        debug: bool = False,
    ):
        self.model: str = model  # model path
        self.config_path: str = config_path  # config file path
        self.port: int = port  # port number
        self.share: bool = share  # whether to share publicly (expose to the external network)
        self.debug: bool = debug  # whether debug mode is enabled

    @classmethod
    def from_dict(cls, dataset_path: str, data: Dict[str, Any]):
        """Build a ``Webui_config`` from a parsed configuration mapping.

        ``data["config_path"]`` and ``data["model"]`` are joined onto
        ``dataset_path``; every key in ``data`` is then passed to ``__init__``
        as a keyword argument.

        Args:
            dataset_path: Directory that ``model`` and ``config_path`` are
                resolved against.
            data: Mapping containing at least ``model`` and ``config_path``.

        Returns:
            A new ``Webui_config`` instance.

        Raises:
            KeyError: If ``config_path`` or ``model`` is missing from ``data``.
            TypeError: If ``data`` has a key ``__init__`` does not accept.
        """
        # Work on a shallow copy so the caller's mapping (e.g. the parsed YAML
        # section) is not rewritten in place.
        data = dict(data)
        data["config_path"] = os.path.join(dataset_path, data["config_path"])
        data["model"] = os.path.join(dataset_path, data["model"])
        return cls(**data)
class Server_config:
    """Server configuration, built from the ``server`` section of the YAML config."""

    def __init__(
        self, models: List[Dict[str, Any]], port: int = 5000, device: str = "cuda"
    ):
        self.models: List[Dict[str, Any]] = models  # configs of all models to load
        self.port: int = port  # port number
        # Device name passed in (default "cuda"); stored so the argument is not
        # silently discarded.
        self.device: str = device

    @classmethod
    def from_dict(cls, data: Dict[str, Any]):
        """Build a ``Server_config`` by passing every key of ``data`` to ``__init__``.

        Args:
            data: Mapping with a ``models`` key and optionally ``port`` and
                ``device``.

        Returns:
            A new ``Server_config`` instance.

        Raises:
            TypeError: If ``models`` is missing or ``data`` has an unknown key.
        """
        return cls(**data)
class Config:
    """Top-level configuration loaded from a YAML file.

    Attributes:
        dataset_path: Value of the top-level ``dataset_path`` key.
        webui_config: Parsed ``webui`` section (paths joined onto ``dataset_path``).
        server_config: Parsed ``server`` section.
        device: Value of the top-level ``device`` key; ``"auto"`` is resolved to
            ``"cuda"`` when torch reports CUDA available, otherwise ``"cpu"``.
    """

    def __init__(self, config_path: str):
        """Read and parse the YAML configuration at ``config_path``.

        Raises:
            FileNotFoundError: If ``config_path`` does not exist.
            ValueError: If the file's top-level YAML value is not a mapping
                (for example, an empty file).
            KeyError: If a required top-level key is missing.
        """
        with open(file=config_path, mode="r", encoding="utf-8") as file:
            yaml_config: Dict[str, Any] = yaml.safe_load(file)
        # safe_load returns None for an empty document; fail with a clear message
        # instead of a TypeError from subscripting None.
        if not isinstance(yaml_config, dict):
            raise ValueError(
                f"config file {config_path!r} must contain a YAML mapping at top level"
            )

        dataset_path: str = yaml_config["dataset_path"]
        self.dataset_path: str = dataset_path
        self.webui_config: Webui_config = Webui_config.from_dict(
            dataset_path, yaml_config["webui"]
        )
        self.server_config: Server_config = Server_config.from_dict(
            yaml_config["server"]
        )
        self.device: str = yaml_config["device"]
        if self.device == "auto":
            # torch is imported only when automatic device selection is requested.
            import torch

            self.device = "cuda" if torch.cuda.is_available() else "cpu"
parser = argparse.ArgumentParser()
# Named config.yml (rather than config.json) to avoid clashing with the older
# config.json file.
parser.add_argument(
    "-y",
    "--yml_config",
    type=str,
    default="config.yml",
    help="path to the YAML configuration file (default: config.yml)",
)
# parse_known_args does not error on unrecognized flags.
# NOTE(review): presumably this lets scripts that import this module accept
# their own command-line options; confirm.
args, _ = parser.parse_known_args()
config = Config(args.yml_config)