ImageGen-FLUX.2 / chain_injectors /reference_latent_injector.py
RioShiina's picture
Upload folder using huggingface_hub
c009d4f verified
def inject(assembler, chain_definition, chain_items):
    """Route the positive conditioning through one ReferenceLatent node per image.

    For every usable entry of ``chain_items`` this adds a
    ``LoadImage`` -> ``VAEEncode`` -> ``ReferenceLatent`` chain to
    ``assembler.workflow``. The ``ReferenceLatent`` nodes are chained one after
    another on the positive conditioning. The output of the last one is then
    wired back into the injection target.

    The injection target is the ``conditioning`` input of the FluxGuidance node
    named by ``chain_definition['flux_guidance_node']``, when that node is in
    ``assembler.node_map`` and has such an input. Otherwise it is the
    ``positive`` input of the KSampler node.

    Args:
        assembler: Workflow assembler. Uses its ``node_map`` (node name ->
            node id), ``workflow`` (node id -> node dict with an ``'inputs'``
            dict), ``_get_unique_id()`` and ``_get_node_template(class_type)``.
        chain_definition: Mapping with optional keys ``'ksampler_node'``
            (default ``'ksampler'``), ``'flux_guidance_node'`` (default none)
            and ``'vae_node'`` (default ``'vae_loader'``).
        chain_items: Sequence of image filenames. Empty or non-string entries
            are skipped.

    Returns:
        None. ``assembler.workflow`` is modified in place. Missing nodes or
        inputs are reported with ``print`` and the injection is skipped.
    """
    if not chain_items:
        return

    ksampler_name = chain_definition.get('ksampler_node', 'ksampler')
    flux_guidance_name = chain_definition.get('flux_guidance_node')
    vae_node_name = chain_definition.get('vae_node', 'vae_loader')

    if ksampler_name not in assembler.node_map:
        print(f"Warning: [ReferenceLatent] KSampler node '{ksampler_name}' not found. Skipping.")
        return
    if vae_node_name not in assembler.node_map:
        print(f"Warning: [ReferenceLatent] VAE loader node '{vae_node_name}' not found. Skipping.")
        return

    ksampler_id = assembler.node_map[ksampler_name]
    vae_node_id = assembler.node_map[vae_node_name]

    # Prefer FluxGuidance's conditioning input, so the reference latents are
    # attached before guidance is applied; otherwise fall back to the KSampler.
    pos_target_node_id = None
    pos_target_input_name = None
    if flux_guidance_name and flux_guidance_name in assembler.node_map:
        flux_guidance_id = assembler.node_map[flux_guidance_name]
        if 'conditioning' in assembler.workflow[flux_guidance_id]['inputs']:
            pos_target_node_id = flux_guidance_id
            pos_target_input_name = 'conditioning'
            print(f"ReferenceLatent injector targeting FluxGuidance node '{flux_guidance_name}'.")

    # Compare against the None sentinel: a falsy but valid node id must not
    # trigger the fallback.
    if pos_target_node_id is None:
        if flux_guidance_name:
            print(
                f"Warning: [ReferenceLatent] FluxGuidance node '{flux_guidance_name}' "
                "not usable. Falling back to KSampler."
            )
        if 'positive' in assembler.workflow[ksampler_id]['inputs']:
            pos_target_node_id = ksampler_id
            pos_target_input_name = 'positive'
            print(f"ReferenceLatent injector targeting KSampler node '{ksampler_name}'.")
        else:
            print("Warning: [ReferenceLatent] Could not find a valid positive injection point. Skipping.")
            return

    target_inputs = assembler.workflow[pos_target_node_id]['inputs']
    current_pos_conditioning = target_inputs[pos_target_input_name]
    applied_count = 0

    for img_filename in chain_items:
        if not img_filename or not isinstance(img_filename, str):
            continue

        load_id = assembler._get_unique_id()
        load_node = assembler._get_node_template("LoadImage")
        load_node['inputs']['image'] = img_filename
        assembler.workflow[load_id] = load_node

        vae_encode_id = assembler._get_unique_id()
        vae_encode_node = assembler._get_node_template("VAEEncode")
        vae_encode_node['inputs']['pixels'] = [load_id, 0]
        vae_encode_node['inputs']['vae'] = [vae_node_id, 0]
        assembler.workflow[vae_encode_id] = vae_encode_node

        # Each ReferenceLatent consumes the previous conditioning, building a chain.
        ref_latent_id = assembler._get_unique_id()
        ref_latent_node = assembler._get_node_template("ReferenceLatent")
        ref_latent_node['inputs']['conditioning'] = current_pos_conditioning
        ref_latent_node['inputs']['latent'] = [vae_encode_id, 0]
        assembler.workflow[ref_latent_id] = ref_latent_node

        current_pos_conditioning = [ref_latent_id, 0]
        applied_count += 1

    if applied_count == 0:
        print("Warning: [ReferenceLatent] No valid reference image filenames provided. Skipping.")
        return

    target_inputs[pos_target_input_name] = current_pos_conditioning
    # Report the images actually wired in, not len(chain_items): invalid entries were skipped.
    print(f"ReferenceLatent injector applied. Re-routed inputs through {applied_count} reference image(s).")