Wan2GP / preprocessing /matanyone /utils /model_assets.py
Egnalkram's picture
Upload folder using huggingface_hub
4689c2b verified
import json
import os
import sys
from typing import Any
from huggingface_hub import hf_hub_download
from mmgp import offload
from omegaconf import OmegaConf
from shared.utils import files_locator as fl
from .model_signature import MATANYONE_V1, MATANYONE_V2, detect_matanyone_model_version
# Server-config key holding the user's MatAnyone generation choice, and the
# version used when that key is absent.
MATANYONE_SETTINGS_KEY = "matanyone_version"
MATANYONE_DEFAULT_VERSION = MATANYONE_V1

# Hugging Face repository hosting the assets, and the sub-folder they live in
# (both in the repo and relative to the local checkpoint/download locations).
MATANYONE_REPO_ID = "DeepBeepMeep/Wan2.1"
MATANYONE_FOLDER = "mask"

# Version-independent assets stored alongside the MatAnyone weights.
MATANYONE_CONFIG_NAME = "config.json"
MATANYONE_SAM_NAME = "sam_vit_h_4b8939_fp16.safetensors"

# Unversioned weights file name used by older installs; migrate_matanyone_install
# renames it to the matching entry of MATANYONE_WEIGHT_FILES.
MATANYONE_LEGACY_NAME = "model.safetensors"

# Weights file name for each supported MatAnyone version.
MATANYONE_WEIGHT_FILES = {
    MATANYONE_V1: "matanyone.safetensors",
    MATANYONE_V2: "matanyone2.safetensors",
}

# Display names for each version (used in the HTML credit line).
MATANYONE_VERSION_LABELS = {
    MATANYONE_V1: "MatAnyone v1 (original)",
    MATANYONE_V2: "MatAnyone v2",
}
def _mask_relpath(filename: str) -> str:
    """Prefix *filename* with the mask asset folder using the OS path separator."""
    parts = (MATANYONE_FOLDER, filename)
    return os.path.join(*parts)
def normalize_matanyone_version(value: Any) -> str:
    """Map a loosely-typed version setting onto a known MatAnyone version id.

    ``None`` yields ``MATANYONE_DEFAULT_VERSION``. Otherwise the value is turned
    into a trimmed, lower-cased string; ``"2"``, ``"matanyone2"`` or the v2 id
    select v2, and anything else falls back to v1.

    NOTE(review): the v2 id is compared against the lower-cased input, so this
    assumes ``MATANYONE_V2`` is itself lower-case.
    """
    if value is None:
        return MATANYONE_DEFAULT_VERSION
    v2_aliases = {"2", MATANYONE_V2, "matanyone2"}
    token = str(value).strip().lower()
    return MATANYONE_V2 if token in v2_aliases else MATANYONE_V1
def _get_runtime_server_config(server_config=None):
    """Return the server settings dict to operate on.

    An explicitly supplied dict is returned as-is (same object). Otherwise the
    ``server_config`` attribute of the ``__main__`` module is returned when it is
    a dict; failing both, a new empty dict is created on every call.
    """
    if not isinstance(server_config, dict):
        candidate = getattr(sys.modules.get("__main__"), "server_config", None)
        server_config = candidate if isinstance(candidate, dict) else {}
    return server_config
def _get_runtime_server_config_filename() -> str | None:
main_module = sys.modules.get("__main__")
filename = getattr(main_module, "server_config_filename", None)
return filename if isinstance(filename, str) and len(filename) > 0 else None
def _save_runtime_server_config(server_config) -> bool:
    """Write *server_config* as 4-space-indented JSON to the main app's config file.

    The target path comes from ``__main__.server_config_filename``.

    Returns:
        ``False`` (nothing written) when *server_config* is not a dict or no
        usable filename is available; ``True`` after a successful write.

    Raises:
        TypeError / ValueError: if ``json`` cannot serialise *server_config*;
            the existing file is left untouched in that case.
        OSError: if the file cannot be opened or written.
    """
    filename = _get_runtime_server_config_filename()
    if not isinstance(server_config, dict) or filename is None:
        return False
    # Serialise BEFORE opening: mode "w" truncates the file immediately, so a
    # serialisation error raised after open() would leave an empty config behind.
    payload = json.dumps(server_config, indent=4)
    with open(filename, "w", encoding="utf-8") as writer:
        writer.write(payload)
    return True
