Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """ | |
| Usage example: | |
| TODO | |
| """ | |
| import ast | |
| import importlib.util | |
| import os | |
| from argparse import ArgumentParser, Namespace | |
| from pathlib import Path | |
| from ..utils import logging | |
| from . import BaseDiffusersCLICommand | |
# Names of the base classes a custom block must inherit from. `CustomBlocksCommand._get_class_names`
# uses this list to pick candidate block classes out of the user's module by base-class name.
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]
# Config filename. It is not used by any live code in the visible part of this module.
# NOTE(review): confirm nothing outside this file imports it before removing it.
CONFIG = "config.json"
def conversion_command_factory(args: Namespace):
    """Create the `custom_blocks` command object from parsed CLI arguments.

    This function is installed as the sub-command's ``func`` default, so ``args`` is the
    namespace produced by parsing the `custom_blocks` options.

    Args:
        args: Parsed arguments carrying ``block_module_name`` and ``block_class_name``.

    Returns:
        A ``CustomBlocksCommand`` configured with those two values.
    """
    module_name = args.block_module_name
    class_name = args.block_class_name
    return CustomBlocksCommand(block_module_name=module_name, block_class_name=class_name)
class CustomBlocksCommand(BaseDiffusersCLICommand):
    """`diffusers-cli custom_blocks`: save a user-defined modular block into the current directory.

    The command works in these steps:
    1. It statically scans a Python file for top-level classes that inherit from one of
       `EXPECTED_PARENT_CLASSES`.
    2. It imports (executes) that file and instantiates the selected class with no arguments.
    3. It calls the instance's `save_pretrained` with the current working directory.
    4. It makes sure a `requirements.txt` file exists there.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the `custom_blocks` sub-command and its options.

        Args:
            parser: The object returned by ``ArgumentParser.add_subparsers()``. It must provide
                ``add_parser``; a plain ``ArgumentParser`` does not.
        """
        conversion_parser = parser.add_parser("custom_blocks")
        conversion_parser.add_argument(
            "--block_module_name",
            type=str,
            default="block.py",
            help="Module filename in which the custom block will be implemented.",
        )
        conversion_parser.add_argument(
            "--block_class_name",
            type=str,
            default=None,
            help="Name of the custom block. If provided None, we will try to infer it.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, block_module_name: str = "block.py", block_class_name: str = None):
        """Store the command options.

        Args:
            block_module_name: Path to the Python file that defines the custom block.
            block_class_name: Class to save. If ``None``, the first matching class found in the
                file is used.
        """
        self.logger = logging.get_logger("diffusers-cli/custom_blocks")
        self.block_module_name = Path(block_module_name)
        self.block_class_name = block_class_name

    def run(self):
        """Import the selected custom block and save it into the current working directory.

        Raises:
            ValueError: If the module has no class inheriting from `EXPECTED_PARENT_CLASSES`,
                if the requested `block_class_name` is not one of the candidates, or if the
                file cannot be loaded as a Python module.
        """
        candidates = self._get_class_names(self.block_module_name)
        # Deduplicate while keeping source order. A `set` would iterate in hash-randomized
        # order, and the reported class could then differ from the one actually used.
        classes_found = list(dict.fromkeys(cls for cls, _ in candidates))

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(candidates, self.block_class_name)
            if child_class is None and parent_class is None:
                raise ValueError(
                    "`block_class_name` could not be retrieved. Available classes from "
                    f"{self.block_module_name}:\n{classes_found}"
                )
        else:
            if not candidates:
                raise ValueError(
                    f"No class inheriting from one of {EXPECTED_PARENT_CLASSES} was found in "
                    f"{self.block_module_name}."
                )
            child_class, parent_class = candidates[0]
            self.logger.info(
                f"Found classes: {classes_found} will be using {child_class}. "
                "If this needs to be changed, re-run the command specifying `block_class_name`."
            )

        # Load the user's file as a module and let the block serialize itself.
        # NOTE: this executes arbitrary code from `block_module_name`. The command is meant to be
        # run by the author of that file.
        module_name = f"__dynamic__{self.block_module_name.stem}"
        spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
        if spec is None or spec.loader is None:
            raise ValueError(f"Could not load {self.block_module_name} as a Python module.")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        getattr(module, child_class)().save_pretrained(os.getcwd())
        # Alternative: write the config by hand from
        # `self._create_automap(parent_class=parent_class, child_class=child_class)`.

        # Make sure a requirements file exists without clobbering one the user already wrote.
        requirements = Path("requirements.txt")
        if not requirements.exists():
            requirements.write_text("", encoding="utf-8")

    def _choose_block(self, candidates, chosen=None):
        """Return the first ``(class_name, parent_name)`` pair whose class name equals ``chosen``.

        Returns ``(None, None)`` when no candidate matches.
        """
        for cls, base in candidates:
            if cls == chosen:
                return cls, base
        return None, None

    def _get_class_names(self, file_path):
        """List the top-level classes in ``file_path`` that inherit from an expected parent.

        A base matches when its last dotted component is in `EXPECTED_PARENT_CLASSES`. This
        means both ``ModularPipelineBlocks`` and ``pkg.ModularPipelineBlocks`` are recognized.

        Args:
            file_path: A ``pathlib.Path`` to a Python source file. It is parsed but not executed.

        Returns:
            A list of ``(class_name, parent_name)`` tuples in source order. There is one tuple
            per matching expected parent.

        Raises:
            ValueError: If the file is not valid Python.
        """
        source = file_path.read_text(encoding="utf-8")
        try:
            tree = ast.parse(source, filename=file_path)
        except SyntaxError as e:
            raise ValueError(f"Could not parse {file_path!r}: {e}") from e

        results: list[tuple[str, str]] = []
        for node in tree.body:  # top-level statements only; nested classes are ignored
            if not isinstance(node, ast.ClassDef):
                continue
            base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
            # Compare on the unqualified name so attribute-style bases also match.
            base_leaves = {name.rsplit(".", 1)[-1] for name in base_names}
            for allowed in EXPECTED_PARENT_CLASSES:
                if allowed in base_leaves:
                    results.append((node.name, allowed))
        return results

    def _get_base_name(self, node: ast.expr):
        """Return the (possibly dotted) name of a base-class expression, or ``None``.

        ``Name`` nodes give their identifier. ``Attribute`` chains give a dotted path such as
        ``"a.b.C"``. Any other expression (calls, subscripts, ...) gives ``None``.
        """
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        """Build an ``{"auto_map": {parent_class: "<module>.<child_class>"}}`` config fragment.

        ``<module>`` is the file stem of `block_module_name` (directories and the ``.py``
        suffix are dropped). Only the last dotted component of that stem is kept.
        """
        module = Path(self.block_module_name).stem.rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}