import json
import os
import shutil
import sys
from typing import Any

from huggingface_hub import hf_hub_download
from mmgp import offload
from omegaconf import OmegaConf

from shared.utils import files_locator as fl

from .model_signature import MATANYONE_V1, MATANYONE_V2, detect_matanyone_model_version

# Server-config key holding the user's MatAnyone version selection.
MATANYONE_SETTINGS_KEY = "matanyone_version"
MATANYONE_DEFAULT_VERSION = MATANYONE_V1
# Hugging Face repository and subfolder hosting the mask-edition assets.
MATANYONE_REPO_ID = "DeepBeepMeep/Wan2.1"
MATANYONE_FOLDER = "mask"
MATANYONE_CONFIG_NAME = "config.json"
MATANYONE_SAM_NAME = "sam_vit_h_4b8939_fp16.safetensors"
# Pre-versioning weights file name; its version is determined at runtime by
# detect_matanyone_model_version().
MATANYONE_LEGACY_NAME = "model.safetensors"
MATANYONE_WEIGHT_FILES = {
    MATANYONE_V1: "matanyone.safetensors",
    MATANYONE_V2: "matanyone2.safetensors",
}
MATANYONE_VERSION_LABELS = {
    MATANYONE_V1: "MatAnyone v1 (original)",
    MATANYONE_V2: "MatAnyone v2",
}


def _mask_relpath(filename: str) -> str:
    """Return the relative path of *filename* inside the MatAnyone asset folder."""
    return os.path.join(MATANYONE_FOLDER, filename)


def normalize_matanyone_version(value: Any) -> str:
    """Map an arbitrary user/config value to a known MatAnyone version id.

    ``None`` maps to the default version. The value is stringified, stripped
    and lower-cased; ``"2"``, ``"matanyone2"`` or the v2 identifier select v2.
    Anything else falls back to v1.
    """
    if value is None:
        return MATANYONE_DEFAULT_VERSION
    text = str(value).strip().lower()
    if text in {"2", MATANYONE_V2, "matanyone2"}:
        return MATANYONE_V2
    return MATANYONE_V1


def _get_runtime_server_config(server_config=None):
    """Return *server_config* if it is a dict, else ``__main__.server_config``.

    Falls back to a new empty dict when neither is a dict. Note that the
    fallback is a fresh object on every call, so mutations to it are not
    visible to later calls.
    """
    if isinstance(server_config, dict):
        return server_config
    main_module = sys.modules.get("__main__")
    runtime_config = getattr(main_module, "server_config", None)
    return runtime_config if isinstance(runtime_config, dict) else {}


def _get_runtime_server_config_filename() -> str | None:
    """Return ``__main__.server_config_filename`` if it is a non-empty string, else None."""
    main_module = sys.modules.get("__main__")
    filename = getattr(main_module, "server_config_filename", None)
    return filename if isinstance(filename, str) and len(filename) > 0 else None


def _save_runtime_server_config(server_config) -> bool:
    """Persist *server_config* as indented JSON to ``__main__.server_config_filename``.

    The payload is serialized before any file is opened, then written to a
    sibling temporary file and swapped in with ``os.replace``. A serialization
    error or an interrupted write therefore cannot truncate the existing config.

    Returns:
        True if the config was written; False if *server_config* is not a dict
        or no target filename is known.
    """
    filename = _get_runtime_server_config_filename()
    if not isinstance(server_config, dict) or filename is None:
        return False
    payload = json.dumps(server_config, indent=4)
    tmp_filename = f"{filename}.tmp"
    with open(tmp_filename, "w", encoding="utf-8") as writer:
        writer.write(payload)
    os.replace(tmp_filename, filename)
    return True


def get_selected_matanyone_version(server_config=None) -> str:
    """Return the normalized MatAnyone version selected in the server config."""
    runtime_config = _get_runtime_server_config(server_config)
    return normalize_matanyone_version(runtime_config.get(MATANYONE_SETTINGS_KEY, MATANYONE_DEFAULT_VERSION))


def get_selected_matanyone_label(server_config=None) -> str:
    """Return the human-readable label of the selected MatAnyone version."""
    return MATANYONE_VERSION_LABELS[get_selected_matanyone_version(server_config)]


def get_matanyone_title_html(server_config=None) -> str:
    """Return the credit/title text shown for the mask-edition tool."""
    return f"Mask Edition is provided by {get_selected_matanyone_label(server_config)}, VRAM optimizations & Extended Masks by DeepBeepMeep"


def get_selected_matanyone_weight_name(server_config=None) -> str:
    """Return the versioned weights file name for the selected MatAnyone version."""
    return MATANYONE_WEIGHT_FILES[get_selected_matanyone_version(server_config)]


def get_selected_matanyone_weights_path(server_config=None):
    """Locate weights for the selected version on disk.

    Prefers the versioned file. Otherwise accepts the legacy
    ``model.safetensors`` only if its detected version matches the selection.

    Returns:
        The located path, or None when no suitable local file exists.
    """
    selected_version = get_selected_matanyone_version(server_config)
    selected_name = MATANYONE_WEIGHT_FILES[selected_version]
    selected_path = fl.locate_file(_mask_relpath(selected_name), error_if_none=False)
    if selected_path is not None:
        return selected_path
    legacy_path = fl.locate_file(_mask_relpath(MATANYONE_LEGACY_NAME), error_if_none=False)
    if legacy_path is None:
        return None
    if detect_matanyone_model_version(legacy_path) == selected_version:
        return legacy_path
    return None


def _download_mask_asset(filename: str) -> str:
    """Download ``mask/<filename>`` from the MatAnyone repo into the download location; return its local path."""
    return hf_hub_download(repo_id=MATANYONE_REPO_ID, filename=filename, local_dir=fl.get_download_location(), subfolder=MATANYONE_FOLDER)


