File size: 8,914 Bytes
b8c861f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
# Copyright 2025 Baidu ERNIE-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ErnieImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import ErnieImageModularPipeline


# Module-level logger for this file, obtained from the package's own logging utilities.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class ErnieImageLoopBeforeDenoiser(ModularPipelineBlocks):
    """Per-step loop sub-block that casts the latents and timestep to the transformer's dtype.

    It writes ``latent_model_input`` and ``timestep`` onto the block state, where later sub-blocks read them.
    """

    model_name = "ernie-image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # Only the transformer is needed here, and only for its dtype.
        transformer_spec = ComponentSpec("transformer", ErnieImageTransformer2DModel)
        return [transformer_spec]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent model input and timestep tensor. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `ErnieImageDenoiseLoopWrapper`)."
        )

    @property
    def inputs(self) -> list[InputParam]:
        latents_param = InputParam(
            "latents",
            required=True,
            type_hint=torch.Tensor,
            description="The latents to denoise.",
        )
        return [latents_param]

    @torch.no_grad()
    def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Prepare the transformer inputs for step ``i``.

        Args:
            components: Pipeline whose ``transformer.dtype`` is the cast target.
            block_state: Loop state; ``latents`` is read, ``latent_model_input`` and ``timestep`` are written.
            i: Step index (unused here).
            t: Current timestep; expanded to one value per latent batch entry.

        Returns:
            The unchanged ``components`` and the updated ``block_state``.
        """
        target_dtype = components.transformer.dtype
        current_latents = block_state.latents
        batch_size = current_latents.shape[0]

        block_state.latent_model_input = current_latents.to(target_dtype)
        # Broadcast the timestep across the batch dimension of the latents.
        block_state.timestep = t.expand(batch_size).to(target_dtype)
        return components, block_state


class ErnieImageLoopDenoiser(ModularPipelineBlocks):
    """Per-step loop sub-block that runs the transformer once per guider batch and merges the predictions.

    The merged prediction is stored as ``noise_pred`` on the block state.
    """

    model_name = "ernie-image"

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that runs the ErnieImage transformer with classifier-free guidance via "
            "the configured guider."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # The guider is built from config when not supplied; its default guidance scale is 4.0.
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 4.0}),
            default_creation_method="from_config",
        )
        transformer_spec = ComponentSpec("transformer", ErnieImageTransformer2DModel)
        return [transformer_spec, guider_spec]

    @property
    def inputs(self) -> list[InputParam]:
        conditional_params = [
            InputParam(
                "text_bth",
                required=True,
                type_hint=torch.Tensor,
                description="Padded text hidden states fed into the transformer.",
            ),
            InputParam(
                "text_lens",
                required=True,
                type_hint=torch.Tensor,
                description="Per-prompt text lengths used by the transformer attention mask.",
            ),
        ]
        # Not marked required; paired with the conditional inputs as the guider's negative branch.
        negative_params = [
            InputParam(
                "negative_text_bth",
                type_hint=torch.Tensor,
                description="Padded negative text hidden states for classifier-free guidance.",
            ),
            InputParam(
                "negative_text_lens",
                type_hint=torch.Tensor,
                description="Per-prompt negative text lengths for classifier-free guidance.",
            ),
        ]
        steps_param = InputParam(
            "num_inference_steps",
            required=True,
            type_hint=int,
            description="Total number of denoising steps. Used by the guider for step-aware scheduling.",
        )
        return [*conditional_params, *negative_params, steps_param]

    @torch.no_grad()
    def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Compute the guided noise prediction for step ``i``.

        Args:
            components: Pipeline providing ``guider`` and ``transformer``.
            block_state: Loop state; reads the text inputs, ``latent_model_input``, ``timestep`` and
                ``num_inference_steps``, and writes ``noise_pred``.
            i: Current step index, forwarded to the guider.
            t: Current timestep, forwarded to the guider.

        Returns:
            The unchanged ``components`` and the updated ``block_state``.
        """
        guider = components.guider
        transformer = components.transformer

        # Each transformer kwarg maps to a (conditional, negative) pair; the guider splits them into batches.
        paired_inputs = {
            "text_bth": (block_state.text_bth, block_state.negative_text_bth),
            "text_lens": (block_state.text_lens, block_state.negative_text_lens),
        }

        guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
        batches = guider.prepare_inputs(paired_inputs)

        for batch in batches:
            guider.prepare_models(transformer)
            model_kwargs = {key: getattr(batch, key) for key in paired_inputs}
            batch.noise_pred = transformer(
                hidden_states=block_state.latent_model_input,
                timestep=block_state.timestep,
                return_dict=False,
                **model_kwargs,
            )[0]
            guider.cleanup_models(transformer)

        # The guider combines the per-batch predictions; the first output element is the guided prediction.
        block_state.noise_pred = guider(batches)[0]
        return components, block_state


