BiliSakura commited on
Commit
bbc2355
·
verified ·
1 Parent(s): bb6fd5f

Update all files for DiffusionSat-SR-Texas-256

Browse files
Files changed (1) hide show
  1. pipeline_diffusionsat.py +17 -3
pipeline_diffusionsat.py CHANGED
@@ -5,11 +5,16 @@ from the checkpoint folder without importing the project package.
5
 
6
  from __future__ import annotations
7
 
 
8
  from typing import Any, Callable, Dict, List, Optional, Union
9
 
10
  import torch
11
  from packaging import version
12
- from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
 
 
13
 
14
  from diffusers.configuration_utils import FrozenDict
15
  from diffusers.models import AutoencoderKL
@@ -17,10 +22,10 @@ from diffusers.schedulers import KarrasDiffusionSchedulers
17
  from diffusers.utils import (
18
  deprecate,
19
  logging,
20
- randn_tensor,
21
  replace_example_docstring,
22
  is_accelerate_available,
23
  )
 
24
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
25
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
26
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -52,6 +57,16 @@ class DiffusionSatPipeline(DiffusionPipeline):
52
 
53
  _optional_components = ["safety_checker", "feature_extractor"]
54
 
 
 
 
 
 
 
 
 
 
 
55
  def __init__(
56
  self,
57
  vae: AutoencoderKL,
@@ -176,7 +191,6 @@ class DiffusionSatPipeline(DiffusionPipeline):
176
  return md
177
 
178
  @torch.no_grad()
179
- @replace_example_docstring(EXAMPLE_DOC_STRING)
180
  def __call__(
181
  self,
182
  prompt: Union[str, List[str]] = None,
 
5
 
6
  from __future__ import annotations
7
 
8
+ import inspect
9
  from typing import Any, Callable, Dict, List, Optional, Union
10
 
11
  import torch
12
  from packaging import version
13
+ try:
14
+ from transformers import CLIPFeatureExtractor
15
+ except ImportError: # transformers>=5.0 drops CLIPFeatureExtractor
16
+ from transformers import CLIPImageProcessor as CLIPFeatureExtractor
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
 
19
  from diffusers.configuration_utils import FrozenDict
20
  from diffusers.models import AutoencoderKL
 
22
  from diffusers.utils import (
23
  deprecate,
24
  logging,
 
25
  replace_example_docstring,
26
  is_accelerate_available,
27
  )
28
+ from diffusers.utils.torch_utils import randn_tensor
29
  from diffusers.pipelines.pipeline_utils import DiffusionPipeline
30
  from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
31
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 
57
 
58
  _optional_components = ["safety_checker", "feature_extractor"]
59
 
60
+ @classmethod
61
+ def _get_signature_types(cls):
62
+ """
63
+ Override to skip strict type resolution when loading via diffusers 0.36,
64
+ which cannot resolve the forward references in these custom modules.
65
+ """
66
+ required, optional = DiffusionPipeline._get_signature_keys(cls)
67
+ keys = list(required) + list(optional)
68
+ return {key: (inspect.Signature.empty,) for key in keys}
69
+
70
  def __init__(
71
  self,
72
  vae: AutoencoderKL,
 
191
  return md
192
 
193
  @torch.no_grad()
 
194
  def __call__(
195
  self,
196
  prompt: Union[str, List[str]] = None,