BiliSakura committed on
Commit
b88ec75
·
verified ·
1 Parent(s): 3b91ebd

Fix generator determinism: forward generator through scheduler steps and seeded noise

Browse files
edm2-img512-l-dino/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-l-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-m-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-s-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-xl-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-xs-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)
edm2-img512-xxl-fid/pipeline.py CHANGED
@@ -350,6 +350,9 @@ class EDM2Pipeline(DiffusionPipeline):
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
 
 
 
353
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
354
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
355
 
@@ -375,13 +378,14 @@ class EDM2Pipeline(DiffusionPipeline):
375
  latents = self._sample_edm2_heun(
376
  denoise_fn=denoise_fn,
377
  noise=noise,
378
- sigmas=self.scheduler.sigmas.to(device),
379
  generator=generator,
380
  progress_bar=self.progress_bar,
381
  dtype=torch.float32,
382
  )
383
 
384
  image = self.decode_latents(latents, output_type=output_type)
 
385
  if not return_dict:
386
  return (image, latents)
387
  return ImagePipelineOutput(images=image)
 
350
 
351
  device = self._execution_device
352
  dtype = self.unet.dtype
353
+ self.unet.eval()
354
+ if getattr(self, "gnet", None) is not None:
355
+ self.gnet.eval()
356
  labels = self._normalize_class_labels(class_labels, batch_size=batch_size, device=device)
357
  noise = self.prepare_latents(batch_size, height, width, dtype, device, generator)
358
 
 
378
  latents = self._sample_edm2_heun(
379
  denoise_fn=denoise_fn,
380
  noise=noise,
381
+ sigmas=self.scheduler.sigmas.to(device).clone(),
382
  generator=generator,
383
  progress_bar=self.progress_bar,
384
  dtype=torch.float32,
385
  )
386
 
387
  image = self.decode_latents(latents, output_type=output_type)
388
+ self.maybe_free_model_hooks()
389
  if not return_dict:
390
  return (image, latents)
391
  return ImagePipelineOutput(images=image)