class ErnieImageLoopAfterDenoiser(ModularPipelineBlocks):
    """Per-step loop sub-block that advances ``latents`` with one scheduler step."""

    model_name = "ernie-image"

    @property
    def description(self) -> str:
        return "Step within the denoising loop that updates the latents using the scheduler step."

    @property
    def expected_components(self) -> list[ComponentSpec]:
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @torch.no_grad()
    def __call__(self, components: ErnieImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Replace ``block_state.latents`` with the scheduler's next sample.

        Args:
            components: Pipeline providing ``scheduler``.
            block_state: Loop state; reads ``noise_pred`` and ``latents``, writes ``latents``.
            i: Step index (unused here).
            t: Current timestep, passed to the scheduler.

        Returns:
            The unchanged ``components`` and the updated ``block_state``.
        """
        incoming_dtype = block_state.latents.dtype
        step_result = components.scheduler.step(block_state.noise_pred, t, block_state.latents, return_dict=False)
        next_latents = step_result[0]
        block_state.latents = next_latents

        # The original dtype is restored only when an MPS backend is available;
        # elsewhere whatever dtype the scheduler produced is kept.
        dtype_changed = block_state.latents.dtype != incoming_dtype
        if dtype_changed and torch.backends.mps.is_available():
            block_state.latents = block_state.latents.to(incoming_dtype)
        return components, block_state


class ErnieImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop wrapper that runs its sub-blocks once for every entry of ``timesteps``.

    Subclasses supply the per-step sub-blocks through ``block_classes`` / ``block_names``
    (for example ``ErnieImageDenoiseStep``).
    """

    model_name = "ernie-image"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attribute."
        )

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", ErnieImageTransformer2DModel),
        ]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for inference.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of denoising steps.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents.")]

    @torch.no_grad()
    def __call__(
        self, components: ErnieImageModularPipeline, state: PipelineState
    ) -> tuple[ErnieImageModularPipeline, PipelineState]:
        """Run every sub-block at each timestep, then write the block state back into ``state``.

        Args:
            components: Pipeline holding the loaded components; passed through each loop step.
            state: Pipeline state that supplies ``timesteps`` / ``num_inference_steps`` and receives the
                updated block state (including ``latents``).

        Returns:
            The ``(components, state)`` pair. The previous annotation claimed a bare ``PipelineState``,
            which did not match this return.
        """
        block_state = self.get_block_state(state)
        # The progress bar is sized by num_inference_steps and advanced once per timestep.
        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                # Each sub-block receives the step index `i` and the current timestep `t`.
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                progress_bar.update()
        self.set_block_state(state, block_state)
        return components, state


class ErnieImageDenoiseStep(ErnieImageDenoiseLoopWrapper):
    """Concrete ERNIE-Image denoise loop: before-denoiser, denoiser, then after-denoiser at every timestep."""

    # Listed in the same order as `block_classes` below.
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]
    block_classes = [
        ErnieImageLoopBeforeDenoiser,
        ErnieImageLoopDenoiser,
        ErnieImageLoopAfterDenoiser,
    ]

    @property
    def description(self) -> str:
        summary = "Denoise step that iteratively denoises the latents. At each iteration it runs:\n"
        sub_block_list = (
            " - `ErnieImageLoopBeforeDenoiser`\n"
            " - `ErnieImageLoopDenoiser`\n"
            " - `ErnieImageLoopAfterDenoiser`"
        )
        return summary + sub_block_list