|
|
import json |
|
|
import logging |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import List, Optional, Tuple, Union |
|
|
|
|
|
import numpy as np |
|
|
import PIL |
|
|
import torch |
|
|
|
|
|
from ..configuration_utils import ConfigMixin |
|
|
from ..image_processor import PipelineImageInput |
|
|
from .modular_pipeline import ModularPipelineBlocks, SequentialPipelineBlocks |
|
|
from .modular_pipeline_utils import InputParam |
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
# Default UI metadata for the user-facing inputs of SDXL modular pipelines.
# Keys are diffusers parameter names; values are InputParam specs (type hint,
# default, required flag, description). ModularNode consults this table when a
# block does not declare its own type hint for a parameter.
SDXL_INPUTS_SCHEMA = {
    "prompt": InputParam(
        "prompt", type_hint=Union[str, List[str]], description="The prompt or prompts to guide the image generation"
    ),
    "prompt_2": InputParam(
        "prompt_2",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts to be sent to the tokenizer_2 and text_encoder_2",
    ),
    "negative_prompt": InputParam(
        "negative_prompt",
        type_hint=Union[str, List[str]],
        description="The prompt or prompts not to guide the image generation",
    ),
    "negative_prompt_2": InputParam(
        "negative_prompt_2",
        type_hint=Union[str, List[str]],
        description="The negative prompt or prompts for text_encoder_2",
    ),
    "cross_attention_kwargs": InputParam(
        "cross_attention_kwargs",
        type_hint=Optional[dict],
        description="Kwargs dictionary passed to the AttentionProcessor",
    ),
    "clip_skip": InputParam(
        "clip_skip", type_hint=Optional[int], description="Number of layers to skip in CLIP text encoder"
    ),
    # img2img / inpainting inputs
    "image": InputParam(
        "image",
        type_hint=PipelineImageInput,
        required=True,
        description="The image(s) to modify for img2img or inpainting",
    ),
    "mask_image": InputParam(
        "mask_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Mask image for inpainting, white pixels will be repainted",
    ),
    "generator": InputParam(
        "generator",
        type_hint=Optional[Union[torch.Generator, List[torch.Generator]]],
        description="Generator(s) for deterministic generation",
    ),
    "height": InputParam("height", type_hint=Optional[int], description="Height in pixels of the generated image"),
    "width": InputParam("width", type_hint=Optional[int], description="Width in pixels of the generated image"),
    "num_images_per_prompt": InputParam(
        "num_images_per_prompt", type_hint=int, default=1, description="Number of images to generate per prompt"
    ),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, default=50, description="Number of denoising steps"
    ),
    "timesteps": InputParam(
        "timesteps", type_hint=Optional[torch.Tensor], description="Custom timesteps for the denoising process"
    ),
    "sigmas": InputParam(
        "sigmas", type_hint=Optional[torch.Tensor], description="Custom sigmas for the denoising process"
    ),
    "denoising_end": InputParam(
        "denoising_end",
        type_hint=Optional[float],
        description="Fraction of denoising process to complete before termination",
    ),

    "strength": InputParam(
        "strength", type_hint=float, default=0.3, description="How much to transform the reference image"
    ),
    "denoising_start": InputParam(
        "denoising_start", type_hint=Optional[float], description="Starting point of the denoising process"
    ),
    "latents": InputParam(
        "latents", type_hint=Optional[torch.Tensor], description="Pre-generated noisy latents for image generation"
    ),
    "padding_mask_crop": InputParam(
        "padding_mask_crop",
        type_hint=Optional[Tuple[int, int]],
        description="Size of margin in crop for image and mask",
    ),
    # SDXL micro-conditioning (see the SDXL paper's size/crop conditioning)
    "original_size": InputParam(
        "original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Original size of the image for SDXL's micro-conditioning",
    ),
    "target_size": InputParam(
        "target_size", type_hint=Optional[Tuple[int, int]], description="Target size for SDXL's micro-conditioning"
    ),
    "negative_original_size": InputParam(
        "negative_original_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on image resolution",
    ),
    "negative_target_size": InputParam(
        "negative_target_size",
        type_hint=Optional[Tuple[int, int]],
        description="Negative conditioning based on target resolution",
    ),
    "crops_coords_top_left": InputParam(
        "crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Top-left coordinates for SDXL's micro-conditioning",
    ),
    "negative_crops_coords_top_left": InputParam(
        "negative_crops_coords_top_left",
        type_hint=Tuple[int, int],
        default=(0, 0),
        description="Negative conditioning crop coordinates",
    ),
    "aesthetic_score": InputParam(
        "aesthetic_score", type_hint=float, default=6.0, description="Simulates aesthetic score of generated image"
    ),
    "negative_aesthetic_score": InputParam(
        "negative_aesthetic_score", type_hint=float, default=2.0, description="Simulates negative aesthetic score"
    ),
    "eta": InputParam("eta", type_hint=float, default=0.0, description="Parameter η in the DDIM paper"),
    "output_type": InputParam(
        "output_type", type_hint=str, default="pil", description="Output format (pil/tensor/np.array)"
    ),
    # IP-Adapter / ControlNet inputs
    "ip_adapter_image": InputParam(
        "ip_adapter_image",
        type_hint=PipelineImageInput,
        required=True,
        description="Image(s) to be used as IP adapter",
    ),
    "control_image": InputParam(
        "control_image", type_hint=PipelineImageInput, required=True, description="ControlNet input condition"
    ),
    "control_guidance_start": InputParam(
        "control_guidance_start",
        type_hint=Union[float, List[float]],
        default=0.0,
        description="When ControlNet starts applying",
    ),
    "control_guidance_end": InputParam(
        "control_guidance_end",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="When ControlNet stops applying",
    ),
    "controlnet_conditioning_scale": InputParam(
        "controlnet_conditioning_scale",
        type_hint=Union[float, List[float]],
        default=1.0,
        description="Scale factor for ControlNet outputs",
    ),
    "guess_mode": InputParam(
        "guess_mode",
        type_hint=bool,
        default=False,
        description="Enables ControlNet encoder to recognize input without prompts",
    ),
    "control_mode": InputParam(
        "control_mode", type_hint=List[int], required=True, description="Control mode for union controlnet"
    ),
}
|
|
|
|
|
# Default UI metadata for intermediate values exchanged between SDXL pipeline
# blocks (embeddings, latents, timesteps, ...). Same layout as
# SDXL_INPUTS_SCHEMA: diffusers parameter name -> InputParam spec.
SDXL_INTERMEDIATE_INPUTS_SCHEMA = {
    "prompt_embeds": InputParam(
        "prompt_embeds",
        type_hint=torch.Tensor,
        required=True,
        description="Text embeddings used to guide image generation",
    ),
    "negative_prompt_embeds": InputParam(
        "negative_prompt_embeds", type_hint=torch.Tensor, description="Negative text embeddings"
    ),
    "pooled_prompt_embeds": InputParam(
        "pooled_prompt_embeds", type_hint=torch.Tensor, required=True, description="Pooled text embeddings"
    ),
    "negative_pooled_prompt_embeds": InputParam(
        "negative_pooled_prompt_embeds", type_hint=torch.Tensor, description="Negative pooled text embeddings"
    ),
    "batch_size": InputParam("batch_size", type_hint=int, required=True, description="Number of prompts"),
    "dtype": InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
    "preprocess_kwargs": InputParam(
        "preprocess_kwargs", type_hint=Optional[dict], description="Kwargs for ImageProcessor"
    ),
    "latents": InputParam(
        "latents", type_hint=torch.Tensor, required=True, description="Initial latents for denoising process"
    ),
    "timesteps": InputParam("timesteps", type_hint=torch.Tensor, required=True, description="Timesteps for inference"),
    "num_inference_steps": InputParam(
        "num_inference_steps", type_hint=int, required=True, description="Number of denoising steps"
    ),
    "latent_timestep": InputParam(
        "latent_timestep", type_hint=torch.Tensor, required=True, description="Initial noise level timestep"
    ),
    "image_latents": InputParam(
        "image_latents", type_hint=torch.Tensor, required=True, description="Latents representing reference image"
    ),
    "mask": InputParam("mask", type_hint=torch.Tensor, required=True, description="Mask for inpainting"),
    "masked_image_latents": InputParam(
        "masked_image_latents", type_hint=torch.Tensor, description="Masked image latents for inpainting"
    ),
    "add_time_ids": InputParam(
        "add_time_ids", type_hint=torch.Tensor, required=True, description="Time ids for conditioning"
    ),
    "negative_add_time_ids": InputParam(
        "negative_add_time_ids", type_hint=torch.Tensor, description="Negative time ids"
    ),
    "timestep_cond": InputParam("timestep_cond", type_hint=torch.Tensor, description="Timestep conditioning for LCM"),
    "noise": InputParam("noise", type_hint=torch.Tensor, description="Noise added to image latents"),
    "crops_coords": InputParam("crops_coords", type_hint=Optional[Tuple[int]], description="Crop coordinates"),
    "ip_adapter_embeds": InputParam(
        "ip_adapter_embeds", type_hint=List[torch.Tensor], description="Image embeddings for IP-Adapter"
    ),
    "negative_ip_adapter_embeds": InputParam(
        "negative_ip_adapter_embeds",
        type_hint=List[torch.Tensor],
        description="Negative image embeddings for IP-Adapter",
    ),
    "images": InputParam(
        "images",
        # Fixed: previously List[np.array] — np.array is a *function*, not a type,
        # so it is not a valid type argument; np.ndarray is the actual array type.
        type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.ndarray]],
        required=True,
        description="Generated images",
    ),
}


