File size: 5,425 Bytes
8a37e0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Type, Union

import torch
from diffusers import UNet2DConditionModel
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput

if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningMode, TextConditioningData


@dataclass
class UNetKwargs:
    """Keyword arguments for one call of the UNet model.

    Field names match the keyword parameters of diffusers'
    `UNet2DConditionModel.forward`. Only the first three fields are required;
    every optional input defaults to None.
    """

    # Latent model input for this step. Presumably the scaled latents
    # (`DenoiseContext.latent_model_input`); confirm against the caller.
    sample: torch.Tensor
    timestep: Union[torch.Tensor, float, int]
    encoder_hidden_states: torch.Tensor

    class_labels: Optional[torch.Tensor] = None
    timestep_cond: Optional[torch.Tensor] = None
    attention_mask: Optional[torch.Tensor] = None
    cross_attention_kwargs: Optional[Dict[str, Any]] = None
    added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None
    # Variable-length tuples (one residual per down block), hence `...`.
    # `Tuple[torch.Tensor]` would denote a tuple of exactly one tensor.
    down_block_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    mid_block_additional_residual: Optional[torch.Tensor] = None
    down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor, ...]] = None
    encoder_attention_mask: Optional[torch.Tensor] = None
    # NOTE(review): `return_dict`, the remaining `UNet2DConditionModel.forward`
    # parameter, is deliberately not a field here; presumably the caller passes
    # it directly at call time — confirm.


@dataclass
class DenoiseInputs:
    """Initial variables passed to denoise.

    These values are meant to stay unchanged for the whole denoising process.
    This is a convention only: the dataclass is not frozen, so nothing prevents
    mutation.
    """

    # The latent-space image to denoise.
    # Shape: [batch, channels, latent_height, latent_width]
    # - If we are inpainting, this is the initial latent image before noise has been added.
    # - If we are generating a new image, this should be initialized to zeros.
    # - In some cases, this may be a partially-noised latent image (e.g. when running the SDXL refiner).
    orig_latents: torch.Tensor

    # Extra keyword arguments forwarded to the scheduler.step() method.
    scheduler_step_kwargs: dict[str, Any]

    # Text conditioning data.
    conditioning_data: TextConditioningData

    # Noise used for two purposes:
    # 1. Used by the scheduler to noise the initial `latents` before denoising.
    # 2. Used to noise the `masked_latents` when inpainting.
    # `noise` should be None if the `latents` tensor has already been noised.
    # Shape: [1 or batch, channels, latent_height, latent_width]
    noise: Optional[torch.Tensor]

    # The seed used to generate the noise for the denoising process.
    # HACK(ryand): seed is only used in a particular case when `noise` is None, but we need to re-generate the
    # same noise used earlier in the pipeline. This should really be handled in a clearer way.
    seed: int

    # The timestep schedule for the denoising process.
    timesteps: torch.Tensor

    # The first timestep in the schedule. It determines the initial noise level, so it
    # should be populated if you want noise applied *even* when `timesteps` is empty.
    init_timestep: torch.Tensor

    # Class of the attention processor in use.
    attention_processor_cls: Type[Any]


@dataclass
class DenoiseContext:
    """Context holding all variables used during denoising.

    Only `inputs` and `scheduler` are required. Every other field starts as
    None, or an empty dict for `extra`, and is populated at the callback stage
    noted next to it.
    """

    # Initial variables passed to denoise. Meant to stay unchanged.
    inputs: DenoiseInputs

    # Scheduler used to apply noise predictions.
    scheduler: SchedulerMixin

    # UNet model.
    unet: Optional[UNet2DConditionModel] = None

    # Current state of the latent-space image in the denoising process.
    # None until the `PRE_DENOISE_LOOP` callback.
    # Shape: [batch, channels, latent_height, latent_width]
    latents: Optional[torch.Tensor] = None

    # Current denoising step index.
    # None until the `PRE_STEP` callback.
    step_index: Optional[int] = None

    # Timestep of the current denoising step.
    # None until the `PRE_STEP` callback.
    timestep: Optional[torch.Tensor] = None

    # Arguments that will be passed to the UNet model.
    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise None.
    unet_kwargs: Optional[UNetKwargs] = None

    # SchedulerOutput returned from the step function (normally generated by the scheduler).
    # Intended for use only in the `POST_STEP` callback, otherwise may be None.
    step_output: Optional[SchedulerOutput] = None

    # Scaled version of `latents`, used to initialize `unet_kwargs`.
    # Available in events inside a step (between `PRE_STEP` and `POST_STEP`).
    # Shape: [batch, channels, latent_height, latent_width]
    latent_model_input: Optional[torch.Tensor] = None

    # [TMP] Defines which conditionings the current UNet call will run on.
    # Available in `PRE_UNET`/`POST_UNET` callbacks, otherwise None.
    conditioning_mode: Optional[ConditioningMode] = None

    # [TMP] Noise predictions from negative conditioning.
    # Available in the `POST_COMBINE_NOISE_PREDS` callback, otherwise None.
    # Shape: [batch, channels, latent_height, latent_width]
    negative_noise_pred: Optional[torch.Tensor] = None

    # [TMP] Noise predictions from positive conditioning.
    # Available in the `POST_COMBINE_NOISE_PREDS` callback, otherwise None.
    # Shape: [batch, channels, latent_height, latent_width]
    positive_noise_pred: Optional[torch.Tensor] = None

    # Combined noise prediction from the supplied conditionings.
    # Available in the `POST_COMBINE_NOISE_PREDS` callback, otherwise None.
    # Shape: [batch, channels, latent_height, latent_width]
    noise_pred: Optional[torch.Tensor] = None

    # Dictionary that lets extensions pass extra info about the denoise process to other extensions.
    # `default_factory` gives every context its own dict instead of one shared instance.
    extra: dict = field(default_factory=dict)