Spaces:
Running
Running
| import hashlib | |
| from io import BytesIO | |
| from typing import Optional | |
| import safetensors.torch | |
| import torch | |
def model_hash(filename):
    """Old model hash used by stable-diffusion-webui.

    Hashes only a 64 KiB window (0x10000 bytes) beginning 1 MiB (0x100000)
    into the file with SHA-256, and keeps the first 8 hex digits. If the file
    is shorter than 1 MiB, the window is empty (or short) and the digest is
    taken over whatever bytes were read.

    Args:
        filename: path of the model file to hash.

    Returns:
        The 8-character hex prefix of the SHA-256 digest. Two sentinel strings
        can be returned instead:
        - "NOFILE" if the path does not exist.
        - "IsADirectory" if opening it raises IsADirectoryError or
          PermissionError.
    """
    window_start, window_size = 0x100000, 0x10000
    try:
        with open(filename, "rb") as handle:
            handle.seek(window_start)
            window = handle.read(window_size)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # The original treats IsADirectoryError as the Linux case and
        # PermissionError as the Windows case for a directory path. Note that a
        # genuine access-denied error also lands here and is reported the same way.
        return "IsADirectory"
    digest = hashlib.sha256(window).hexdigest()
    return digest[:8]
def calculate_sha256(filename):
    """New model hash used by stable-diffusion-webui.

    Computes the SHA-256 of the entire file, streaming it in 1 MiB chunks so
    large model files are never held in memory all at once.

    Args:
        filename: path of the file to hash.

    Returns:
        The full lowercase hex SHA-256 digest. Two sentinel strings can be
        returned instead:
        - "NOFILE" if the path does not exist.
        - "IsADirectory" if opening or reading it raises IsADirectoryError or
          PermissionError.
    """
    chunk_size = 1024 * 1024
    digest = hashlib.sha256()
    try:
        with open(filename, "rb") as handle:
            while True:
                chunk = handle.read(chunk_size)
                if not chunk:  # b"" signals end of file
                    break
                digest.update(chunk)
    except FileNotFoundError:
        return "NOFILE"
    except (IsADirectoryError, PermissionError):
        # IsADirectoryError: directory path (Linux); PermissionError: what the
        # original code treats as the Windows equivalent. A genuine
        # access-denied error is also reported as "IsADirectory".
        return "IsADirectory"
    return digest.hexdigest()
def addnet_hash_legacy(b):
    """Old model hash used by sd-webui-additional-networks for .safetensors format files.

    Hashes with SHA-256 the 64 KiB window (0x10000 bytes) that starts 1 MiB
    (0x100000) into the stream, and returns the first 8 hex digits. This is
    the same windowed scheme that ``model_hash`` applies to a file on disk.

    Args:
        b: a seekable binary stream (e.g. ``io.BytesIO``). It is repositioned
            by this call and left just after the bytes that were read.

    Returns:
        The 8-character hex prefix of the SHA-256 digest.
    """
    b.seek(0x100000)
    window = b.read(0x10000)
    return hashlib.sha256(window).hexdigest()[:8]
def addnet_hash_safetensors(b):
    """New model hash used by sd-webui-additional-networks for .safetensors format files.

    The stream starts with an 8-byte little-endian integer N, followed by N
    header bytes. Everything after offset ``8 + N`` is hashed with SHA-256 in
    1 MiB chunks, so the header bytes do not affect the result.

    Args:
        b: a seekable binary stream positioned anywhere. It is rewound to the
            start and left at end of stream.

    Returns:
        The full lowercase hex SHA-256 digest of the bytes after the header.
    """
    b.seek(0)
    header_length = int.from_bytes(b.read(8), "little")
    b.seek(8 + header_length)

    digest = hashlib.sha256()
    while True:
        chunk = b.read(1024 * 1024)
        if not chunk:
            break
        digest.update(chunk)
    return digest.hexdigest()
