CarolineM5 committed
Commit a9653d6 · verified · 1 Parent(s): 170b294

Upload inference.py

Files changed (1)
  1. inference.py +10 -6
inference.py CHANGED
@@ -11,7 +11,7 @@ from diffusers import StableDiffusionInstructPix2PixPipeline, UNet2DModel, Autoe
 import numpy as np
 import torch.nn as nn
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
-
+from PIL import Image
 
 class UNetNoCondWrapper(nn.Module):
     def __init__(self, base_unet: UNet2DModel):
@@ -66,13 +66,17 @@ def inference(pipe, img1, img2, num_steps):
     all_images = []
 
     def cb_fn(step, timestep, latents):
-        # 1) Decode the latents -> DecoderOutput
+        # 1) Decode the latents
         with torch.no_grad():
             decoded_output = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
-        # 2) Extract the tensor: .sample holds the batch of outputs
-        decoded_tensor = decoded_output.sample  # type: torch.Tensor
-        # 3) Move to CPU, clamp, and convert to PIL
-        img = pipe.numpy_to_pil(decoded_tensor.cpu().clamp(0, 1))[0]
+        decoded_tensor = decoded_output.sample  # (B, C, H, W)
+
+        # 2) Convert to NumPy (channels last) and to uint8 [0–255]
+        t = decoded_tensor.cpu().clamp(0, 1)[0]  # (C, H, W)
+        arr = (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)  # (H, W, C)
+
+        # 3) Build the PIL.Image
+        img = Image.fromarray(arr)
         all_images.append(img)
 
     num_inference_steps = num_steps
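
For context, a minimal usage sketch of how a per-step callback like cb_fn is typically passed to the pipeline so that the intermediate frames get collected. The model id, prompt, and input image below are placeholders, and the callback=/callback_steps= arguments assume the older diffusers callback API (newer releases use callback_on_step_end instead); none of this is taken from the commit itself.

import torch
import numpy as np
from PIL import Image
from diffusers import StableDiffusionInstructPix2PixPipeline

# Placeholder model id; the checkpoint actually used by this repo may differ.
pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix", torch_dtype=torch.float16
).to("cuda")

all_images = []

def cb_fn(step, timestep, latents):
    # Decode the current latents with the VAE and keep a PIL snapshot,
    # mirroring the callback defined in inference.py.
    with torch.no_grad():
        decoded = pipe.vae.decode(latents / pipe.vae.config.scaling_factor).sample
    t = decoded.cpu().clamp(0, 1)[0]                            # (C, H, W)
    arr = (t.permute(1, 2, 0).numpy() * 255).astype(np.uint8)   # (H, W, C)
    all_images.append(Image.fromarray(arr))

input_image = Image.open("input.png").convert("RGB")  # placeholder input image

result = pipe(
    "turn the photo into a watercolor painting",  # placeholder edit instruction
    image=input_image,
    num_inference_steps=20,
    callback=cb_fn,        # called with (step, timestep, latents)
    callback_steps=1,      # invoke the callback at every denoising step
)
final_image = result.images[0]
# all_images now holds one decoded frame per denoising step.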