def query_matanyone_download_def(server_config=None):
    """Describe the assets required for the selected MatAnyone version.

    Returns:
        A dict with ``repoId``, ``sourceFolderList`` and ``fileList``. The
        single file list holds the SAM weights, the selected MatAnyone weights
        and the config.
    """
    runtime_config = _get_runtime_server_config(server_config)
    return {
        "repoId": MATANYONE_REPO_ID,
        "sourceFolderList": [MATANYONE_FOLDER],
        "fileList": [[MATANYONE_SAM_NAME, get_selected_matanyone_weight_name(runtime_config), MATANYONE_CONFIG_NAME]],
    }


def migrate_matanyone_install(server_config=None):
    """Move a legacy ``mask/model.safetensors`` to its versioned file name.

    The legacy file's version is detected and the file is moved to
    ``mask/matanyone.safetensors`` or ``mask/matanyone2.safetensors`` in the
    download location. If that target already exists, the legacy file is
    deleted instead. A v2 legacy file also selects v2 in the server config.
    When the config had no version setting yet, the detected version is
    stored and saved.

    Returns:
        A human-readable message describing the migration, or None when there
        is no legacy file or its version cannot be detected.
    """
    runtime_config = _get_runtime_server_config(server_config)
    # Record whether a selection existed *before* filling in the default;
    # otherwise the "no setting yet" branch below could never be taken.
    had_setting = MATANYONE_SETTINGS_KEY in runtime_config
    runtime_config.setdefault(MATANYONE_SETTINGS_KEY, MATANYONE_DEFAULT_VERSION)

    legacy_path = fl.locate_file(_mask_relpath(MATANYONE_LEGACY_NAME), error_if_none=False)
    if legacy_path is None:
        return None
    legacy_version = detect_matanyone_model_version(legacy_path)
    if legacy_version is None:
        return None

    target_name = MATANYONE_WEIGHT_FILES[legacy_version]
    target_path = fl.get_download_location(_mask_relpath(target_name))
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    same_file = os.path.normcase(os.path.abspath(legacy_path)) == os.path.normcase(os.path.abspath(target_path))
    if not same_file:
        if os.path.isfile(target_path):
            # A versioned copy already exists, so the legacy file is redundant.
            os.remove(legacy_path)
        else:
            # locate_file may find the legacy file on a different filesystem
            # than the download location; shutil.move then falls back to
            # copy+delete where os.replace would raise OSError.
            shutil.move(legacy_path, target_path)

    config_changed = False
    if legacy_version == MATANYONE_V2 and runtime_config.get(MATANYONE_SETTINGS_KEY) != MATANYONE_V2:
        runtime_config[MATANYONE_SETTINGS_KEY] = MATANYONE_V2
        config_changed = True
    elif not had_setting:
        runtime_config[MATANYONE_SETTINGS_KEY] = legacy_version
        config_changed = True
    if config_changed:
        _save_runtime_server_config(runtime_config)

    if legacy_version == MATANYONE_V2:
        return "Migrated legacy MatAnyone v2 weights to 'mask/matanyone2.safetensors' and selected MatAnyone v2."
    return "Migrated legacy MatAnyone v1 weights to 'mask/matanyone.safetensors'."
def ensure_selected_matanyone_assets(server_config=None):
    """Ensure the config, SAM weights and selected MatAnyone weights exist locally.

    Runs the legacy-file migration first, then downloads any missing SAM or
    config file. Finally it resolves the selected weights, downloading them
    only if no suitable local file (versioned or matching legacy) is found.

    Returns:
        ``(config_path, weights_path, sam_path)``.
    """
    runtime_config = _get_runtime_server_config(server_config)
    migrate_matanyone_install(runtime_config)
    # Loop-invariant: nothing in the loop below touches runtime_config.
    selected_weight_name = get_selected_matanyone_weight_name(runtime_config)
    for filename in query_matanyone_download_def(runtime_config)["fileList"][0]:
        # Weights are resolved separately because a matching legacy file is also acceptable.
        if filename == selected_weight_name:
            continue
        if fl.locate_file(_mask_relpath(filename), error_if_none=False) is None:
            _download_mask_asset(filename)
    weights_path = get_selected_matanyone_weights_path(runtime_config)
    if weights_path is None:
        weights_path = _download_mask_asset(selected_weight_name)
    config_path = fl.locate_file(_mask_relpath(MATANYONE_CONFIG_NAME))
    sam_path = fl.locate_file(_mask_relpath(MATANYONE_SAM_NAME))
    return config_path, weights_path, sam_path


def load_selected_matanyone_model(server_config=None):
    """Build the MatAnyone model for the selected version and load its weights.

    The server config is resolved once and reused for every step. Any
    migration-driven version change is then reflected both in the weights
    loaded and in the version reported, even when the config is a fallback
    dict that would otherwise be re-created on every lookup.

    Returns:
        ``(model, version, weights_path)``. The model is in eval mode.
    """
    runtime_config = _get_runtime_server_config(server_config)
    config_path, weights_path, _ = ensure_selected_matanyone_assets(runtime_config)
    # "utf-8-sig" tolerates a UTF-8 BOM at the start of config.json.
    with open(config_path, "r", encoding="utf-8-sig") as reader:
        config_data = json.load(reader)
    # Imported lazily, presumably to defer loading the model package until it is needed.
    from ..matanyone.model.matanyone import MatAnyone

    model = MatAnyone(OmegaConf.create(config_data["cfg"]), single_object=config_data.get("single_object", True)).eval()
    offload.load_model_data(model, weights_path, writable_tensors=False)
    return model, get_selected_matanyone_version(runtime_config), weights_path