File size: 8,381 Bytes
c009d4f
 
 
 
b7d4bc8
 
 
c009d4f
 
b7d4bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c009d4f
b7d4bc8
c009d4f
 
b7d4bc8
c009d4f
 
 
 
 
 
 
 
 
 
 
 
b7d4bc8
c009d4f
 
 
 
 
b7d4bc8
c009d4f
b7d4bc8
c009d4f
 
 
 
b7d4bc8
 
 
 
 
 
 
 
 
c009d4f
b7d4bc8
c009d4f
 
 
b7d4bc8
c009d4f
 
b7d4bc8
 
 
 
 
 
 
 
c009d4f
 
b7d4bc8
c009d4f
b7d4bc8
c009d4f
 
 
 
b7d4bc8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c009d4f
 
b7d4bc8
 
c009d4f
b7d4bc8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
def _encode_reference_image(assembler, img_filename, vae_node_id, index,
                            megapixels=1.0, upscale_method="lanczos"):
    """Add LoadImage -> ImageScaleToTotalPixels -> VAEEncode nodes for one image.

    Args:
        assembler: Object exposing ``workflow``, ``_get_unique_id()`` and
            ``_get_node_template(class_type)``.
        img_filename: Value written to the LoadImage node's ``image`` input.
        vae_node_id: Node id whose output 0 is wired into VAEEncode's ``vae``.
        index: 1-based position of the image, used only in node titles.
        megapixels: ``megapixels`` input of the ImageScaleToTotalPixels node.
        upscale_method: ``upscale_method`` input of the ImageScaleToTotalPixels node.

    Returns:
        ``[vae_encode_id, 0]`` — the link to the encoded latent output.
    """
    load_id = assembler._get_unique_id()
    load_node = assembler._get_node_template("LoadImage")
    load_node['inputs']['image'] = img_filename
    load_node['_meta']['title'] = f"Load Reference Image {index}"
    assembler.workflow[load_id] = load_node

    scale_id = assembler._get_unique_id()
    scale_node = assembler._get_node_template("ImageScaleToTotalPixels")
    scale_node['inputs']['megapixels'] = megapixels
    scale_node['inputs']['upscale_method'] = upscale_method
    scale_node['inputs']['image'] = [load_id, 0]
    scale_node['_meta']['title'] = f"Scale Reference {index}"
    assembler.workflow[scale_id] = scale_node

    vae_encode_id = assembler._get_unique_id()
    vae_encode_node = assembler._get_node_template("VAEEncode")
    vae_encode_node['inputs']['pixels'] = [scale_id, 0]
    vae_encode_node['inputs']['vae'] = [vae_node_id, 0]
    vae_encode_node['_meta']['title'] = f"VAE Encode Reference {index}"
    assembler.workflow[vae_encode_id] = vae_encode_node

    return [vae_encode_id, 0]


def _add_reference_latent(assembler, conditioning, latent_link, title):
    """Add one ReferenceLatent node fed by ``conditioning`` and ``latent_link``.

    Returns ``[ref_latent_id, 0]`` so callers can chain the next node onto it.
    """
    ref_latent_id = assembler._get_unique_id()
    ref_latent_node = assembler._get_node_template("ReferenceLatent")
    ref_latent_node['inputs']['conditioning'] = conditioning
    ref_latent_node['inputs']['latent'] = latent_link
    ref_latent_node['_meta']['title'] = title
    assembler.workflow[ref_latent_id] = ref_latent_node
    return [ref_latent_id, 0]


def _parse_start_connection(assembler, conn_str):
    """Resolve a ``"node_name:output_index"`` string to ``[node_id, output_index]``.

    Splits on the LAST colon so node names may themselves contain ``':'``.

    Raises:
        ValueError: no colon present, or the index is not an integer.
        KeyError: ``node_name`` is not in ``assembler.node_map``.
        AttributeError: ``conn_str`` is not a string.
    """
    node_name, idx_str = conn_str.rsplit(':', 1)
    return [assembler.node_map[node_name], int(idx_str)]


