CarolineM5 commited on
Commit
170b294
·
verified ·
1 Parent(s): c2ea64b

Upload inference.py

Browse files
Files changed (1) hide show
  1. inference.py +8 -5
inference.py CHANGED
@@ -64,14 +64,17 @@ def inference(pipe, img1, img2, num_steps):
64
  image = PIL.ImageOps.exif_transpose(image)
65
 
66
  all_images = []
 
67
  def cb_fn(step, timestep, latents):
68
- # latents(torch.Tensor) -> image via VAE decode
69
  with torch.no_grad():
70
- decoded = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
71
- # post‐traitement en PIL
72
- img = pipe.numpy_to_pil(decoded.cpu().clamp(0,1))[0]
 
 
73
  all_images.append(img)
74
-
75
  num_inference_steps = num_steps
76
  image_guidance_scale = 1.9
77
  guidance_scale = 10
 
64
  image = PIL.ImageOps.exif_transpose(image)
65
 
66
  all_images = []
67
+
68
  def cb_fn(step, timestep, latents):
69
+ # 1) Décoder les latents -> DecoderOutput
70
  with torch.no_grad():
71
+ decoded_output = pipe.vae.decode(latents / pipe.vae.config.scaling_factor)
72
+ # 2) Extraire le tenseur : .sample contient le batch de sorties
73
+ decoded_tensor = decoded_output.sample # type: torch.Tensor
74
+ # 3) Passer sur CPU, clampler et convertir en PIL
75
+ img = pipe.numpy_to_pil(decoded_tensor.cpu().clamp(0, 1))[0]
76
  all_images.append(img)
77
+
78
  num_inference_steps = num_steps
79
  image_guidance_scale = 1.9
80
  guidance_scale = 10