KublaiKhan1 commited on
Commit
11d6975
·
verified ·
1 Parent(s): 1f7ec51

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. normal_mean_scale_std/stable_vae.py +72 -0
normal_mean_scale_std/stable_vae.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from functools import partial, cached_property
3
+
4
+ import jax
5
+ from diffusers import FlaxAutoencoderKL
6
+ from einops import rearrange
7
+ from flax import struct
8
+
9
+ from jaxtyping import Array, PyTree, Key, Float, Shaped, Int, UInt8, jaxtyped
10
+ from typeguard import typechecked
11
+ from functools import partial
12
+ typecheck = partial(jaxtyped, typechecker=typechecked)
13
+
14
+ import jax.numpy as jnp
15
+
16
+ @struct.dataclass
17
+ class StableVAE:
18
+ params: PyTree[Float[Array, "..."]]
19
+ module: FlaxAutoencoderKL = struct.field(pytree_node=False)
20
+
21
+ @classmethod
22
+ def create(cls) -> "VAE":
23
+ # module, params = FlaxAutoencoderKL.from_pretrained(
24
+ # "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
25
+ # )
26
+ module, params = FlaxAutoencoderKL.from_pretrained(
27
+ "pcuenq/sd-vae-ft-mse-flax"
28
+ )
29
+ params = jax.device_get(params)
30
+ return cls(
31
+ params=params,
32
+ module=module,
33
+ )
34
+
35
+ @partial(jax.jit, static_argnames="scale")
36
+ def encode(
37
+ self, key: Key[Array, ""], images: Float[Array, "b h w 3"], scale: bool = True
38
+ ) -> Float[Array, "b lh lw 4"]:
39
+ images = rearrange(images, "b h w c -> b c h w")
40
+ latents = self.module.apply(
41
+ {"params": self.params}, images, method=self.module.encode
42
+ ).latent_dist.sample(key)
43
+ if scale:
44
+ # latents *= self.module.config.scaling_factor
45
+ mean = jnp.array([1.1743683, -0.4075004, 0.4488433, 0.6760574])
46
+ std = jnp.array([4.9045634, 5.4250283, 3.9848266, 4.010667])
47
+ #latents = latents - mean
48
+ latents = latents * 1.0/std#Testing mean shift + global scale
49
+ return latents
50
+
51
+ @partial(jax.jit, static_argnames="scale")
52
+ def decode(
53
+ self, latents: Float[Array, "b lh lw 4"], scale: bool = True
54
+ ) -> Float[Array, "b h w 3"]:
55
+ if scale:
56
+ #latents /= self.module.config.scaling_factor
57
+ mean = jnp.array([1.1743683, -0.4075004, 0.4488433, 0.6760574])
58
+ std = jnp.array([4.9045634, 5.4250283, 3.9848266, 4.010667])
59
+ latents = latents * std# + mean
60
+
61
+ #Now we go back to per channel mean/std
62
+
63
+ images = self.module.apply(
64
+ {"params": self.params}, latents, method=self.module.decode
65
+ ).sample
66
+ # convert to channels-last
67
+ images = rearrange(images, "b c h w -> b h w c")
68
+ return images
69
+
70
+ @cached_property
71
+ def downscale_factor(self) -> int:
72
+ return 2 ** (len(self.module.block_out_channels) - 1)