AbstractPhil commited on
Commit
dfebd94
·
verified ·
1 Parent(s): 4fd71cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -26
app.py CHANGED
@@ -318,32 +318,9 @@ print("✓ Text encoders loaded")
318
 
319
  # VAE (local weights - Apache 2.0 from Flux)
320
  print("Loading VAE...")
321
- from diffusers.models.autoencoders.autoencoder_kl import AutoencoderKL
322
-
323
- # VAE config from Flux
324
- vae_config = {
325
- "in_channels": 3,
326
- "out_channels": 3,
327
- "latent_channels": 16,
328
- "block_out_channels": [128, 256, 512, 512],
329
- "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D", "DownEncoderBlock2D"],
330
- "up_block_types": ["UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"],
331
- "layers_per_block": 2,
332
- "norm_num_groups": 32,
333
- "act_fn": "silu",
334
- "sample_size": 1024,
335
- "scaling_factor": 0.3611,
336
- "shift_factor": 0.1159,
337
- "use_quant_conv": False,
338
- "use_post_quant_conv": False,
339
- "mid_block_add_attention": True,
340
- }
341
-
342
- vae = AutoencoderKL(**vae_config)
343
- vae_weights = load_file("ae.safetensors")
344
- vae.load_state_dict(vae_weights)
345
- vae.to(DTYPE).eval()
346
- VAE_SCALE = vae_config["scaling_factor"]
347
  print("✓ VAE loaded")
348
 
349
 
 
318
 
319
  # VAE (local weights - Apache 2.0 from Flux)
320
  print("Loading VAE...")
321
+ vae = AutoencoderKL.from_single_file("ae.safetensors", torch_dtype=DTYPE)
322
+ vae.eval()
323
+ VAE_SCALE = 0.3611 # Flux VAE scaling factor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  print("✓ VAE loaded")
325
 
326