import datetime
import hashlib
from io import BytesIO
import os
from typing import List, Optional, Tuple, Union

import safetensors
import safetensors.torch  # needed for safetensors.torch.save below; the bare package import does not load this submodule
import logging


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
| r""" |
| # Metadata Example |
| metadata = { |
| # === Must === |
| "modelspec.sai_model_spec": "1.0.0", # Required version ID for the spec |
| "modelspec.architecture": "stable-diffusion-xl-v1-base", # Architecture, reference the ID of the original model of the arch to match the ID |
| "modelspec.implementation": "sgm", |
| "modelspec.title": "Example Model Version 1.0", # Clean, human-readable title. May use your own phrasing/language/etc |
| # === Should === |
| "modelspec.author": "Example Corp", # Your name or company name |
| "modelspec.description": "This is my example model to show you how to do it!", # Describe the model in your own words/language/etc. Focus on what users need to know |
| "modelspec.date": "2023-07-20", # ISO-8601 compliant date of when the model was created |
| # === Can === |
| "modelspec.license": "ExampleLicense-1.0", # eg CreativeML Open RAIL, etc. |
| "modelspec.usage_hint": "Use keyword 'example'" # In your own language, very short hints about how the user should use the model |
| } |
| """ |
|
|
BASE_METADATA = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0",
    "modelspec.architecture": None,
    "modelspec.implementation": None,
    "modelspec.title": None,
    "modelspec.resolution": None,
    # === Should ===
    "modelspec.description": None,
    "modelspec.author": None,
    "modelspec.date": None,
    # === Can ===
    "modelspec.license": None,
    "modelspec.tags": None,
    "modelspec.merged_from": None,
    "modelspec.prediction_type": None,
    "modelspec.timestep_range": None,
    "modelspec.encoder_layer": None,
}
|
|
MODELSPEC_TITLE = "modelspec.title"

ARCH_HUNYUAN_VIDEO = "hunyuan-video"
ADAPTER_LORA = "lora"

IMPL_HUNYUAN_VIDEO = "https://github.com/Tencent/HunyuanVideo"

PRED_TYPE_EPSILON = "epsilon"
|
|
|
|
def load_bytes_in_safetensors(tensors):
    """Serialize tensors with safetensors and return only the packed tensor data, skipping the header."""
    saved = safetensors.torch.save(tensors)  # avoid shadowing the builtin `bytes`
    b = BytesIO(saved)

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")  # first 8 bytes: little-endian length of the JSON header

    offset = n + 8  # skip the length prefix and the JSON header
    b.seek(offset)

    return b.read()
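
# A minimal sketch of what this helper returns: for a single float32 tensor the
# payload is just the raw tensor bytes, with the header excluded. Values are
# illustrative.
r"""
import torch

payload = load_bytes_in_safetensors({"tensor": torch.zeros(2, 2)})
print(len(payload))  # 16: four float32 values, header bytes not included
"""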
|
|
|
|
def precalculate_safetensors_hashes(state_dict):
    """Hash the tensor data only, so the result is independent of file metadata."""
    # hash tensors one by one to reduce peak memory usage
    hash_sha256 = hashlib.sha256()
    for tensor in state_dict.values():
        single_tensor_sd = {"tensor": tensor}
        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
        hash_sha256.update(bytes_for_tensor)

    return f"0x{hash_sha256.hexdigest()}"
|
|
|
|
def update_hash_sha256(metadata: dict, state_dict: dict):
    raise NotImplementedError
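
# One plausible implementation (an assumption; this module leaves the function
# unimplemented): store the precalculated tensor-data hash under the spec's
# hash key, which the verification script at the bottom of this file reads.
r"""
def update_hash_sha256(metadata: dict, state_dict: dict):
    metadata["modelspec.hash_sha256"] = precalculate_safetensors_hashes(state_dict)
    return metadata
"""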
|
|
|
|
def build_metadata(
    state_dict: Optional[dict],
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[int, str, Tuple[int, int]]] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Union[int, str, Tuple[int, int]]] = None,
):
    metadata = {}
    metadata.update(BASE_METADATA)

    arch = ARCH_HUNYUAN_VIDEO
    arch += f"/{ADAPTER_LORA}"
    metadata["modelspec.architecture"] = arch

    impl = IMPL_HUNYUAN_VIDEO
    metadata["modelspec.implementation"] = impl

    if title is None:
        title = "LoRA"
        title += f"@{timestamp}"
    metadata[MODELSPEC_TITLE] = title

    if author is not None:
        metadata["modelspec.author"] = author
    else:
        del metadata["modelspec.author"]

    if description is not None:
        metadata["modelspec.description"] = description
    else:
        del metadata["modelspec.description"]

    if merged_from is not None:
        metadata["modelspec.merged_from"] = merged_from
    else:
        del metadata["modelspec.merged_from"]

    if license is not None:
        metadata["modelspec.license"] = license
    else:
        del metadata["modelspec.license"]

    if tags is not None:
        metadata["modelspec.tags"] = tags
    else:
        del metadata["modelspec.tags"]

    # drop sub-second precision from the timestamp
    int_ts = int(timestamp)

    # convert to an ISO-8601 compliant date string
    date = datetime.datetime.fromtimestamp(int_ts).isoformat()
    metadata["modelspec.date"] = date

    if reso is not None:
        if isinstance(reso, str):
            # accept a comma-separated string such as "1280,720"
            reso = tuple(map(int, reso.split(",")))
        if isinstance(reso, int):
            # check int before len(), which would raise TypeError on an int
            reso = (reso, reso)
        elif len(reso) == 1:
            reso = (reso[0], reso[0])
    else:
        # fall back to the default resolution
        reso = (1280, 720)

    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

    # prediction type is not set here, so remove the placeholder key
    del metadata["modelspec.prediction_type"]

    if timesteps is not None:
        if isinstance(timesteps, str):
            # accept a comma-separated string such as "0,1000"
            timesteps = tuple(map(int, timesteps.split(",")))
        if isinstance(timesteps, int):
            timesteps = (timesteps, timesteps)
        if len(timesteps) == 1:
            timesteps = (timesteps[0], timesteps[0])
        metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
    else:
        del metadata["modelspec.timestep_range"]

    # encoder_layer (e.g. clip skip) is not set here, so remove the placeholder key
    del metadata["modelspec.encoder_layer"]

    # every remaining value should have been filled in by now
    if any(v is None for v in metadata.values()):
        logger.error(f"Internal error: some metadata values are None: {metadata}")

    return metadata
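
# Usage sketch (illustrative values): build the metadata dict and pass it to
# safetensors.torch.save_file, which accepts a str-to-str metadata mapping.
# `tensors` below stands in for a real LoRA state dict.
r"""
import time
from safetensors.torch import save_file

metadata = build_metadata(
    None,  # state_dict is not used by the current implementation
    time.time(),
    title="my-hunyuan-lora",
    reso=(1280, 720),
    author="example-author",
)
save_file(tensors, "my_lora.safetensors", metadata=metadata)
"""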
|
|
|
|
|
|
|
|
def get_title(metadata: dict) -> Optional[str]:
    return metadata.get(MODELSPEC_TITLE, None)
|
|
|
|
def load_metadata_from_safetensors(model: str) -> dict:
    if not model.endswith(".safetensors"):
        return {}

    with safetensors.safe_open(model, framework="pt") as f:
        metadata = f.metadata()
    if metadata is None:
        metadata = {}
    return metadata
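
# Usage sketch: read the embedded metadata back without loading any tensor
# data. The path is illustrative.
r"""
meta = load_metadata_from_safetensors("my_lora.safetensors")
print(get_title(meta))  # e.g. "my-hunyuan-lora"
"""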
|
|
|
|
def build_merged_from(models: List[str]) -> str:
    def get_title(model: str):
        metadata = load_metadata_from_safetensors(model)
        title = metadata.get(MODELSPEC_TITLE, None)
        if title is None:
            # fall back to the file name without its extension
            title = os.path.splitext(os.path.basename(model))[0]
        return title

    titles = [get_title(model) for model in models]
    return ", ".join(titles)
|
|
|
|
|
|
|
|
| r""" |
| if __name__ == "__main__": |
| import argparse |
| import torch |
| from safetensors.torch import load_file |
| from library import train_util |
| |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--ckpt", type=str, required=True) |
| args = parser.parse_args() |
| |
| print(f"Loading {args.ckpt}") |
| state_dict = load_file(args.ckpt) |
| |
| print(f"Calculating metadata") |
| metadata = get(state_dict, False, False, False, False, "sgm", False, False, "title", "date", 256, 1000, 0) |
| print(metadata) |
| del state_dict |
| |
| # by reference implementation |
| with open(args.ckpt, mode="rb") as file_data: |
| file_hash = hashlib.sha256() |
| head_len = struct.unpack("Q", file_data.read(8)) # int64 header length prefix |
| header = json.loads(file_data.read(head_len[0])) # header itself, json string |
| content = ( |
| file_data.read() |
| ) # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl. |
| file_hash.update(content) |
| # ===== Update the hash for modelspec ===== |
| by_ref = f"0x{file_hash.hexdigest()}" |
| print(by_ref) |
| print("is same?", by_ref == metadata["modelspec.hash_sha256"]) |
| |
| """ |
|
|