developerskyebrowse commited on
Commit
2be296c
·
1 Parent(s): a9de8a3

preprocessor class

Browse files
Files changed (1) hide show
  1. app.py +24 -3
app.py CHANGED
@@ -10,6 +10,7 @@ import random
10
  import time
11
  import gradio as gr
12
  import numpy as np
 
13
  # import imageio
14
  import torch
15
  from PIL import Image
@@ -19,7 +20,8 @@ from diffusers import (
19
  StableDiffusionControlNetPipeline,
20
  AutoencoderKL,
21
  )
22
- # from diffusers.models.attention_processor import AttnProcessor2_0
 
23
  MAX_SEED = np.iinfo(np.int32).max
24
  API_KEY = os.environ.get("API_KEY", None)
25
 
@@ -27,7 +29,27 @@ print("CUDA version:", torch.version.cuda)
27
  print("loading everything")
28
  compiled = False
29
 
30
- import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  torch.cuda.max_memory_allocated(device="cuda")
33
  # Controlnet Normal
@@ -87,7 +109,6 @@ pipe.load_textual_inversion("broyang/hentaidigitalart_v20", weight_name="HDA_Shi
87
  pipe.to("cuda")
88
 
89
  print("loading preprocessor")
90
- from preprocess import Preprocessor
91
  preprocessor = Preprocessor()
92
  # preprocessor.load("NormalBae")
93
 
 
10
  import time
11
  import gradio as gr
12
  import numpy as np
13
+ import spaces
14
  # import imageio
15
  import torch
16
  from PIL import Image
 
20
  StableDiffusionControlNetPipeline,
21
  AutoencoderKL,
22
  )
23
+ from controlnet_aux_local import NormalBaeDetector
24
+
25
  MAX_SEED = np.iinfo(np.int32).max
26
  API_KEY = os.environ.get("API_KEY", None)
27
 
 
29
  print("loading everything")
30
  compiled = False
31
 
32
+ class Preprocessor:
33
+ MODEL_ID = "lllyasviel/Annotators"
34
+
35
+ def __init__(self):
36
+ self.model = None
37
+ self.name = ""
38
+
39
+ def load(self, name: str) -> None:
40
+ if name == self.name:
41
+ return
42
+ elif name == "NormalBae":
43
+ print("Loading NormalBae")
44
+ self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
45
+ torch.cuda.empty_cache()
46
+ self.name = name
47
+ else:
48
+ raise ValueError
49
+ return
50
+
51
+ def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
52
+ return self.model(image, **kwargs)
53
 
54
  torch.cuda.max_memory_allocated(device="cuda")
55
  # Controlnet Normal
 
109
  pipe.to("cuda")
110
 
111
  print("loading preprocessor")
 
112
  preprocessor = Preprocessor()
113
  # preprocessor.load("NormalBae")
114