def precalculate_safetensors_hashes(tensors, metadata):
    """Precalculate the model hashes needed by sd-webui-additional-networks to
    save time on indexing the model later.

    The tensors are serialized in memory with ``safetensors.torch.save``,
    together with a filtered copy of ``metadata``. Both additional-networks
    hashes are then computed over that buffer.

    Args:
        tensors: the tensor mapping to serialize, passed straight to
            ``safetensors.torch.save``.
        metadata: mapping of metadata entries. Only keys starting with
            ``"ss_"`` are kept. ``None`` is treated as an empty mapping.

    Returns:
        A tuple ``(model_hash, legacy_hash)``:
        - ``model_hash``: the result of ``addnet_hash_safetensors`` (full hex
          SHA-256 of the data after the header).
        - ``legacy_hash``: the result of ``addnet_hash_legacy`` (8 hex digits).
    """
    # Because writing user metadata to the file can change the result of
    # sd_models.model_hash(), only retain the training metadata for purposes of
    # calculating the hash, as they are meant to be immutable.
    training_metadata = {
        key: value for key, value in (metadata or {}).items() if key.startswith("ss_")
    }
    # Locals are named so they do not shadow the builtin `bytes` or this
    # module's `model_hash` function.
    serialized = safetensors.torch.save(tensors, training_metadata)
    buffer = BytesIO(serialized)
    new_hash = addnet_hash_safetensors(buffer)
    # addnet_hash_legacy seeks on its own, so sharing the buffer is safe.
    legacy_hash = addnet_hash_legacy(buffer)
    return new_hash, legacy_hash
def dtype_to_str(dtype: torch.dtype) -> str:
    """Return the bare name of a dtype, e.g. ``torch.float32`` -> ``"float32"``.

    The text after the last ``.`` in ``str(dtype)`` is returned. If there is
    no dot, the whole string is returned.
    """
    _, _, bare_name = str(dtype).rpartition(".")
    return bare_name
# Accepted spellings -> name of the torch attribute they denote. The attribute
# is looked up lazily with getattr so that a torch build lacking one of the
# float8 types only fails (with AttributeError) when that type is requested.
_DTYPE_ALIASES = {
    "bf16": "bfloat16",
    "bfloat16": "bfloat16",
    "fp16": "float16",
    "float16": "float16",
    "fp32": "float32",
    "float32": "float32",
    "float": "float32",
    "fp8_e4m3fn": "float8_e4m3fn",
    "e4m3fn": "float8_e4m3fn",
    "float8_e4m3fn": "float8_e4m3fn",
    "fp8_e4m3fnuz": "float8_e4m3fnuz",
    "e4m3fnuz": "float8_e4m3fnuz",
    "float8_e4m3fnuz": "float8_e4m3fnuz",
    "fp8_e5m2": "float8_e5m2",
    "e5m2": "float8_e5m2",
    "float8_e5m2": "float8_e5m2",
    "fp8_e5m2fnuz": "float8_e5m2fnuz",
    "e5m2fnuz": "float8_e5m2fnuz",
    "float8_e5m2fnuz": "float8_e5m2fnuz",
    # Bare "fp8"/"float8" default to the e4m3fn variant.
    "fp8": "float8_e4m3fn",
    "float8": "float8_e4m3fn",
}


def str_to_dtype(s: Optional[str], default_dtype: Optional[torch.dtype] = None) -> Optional[torch.dtype]:
    """
    Convert a string to a torch.dtype

    Matching is exact and case-sensitive (e.g. "FP16" is rejected).
    "float" is accepted as an alias for float32, and bare "fp8"/"float8"
    map to float8_e4m3fn.

    Args:
        s: string representation of the dtype
        default_dtype: default dtype to return if s is None
    Returns:
        torch.dtype: the corresponding torch.dtype, or default_dtype (which
        may itself be None) when s is None
    Raises:
        ValueError: if the dtype is not supported (including non-string input)
        AttributeError: if the installed torch lacks the requested float8 type
    Examples:
        >>> str_to_dtype("float32")
        torch.float32
        >>> str_to_dtype("fp32")
        torch.float32
        >>> str_to_dtype("float16")
        torch.float16
        >>> str_to_dtype("fp16")
        torch.float16
        >>> str_to_dtype("bfloat16")
        torch.bfloat16
        >>> str_to_dtype("bf16")
        torch.bfloat16
        >>> str_to_dtype("fp8")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fn")
        torch.float8_e4m3fn
        >>> str_to_dtype("fp8_e4m3fnuz")
        torch.float8_e4m3fnuz
        >>> str_to_dtype("fp8_e5m2")
        torch.float8_e5m2
        >>> str_to_dtype("fp8_e5m2fnuz")
        torch.float8_e5m2fnuz
    """
    if s is None:
        return default_dtype
    # Only strings can match an alias. Anything else (e.g. an actual
    # torch.dtype) is rejected with ValueError rather than a TypeError from
    # hashing an unhashable value.
    attr_name = _DTYPE_ALIASES.get(s) if isinstance(s, str) else None
    if attr_name is None:
        raise ValueError(f"Unsupported dtype: {s}")
    return getattr(torch, attr_name)