BiliSakura commited on
Commit
113a3ed
·
verified ·
1 Parent(s): 62f39a6

Update all files for DiffusionSat-Single-512

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat.py +10 -1
pipeline_diffusionsat.py CHANGED
@@ -21,10 +21,13 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
21
  from diffusers.utils import (
22
  deprecate,
23
  logging,
24
- randn_tensor,
25
  replace_example_docstring,
26
  is_accelerate_available,
27
  )
 
 
 
 
28
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
29
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
30
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -202,6 +205,12 @@ class DiffusionSatPipeline(DiffusionPipeline):
202
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
203
  metadata: Optional[List[float]] = None,
204
  ):
 
 
 
 
 
 
205
  # 0. Default height and width to unet
206
  height = height or self.unet.config.sample_size * self.vae_scale_factor
207
  width = width or self.unet.config.sample_size * self.vae_scale_factor
 
21
  from diffusers.utils import (
22
  deprecate,
23
  logging,
 
24
  replace_example_docstring,
25
  is_accelerate_available,
26
  )
27
+ try:
28
+ from diffusers.utils import randn_tensor
29
+ except ImportError:
30
+ from diffusers.utils.torch_utils import randn_tensor
31
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
205
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
206
  metadata: Optional[List[float]] = None,
207
  ):
208
+ """
209
+ Run inference (text-to-image with optional metadata).
210
+
211
+ Examples:
212
+
213
+ """
214
  # 0. Default height and width to unet
215
  height = height or self.unet.config.sample_size * self.vae_scale_factor
216
  width = width or self.unet.config.sample_size * self.vae_scale_factor