def _inject_reference_guider_chain(assembler, chain_definition, chain_items,
                                   guider_node_name, guider_target_inputs,
                                   vae_node_name, scale_options):
    """Re-route the guider's listed conditioning inputs through ReferenceLatent chains."""
    guider_id = assembler.node_map[guider_node_name]
    if vae_node_name not in assembler.node_map:
        print(f"Warning: VAE node '{vae_node_name}' not found for Guider chain. Skipping.")
        return
    vae_node_id = assembler.node_map[vae_node_name]

    print(f"ReferenceLatent injector targeting DualCFGGuider node '{guider_node_name}'.")

    start_connections_map = chain_definition.get('start_connections', {})
    current_connections = {}
    for target_input in guider_target_inputs:
        conn_str = start_connections_map.get(target_input)
        if not conn_str:
            print(f"Warning: No start connection defined for '{target_input}' in Guider chain. Skipping this input.")
            continue
        try:
            current_connections[target_input] = _parse_start_connection(assembler, conn_str)
        except (ValueError, KeyError, AttributeError):
            print(f"Warning: Invalid start connection '{conn_str}' for '{target_input}'. Skipping.")

    if not current_connections:
        # Nothing to re-route: avoid leaving orphan LoadImage/VAEEncode nodes in the graph.
        print("Warning: No valid start connections for Guider chain. Skipping.")
        return

    # Encode every reference image once; the latents are shared by all re-routed inputs.
    encoded_latents = [
        _encode_reference_image(assembler, img_filename, vae_node_id, i, **scale_options)
        for i, img_filename in enumerate(chain_items, start=1)
    ]

    for target_input_name, chain_head in current_connections.items():
        for i, latent_link in enumerate(encoded_latents, start=1):
            chain_head = _add_reference_latent(
                assembler, chain_head, latent_link, f"{target_input_name} RefLatent {i}"
            )
        assembler.workflow[guider_id]['inputs'][target_input_name] = chain_head
        print(f"  - Input '{target_input_name}' of node '{guider_node_name}' re-routed through {len(chain_items)} reference images.")


def _inject_reference_ksampler_chain(assembler, chain_definition, chain_items,
                                     vae_node_name, scale_options):
    """Re-route positive (FluxGuidance or KSampler) and KSampler negative conditioning."""
    flux_guidance_name = chain_definition.get('flux_guidance_node')
    ksampler_name = chain_definition.get('ksampler_node', 'ksampler')

    if ksampler_name not in assembler.node_map:
        print(f"Warning: KSampler node '{ksampler_name}' not found for ReferenceLatent chain. Skipping.")
        return
    if vae_node_name not in assembler.node_map:
        print(f"Warning: VAE loader node '{vae_node_name}' not found for ReferenceLatent chain. Skipping.")
        return

    ksampler_id = assembler.node_map[ksampler_name]
    vae_node_id = assembler.node_map[vae_node_name]

    # Prefer the FluxGuidance node's 'conditioning' input for the positive chain.
    pos_target_node_id = None
    pos_target_input_name = None
    if flux_guidance_name and flux_guidance_name in assembler.node_map:
        flux_guidance_id = assembler.node_map[flux_guidance_name]
        if 'conditioning' in assembler.workflow[flux_guidance_id]['inputs']:
            pos_target_node_id = flux_guidance_id
            pos_target_input_name = 'conditioning'
            print(f"ReferenceLatent injector targeting FluxGuidance node '{flux_guidance_name}' for positive chain.")

    # Compare against None: node ids may legitimately be falsy (e.g. 0 or "").
    if pos_target_node_id is None:
        if 'positive' in assembler.workflow[ksampler_id]['inputs']:
            pos_target_node_id = ksampler_id
            pos_target_input_name = 'positive'
            print(f"ReferenceLatent injector targeting KSampler node '{ksampler_name}' for positive chain.")
        else:
            print("Warning: Could not find a valid positive injection point for ReferenceLatent chain. Skipping.")
            return

    current_pos_conditioning = assembler.workflow[pos_target_node_id]['inputs'][pos_target_input_name]

    has_negative = 'negative' in assembler.workflow[ksampler_id]['inputs']
    if not has_negative:
        print(f"Warning: KSampler node '{ksampler_name}' has no 'negative' input. Skipping negative ReferenceLatent chain.")
    current_neg_conditioning = (
        assembler.workflow[ksampler_id]['inputs']['negative'] if has_negative else None
    )

    # Node ids are allocated per image in the order: load, scale, encode, positive, negative.
    for i, img_filename in enumerate(chain_items, start=1):
        latent_link = _encode_reference_image(assembler, img_filename, vae_node_id, i, **scale_options)
        current_pos_conditioning = _add_reference_latent(
            assembler, current_pos_conditioning, latent_link, f"Positive ReferenceLatent {i}"
        )
        if has_negative:
            current_neg_conditioning = _add_reference_latent(
                assembler, current_neg_conditioning, latent_link, f"Negative ReferenceLatent {i}"
            )

    assembler.workflow[pos_target_node_id]['inputs'][pos_target_input_name] = current_pos_conditioning
    if has_negative:
        assembler.workflow[ksampler_id]['inputs']['negative'] = current_neg_conditioning

    print(f"ReferenceLatent injector applied. Re-routed inputs through {len(chain_items)} reference images.")


