|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
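Export a custom modular pipeline block defined in a Python file (a top-level class inheriting from
`ModularPipelineBlocks`) by instantiating it and calling its `save_pretrained` on the current directory.

Usage example:

    diffusers-cli custom_blocks --block_module_name block.py --block_class_name MyCustomBlock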
|
|
""" |
|
|
|
|
|
import ast
import importlib.util
import os
from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import Optional

from ..utils import logging
from . import BaseDiffusersCLICommand
|
|
|
|
|
|
|
|
# Base-class names that mark a top-level class in the user's module as a custom block candidate
# (checked statically by `CustomBlocksCommand._get_class_names`).
EXPECTED_PARENT_CLASSES = ["ModularPipelineBlocks"]

# Name of the JSON configuration file.
# NOTE(review): not referenced by the code in this module as visible here — confirm whether it is still needed.
CONFIG = "config.json"
|
|
|
|
|
|
|
|
def conversion_command_factory(args: Namespace):
    """Create the ``custom_blocks`` command object from parsed CLI arguments.

    Args:
        args: Parsed namespace carrying ``block_module_name`` and ``block_class_name``.

    Returns:
        A ``CustomBlocksCommand`` configured from ``args``.
    """
    return CustomBlocksCommand(
        block_module_name=args.block_module_name,
        block_class_name=args.block_class_name,
    )
|
|
|
|
|
|
|
|
class CustomBlocksCommand(BaseDiffusersCLICommand):
    """CLI command ``custom_blocks``: export a user-defined modular pipeline block.

    The command statically inspects ``block_module_name`` (via ``ast``, without importing it) to find
    top-level classes inheriting from one of ``EXPECTED_PARENT_CLASSES``. It then imports the file,
    instantiates the selected class with no arguments and calls ``save_pretrained`` on the current
    working directory. Finally, it makes sure a ``requirements.txt`` exists there.
    """

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``custom_blocks`` sub-command and its arguments.

        Args:
            parser: Object providing ``add_parser`` — presumably the action returned by
                ``ArgumentParser.add_subparsers()``; confirm against the CLI entry point.
        """
        conversion_parser = parser.add_parser("custom_blocks")
        conversion_parser.add_argument(
            "--block_module_name",
            type=str,
            default="block.py",
            help="Module filename in which the custom block will be implemented.",
        )
        conversion_parser.add_argument(
            "--block_class_name",
            type=str,
            default=None,
            help="Name of the custom block. If provided None, we will try to infer it.",
        )
        conversion_parser.set_defaults(func=conversion_command_factory)

    def __init__(self, block_module_name: str = "block.py", block_class_name: Optional[str] = None):
        """
        Args:
            block_module_name: Path to the Python file that defines the custom block.
            block_class_name: Name of the block class to export. When ``None``, the first matching
                class (in definition order) is used.
        """
        self.logger = logging.get_logger("diffusers-cli/custom_blocks")
        self.block_module_name = Path(block_module_name)
        self.block_class_name = block_class_name

    def run(self):
        """Save the selected block to the current directory and ensure ``requirements.txt`` exists.

        Raises:
            ValueError: If the module cannot be parsed or given an import spec, contains no class
                inheriting from ``EXPECTED_PARENT_CLASSES``, or does not define ``block_class_name``.
        """
        child_class, _parent_class = self._resolve_block()
        module = self._load_module()
        # This executes the user's own module and block constructor; the user explicitly pointed
        # the command at this file.
        getattr(module, child_class)().save_pretrained(os.getcwd())
        self._ensure_requirements_file()

    def _resolve_block(self):
        """Return the ``(class_name, parent_class_name)`` pair of the block to export."""
        candidates = self._get_class_names(self.block_module_name)
        # `dict.fromkeys` de-duplicates while keeping definition order, so `classes_found[0]` is
        # the same class as `candidates[0]` (a set would make the logged default arbitrary).
        classes_found = list(dict.fromkeys(cls for cls, _ in candidates))

        if not candidates:
            raise ValueError(
                f"No class inheriting from one of {EXPECTED_PARENT_CLASSES} was found in "
                f"{self.block_module_name}."
            )

        if self.block_class_name is not None:
            child_class, parent_class = self._choose_block(candidates, self.block_class_name)
            if child_class is None and parent_class is None:
                raise ValueError(
                    "`block_class_name` could not be retrieved. Available classes from "
                    f"{self.block_module_name}:\n{classes_found}"
                )
            return child_class, parent_class

        self.logger.info(
            f"Found classes: {classes_found} will be using {classes_found[0]}. "
            "If this needs to be changed, re-run the command specifying `block_class_name`."
        )
        return candidates[0]

    def _load_module(self):
        """Import ``block_module_name`` from its file path under a private module name and return it."""
        module_name = f"__dynamic__{self.block_module_name.stem}"
        spec = importlib.util.spec_from_file_location(module_name, str(self.block_module_name))
        # `spec_from_file_location` may return None (e.g. for a suffix no loader recognises);
        # fail with a clear message instead of an AttributeError on `spec.loader`.
        if spec is None or spec.loader is None:
            raise ValueError(f"Could not create an import spec for {self.block_module_name}.")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module

    def _ensure_requirements_file(self):
        """Create an empty ``requirements.txt`` in the working directory unless one already exists."""
        # Mode "x" creates the file and fails if it exists, so a user's existing requirements
        # are never truncated on re-runs.
        try:
            with open("requirements.txt", "x"):
                pass
        except FileExistsError:
            self.logger.info("`requirements.txt` already exists; leaving it untouched.")

    def _choose_block(self, candidates, chosen=None):
        """Return the first ``(class_name, parent_class_name)`` pair whose class name equals ``chosen``.

        Returns:
            The matching pair, or ``(None, None)`` when no candidate matches.
        """
        for cls, base in candidates:
            if cls == chosen:
                return cls, base
        return None, None

    def _get_class_names(self, file_path):
        """Statically list top-level classes in ``file_path`` that inherit from an expected parent.

        Args:
            file_path: ``pathlib.Path`` of the Python source file to inspect (read as UTF-8).

        Returns:
            ``(class_name, parent_class_name)`` tuples in definition order; a class that names several
            expected parents appears once per matching parent.

        Raises:
            ValueError: If the file is not valid Python.
        """
        source = file_path.read_text(encoding="utf-8")
        try:
            tree = ast.parse(source, filename=str(file_path))
        except SyntaxError as e:
            raise ValueError(f"Could not parse {file_path!r}: {e}") from e

        results: list[tuple[str, str]] = []
        for node in tree.body:
            if not isinstance(node, ast.ClassDef):
                continue

            base_names = [bname for b in node.bases if (bname := self._get_base_name(b)) is not None]
            # Compare on the last dotted component so qualified bases such as
            # `diffusers.ModularPipelineBlocks` match as well as bare names.
            base_leaves = {name.rsplit(".", 1)[-1] for name in base_names}

            for allowed in EXPECTED_PARENT_CLASSES:
                if allowed in base_leaves:
                    results.append((node.name, allowed))

        return results

    def _get_base_name(self, node: ast.expr):
        """Return the (possibly dotted) name of a base-class expression, or ``None`` if it is not a name.

        ``Name`` nodes yield their identifier; ``Attribute`` chains yield ``"a.b.C"``. Other expressions
        (calls, subscripts, ...) yield ``None``.
        """
        if isinstance(node, ast.Name):
            return node.id
        elif isinstance(node, ast.Attribute):
            val = self._get_base_name(node.value)
            return f"{val}.{node.attr}" if val else node.attr
        return None

    def _create_automap(self, parent_class, child_class):
        """Build an ``{"auto_map": {parent_class: "<module>.<child_class>"}}`` config fragment.

        ``<module>`` is the file name of ``block_module_name`` without directories or its final suffix.
        Not called by ``run``.
        """
        # `.stem` drops directory components and only the final suffix; the previous string replace kept
        # directories (e.g. "blocks/my_block") and removed ".py" anywhere in the path.
        module = self.block_module_name.stem.rsplit(".", 1)[-1]
        auto_map = {f"{parent_class}": f"{module}.{child_class}"}
        return {"auto_map": auto_map}
|
|
|