# Combined lookup used to resolve a type hint when a block declares none.
# Intermediate entries win on key collisions (e.g. "latents", "timesteps").
SDXL_PARAM_SCHEMA = {**SDXL_INPUTS_SCHEMA, **SDXL_INTERMEDIATE_INPUTS_SCHEMA}
|
|
|
|
|
|
|
|
# Ready-made Mellon widget definitions for common parameter names. When an
# input/output name appears here, ModularNode copies this spec verbatim
# (overriding "display" to "output" for outputs) instead of inferring a widget
# from the type hint.
DEFAULT_PARAM_MAPS = {
    "prompt": {
        "label": "Prompt",
        "type": "string",
        "default": "a bear sitting in a chair drinking a milkshake",
        "display": "textarea",
    },
    "negative_prompt": {
        "label": "Negative Prompt",
        "type": "string",
        "default": "deformed, ugly, wrong proportion, low res, bad anatomy, worst quality, low quality",
        "display": "textarea",
    },
    "num_inference_steps": {
        "label": "Steps",
        "type": "int",
        "default": 25,
        "min": 1,
        "max": 1000,
    },
    "seed": {
        "label": "Seed",
        "type": "int",
        "default": 0,
        "min": 0,
        "display": "random",
    },
    # width/height share the "dimensions" group so the UI renders them together
    "width": {
        "label": "Width",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "height": {
        "label": "Height",
        "type": "int",
        "display": "text",
        "default": 1024,
        "min": 8,
        "max": 8192,
        "step": 8,
        "group": "dimensions",
    },
    "images": {
        "label": "Images",
        "type": "image",
        "display": "output",
    },
    "image": {
        "label": "Image",
        "type": "image",
        "display": "input",
    },
}
|
|
|
|
|
# Fallback widget specs keyed by a substring to look for in the (lowercased)
# stringified type hint of a parameter, e.g. "Optional[int]" -> "int" widget.
DEFAULT_TYPE_MAPS = {
    "int": {
        "type": "int",
        "default": 0,
        "min": 0,
    },
    "float": {
        "type": "float",
        "default": 0.0,
        "min": 0.0,
    },
    "str": {
        "type": "string",
        "default": "",
    },
    "bool": {
        "type": "boolean",
        "default": False,
    },
    "image": {
        "type": "image",
    },
}

# Component names rendered as loadable diffusers models in the UI.
DEFAULT_MODEL_KEYS = ["unet", "vae", "text_encoder", "tokenizer", "controlnet", "transformer", "image_encoder"]
# Default node category shown in the Mellon UI.
DEFAULT_CATEGORY = "Modular Diffusers"
# Components skipped entirely when building a node (matched by substring).
DEFAULT_EXCLUDE_MODEL_KEYS = ["processor", "feature_extractor", "safety_checker"]
# UI parameter groups: group name -> substrings that place a param in the group
# (see get_group_name below).
DEFAULT_PARAMS_GROUPS_KEYS = {
    "text_encoders": ["text_encoder", "tokenizer"],
    "ip_adapter_embeds": ["ip_adapter_embeds"],
    "prompt_embeddings": ["prompt_embeds"],
}
|
|
|
|
|
|
|
|
def get_group_name(name, group_params_keys=None):
    """
    Return the UI group a parameter belongs to, or None if it is ungrouped.

    Matching is by substring: a parameter belongs to the first group whose key list contains a key occurring anywhere
    in ``name``. E.g. "prompt_embeds" -> "prompt_embeddings", "text_encoder" -> "text_encoders", "prompt" -> None.

    Args:
        name (str, optional): Parameter name to classify. None returns None.
        group_params_keys (dict, optional): Mapping of group name to a list of substrings. Defaults to
            DEFAULT_PARAMS_GROUPS_KEYS.

    Returns:
        str or None: The group name, or None when no group key matches.
    """
    if name is None:
        return None
    # Resolve the default at call time instead of binding the module constant at
    # definition time; this also makes an explicit None argument safe.
    if group_params_keys is None:
        group_params_keys = DEFAULT_PARAMS_GROUPS_KEYS
    for group_name, group_keys in group_params_keys.items():
        for group_key in group_keys:
            if group_key in name:
                return group_name
    return None
|
|
|
|
|
|
|
|
class ModularNode(ConfigMixin):
    """
    A ModularNode is a base class to build UI nodes using diffusers. Currently only supports Mellon. It is a wrapper
    around a ModularPipelineBlocks object.

    The node definition is derived from the blocks' declared inputs, expected components and intermediate outputs.
    Any of these can be overridden by passing a dict keyed by the parameter name to ``__init__``; an optional
    ``"name"`` entry in such a dict renames the parameter on the Mellon side (recorded in ``name_mapping``).

    <Tip warning={true}>

    This is an experimental feature and is likely to change in the future.

    </Tip>
    """

    # File name used by ConfigMixin when saving/loading this node's config.
    config_name = "node_config.json"

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        trust_remote_code: Optional[bool] = None,
        **kwargs,
    ):
        """
        Create a ModularNode from pretrained pipeline blocks.

        Args:
            pretrained_model_name_or_path (str): Repo id or local path forwarded to
                `ModularPipelineBlocks.from_pretrained`.
            trust_remote_code (bool, optional): Whether to allow custom remote code when loading the blocks.
            **kwargs: Forwarded both to `ModularPipelineBlocks.from_pretrained` and to the constructor
                (per-parameter UI spec overrides).

        Returns:
            ModularNode: Node wrapping the loaded blocks.
        """
        blocks = ModularPipelineBlocks.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs
        )
        return cls(blocks, **kwargs)

    def __init__(self, blocks, category=DEFAULT_CATEGORY, label=None, **kwargs):
        """
        Build the node definition (input/component/output params) from `blocks` and register it to the config.

        Args:
            blocks: A ModularPipelineBlocks instance this node wraps.
            category (str): Node category shown in the Mellon UI.
            label (str, optional): Node label; defaults to the blocks' class name.
            **kwargs: Optional per-parameter UI spec dicts keyed by input/component/output name; an optional
                "name" entry renames the parameter on the Mellon side. Unrecognized kwargs are warned about.
        """
        self.blocks = blocks

        if label is None:
            label = self.blocks.__class__.__name__

        # Maps diffusers param name -> Mellon param name; populated when a caller
        # renames a param or when a param is folded into a group.
        self.name_mapping = {}

        input_params = {}

        # User-facing inputs plus intermediate inputs produced by upstream nodes.
        inputs = self.blocks.inputs + self.blocks.intermediate_inputs
        for inp in inputs:
            param = kwargs.pop(inp.name, None)
            if param:
                # Caller-supplied spec wins. NOTE: pops "name" from the caller's
                # dict in place; an empty dict is treated the same as no override.
                input_params[inp.name] = param
                mellon_name = param.pop("name", inp.name)
                if mellon_name != inp.name:
                    self.name_mapping[inp.name] = mellon_name
                continue

            # Skip optional params for which we have neither a default widget nor a group.
            if inp.name not in DEFAULT_PARAM_MAPS and not inp.required and not get_group_name(inp.name):
                continue

            if inp.name in DEFAULT_PARAM_MAPS:
                # Copy so later mutation cannot leak into the shared default map.
                param = DEFAULT_PARAM_MAPS[inp.name].copy()
            elif get_group_name(inp.name):
                # Grouped params are represented by their group name (a str).
                param = get_group_name(inp.name)
                if inp.name not in self.name_mapping:
                    self.name_mapping[inp.name] = param
            else:
                # Fall back to a generic widget inferred from the type hint
                # (consulting the SDXL schema when the block declares no hint).
                if inp.type_hint is not None:
                    type_str = str(inp.type_hint).lower()
                else:
                    inp_spec = SDXL_PARAM_SCHEMA.get(inp.name, None)
                    type_str = str(inp_spec.type_hint).lower() if inp_spec else ""
                for type_key, type_param in DEFAULT_TYPE_MAPS.items():
                    if type_key in type_str:
                        param = type_param.copy()
                        param["label"] = inp.name
                        param["display"] = "input"
                        break
                else:
                    # Unknown type: keep just the name; rendered later as a
                    # custom-typed input socket.
                    param = inp.name

            input_params[inp.name] = param

        component_params = {}
        for comp in self.blocks.expected_components:
            param = kwargs.pop(comp.name, None)
            if param:
                # Caller-supplied spec wins; same in-place "name" pop as above.
                component_params[comp.name] = param
                mellon_name = param.pop("name", comp.name)
                if mellon_name != comp.name:
                    self.name_mapping[comp.name] = mellon_name
                continue

            # Drop components that should never appear in the UI (substring match).
            to_exclude = False
            for exclude_key in DEFAULT_EXCLUDE_MODEL_KEYS:
                if exclude_key in comp.name:
                    to_exclude = True
                    break
            if to_exclude:
                continue

            if get_group_name(comp.name):
                # Grouped components are represented by their group name (a str).
                param = get_group_name(comp.name)
                if comp.name not in self.name_mapping:
                    self.name_mapping[comp.name] = param
            elif comp.name in DEFAULT_MODEL_KEYS:
                # Known model components get a loadable-model widget.
                param = {"label": comp.name, "type": "diffusers_auto_model", "display": "input"}
            else:
                param = comp.name

            component_params[comp.name] = param

        output_params = {}
        if isinstance(self.blocks, SequentialPipelineBlocks):
            # For sequential blocks, only the last sub-block's outputs are exposed.
            last_block_name = list(self.blocks.sub_blocks.keys())[-1]
            outputs = self.blocks.sub_blocks[last_block_name].intermediate_outputs
        else:
            outputs = self.blocks.intermediate_outputs

        for out in outputs:
            param = kwargs.pop(out.name, None)
            if param:
                # Caller-supplied spec wins; same in-place "name" pop as above.
                output_params[out.name] = param
                mellon_name = param.pop("name", out.name)
                if mellon_name != out.name:
                    self.name_mapping[out.name] = mellon_name
                continue

            if out.name in DEFAULT_PARAM_MAPS:
                param = DEFAULT_PARAM_MAPS[out.name].copy()
                param["display"] = "output"
            else:
                group_name = get_group_name(out.name)
                if group_name:
                    param = group_name
                    if out.name not in self.name_mapping:
                        self.name_mapping[out.name] = param
                else:
                    param = out.name

            output_params[out.name] = param

        # Anything left in kwargs did not match an input/component/output name.
        if len(kwargs) > 0:
            logger.warning(f"Unused kwargs: {kwargs}")

        register_dict = {
            "category": category,
            "label": label,
            "input_params": input_params,
            "component_params": component_params,
            "output_params": output_params,
            "name_mapping": self.name_mapping,
        }
        self.register_to_config(**register_dict)

    def setup(self, components_manager, collection=None):
        """Initialize the underlying pipeline with a components manager (and optional collection)."""
        self.pipeline = self.blocks.init_pipeline(components_manager=components_manager, collection=collection)
        self._components_manager = components_manager

    @property
    def mellon_config(self):
        """dict: The node definition in Mellon's module format (built on every access)."""
        return self._convert_to_mellon_config()

    def _convert_to_mellon_config(self):
        """
        Convert the registered config into Mellon's node dict: label, category and a flat "params" mapping keyed by
        the Mellon-side parameter names.
        """
        node = {}
        node["label"] = self.config.label
        node["category"] = self.config.category

        node_param = {}
        for inp_name, inp_param in self.config.input_params.items():
            if inp_name in self.name_mapping:
                mellon_name = self.name_mapping[inp_name]
            else:
                mellon_name = inp_name
            # A bare string means "custom type": expand it into a minimal widget spec.
            if isinstance(inp_param, str):
                param = {
                    "label": inp_param,
                    "type": inp_param,
                    "display": "input",
                }
            else:
                param = inp_param

            # First writer wins: grouped params collapse into a single entry.
            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Input param {mellon_name} already exists in node_param, skipping {inp_name}")

        for comp_name, comp_param in self.config.component_params.items():
            if comp_name in self.name_mapping:
                mellon_name = self.name_mapping[comp_name]
            else:
                mellon_name = comp_name
            # Same string-to-widget expansion as for inputs.
            if isinstance(comp_param, str):
                param = {
                    "label": comp_param,
                    "type": comp_param,
                    "display": "input",
                }
            else:
                param = comp_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Component param {comp_param} already exists in node_param, skipping {comp_name}")

        for out_name, out_param in self.config.output_params.items():
            if out_name in self.name_mapping:
                mellon_name = self.name_mapping[out_name]
            else:
                mellon_name = out_name
            # Outputs expand with display="output" instead of "input".
            if isinstance(out_param, str):
                param = {
                    "label": out_param,
                    "type": out_param,
                    "display": "output",
                }
            else:
                param = out_param

            if mellon_name not in node_param:
                node_param[mellon_name] = param
            else:
                logger.debug(f"Output param {out_param} already exists in node_param, skipping {out_name}")
        node["params"] = node_param
        return node

    def save_mellon_config(self, file_path):
        """
        Save the Mellon configuration to a JSON file.

        Args:
            file_path (str or Path): Path where the JSON file will be saved

        Returns:
            Path: Path to the saved config file
        """
        file_path = Path(file_path)

        # Create parent directories as needed.
        os.makedirs(file_path.parent, exist_ok=True)

        config = {"module": self.mellon_config, "name_mapping": self.name_mapping}

        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        logger.info(f"Mellon config and name mapping saved to {file_path}")

        return file_path

    @classmethod
    def load_mellon_config(cls, file_path):
        """
        Load a Mellon configuration from a JSON file.

        Args:
            file_path (str or Path): Path to the JSON file containing Mellon config

        Returns:
            dict: The loaded combined configuration containing 'module' and 'name_mapping'

        Raises:
            FileNotFoundError: If `file_path` does not exist.
        """
        file_path = Path(file_path)

        if not file_path.exists():
            raise FileNotFoundError(f"Config file not found: {file_path}")

        with open(file_path, "r", encoding="utf-8") as f:
            config = json.load(f)

        logger.info(f"Mellon config loaded from {file_path}")

        return config

    def process_inputs(self, **kwargs):
        """
        Split Mellon runtime kwargs into pipeline components, run kwargs and output names.

        Args:
            **kwargs: Runtime values keyed by Mellon-side names; grouped params arrive as dicts keyed by the
                original diffusers names. NOTE(review): values are popped from `kwargs` (and from nested group
                dicts) in place.

        Returns:
            Tuple[dict, dict, list]: (components to load keyed by component name, pipeline call kwargs,
            names of outputs to request).
        """
        params_components = {}
        for comp_name, comp_param in self.config.component_params.items():
            logger.debug(f"component: {comp_name}")
            mellon_comp_name = self.name_mapping.get(comp_name, comp_name)
            if mellon_comp_name in kwargs:
                # Grouped components arrive as a dict keyed by original names.
                if isinstance(kwargs[mellon_comp_name], dict) and comp_name in kwargs[mellon_comp_name]:
                    comp = kwargs[mellon_comp_name].pop(comp_name)
                else:
                    comp = kwargs.pop(mellon_comp_name)
                if comp:
                    # Resolve the actual component object from the manager via its model_id.
                    params_components[comp_name] = self._components_manager.get_one(comp["model_id"])

        params_run = {}
        for inp_name, inp_param in self.config.input_params.items():
            logger.debug(f"input: {inp_name}")
            mellon_inp_name = self.name_mapping.get(inp_name, inp_name)
            if mellon_inp_name in kwargs:
                # Grouped inputs arrive as a dict keyed by original names.
                if isinstance(kwargs[mellon_inp_name], dict) and inp_name in kwargs[mellon_inp_name]:
                    inp = kwargs[mellon_inp_name].pop(inp_name)
                else:
                    inp = kwargs.pop(mellon_inp_name)
                if inp is not None:
                    params_run[inp_name] = inp

        # Request every configured output from the pipeline.
        return_output_names = list(self.config.output_params.keys())

        return params_components, params_run, return_output_names

    def execute(self, **kwargs):
        """Run the wrapped pipeline with Mellon-style kwargs and return the requested outputs."""
        params_components, params_run, return_output_names = self.process_inputs(**kwargs)

        self.pipeline.update_components(**params_components)
        output = self.pipeline(**params_run, output=return_output_names)
        return output
|
|
|