import datetime
import hashlib
import os
from io import BytesIO
from typing import List, Optional, Tuple, Union

import safetensors
import safetensors.torch

r"""
# Metadata Example
metadata = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0",  # Required version ID for the spec
    "modelspec.architecture": "stable-diffusion-xl-v1-base",  # Architecture: reference the ID of the original model of the arch to match the ID
    "modelspec.implementation": "sgm",
    "modelspec.title": "Example Model Version 1.0",  # Clean, human-readable title. May use your own phrasing/language/etc
    # === Should ===
    "modelspec.author": "Example Corp",  # Your name or company name
    "modelspec.description": "This is my example model to show you how to do it!",  # Describe the model in your own words/language/etc. Focus on what users need to know
    "modelspec.date": "2023-07-20",  # ISO-8601 compliant date of when the model was created
    # === Can ===
    "modelspec.license": "ExampleLicense-1.0",  # e.g. CreativeML Open RAIL, etc.
    "modelspec.usage_hint": "Use keyword 'example'",  # In your own language, very short hints about how the user should use the model
}
"""

BASE_METADATA = {
    # === Must ===
    "modelspec.sai_model_spec": "1.0.0",
    "modelspec.architecture": None,
    "modelspec.implementation": None,
    "modelspec.title": None,
    "modelspec.resolution": None,
    # === Should ===
    "modelspec.description": None,
    "modelspec.author": None,
    "modelspec.date": None,
    # === Can ===
    "modelspec.license": None,
    "modelspec.tags": None,
    "modelspec.merged_from": None,
    "modelspec.prediction_type": None,
    "modelspec.timestep_range": None,
    "modelspec.encoder_layer": None,
}

MODELSPEC_TITLE = "modelspec.title"

ARCH_SD_V1 = "stable-diffusion-v1"
ARCH_SD_V2_512 = "stable-diffusion-v2-512"
ARCH_SD_V2_768_V = "stable-diffusion-v2-768-v"
ARCH_SD_XL_V1_BASE = "stable-diffusion-xl-v1-base"

ADAPTER_LORA = "lora"
ADAPTER_TEXTUAL_INVERSION = "textual-inversion"

IMPL_STABILITY_AI = "https://github.com/Stability-AI/generative-models"
IMPL_DIFFUSERS = "diffusers"

PRED_TYPE_EPSILON = "epsilon"
PRED_TYPE_V = "v"

def load_bytes_in_safetensors(tensors):
    # serialize the tensors with safetensors and return only the raw tensor data,
    # skipping the JSON header at the start of the buffer
    serialized = safetensors.torch.save(tensors)
    b = BytesIO(serialized)

    b.seek(0)
    header = b.read(8)
    n = int.from_bytes(header, "little")  # header length, little-endian uint64

    offset = n + 8
    b.seek(offset)

    return b.read()

def precalculate_safetensors_hashes(state_dict):
    # hash the raw bytes of each tensor, serializing them one at a time to keep memory usage low
    hash_sha256 = hashlib.sha256()
    for tensor in state_dict.values():
        single_tensor_sd = {"tensor": tensor}
        bytes_for_tensor = load_bytes_in_safetensors(single_tensor_sd)
        hash_sha256.update(bytes_for_tensor)

    return f"0x{hash_sha256.hexdigest()}"

def update_hash_sha256(metadata: dict, state_dict: dict):
    # minimal completion: assumes the per-tensor hash from precalculate_safetensors_hashes
    # is the value expected for modelspec.hash_sha256
    metadata["modelspec.hash_sha256"] = precalculate_safetensors_hashes(state_dict)
    return metadata

def build_metadata(
    state_dict: Optional[dict],
    v2: bool,
    v_parameterization: bool,
    sdxl: bool,
    lora: bool,
    textual_inversion: bool,
    timestamp: float,
    title: Optional[str] = None,
    reso: Optional[Union[str, int, Tuple[int, int]]] = None,
    is_stable_diffusion_ckpt: Optional[bool] = None,
    author: Optional[str] = None,
    description: Optional[str] = None,
    license: Optional[str] = None,
    tags: Optional[str] = None,
    merged_from: Optional[str] = None,
    timesteps: Optional[Union[str, int, Tuple[int, int]]] = None,
    clip_skip: Optional[int] = None,
):
    # start from the full template; optional keys are removed later if no value is given
    metadata = {}
    metadata.update(BASE_METADATA)

    if sdxl:
        arch = ARCH_SD_XL_V1_BASE
    elif v2:
        if v_parameterization:
            arch = ARCH_SD_V2_768_V
        else:
            arch = ARCH_SD_V2_512
    else:
        arch = ARCH_SD_V1

    if lora:
        arch += f"/{ADAPTER_LORA}"
    elif textual_inversion:
        arch += f"/{ADAPTER_TEXTUAL_INVERSION}"

    metadata["modelspec.architecture"] = arch

    if not lora and not textual_inversion and is_stable_diffusion_ckpt is None:
        is_stable_diffusion_ckpt = True  # default to a Stable Diffusion checkpoint

    if (lora and sdxl) or textual_inversion or is_stable_diffusion_ckpt:
        # checkpoints, textual inversion and SDXL LoRA follow the reference (sgm) implementation
        impl = IMPL_STABILITY_AI
    else:
        # SD 1.x/2.x LoRA targets the Diffusers implementation
        impl = IMPL_DIFFUSERS
    metadata["modelspec.implementation"] = impl

    if title is None:
        if lora:
            title = "LoRA"
        elif textual_inversion:
            title = "TextualInversion"
        else:
            title = "Checkpoint"
        title += f"@{timestamp}"
    metadata[MODELSPEC_TITLE] = title

    if author is not None:
        metadata["modelspec.author"] = author
    else:
        del metadata["modelspec.author"]

    if description is not None:
        metadata["modelspec.description"] = description
    else:
        del metadata["modelspec.description"]

    if merged_from is not None:
        metadata["modelspec.merged_from"] = merged_from
    else:
        del metadata["modelspec.merged_from"]

    if license is not None:
        metadata["modelspec.license"] = license
    else:
        del metadata["modelspec.license"]

    if tags is not None:
        metadata["modelspec.tags"] = tags
    else:
        del metadata["modelspec.tags"]

    # drop fractional seconds and store the creation time as an ISO-8601 date
    int_ts = int(timestamp)
    date = datetime.datetime.fromtimestamp(int_ts).isoformat()
    metadata["modelspec.date"] = date

    if reso is not None:
        # a comma-separated string becomes a tuple of ints
        if isinstance(reso, str):
            reso = tuple(map(int, reso.split(",")))
        # a single int means a square resolution
        if isinstance(reso, int):
            reso = (reso, reso)
        if len(reso) == 1:
            reso = (reso[0], reso[0])
    else:
        # fall back to the default training resolution of the architecture
        if sdxl:
            reso = 1024
        elif v2 and v_parameterization:
            reso = 768
        else:
            reso = 512
    if isinstance(reso, int):
        reso = (reso, reso)

    metadata["modelspec.resolution"] = f"{reso[0]}x{reso[1]}"

    if v_parameterization:
        metadata["modelspec.prediction_type"] = PRED_TYPE_V
    else:
        metadata["modelspec.prediction_type"] = PRED_TYPE_EPSILON

    if timesteps is not None:
        if isinstance(timesteps, (str, int)):
            timesteps = (timesteps, timesteps)
        if len(timesteps) == 1:
            timesteps = (timesteps[0], timesteps[0])
        metadata["modelspec.timestep_range"] = f"{timesteps[0]},{timesteps[1]}"
    else:
        del metadata["modelspec.timestep_range"]

    if clip_skip is not None:
        metadata["modelspec.encoder_layer"] = f"{clip_skip}"
    else:
        del metadata["modelspec.encoder_layer"]

    # every key that is still present must have a value at this point
    if not all(v is not None for v in metadata.values()):
        print(f"Internal error: some metadata values are None: {metadata}")

    return metadata
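
# Illustrative sketch (not executed): one possible call to build_metadata for an SDXL LoRA.
# All values below are made-up placeholders.
r"""
import time

metadata = build_metadata(
    state_dict=None,
    v2=False,
    v_parameterization=False,
    sdxl=True,
    lora=True,
    textual_inversion=False,
    timestamp=time.time(),
    title="My Example LoRA",
    reso=1024,
    author="Example Corp",
    tags="style,example",
)
# metadata["modelspec.architecture"] == "stable-diffusion-xl-v1-base/lora"
# metadata["modelspec.implementation"] == IMPL_STABILITY_AI
# metadata["modelspec.resolution"] == "1024x1024"
"""
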
def get_title(metadata: dict) -> Optional[str]:
    return metadata.get(MODELSPEC_TITLE, None)

def load_metadata_from_safetensors(model: str) -> dict:
    if not model.endswith(".safetensors"):
        return {}

    with safetensors.safe_open(model, framework="pt") as f:
        metadata = f.metadata()
    if metadata is None:
        metadata = {}
    return metadata

def build_merged_from(models: List[str]) -> str:
    def get_model_title(model: str) -> str:
        # use the modelspec title if present, otherwise fall back to the file name
        metadata = load_metadata_from_safetensors(model)
        title = metadata.get(MODELSPEC_TITLE, None)
        if title is None:
            title = os.path.splitext(os.path.basename(model))[0]
        return title

    titles = [get_model_title(model) for model in models]
    return ", ".join(titles)
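
# Illustrative sketch (not executed): building the merged_from string for two hypothetical files.
r"""
merged_from = build_merged_from(["modelA.safetensors", "modelB.safetensors"])
# e.g. "Model A Title, modelB" -- titles come from the modelspec metadata when available,
# otherwise from the file name without extension
"""
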
r"""
if __name__ == "__main__":
    import argparse
    import json
    import struct
    import time

    from safetensors.torch import load_file

    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, required=True)
    args = parser.parse_args()

    print(f"Loading {args.ckpt}")
    state_dict = load_file(args.ckpt)

    print("Calculating metadata")
    metadata = build_metadata(state_dict, False, False, False, False, False, time.time(), title="title")
    update_hash_sha256(metadata, state_dict)
    print(metadata)
    del state_dict

    # by reference implementation
    with open(args.ckpt, mode="rb") as file_data:
        file_hash = hashlib.sha256()
        head_len = struct.unpack("Q", file_data.read(8))  # int64 header length prefix
        header = json.loads(file_data.read(head_len[0]))  # the header itself, a JSON string
        content = (
            file_data.read()
        )  # All other content is tightly packed tensors. Copy to RAM for simplicity, but you can avoid this read with a more careful FS-dependent impl.
        file_hash.update(content)
        # ===== Update the hash for modelspec =====
        by_ref = f"0x{file_hash.hexdigest()}"
        print(by_ref)
        print("is same?", by_ref == metadata["modelspec.hash_sha256"])
"""