Update app.py
Browse files
app.py
CHANGED
|
@@ -47,10 +47,8 @@ if torch.cuda.is_available():
|
|
| 47 |
previewer = Previewer()
|
| 48 |
previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
|
| 49 |
previewer.load_state_dict(previewer_state_dict)
|
| 50 |
-
def callback_prior(
|
| 51 |
-
latents = kwargs["latents"]
|
| 52 |
output = previewer(latents)
|
| 53 |
-
print(output)
|
| 54 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
| 55 |
return output
|
| 56 |
callback_steps = 1
|
|
@@ -100,7 +98,8 @@ def generate(
|
|
| 100 |
guidance_scale=prior_guidance_scale,
|
| 101 |
num_images_per_prompt=num_images_per_prompt,
|
| 102 |
generator=generator,
|
| 103 |
-
|
|
|
|
| 104 |
)
|
| 105 |
|
| 106 |
if PREVIEW_IMAGES:
|
|
|
|
| 47 |
previewer = Previewer()
|
| 48 |
previewer_state_dict = torch.load("previewer/previewer_v1_100k.pt", map_location=torch.device('cpu'))["state_dict"]
|
| 49 |
previewer.load_state_dict(previewer_state_dict)
|
| 50 |
+
def callback_prior(i, t, latents):
|
|
|
|
| 51 |
output = previewer(latents)
|
|
|
|
| 52 |
output = numpy_to_pil(output.clamp(0, 1).permute(0, 2, 3, 1).float().cpu().numpy())
|
| 53 |
return output
|
| 54 |
callback_steps = 1
|
|
|
|
| 98 |
guidance_scale=prior_guidance_scale,
|
| 99 |
num_images_per_prompt=num_images_per_prompt,
|
| 100 |
generator=generator,
|
| 101 |
+
callback=callback_prior,
|
| 102 |
+
callback_steps=callback_steps
|
| 103 |
)
|
| 104 |
|
| 105 |
if PREVIEW_IMAGES:
|