def inject(assembler, chain_definition, chain_items):
    """Insert a chain of ReferenceLatent nodes fed by VAE-encoded reference images.

    For each filename in ``chain_items``, the function adds one LoadImage node,
    one ImageScaleToTotalPixels node and one VAEEncode node to
    ``assembler.workflow``. It then splices a ReferenceLatent node per image
    into existing conditioning links.

    There are two modes:

    * Guider mode: used when ``chain_definition['guider_node']`` is in
      ``assembler.node_map`` and ``guider_target_inputs`` is non-empty.
      Each listed guider input is rebuilt from its ``start_connections`` entry,
      a ``"node_name:output_index"`` string.
    * KSampler mode: the fallback. The positive chain goes into the FluxGuidance
      node's ``conditioning`` input when that node is available, otherwise into
      the KSampler's ``positive`` input. The KSampler's ``negative`` input, if
      present, gets a parallel chain.

    Args:
        assembler: Object exposing ``node_map`` (name -> node id), ``workflow``
            (node id -> node dict with ``inputs`` and ``_meta``),
            ``_get_unique_id()`` and ``_get_node_template(class_type)``.
        chain_definition: Mapping that may contain these keys:
            ``guider_node``, ``guider_target_inputs``, ``start_connections``,
            ``vae_node`` (default ``'vae_loader'``), ``flux_guidance_node``,
            ``ksampler_node`` (default ``'ksampler'``), ``megapixels``
            (default ``1.0``) and ``upscale_method`` (default ``'lanczos'``).
        chain_items: Sized sequence of image filenames. When it is empty, the
            function does nothing.

    Returns:
        None. ``assembler.workflow`` is mutated in place. Problems are reported
        with ``print`` warnings rather than raised.
    """
    if not chain_items:
        return

    guider_node_name = chain_definition.get('guider_node')
    guider_target_inputs = chain_definition.get('guider_target_inputs', [])
    vae_node_name = chain_definition.get('vae_node', 'vae_loader')
    # Optional pre-processing overrides; defaults match the previous hard-coded values.
    scale_options = {
        'megapixels': chain_definition.get('megapixels', 1.0),
        'upscale_method': chain_definition.get('upscale_method', 'lanczos'),
    }

    if guider_node_name and guider_node_name in assembler.node_map and guider_target_inputs:
        _inject_reference_guider_chain(
            assembler, chain_definition, chain_items,
            guider_node_name, guider_target_inputs, vae_node_name, scale_options,
        )
        return

    _inject_reference_ksampler_chain(
        assembler, chain_definition, chain_items, vae_node_name, scale_options,
    )