Spaces:
Paused
Paused
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| Utility functions for Blissful Tuner extension | |
| License: Apache 2.0 | |
| Created on Sat Apr 12 14:09:37 2025 | |
| @author: blyss | |
| """ | |
| import argparse | |
| import hashlib | |
| import torch | |
| import safetensors | |
| from typing import List, Union, Dict, Tuple, Optional | |
| import logging | |
| from rich.logging import RichHandler | |
| # Adapted from ComfyUI | |
| def load_torch_file( | |
| ckpt: str, | |
| safe_load: Optional[bool] = True, | |
| device: Optional[Union[str, torch.device]] = None, | |
| return_metadata: Optional[bool] = False | |
| ) -> Union[ | |
| Dict[str, torch.Tensor], | |
| Tuple[Dict[str, torch.Tensor], Optional[Dict[str, str]]] | |
| ]: | |
| if device is None: | |
| device = torch.device("cpu") | |
| metadata = None | |
| if ckpt.lower().endswith(".safetensors") or ckpt.lower().endswith(".sft"): | |
| try: | |
| with safetensors.safe_open(ckpt, framework="pt", device=device.type) as f: | |
| sd = {} | |
| for k in f.keys(): | |
| sd[k] = f.get_tensor(k) | |
| if return_metadata: | |
| metadata = f.metadata() | |
| except Exception as e: | |
| if len(e.args) > 0: | |
| message = e.args[0] | |
| if "HeaderTooLarge" in message: | |
| raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt or invalid. Make sure this is actually a safetensors file and not a ckpt or pt or other filetype.".format(message, ckpt)) | |
| if "MetadataIncompleteBuffer" in message: | |
| raise ValueError("{}\n\nFile path: {}\n\nThe safetensors file is corrupt/incomplete. Check the file size and make sure you have copied/downloaded it correctly.".format(message, ckpt)) | |
| raise e | |
| else: | |
| pl_sd = torch.load(ckpt, map_location=device, weights_only=safe_load) | |
| if "state_dict" in pl_sd: | |
| sd = pl_sd["state_dict"] | |
| else: | |
| if len(pl_sd) == 1: | |
| key = list(pl_sd.keys())[0] | |
| sd = pl_sd[key] | |
| if not isinstance(sd, dict): | |
| sd = pl_sd | |
| else: | |
| sd = pl_sd | |
| return (sd, metadata) if return_metadata else sd | |
| def add_noise_to_reference_video( | |
| image: torch.Tensor, | |
| ratio: Optional[float] = None | |
| ) -> torch.Tensor: | |
| """ | |
| Add Gaussian noise (scaled by `ratio`) to an image or batch of images. | |
| Supports: | |
| • Single image: (C, H, W) | |
| • Batch of images: (B, C, H, W) | |
| Any pixel exactly == –1 will have zero noise (mask value). | |
| """ | |
| if ratio is None or ratio == 0.0: | |
| return image | |
| dims = image.ndim | |
| if dims == 3: | |
| # Single image -> make it a batch of 1 | |
| image = image.unsqueeze(0) # -> (1, C, H, W) | |
| squeeze_back = True | |
| elif dims == 4: | |
| squeeze_back = False | |
| else: | |
| raise ValueError( | |
| f"add_noise_to_reference_video() expected 3D or 4D tensor, got {dims}D" | |
| ) | |
| # image is now (B, C, H, W) | |
| B, C, H, W = image.shape | |
| # make a (B,) sigma array, all = ratio | |
| sigma = image.new_ones((B,)) * ratio | |
| # sample noise and scale by sigma | |
| noise = torch.randn_like(image) * sigma.view(B, 1, 1, 1) | |
| # zero out noise wherever the original was -1 | |
| noise = torch.where(image == -1, torch.zeros_like(image), noise) | |
| out = image + noise | |
| return out.squeeze(0) if squeeze_back else out | |
| # Below here, Blyss wrote it! | |
| class BlissfulLogger: | |
| def __init__(self, logging_source: str, log_color: str, do_announce: Optional[bool] = False): | |
| logging_source = f"{logging_source}" | |
| self.logging_source = "{:<8}".format(logging_source) | |
| self.log_color = log_color | |
| self.logger = logging.getLogger(self.logging_source) | |
| self.logger.setLevel(logging.DEBUG) | |
| self.handler = RichHandler( | |
| show_time=False, | |
| show_level=True, | |
| show_path=True, | |
| rich_tracebacks=True, | |
| markup=True | |
| ) | |
| formatter = logging.Formatter( | |
| f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]" | |
| ) | |
| self.handler.setFormatter(formatter) | |
| self.logger.addHandler(self.handler) | |
| if do_announce: | |
| self.logger.info("Set up logging!") | |
| def set_color(self, new_color): | |
| self.log_color = new_color | |
| formatter = logging.Formatter( | |
| f"[{self.log_color} bold]%(name)s[/] | %(message)s [dim](%(funcName)s)[/]" | |
| ) | |
| self.handler.setFormatter(formatter) | |
| def set_name(self, new_name): | |
| self.logging_source = "{:<8}".format(new_name) | |
| self.logger = logging.getLogger(self.logging_source) | |
| self.logger.setLevel(logging.DEBUG) | |
| # Remove any existing handlers (just in case) | |
| if not self.logger.hasHandlers(): | |
| self.logger.addHandler(self.handler) | |
| else: | |
| self.logger.handlers.clear() | |
| self.logger.addHandler(self.handler) | |
| def info(self, msg): | |
| self.logger.info(msg, stacklevel=2) | |
| def debug(self, msg): | |
| self.logger.debug(msg, stacklevel=2) | |
| def warning(self, msg, levelmod=0): | |
| self.logger.warning(msg, stacklevel=2 + levelmod) | |
| def warn(self, msg): | |
| self.logger.warning(msg, stacklevel=2) | |
| def error(self, msg): | |
| self.logger.error(msg, stacklevel=2) | |
| def critical(self, msg): | |
| self.logger.critical(msg, stacklevel=2) | |
| def setLevel(self, level): | |
| self.logger.set_level(level) | |
| def parse_scheduled_cfg(schedule: str, infer_steps: int, guidance_scale: int) -> List[int]: | |
| """ | |
| Parse a schedule string like "1-10,20,!5,e~3" into a sorted list of steps. | |
| - "start-end" includes all steps in [start, end] | |
| - "e~n" includes every nth step (n, 2n, ...) up to infer_steps | |
| - "x" includes the single step x | |
| - Prefix "!" on any token to exclude those steps instead of including them. | |
| - Postfix ":float" e.g. ":6.0" to any step or range to specify a guidance_scale override for that step | |
| Raises argparse.ArgumentTypeError on malformed tokens or out-of-range steps. | |
| """ | |
| excluded = set() | |
| guidance_scale_dict = {} | |
| for raw in schedule.split(","): | |
| token = raw.strip() | |
| if not token: | |
| continue # skip empty tokens | |
| # exclusion if it starts with "!" | |
| if token.startswith("!"): | |
| target = "exclude" | |
| token = token[1:] | |
| else: | |
| target = "include" | |
| weight = guidance_scale | |
| if ":" in token: | |
| token, float_part = token.rsplit(":", 1) | |
| weight = float(float_part) | |
| # modulus syntax: e.g. "e~3" | |
| if token.startswith("e~"): | |
| num_str = token[2:] | |
| try: | |
| n = int(num_str) | |
| except ValueError: | |
| raise argparse.ArgumentTypeError(f"Invalid modulus in '{raw}'") | |
| if n < 1: | |
| raise argparse.ArgumentTypeError(f"Modulus must be ≥ 1 in '{raw}'") | |
| steps = range(n, infer_steps + 1, n) | |
| # range syntax: e.g. "5-10" | |
| elif "-" in token: | |
| parts = token.split("-") | |
| if len(parts) != 2: | |
| raise argparse.ArgumentTypeError(f"Malformed range '{raw}'") | |
| start_str, end_str = parts | |
| try: | |
| start = int(start_str) | |
| end = int(end_str) | |
| except ValueError: | |
| raise argparse.ArgumentTypeError(f"Non‑integer in range '{raw}'") | |
| if start < 1 or end < 1: | |
| raise argparse.ArgumentTypeError(f"Steps must be ≥ 1 in '{raw}'") | |
| if start > end: | |
| raise argparse.ArgumentTypeError(f"Start > end in '{raw}'") | |
| if end > infer_steps: | |
| raise argparse.ArgumentTypeError(f"End > infer_steps ({infer_steps}) in '{raw}'") | |
| steps = range(start, end + 1) | |
| # single‑step syntax: e.g. "7" | |
| else: | |
| try: | |
| step = int(token) | |
| except ValueError: | |
| raise argparse.ArgumentTypeError(f"Invalid token '{raw}'") | |
| if step < 1 or step > infer_steps: | |
| raise argparse.ArgumentTypeError(f"Step {step} out of range 1–{infer_steps} in '{raw}'") | |
| steps = [step] | |
| # apply include/exclude | |
| if target == "include": | |
| for step in steps: | |
| guidance_scale_dict[step] = weight | |
| else: | |
| excluded.update(steps) | |
| for step in excluded: | |
| guidance_scale_dict.pop(step, None) | |
| return guidance_scale_dict | |
| def setup_compute_context(device: Optional[Union[torch.device, str]] = None, dtype: Optional[Union[torch.dtype, str]] = None) -> Tuple[torch.device, torch.dtype]: | |
| dtype_mapping = { | |
| "fp16": torch.float16, | |
| "float16": torch.float16, | |
| "bf16": torch.bfloat16, | |
| "bfloat16": torch.bfloat16, | |
| "fp32": torch.float32, | |
| "float32": torch.float32, | |
| "fp8": torch.float8_e4m3fn, | |
| "float8": torch.float8_e4m3fn | |
| } | |
| if device is None: | |
| device = torch.device("cpu") | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| elif torch.mps.is_available(): | |
| device = torch.device("mps") | |
| elif isinstance(device, str): | |
| device = torch.device(device) | |
| if dtype is None: | |
| dtype = torch.float32 | |
| elif isinstance(dtype, str): | |
| if dtype not in dtype_mapping: | |
| raise ValueError(f"Unknown dtype string '{dtype}'") | |
| dtype = dtype_mapping[dtype] | |
| torch.set_float32_matmul_precision('high') | |
| if dtype == torch.float16 or dtype == torch.bfloat16: | |
| if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): | |
| torch.backends.cuda.matmul.allow_fp16_accumulation = True | |
| print("FP16 accumulation enabled.") | |
| return device, dtype | |
| def string_to_seed(s: str, bits: int = 63) -> int: | |
| """ | |
| Turn any string into a reproducible integer in [0, 2**bits) with a hash and some other logic. | |
| Args: | |
| s: Input string | |
| bits: Number of bits for the final seed (PyTorch accepts up to 63 safely, numpy likes 32) | |
| Returns: | |
| A non-negative int < 2**bits | |
| """ | |
| digest = hashlib.sha256(s.encode("utf-8")).digest() | |
| crypto = int.from_bytes(digest, byteorder="big") | |
| mask = (1 << bits) - 1 | |
| algo = 0 | |
| for i, char in enumerate(s): | |
| char_val = ord(char) | |
| if i % 2 == 0: | |
| algo *= char_val | |
| elif i % 3 == 0: | |
| algo -= char_val | |
| elif i % 5 == 0: | |
| algo /= char_val | |
| else: | |
| algo += char_val | |
| seed = (abs(crypto - int(algo))) & mask | |
| return seed | |
| def error_out(error, message): | |
| logger = BlissfulLogger(__name__, "#8e00ed") | |
| logger.warning(message, levelmod=1) | |
| raise error(message) | |