BiliSakura committed on
Commit
f11fd20
·
verified ·
1 Parent(s): 42070c7

Update all files for DiffusionSat-Single-256

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat.py +16 -3
pipeline_diffusionsat.py CHANGED
@@ -9,7 +9,11 @@ from typing import Any, Callable, Dict, List, Optional, Union
9
 
10
  import torch
11
  from packaging import version
12
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
 
 
13
 
14
  from diffusers.configuration_utils import FrozenDict
15
  from diffusers.models import AutoencoderKL
@@ -17,10 +21,13 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
17
  from diffusers.utils import (
18
  deprecate,
19
  logging,
20
- randn_tensor,
21
  replace_example_docstring,
22
  is_accelerate_available,
23
  )
 
 
 
 
24
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -60,7 +67,7 @@ class DiffusionSatPipeline(DiffusionPipeline):
60
  unet: Any,
61
  scheduler: KarrasDiffusionSchedulers,
62
  safety_checker: StableDiffusionSafetyChecker,
63
- feature_extractor: CLIPFeatureExtractor,
64
  requires_safety_checker: bool = True,
65
  ):
66
  super().__init__()
@@ -198,6 +205,12 @@ class DiffusionSatPipeline(DiffusionPipeline):
198
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
199
  metadata: Optional[List[float]] = None,
200
  ):
 
 
 
 
 
 
201
  # 0. Default height and width to unet
202
  height = height or self.unet.config.sample_size * self.vae_scale_factor
203
  width = width or self.unet.config.sample_size * self.vae_scale_factor
 
9
 
10
  import torch
11
  from packaging import version
12
+ from transformers import CLIPTextModel, CLIPTokenizer
13
+ try:
14
+ from transformers import CLIPImageProcessor
15
+ except ImportError:
16
+ from transformers import CLIPFeatureExtractor as CLIPImageProcessor
17
 
18
  from diffusers.configuration_utils import FrozenDict
19
  from diffusers.models import AutoencoderKL
 
21
  from diffusers.utils import (
22
  deprecate,
23
  logging,
 
24
  replace_example_docstring,
25
  is_accelerate_available,
26
  )
27
+ try:
28
+ from diffusers.utils import randn_tensor
29
+ except ImportError:
30
+ from diffusers.utils.torch_utils import randn_tensor
31
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
32
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
33
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
67
  unet: Any,
68
  scheduler: KarrasDiffusionSchedulers,
69
  safety_checker: StableDiffusionSafetyChecker,
70
+ feature_extractor: CLIPImageProcessor,
71
  requires_safety_checker: bool = True,
72
  ):
73
  super().__init__()
 
205
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
206
  metadata: Optional[List[float]] = None,
207
  ):
208
+ """
209
+ Run inference (text-to-image with optional metadata).
210
+
211
+ Examples:
212
+
213
+ """
214
  # 0. Default height and width to unet
215
  height = height or self.unet.config.sample_size * self.vae_scale_factor
216
  width = width or self.unet.config.sample_size * self.vae_scale_factor