def get_selected_matanyone_version(server_config=None) -> str:
    """Return the normalised MatAnyone version chosen in the runtime server config.

    Falls back to ``MATANYONE_DEFAULT_VERSION`` when the setting is absent.
    """
    settings = _get_runtime_server_config(server_config)
    raw_choice = settings.get(MATANYONE_SETTINGS_KEY, MATANYONE_DEFAULT_VERSION)
    return normalize_matanyone_version(raw_choice)
def get_selected_matanyone_label(server_config=None) -> str:
    """Return the display label of the currently selected MatAnyone version."""
    version = get_selected_matanyone_version(server_config)
    return MATANYONE_VERSION_LABELS[version]
def get_matanyone_title_html(server_config=None) -> str:
    """Build the bold HTML credit line naming the selected MatAnyone version."""
    label = get_selected_matanyone_label(server_config)
    return (
        f"<B>Mask Edition is provided by {label}, "
        "VRAM optimizations & Extended Masks by DeepBeepMeep</B>"
    )
def get_selected_matanyone_weight_name(server_config=None) -> str:
    """Return the weights file name (inside the mask folder) for the selected version."""
    version = get_selected_matanyone_version(server_config)
    return MATANYONE_WEIGHT_FILES[version]
def get_selected_matanyone_weights_path(server_config=None):
    """Locate the weights file for the selected MatAnyone version, or return ``None``.

    The versioned file from ``MATANYONE_WEIGHT_FILES`` is preferred. If it is
    missing, a legacy ``model.safetensors`` in the mask folder is accepted only
    when ``detect_matanyone_model_version`` reports the selected version for it.
    """
    wanted_version = get_selected_matanyone_version(server_config)
    versioned = fl.locate_file(_mask_relpath(MATANYONE_WEIGHT_FILES[wanted_version]), error_if_none=False)
    if versioned is not None:
        return versioned
    legacy = fl.locate_file(_mask_relpath(MATANYONE_LEGACY_NAME), error_if_none=False)
    if legacy is not None and detect_matanyone_model_version(legacy) == wanted_version:
        return legacy
    return None
def _download_mask_asset(filename: str) -> str:
    """Fetch *filename* from the repo's mask folder into the local download location.

    Returns the local path reported by ``hf_hub_download``.
    """
    target_root = fl.get_download_location()
    return hf_hub_download(
        repo_id=MATANYONE_REPO_ID,
        filename=filename,
        local_dir=target_root,
        subfolder=MATANYONE_FOLDER,
    )
def query_matanyone_download_def(server_config=None):
    """Describe the files needed for the selected MatAnyone version.

    Returns a dict with ``repoId``, ``sourceFolderList`` (just the mask folder)
    and ``fileList`` (one list per folder: SAM weights, the selected MatAnyone
    weights, then the config). Presumably consumed by the app's downloader;
    confirm against callers.
    """
    weights_name = get_selected_matanyone_weight_name(_get_runtime_server_config(server_config))
    mask_folder_files = [MATANYONE_SAM_NAME, weights_name, MATANYONE_CONFIG_NAME]
    return dict(
        repoId=MATANYONE_REPO_ID,
        sourceFolderList=[MATANYONE_FOLDER],
        fileList=[mask_folder_files],
    )
