Spaces:
Sleeping
Sleeping
Upload inference.py
Browse files- inference.py +10 -6
inference.py
CHANGED
|
@@ -11,7 +11,7 @@ from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, Autoe
|
|
| 11 |
import numpy as np
|
| 12 |
import torch.nn as nn
|
| 13 |
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
|
| 14 |
-
|
| 15 |
|
| 16 |
class UNetNoCondWrapper(nn.Module):
|
| 17 |
def __init__(self, base_unet: UNet2DModel):
|
|
@@ -66,13 +66,17 @@ def inference(pipe, img1, img2, num_steps):
|
|
| 66 |
all_images = []
|
| 67 |
|
| 68 |
def cb_fn(step, timestep, latents):
|
| 69 |
-
# 1) Décoder
|
| 70 |
with torch.no_grad():
|
| 71 |
decoded_output = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
#
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
all_images.append(img)
|
| 77 |
|
| 78 |
num_inference_steps = num_steps
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import torch.nn as nn
|
| 13 |
from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
|
| 14 |
+
from PIL import Image
|
| 15 |
|
| 16 |
class UNetNoCondWrapper(nn.Module):
|
| 17 |
def __init__(self, base_unet: UNet2DModel):
|
|
|
|
| 66 |
all_images = []
|
| 67 |
|
| 68 |
def cb_fn(step, timestep, latents):
    """Decode the current latents and record the frame as a PIL image.

    Used as a per-step callback: the latents are divided by the VAE scaling
    factor and passed through ``pipe.vae.decode``; the first sample of the
    decoded batch is converted to an ``(H, W, C)`` uint8 array and the
    resulting ``PIL.Image`` is appended to ``all_images``.

    Args:
        step: Index of the current denoising step (unused here).
        timestep: Scheduler timestep for this step (unused here).
        latents: Latent tensor produced at this step.
    """
    # 1) Decode the latents back to image space.
    with torch.no_grad():
        decoded_output = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
    decoded_tensor = decoded_output.sample  # (B, C, H, W)

    # 2) Convert to a channels-last NumPy array in uint8 [0-255].
    # Only the first sample is kept, so index before moving to CPU.
    # detach() + float() keep .numpy() working for grad-tracking or
    # bfloat16/half tensors.
    # NOTE(review): values are clamped to [0, 1] unchanged. Stock Stable
    # Diffusion VAEs emit [-1, 1] and are usually mapped with x / 2 + 0.5
    # first -- confirm which range this VAE was trained on.
    t = decoded_tensor[0].detach().float().cpu().clamp(0, 1)  # (C, H, W)
    # Round before the cast: astype(uint8) alone truncates toward zero.
    arr = (t.permute(1, 2, 0).numpy() * 255).round().astype(np.uint8)  # (H, W, C)

    # 3) Build the PIL image and store it.
    img = Image.fromarray(arr)
    all_images.append(img)
|
| 81 |
|
| 82 |
num_inference_steps = num_steps
|