benjamin-paine committed
Commit 26cc2dd · verified · 1 Parent(s): 1ae5648

Create flux2_tiny_autoencoder.py

Files changed (1):
  flux2_tiny_autoencoder.py +104 -0
flux2_tiny_autoencoder.py ADDED
@@ -0,0 +1,104 @@
# Apache License
#
# Copyright 2025 fal - features and labels, inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn as nn

from diffusers.models import AutoencoderTiny
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.autoencoders.vae import EncoderOutput, DecoderOutput
from diffusers.configuration_utils import ConfigMixin, register_to_config


class Flux2TinyAutoEncoder(ModelMixin, ConfigMixin):
    """A tiny autoencoder that wraps `AutoencoderTiny` with one extra strided
    stage, so that with the default configuration images are compressed 16x
    spatially (8x from the tiny VAE, 2x from the extra stage) into
    `latent_channels`-channel latents."""

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        latent_channels: int = 128,
        encoder_block_out_channels: list[int] = [64, 64, 64, 64],
        decoder_block_out_channels: list[int] = [64, 64, 64, 64],
        act_fn: str = "silu",
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: list[int] = [1, 3, 3, 3],
        num_decoder_blocks: list[int] = [3, 3, 3, 1],
        latent_magnitude: float = 3.0,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 0.13025,
    ) -> None:
        super().__init__()
        # The base tiny VAE produces latents at a quarter of the target
        # channel count; the extra stages below trade one more 2x of spatial
        # resolution for the remaining channels.
        self.tiny_vae = AutoencoderTiny(
            in_channels=in_channels,
            out_channels=out_channels,
            encoder_block_out_channels=encoder_block_out_channels,
            decoder_block_out_channels=decoder_block_out_channels,
            act_fn=act_fn,
            latent_channels=latent_channels // 4,
            upsampling_scaling_factor=upsampling_scaling_factor,
            num_encoder_blocks=num_encoder_blocks,
            num_decoder_blocks=num_decoder_blocks,
            latent_magnitude=latent_magnitude,
            latent_shift=latent_shift,
            force_upcast=force_upcast,
            scaling_factor=scaling_factor,
        )
        # Strided conv: halves spatial resolution, quadruples channels.
        self.extra_encoder = nn.Conv2d(
            latent_channels // 4, latent_channels,
            kernel_size=4, stride=2, padding=1,
        )
        # Transposed conv: the shape-wise inverse of extra_encoder.
        self.extra_decoder = nn.ConvTranspose2d(
            latent_channels, latent_channels // 4,
            kernel_size=4, stride=2, padding=1,
        )
        # Residual refinement blocks applied after each extra stage.
        self.residual_encoder = nn.Sequential(
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels),
            nn.SiLU(),
            nn.Conv2d(latent_channels, latent_channels, kernel_size=3, padding=1),
        )
        self.residual_decoder = nn.Sequential(
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
            nn.GroupNorm(8, latent_channels // 4),
            nn.SiLU(),
            nn.Conv2d(latent_channels // 4, latent_channels // 4, kernel_size=3, padding=1),
        )

    def encode(self, x: torch.Tensor, return_dict: bool = True) -> EncoderOutput:
        encoded = self.tiny_vae.encode(x, return_dict=False)[0]
        compressed = self.extra_encoder(encoded)
        enhanced = self.residual_encoder(compressed) + compressed

        if return_dict:
            return EncoderOutput(latent=enhanced)
        # Follow the diffusers convention of returning a tuple when
        # return_dict=False, so indexing with [0] yields the tensor
        # (returning the bare tensor would make [0] slice the batch).
        return (enhanced,)

    def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput:
        decompressed = self.extra_decoder(z)
        enhanced = self.residual_decoder(decompressed) + decompressed
        decoded = self.tiny_vae.decode(enhanced, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)

    def forward(self, sample: torch.Tensor, return_dict: bool = True) -> DecoderOutput:
        encoded = self.encode(sample, return_dict=False)[0]
        decoded = self.decode(encoded, return_dict=False)[0]

        if return_dict:
            return DecoderOutput(sample=decoded)
        return (decoded,)
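
For reference, a minimal smoke test of the encode/decode round trip. This is a sketch, not part of the commit: the 1x3x256x256 input is an arbitrary choice, the weights are randomly initialized (the commit ships no checkpoint; trained weights would normally be loaded with `Flux2TinyAutoEncoder.from_pretrained`, which `ModelMixin` provides), and the shapes shown assume the default configuration above.

import torch

model = Flux2TinyAutoEncoder().eval()

with torch.no_grad():
    image = torch.randn(1, 3, 256, 256)  # dummy RGB batch, size assumed
    latent = model.encode(image).latent  # (1, 128, 16, 16): 16x down, 128 channels
    recon = model.decode(latent).sample  # (1, 3, 256, 256)

print(latent.shape, recon.shape)

The `EncoderOutput.latent` and `DecoderOutput.sample` fields are exactly what the class constructs in its own `encode` and `decode` methods, so the same access pattern works anywhere the model is used in place of a full VAE.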