def migrate_matanyone_install(server_config=None):
    """Move a legacy unversioned MatAnyone weights file to its versioned name.

    The version key is always ensured in the runtime config, defaulting to
    ``MATANYONE_DEFAULT_VERSION``. If a legacy ``model.safetensors`` is found in
    the mask folder and its version can be detected, it is moved to the
    matching ``MATANYONE_WEIGHT_FILES`` name under the download location. If
    that target already exists, the legacy copy is deleted instead.

    Legacy v2 weights switch the selection to v2. When the user had no explicit
    selection, the detected version is stored. The config is saved whenever
    this changes the selection.

    Returns:
        A status message when a legacy file was handled, otherwise ``None``
        (no legacy file, or its version could not be detected).
    """
    runtime_config = _get_runtime_server_config(server_config)
    # Record whether a choice existed BEFORE filling in the default; checking
    # after setdefault() would always find the key, making the "no explicit
    # choice" branch below unreachable.
    had_explicit_choice = MATANYONE_SETTINGS_KEY in runtime_config
    runtime_config.setdefault(MATANYONE_SETTINGS_KEY, MATANYONE_DEFAULT_VERSION)
    legacy_path = fl.locate_file(_mask_relpath(MATANYONE_LEGACY_NAME), error_if_none=False)
    if legacy_path is None:
        return None
    legacy_version = detect_matanyone_model_version(legacy_path)
    if legacy_version is None:
        return None
    target_name = MATANYONE_WEIGHT_FILES[legacy_version]
    target_path = fl.get_download_location(_mask_relpath(target_name))
    os.makedirs(os.path.dirname(target_path), exist_ok=True)
    if os.path.normcase(os.path.abspath(legacy_path)) != os.path.normcase(os.path.abspath(target_path)):
        if os.path.isfile(target_path):
            # A versioned copy already exists; the legacy duplicate is redundant.
            os.remove(legacy_path)
        else:
            # The located legacy file may live outside the download location
            # (possibly on another filesystem); shutil.move falls back to
            # copy+delete there, whereas os.replace would raise.
            import shutil

            shutil.move(legacy_path, target_path)
    config_changed = False
    if legacy_version == MATANYONE_V2 and runtime_config.get(MATANYONE_SETTINGS_KEY) != MATANYONE_V2:
        runtime_config[MATANYONE_SETTINGS_KEY] = MATANYONE_V2
        config_changed = True
    elif not had_explicit_choice:
        runtime_config[MATANYONE_SETTINGS_KEY] = legacy_version
        config_changed = True
    if config_changed:
        _save_runtime_server_config(runtime_config)
    if legacy_version == MATANYONE_V2:
        return "Migrated legacy MatAnyone v2 weights to 'mask/matanyone2.safetensors' and selected MatAnyone v2."
    return "Migrated legacy MatAnyone v1 weights to 'mask/matanyone.safetensors'."
def ensure_selected_matanyone_assets(server_config=None):
    """Make sure every file needed by the selected MatAnyone version is available.

    Runs the legacy-weights migration first. It then downloads any missing
    non-weight asset listed by ``query_matanyone_download_def``, and finally
    the selected weights if ``get_selected_matanyone_weights_path`` cannot
    find them.

    Returns:
        ``(config_path, weights_path, sam_path)``.
    """
    runtime_config = _get_runtime_server_config(server_config)
    migrate_matanyone_install(runtime_config)

    selected_weights = get_selected_matanyone_weight_name(runtime_config)
    for asset in query_matanyone_download_def(runtime_config)["fileList"][0]:
        # Weights are handled separately below (legacy files may satisfy them).
        if asset != selected_weights and fl.locate_file(_mask_relpath(asset), error_if_none=False) is None:
            _download_mask_asset(asset)

    weights_path = get_selected_matanyone_weights_path(runtime_config)
    if weights_path is None:
        weights_path = _download_mask_asset(selected_weights)

    config_path = fl.locate_file(_mask_relpath(MATANYONE_CONFIG_NAME))
    sam_path = fl.locate_file(_mask_relpath(MATANYONE_SAM_NAME))
    return config_path, weights_path, sam_path
def load_selected_matanyone_model(server_config=None):
    """Build the selected MatAnyone network and load its weights.

    First makes sure the required assets are present (migrating or downloading
    them as needed). It then builds ``MatAnyone`` from the ``cfg`` entry of the
    JSON config, switches it to eval mode, and loads the weights with
    ``offload.load_model_data``.

    Returns:
        ``(model, version, weights_path)``, where *version* is the selection
        that was used to choose *weights_path*.
    """
    # Resolve the settings dict once and reuse it. With no dict supplied and
    # none on __main__, _get_runtime_server_config returns a new empty dict per
    # call. A version switched by the migration step would then be lost, and
    # the reported version could disagree with the weights actually loaded.
    runtime_config = _get_runtime_server_config(server_config)
    config_path, weights_path, _ = ensure_selected_matanyone_assets(runtime_config)
    with open(config_path, "r", encoding="utf-8-sig") as reader:
        config_data = json.load(reader)
    from ..matanyone.model.matanyone import MatAnyone
    model = MatAnyone(OmegaConf.create(config_data["cfg"]), single_object=config_data.get("single_object", True)).eval()
    offload.load_model_data(model, weights_path, writable_tensors=False)
    return model, get_selected_matanyone_version(runtime_config), weights_path