File size: 1,073 Bytes
65266c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from dataclasses import dataclass

from ..utils import BaseOutput


@dataclass
class AutoencoderKLOutput(BaseOutput):
    """
    Output of AutoencoderKL encoding method.

    Args:
        latent_dist (`DiagonalGaussianDistribution`):
            Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
            `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
    """

    # The annotation is a string forward reference: `DiagonalGaussianDistribution` is not
    # imported in this module, so the linter's undefined-name check (F821) is suppressed.
    latent_dist: "DiagonalGaussianDistribution"  # noqa: F821


@dataclass
class Transformer2DModelOutput(BaseOutput):
    """
    The output of [`Transformer2DModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` or `(batch_size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
            distributions for the unnoised latent pixels.
    """

    # The annotation is a string forward reference: `torch` is not imported in this module,
    # so the linter's undefined-name check (F821) is suppressed.
    sample: "torch.Tensor"  # noqa: F821