| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Usage example: |
| TODO |
| """ |
|
|
| import ast |
| import importlib.util |
| import os |
| from argparse import ArgumentParser, Namespace |
| from pathlib import Path |
|
|
| from ..utils import logging |
| from . import BaseDiffusersCLICommand |
|
|
|
|
| EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"] |
| CONFIG = "config.json" |
|
|
|
|
| def conversion_command_factory(args: Namespace): |
| return CustomBlocksCommand(args.block_module_name, args.block_class_name) |
|
|
|
|
| class CustomBlocksCommand(BaseDiffusersCLICommand): |
| @staticmethod |
| def register_subcommand(parser: ArgumentParser): |
| conversion_parser = parser.add_parser("custom_blocks") |
| conversion_parser.add_argument( |
| "--block_module_name", |
| type=str, |
| default="block.py", |
| help="Module filename in which the custom block will be implemented.", |
| ) |
| conversion_parser.add_argument( |
| "--block_class_name", |
| type=str, |
| default=None, |
| help="Name of the custom block. If provided None, we will try to infer it.", |
| ) |
| conversion_parser.set_defaults(func=conversion_command_factory) |
|
|
| def __init__(self, block_module_name: str = "block.py", block_class_name: str = None): |
| self.logger = logging.get_logger("diffusers-cli/custom_blocks") |
| self.block_module_name = Path(block_module_name) |
| self.block_class_name = block_class_name |
|
|
| def run(self): |
| |
| out = self._get_class_names(self.block_module_name) |
| classes_found = list({cls for cls, _ in out}) |
|
|
| if self.block_class_name is not None: |
| child_class, parent_class = self._choose_block(out, self.block_class_name) |
| if child_class is None and parent_class is None: |
| raise ValueError( |
| "`block_class_name` could not be retrieved. Available classes from " |
| f"{self.block_module_name}:\n{classes_found}" |
| ) |
| else: |
| self.logger.info( |
| f"Found classes: {classes_found} will be using {classes_found[0]}. " |
| "If this needs to be changed, re-run the command specifying `block_class_name`." |
| ) |
| child_class, parent_class = out[0][0], out[0][1] |
|
|
| |
| |
| module_name = f"__dynamic__{self.block_module_name.stem}" |
| spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name)) |
| module = importlib.util.module_from_spec(spec) |
| spec.loader.exec_module(module) |
| getattr(module, child_class)().save_pretrained(os.getcwd()) |
|
|
| |
| |
| |
| |
|
|
| def _choose_block(self, candidates, chosen=None): |
| for cls, base in candidates: |
| if cls == chosen: |
| return cls, base |
| return None, None |
|
|
| def _get_class_names(self, file_path): |
| source = file_path.read_text(encoding="utf-8") |
| try: |
| tree = ast.parse(source, filename=file_path) |
| except SyntaxError as e: |
| raise ValueError(f"Could not parse {file_path!r}: {e}") from e |
|
|
| results: list[tuple[str, str]] = [] |
| for node in tree.body: |
| if not isinstance(node, ast.ClassDef): |
| continue |
|
|
| |
| base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None] |
|
|
| |
| for allowed in EXPECTED_PARENT_CLASSES: |
| if allowed in base_names: |
| results.append((node.name, allowed)) |
|
|
| return results |
|
|
| def _get_base_name(self, node: ast.expr): |
| if isinstance(node, ast.Name): |
| return node.id |
| elif isinstance(node, ast.Attribute): |
| val = self._get_base_name(node.value) |
| return f"{val}.{node.attr}" if val else node.attr |
| return None |
|
|
| def _create_automap(self, parent_class, child_class): |
| module = str(self.block_module_name).replace(".py", "").rsplit(".", 1)[-1] |
| auto_map = {f"{parent_class}": f"{module}.{child_class}"} |
| return {"auto_map": auto_map} |
|
|