|
|
import ast |
|
|
import importlib |
|
|
import inspect |
|
|
import textwrap |
|
|
|
|
|
|
|
|
class ReturnNameVisitor(ast.NodeVisitor):
    """Collect a string for every value returned inside an AST.

    Thanks to ChatGPT for pairing.

    After ``visit(tree)`` has run, ``return_names`` holds one entry per
    returned expression, in the order the ``return`` statements were visited.
    A ``return a, b`` statement contributes one entry per tuple element; any
    other ``return`` contributes a single entry.
    """

    def __init__(self):
        super().__init__()
        # Accumulates across every ``visit`` call made on this instance.
        self.return_names = []

    @staticmethod
    def _expr_to_name(expr):
        """Return the string recorded for a single returned expression.

        Plain names are recorded as their identifier; anything else is
        recorded as its unparsed source text.
        """
        if expr is None:
            # Bare ``return``: there is no value node to unparse. Recorded as
            # "None", the same text an explicit ``return None`` produces.
            return "None"
        if isinstance(expr, ast.Name):
            return expr.id
        try:
            return ast.unparse(expr)
        except Exception:
            # Deliberate best effort: if the expression cannot be unparsed,
            # fall back to the node's ``str()`` instead of aborting the walk.
            return str(expr)

    def visit_Return(self, node):
        """Record the returned expression(s) of ``node``, then keep walking."""
        value = node.value
        # A tuple return is split so that each element gets its own entry.
        exprs = value.elts if isinstance(value, ast.Tuple) else [value]
        self.return_names.extend(self._expr_to_name(expr) for expr in exprs)
        self.generic_visit(node)

    def _determine_parent_module(self, cls):
        """Return the ``diffusers`` submodule name that ``cls`` belongs to.

        Returns ``"pipelines"`` for ``DiffusionPipeline`` subclasses and
        ``"models"`` for ``ModelMixin`` subclasses.

        Raises:
            NotImplementedError: if ``cls`` is a subclass of neither.
        """
        # Imported lazily so that this module can be imported (and the
        # visitor used on arbitrary trees) without importing diffusers.
        from diffusers import DiffusionPipeline
        from diffusers.models.modeling_utils import ModelMixin

        if issubclass(cls, DiffusionPipeline):
            return "pipelines"
        if issubclass(cls, ModelMixin):
            return "models"
        raise NotImplementedError(
            f"Cannot determine the parent module of {cls!r}: expected a "
            "subclass of DiffusionPipeline or ModelMixin."
        )

    def get_ast_tree(self, cls, attribute_name="encode_prompt"):
        """Parse the source of ``cls.<attribute_name>`` into an AST.

        The class is looked up by name on ``diffusers.pipelines`` or
        ``diffusers.models`` (see ``_determine_parent_module``), and the
        attribute is read from the object found there rather than from the
        ``cls`` argument itself.

        Args:
            cls: A ``DiffusionPipeline`` or ``ModelMixin`` subclass.
            attribute_name: Name of the attribute whose source is parsed.

        Returns:
            The tree produced by ``ast.parse`` for the dedented source.
        """
        parent_module_name = self._determine_parent_module(cls)
        main_module = importlib.import_module(f"diffusers.{parent_module_name}")
        resolved_cls = getattr(main_module, cls.__name__)
        source_code = inspect.getsource(getattr(resolved_cls, attribute_name))
        # A method's source is indented relative to its class; dedent it so
        # that it parses as a top-level definition.
        source_code = textwrap.dedent(source_code)
        return ast.parse(source_code)
|
|
|