Spaces:
Running on Zero
Update app.py
app.py
CHANGED
@@ -41,6 +41,9 @@ def load_dit_model(dit_size):
     # Load checkpoint
     checkpoint = torch.load(ckpt_path, map_location="cpu")
     model.load_state_dict(checkpoint["model_state_dict"])
+
+    # Use half precision to speed up sampling
+    model = model.half()
 
     return model
 
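The new cast in load_dit_model only touches the model's weights and buffers; the checkpoint round-trip is unchanged. A minimal, self-contained sketch of the same pattern (the nn.Linear stand-in and the ckpt.pt path are illustrative, not from app.py):

import torch
import torch.nn as nn

# Toy stand-in for the DiT; only the checkpoint round-trip and the fp16 cast are illustrated.
model = nn.Linear(8, 8)
torch.save({"model_state_dict": model.state_dict()}, "ckpt.pt")  # assumed checkpoint layout

restored = nn.Linear(8, 8)
checkpoint = torch.load("ckpt.pt", map_location="cpu")   # load on CPU, move/cast afterwards
restored.load_state_dict(checkpoint["model_state_dict"])

restored = restored.half()  # casts parameters and buffers to float16; inputs must be float16 too

The last line is the reason the sampler below now creates its latents with dtype=torch.float16.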
@@ -98,7 +101,7 @@ class DiffusionSampler:
         s_null = torch.tensor([model.num_styles] * num_samples, device=self.device, dtype=torch.long)
 
         # Start with random latents
-        latents = torch.randn((num_samples, 4, 32, 32), device=self.device)
+        latents = torch.randn((num_samples, 4, 32, 32), device=self.device, dtype=torch.float16)
 
         # Use classifier-free guidance for better quality
         cfg_scale = 2.5
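These fp16 latents feed a classifier-free-guidance sampler that uses cfg_scale = 2.5 and a "null" style id (model.num_styles). The guidance step itself is outside this hunk; a rough sketch of how such a step is usually written (the model call signature and the names t and s_cond are assumptions about app.py's internals):

import torch

def cfg_denoise(model, latents, t, s_cond, s_null, cfg_scale=2.5):
    # Two noise predictions: one with the real style ids, one with the null style id
    eps_cond = model(latents, t, s_cond)
    eps_uncond = model(latents, t, s_null)
    # Guidance pushes the conditional prediction away from the unconditional one
    return eps_uncond + cfg_scale * (eps_cond - eps_uncond)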
@@ -135,7 +138,9 @@ class DiffusionSampler:
 
         # Decode latents to images
         self.load_vae()
-        latents = latents / self.vae.config.scaling_factor
+
+        # Convert latents back to float32 for vae decoding
+        latents = latents.to(dtype=torch.float32) / self.vae.config.scaling_factor
         latents = latents.to(self.device)
 
         progress(0.95, desc="Decoding images...")
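With the sampler running in fp16, the latents have to be cast back to the VAE's dtype and un-scaled before decoding. A sketch of the decode path, assuming load_vae() wraps a diffusers AutoencoderKL kept in float32 (the sd-vae-ft-mse checkpoint is an illustrative choice, not necessarily the one app.py loads):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")  # float32 by default

latents = torch.randn(1, 4, 32, 32, dtype=torch.float16)          # what the fp16 sampler produces

# Match the VAE's dtype and undo the latent scaling before decoding
latents = latents.to(dtype=torch.float32) / vae.config.scaling_factor
with torch.no_grad():
    images = vae.decode(latents).sample        # (1, 3, 256, 256), roughly in [-1, 1]
images = (images / 2 + 0.5).clamp(0, 1)        # map to [0, 1] for display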
@@ -167,9 +172,9 @@ def generate_random_seed():
     return random.randint(0, 2**32 - 1)
 
 MODEL_SAMPLE_LIMITS = {
-    "S": {"min":1, "max":
-    "B": {"min":1, "max":
-    "L": {"min":1, "max":
+    "S": {"min":1, "max": 16, "default": 4},
+    "B": {"min":1, "max": 12, "default": 4},
+    "L": {"min":1, "max": 4, "default": 1}
 }
 
 def update_sample_slider(dit_size):
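MODEL_SAMPLE_LIMITS now caps the batch size per DiT size, presumably because larger models leave less GPU memory for parallel samples. The body of update_sample_slider is not part of this diff; a plausible sketch of how it could re-bound a Gradio slider from that dict (the component names in the wiring comment are assumptions):

import gradio as gr

def update_sample_slider(dit_size):
    limits = MODEL_SAMPLE_LIMITS[dit_size]
    return gr.update(minimum=limits["min"], maximum=limits["max"], value=limits["default"])

# Typical wiring: dit_size_dropdown.change(update_sample_slider, inputs=dit_size_dropdown, outputs=sample_slider)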