Bobby commited on
Commit
4d4920b
·
1 Parent(s): 1667c1c

testing out caching

Browse files
Files changed (1) hide show
  1. app.py +64 -9
app.py CHANGED
@@ -19,12 +19,14 @@ from diffusers import (
19
  ControlNetModel,
20
  DPMSolverMultistepScheduler,
21
  StableDiffusionControlNetPipeline,
22
- AutoencoderKL,
23
  )
 
24
  from controlnet_aux_local import NormalBaeDetector
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
  API_KEY = os.environ.get("API_KEY", None)
 
28
 
29
  print("CUDA version:", torch.version.cuda)
30
  print("loading everything")
@@ -37,34 +39,75 @@ class Preprocessor:
37
  self.model = None
38
  self.name = ""
39
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def load(self, name: str) -> None:
41
  if name == self.name:
42
  return
43
  elif name == "NormalBae":
44
  print("Loading NormalBae")
45
- self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
 
46
  torch.cuda.empty_cache()
47
  self.name = name
48
  else:
49
  raise ValueError
50
  return
51
-
52
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
53
  return self.model(image, **kwargs)
54
 
55
  torch.cuda.max_memory_allocated(device="cuda")
 
 
 
 
 
 
 
 
 
 
56
  # Controlnet Normal
57
  model_id = "lllyasviel/control_v11p_sd15_normalbae"
58
  print("initializing controlnet")
 
59
  controlnet = ControlNetModel.from_pretrained(
60
- model_id,
61
  torch_dtype=torch.float16,
62
  attn_implementation="flash_attention_2",
63
  ).to("cuda")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  # Scheduler
 
 
66
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
67
- "runwayml/stable-diffusion-v1-5",
68
  solver_order=2,
69
  subfolder="scheduler",
70
  use_karras_sigmas=True,
@@ -86,18 +129,30 @@ vae_url = "https://huggingface.co/stabilityai/sd-vae-ft-mse-original/blob/main/v
86
  # vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
87
  # vae.to(memory_format=torch.channels_last)
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  print('loading pipe')
90
  pipe = StableDiffusionControlNetPipeline.from_single_file(
91
- base_model_url,
92
  safety_checker=None,
93
- # load_safety_checker=True,
94
  controlnet=controlnet,
95
  scheduler=scheduler,
96
- # vae=vae,
97
  torch_dtype=torch.float16,
98
  )
99
 
100
-
101
  print("loading preprocessor")
102
  preprocessor = Preprocessor()
103
  preprocessor.load("NormalBae")
 
19
  ControlNetModel,
20
  DPMSolverMultistepScheduler,
21
  StableDiffusionControlNetPipeline,
22
+ # AutoencoderKL,
23
  )
24
+ from huggingface_hub import cached_download, hf_hub_url
25
  from controlnet_aux_local import NormalBaeDetector
26
 
27
  MAX_SEED = np.iinfo(np.int32).max
28
  API_KEY = os.environ.get("API_KEY", None)
29
+ os.environ['HF_HOME'] = '/data/.huggingface'
30
 
31
  print("CUDA version:", torch.version.cuda)
32
  print("loading everything")
 
39
  self.model = None
40
  self.name = ""
41
 
42
+ # def load(self, name: str) -> None:
43
+ # if name == self.name:
44
+ # return
45
+ # elif name == "NormalBae":
46
+ # print("Loading NormalBae")
47
+ # self.model = NormalBaeDetector.from_pretrained(self.MODEL_ID).to("cuda")
48
+ # torch.cuda.empty_cache()
49
+ # self.name = name
50
+ # else:
51
+ # raise ValueError
52
+ # return
53
+
54
  def load(self, name: str) -> None:
55
  if name == self.name:
56
  return
57
  elif name == "NormalBae":
58
  print("Loading NormalBae")
59
+ model_file = cached_download(hf_hub_url(self.MODEL_ID, filename="NormalBaeDetector.pth"))
60
+ self.model = NormalBaeDetector.from_pretrained(model_file).to("cuda")
61
  torch.cuda.empty_cache()
62
  self.name = name
63
  else:
64
  raise ValueError
65
  return
66
+
67
  def __call__(self, image: Image.Image, **kwargs) -> Image.Image:
68
  return self.model(image, **kwargs)
69
 
70
  torch.cuda.max_memory_allocated(device="cuda")
71
+
72
+ # # Controlnet Normal
73
+ # model_id = "lllyasviel/control_v11p_sd15_normalbae"
74
+ # print("initializing controlnet")
75
+ # controlnet = ControlNetModel.from_pretrained(
76
+ # model_id,
77
+ # torch_dtype=torch.float16,
78
+ # attn_implementation="flash_attention_2",
79
+ # ).to("cuda")
80
+
81
  # Controlnet Normal
82
  model_id = "lllyasviel/control_v11p_sd15_normalbae"
83
  print("initializing controlnet")
84
+ controlnet_file = cached_download(hf_hub_url(model_id, filename="diffusion_pytorch_model.safetensors"))
85
  controlnet = ControlNetModel.from_pretrained(
86
+ controlnet_file,
87
  torch_dtype=torch.float16,
88
  attn_implementation="flash_attention_2",
89
  ).to("cuda")
90
 
91
+ # # Scheduler
92
+ # scheduler = DPMSolverMultistepScheduler.from_pretrained(
93
+ # "runwayml/stable-diffusion-v1-5",
94
+ # solver_order=2,
95
+ # subfolder="scheduler",
96
+ # use_karras_sigmas=True,
97
+ # final_sigmas_type="sigma_min",
98
+ # algorithm_type="sde-dpmsolver++",
99
+ # prediction_type="epsilon",
100
+ # thresholding=False,
101
+ # denoise_final=True,
102
+ # device_map="cuda",
103
+ # torch_dtype=torch.float16,
104
+ # )
105
+
106
  # Scheduler
107
+ scheduler_repo = "runwayml/stable-diffusion-v1-5"
108
+ scheduler_file = cached_download(hf_hub_url(scheduler_repo, filename="scheduler/scheduler_config.json", subfolder="scheduler"))
109
  scheduler = DPMSolverMultistepScheduler.from_pretrained(
110
+ scheduler_file,
111
  solver_order=2,
112
  subfolder="scheduler",
113
  use_karras_sigmas=True,
 
129
  # vae = AutoencoderKL.from_single_file(vae_url, torch_dtype=torch.float16).to("cuda")
130
  # vae.to(memory_format=torch.channels_last)
131
 
132
+ # print('loading pipe')
133
+ # pipe = StableDiffusionControlNetPipeline.from_single_file(
134
+ # base_model_url,
135
+ # safety_checker=None,
136
+ # # load_safety_checker=True,
137
+ # controlnet=controlnet,
138
+ # scheduler=scheduler,
139
+ # # vae=vae,
140
+ # torch_dtype=torch.float16,
141
+ # )
142
+
143
+ # Stable Diffusion Pipeline
144
+ base_model_repo = "Lykon/AbsoluteReality"
145
+ base_model_file = cached_download(hf_hub_url(base_model_repo, filename="AbsoluteReality_1.8.1_pruned.safetensors"))
146
+
147
  print('loading pipe')
148
  pipe = StableDiffusionControlNetPipeline.from_single_file(
149
+ base_model_file,
150
  safety_checker=None,
 
151
  controlnet=controlnet,
152
  scheduler=scheduler,
 
153
  torch_dtype=torch.float16,
154
  )
155
 
 
156
  print("loading preprocessor")
157
  preprocessor = Preprocessor()
158
  preprocessor.load("NormalBae")