BiliSakura commited on
Commit
d4cb4c6
·
verified ·
1 Parent(s): de5f0e6

Update all files for DiffusionSat-SR-Texas-256

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat_controlnet.py +18 -6
pipeline_diffusionsat_controlnet.py CHANGED
@@ -8,6 +8,13 @@ from __future__ import annotations
8
  import os
9
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
 
 
 
 
 
 
 
 
11
  import einops
12
  import numpy as np
13
  import PIL.Image
@@ -22,11 +29,11 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
22
  from diffusers.utils import (
23
  PIL_INTERPOLATION,
24
  logging,
25
- randn_tensor,
26
  replace_example_docstring,
27
  is_accelerate_available,
28
  is_accelerate_version,
29
  )
 
30
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
31
  from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
32
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -83,11 +90,11 @@ class DiffusionSatControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMi
83
  vae: AutoencoderKL,
84
  text_encoder: CLIPTextModel,
85
  tokenizer: CLIPTokenizer,
86
- unet: Any,
87
- controlnet: Any,
88
  scheduler: KarrasDiffusionSchedulers,
89
- safety_checker: StableDiffusionSafetyChecker,
90
- feature_extractor: CLIPImageProcessor,
91
  requires_safety_checker: bool = True,
92
  ):
93
  super().__init__()
@@ -229,7 +236,12 @@ class DiffusionSatControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMi
229
  cond_metadata: Optional[List[float]] = None,
230
  is_temporal: bool = False,
231
  conditioning_downsample: bool = True,
232
- ):
 
 
 
 
 
233
  # 0. Default height and width to unet
234
  height, width = self._default_height_width(height, width, image)
235
  cond_height, cond_width = height, width
 
8
  import os
9
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
10
 
11
+ # Import types directly (not just for TYPE_CHECKING) so diffusers can introspect them
12
+ from diffusers.models import AutoencoderKL
13
+ from diffusers.models.controlnets import ControlNetModel
14
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
15
+ from diffusers.schedulers import KarrasDiffusionSchedulers
16
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
17
+
18
  import einops
19
  import numpy as np
20
  import PIL.Image
 
29
  from diffusers.utils import (
30
  PIL_INTERPOLATION,
31
  logging,
 
32
  replace_example_docstring,
33
  is_accelerate_available,
34
  is_accelerate_version,
35
  )
36
+ from diffusers.utils.torch_utils import randn_tensor
37
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
38
  from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
39
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
90
  vae: AutoencoderKL,
91
  text_encoder: CLIPTextModel,
92
  tokenizer: CLIPTokenizer,
93
+ unet: UNet2DConditionModel,
94
+ controlnet: ControlNetModel,
95
  scheduler: KarrasDiffusionSchedulers,
96
+ safety_checker: Optional[StableDiffusionSafetyChecker] = None,
97
+ feature_extractor: Optional[CLIPImageProcessor] = None,
98
  requires_safety_checker: bool = True,
99
  ):
100
  super().__init__()
 
236
  cond_metadata: Optional[List[float]] = None,
237
  is_temporal: bool = False,
238
  conditioning_downsample: bool = True,
239
+ ) -> Union[StableDiffusionPipelineOutput, Tuple]:
240
+ """
241
+ Function invoked when calling the pipeline for generation.
242
+
243
+ Examples:
244
+ """
245
  # 0. Default height and width to unet
246
  height, width = self._default_height_width(height, width, image)
247
  cond_height, cond_width = height, width