stanley commited on
Commit
ef0bf45
·
1 Parent(s): 47a0dbd

restructure for gpu

Browse files
Files changed (1) hide show
  1. app.py +468 -177
app.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import io
2
  import base64
3
  import os
@@ -86,15 +89,40 @@ USE_GLID = False
86
  # except:
87
  # USE_GLID = False
88
 
89
- try:
90
- import onnxruntime
91
- onnx_available = True
92
- onnx_providers = ["CUDAExecutionProvider", "DmlExecutionProvider", "OpenVINOExecutionProvider", 'CPUExecutionProvider']
93
- available_providers = onnxruntime.get_available_providers()
94
- onnx_providers = [item for item in onnx_providers if item in available_providers]
95
- except:
96
- onnx_available = False
97
- onnx_providers = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
99
  try:
100
  cuda_available = torch.cuda.is_available()
@@ -108,17 +136,17 @@ finally:
108
  else:
109
  device = "cpu"
110
 
111
- if device != "cuda":
112
- import contextlib
113
 
114
- autocast = contextlib.nullcontext
115
 
116
  with open("config.yaml", "r") as yaml_in:
117
  yaml_object = yaml.safe_load(yaml_in)
118
  config_json = json.dumps(yaml_object)
119
-
120
 
121
 
 
 
122
  def load_html():
123
  body, canvaspy = "", ""
124
  with open("index.html", encoding="utf8") as f:
@@ -315,12 +343,13 @@ class StableDiffusionInpaint:
315
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
316
  # if device == "cuda" and not args.fp32:
317
  # vae.to(torch.float16)
 
318
  if original_checkpoint:
319
  print(f"Converting & Loading {model_path}")
320
  from convert_checkpoint import convert_checkpoint
321
 
322
  pipe = convert_checkpoint(model_path, inpainting=True)
323
- if device == "cuda" and not args.fp32:
324
  pipe.to(torch.float16)
325
  inpaint = StableDiffusionInpaintPipeline(
326
  vae=vae,
@@ -333,7 +362,7 @@ class StableDiffusionInpaint:
333
  )
334
  else:
335
  print(f"Loading {model_name}")
336
- if device == "cuda" and not args.fp32:
337
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
338
  model_name,
339
  revision="fp16",
@@ -476,7 +505,6 @@ class StableDiffusionInpaint:
476
  )["images"]
477
  return images
478
 
479
-
480
  class StableDiffusion:
481
  def __init__(
482
  self,
@@ -488,134 +516,74 @@ class StableDiffusion:
488
  ):
489
  self.token = token
490
  original_checkpoint = False
491
- if device=="cpu" and onnx_available:
492
- from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
493
- text2img = OnnxStableDiffusionPipeline.from_pretrained(
494
- model_name,
495
- revision="onnx",
496
- provider=onnx_providers[0] if onnx_providers else None
497
- )
498
- inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
499
- vae_encoder=text2img.vae_encoder,
500
- vae_decoder=text2img.vae_decoder,
501
- text_encoder=text2img.text_encoder,
502
- tokenizer=text2img.tokenizer,
503
- unet=text2img.unet,
504
- scheduler=text2img.scheduler,
505
- safety_checker=text2img.safety_checker,
506
- feature_extractor=text2img.feature_extractor,
507
- )
508
- img2img = OnnxStableDiffusionImg2ImgPipeline(
509
- vae_encoder=text2img.vae_encoder,
510
- vae_decoder=text2img.vae_decoder,
511
- text_encoder=text2img.text_encoder,
512
- tokenizer=text2img.tokenizer,
513
- unet=text2img.unet,
514
- scheduler=text2img.scheduler,
515
- safety_checker=text2img.safety_checker,
516
- feature_extractor=text2img.feature_extractor,
517
- )
518
  else:
519
- if model_path and os.path.exists(model_path):
520
- if model_path.endswith(".ckpt"):
521
- original_checkpoint = True
522
- elif model_path.endswith(".json"):
523
- model_name = os.path.dirname(model_path)
524
- else:
525
- model_name = model_path
526
- vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
527
  if device == "cuda" and not args.fp32:
528
- vae.to(torch.float16)
529
- if original_checkpoint:
530
- print(f"Converting & Loading {model_path}")
531
- from convert_checkpoint import convert_checkpoint
532
-
533
- pipe = convert_checkpoint(model_path)
534
- if device == "cuda" and not args.fp32:
535
- pipe.to(torch.float16)
536
- text2img = StableDiffusionPipeline(
537
- vae=vae,
538
- text_encoder=pipe.text_encoder,
539
- tokenizer=pipe.tokenizer,
540
- unet=pipe.unet,
541
- scheduler=pipe.scheduler,
542
- safety_checker=pipe.safety_checker,
543
- feature_extractor=pipe.feature_extractor,
544
  )
545
  else:
546
- print(f"Loading {model_name}")
547
- if device == "cuda" and not args.fp32:
548
- text2img = StableDiffusionPipeline.from_pretrained(
549
- model_name,
550
- revision="fp16",
551
- torch_dtype=torch.float16,
552
- use_auth_token=token,
553
- vae=vae,
554
- )
555
- else:
556
- text2img = StableDiffusionPipeline.from_pretrained(
557
- model_name, use_auth_token=token, vae=vae
558
- )
559
- if inpainting_model:
560
- # can reduce vRAM by reusing models except unet
561
- text2img_unet = text2img.unet
562
- del text2img.vae
563
- del text2img.text_encoder
564
- del text2img.tokenizer
565
- del text2img.scheduler
566
- del text2img.safety_checker
567
- del text2img.feature_extractor
568
- import gc
569
-
570
- gc.collect()
571
- if device == "cuda" and not args.fp32:
572
- inpaint = StableDiffusionInpaintPipeline.from_pretrained(
573
- "runwayml/stable-diffusion-inpainting",
574
- revision="fp16",
575
- torch_dtype=torch.float16,
576
- use_auth_token=token,
577
- vae=vae,
578
- ).to(device)
579
- else:
580
- inpaint = StableDiffusionInpaintPipeline.from_pretrained(
581
- "runwayml/stable-diffusion-inpainting",
582
- use_auth_token=token,
583
- vae=vae,
584
- ).to(device)
585
- text2img_unet.to(device)
586
- text2img = StableDiffusionPipeline(
587
- vae=inpaint.vae,
588
- text_encoder=inpaint.text_encoder,
589
- tokenizer=inpaint.tokenizer,
590
- unet=text2img_unet,
591
- scheduler=inpaint.scheduler,
592
- safety_checker=inpaint.safety_checker,
593
- feature_extractor=inpaint.feature_extractor,
594
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
595
  else:
596
- inpaint = StableDiffusionInpaintPipelineLegacy(
597
- vae=text2img.vae,
598
- text_encoder=text2img.text_encoder,
599
- tokenizer=text2img.tokenizer,
600
- unet=text2img.unet,
601
- scheduler=text2img.scheduler,
602
- safety_checker=text2img.safety_checker,
603
- feature_extractor=text2img.feature_extractor,
604
  ).to(device)
605
- text_encoder = text2img.text_encoder
606
- tokenizer = text2img.tokenizer
607
- if os.path.exists("./embeddings"):
608
- for item in os.listdir("./embeddings"):
609
- if item.endswith(".bin"):
610
- load_learned_embed_in_clip(
611
- os.path.join("./embeddings", item),
612
- text2img.text_encoder,
613
- text2img.tokenizer,
614
- )
615
- text2img.to(device)
616
- if device == "mps":
617
- _ = text2img("", num_inference_steps=1)
618
- img2img = StableDiffusionImg2ImgPipeline(
619
  vae=text2img.vae,
620
  text_encoder=text2img.text_encoder,
621
  tokenizer=text2img.tokenizer,
@@ -624,6 +592,19 @@ class StableDiffusion:
624
  safety_checker=text2img.safety_checker,
625
  feature_extractor=text2img.feature_extractor,
626
  ).to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
627
  scheduler_dict["PLMS"] = text2img.scheduler
628
  scheduler_dict["DDIM"] = prepare_scheduler(
629
  DDIMScheduler(
@@ -639,44 +620,40 @@ class StableDiffusion:
639
  beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
640
  )
641
  )
642
- scheduler_dict["PNDM"] = prepare_scheduler(
643
- PNDMScheduler(
644
- beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
645
- skip_prk_steps=True
646
- )
647
- )
648
  scheduler_dict["DPM"] = prepare_scheduler(
649
  DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
650
  )
651
  self.safety_checker = text2img.safety_checker
 
 
 
 
 
 
 
 
 
652
  save_token(token)
653
  try:
654
  total_memory = torch.cuda.get_device_properties(0).total_memory // (
655
  1024 ** 3
656
  )
657
- if total_memory <= 5 or args.lowvram:
658
  inpaint.enable_attention_slicing()
659
- inpaint.enable_sequential_cpu_offload()
660
- if inpainting_model:
661
- text2img.enable_attention_slicing()
662
- text2img.enable_sequential_cpu_offload()
663
  except:
664
  pass
665
  self.text2img = text2img
666
  self.inpaint = inpaint
667
  self.img2img = img2img
668
- if True:
669
- self.unified = inpaint
670
- else:
671
- self.unified = UnifiedPipeline(
672
- vae=text2img.vae,
673
- text_encoder=text2img.text_encoder,
674
- tokenizer=text2img.tokenizer,
675
- unet=text2img.unet,
676
- scheduler=text2img.scheduler,
677
- safety_checker=text2img.safety_checker,
678
- feature_extractor=text2img.feature_extractor,
679
- ).to(device)
680
  self.inpainting_model = inpainting_model
681
 
682
  def run(
@@ -707,7 +684,7 @@ class StableDiffusion:
707
  selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
708
  for item in [text2img, inpaint, img2img, unified]:
709
  item.scheduler = selected_scheduler
710
- if enable_safety or self.safety_checker is None:
711
  item.safety_checker = self.safety_checker
712
  else:
713
  item.safety_checker = lambda images, **kwargs: (images, False)
@@ -743,7 +720,7 @@ class StableDiffusion:
743
  if True:
744
  images = img2img(
745
  prompt=prompt,
746
- image=init_image.resize(
747
  (process_width, process_height), resample=SAMPLING_MODE
748
  ),
749
  strength=strength,
@@ -753,40 +730,33 @@ class StableDiffusion:
753
  if fill_mode == "g_diffuser" and not self.inpainting_model:
754
  mask = 255 - mask
755
  mask = mask[:, :, np.newaxis].repeat(3, axis=2)
756
- img, mask = functbl[fill_mode](img, mask)
757
  extra_kwargs["strength"] = 1.0
758
- extra_kwargs["out_mask"] = Image.fromarray(mask)
759
  inpaint_func = unified
760
  else:
761
  img, mask = functbl[fill_mode](img, mask)
762
  mask = 255 - mask
763
  mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
764
  mask = mask.repeat(8, axis=0).repeat(8, axis=1)
 
765
  inpaint_func = inpaint
766
  init_image = Image.fromarray(img)
767
  mask_image = Image.fromarray(mask)
768
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
769
- input_image = init_image.resize(
770
- (process_width, process_height), resample=SAMPLING_MODE
771
- )
772
- if self.inpainting_model:
773
  images = inpaint_func(
774
  prompt=prompt,
 
775
  image=input_image,
776
  width=process_width,
777
  height=process_height,
778
  mask_image=mask_image.resize((process_width, process_height)),
779
  **extra_kwargs,
780
  )["images"]
781
- else:
782
- extra_kwargs["strength"] = strength
783
- if True:
784
- images = inpaint_func(
785
- prompt=prompt,
786
- image=input_image,
787
- mask_image=mask_image.resize((process_width, process_height)),
788
- **extra_kwargs,
789
- )["images"]
790
  else:
791
  if True:
792
  images = text2img(
@@ -798,6 +768,327 @@ class StableDiffusion:
798
  return images
799
 
800
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
801
  def get_model(token="", model_choice="", model_path=""):
802
  if "model" not in model:
803
  model_name = ""
 
1
+ import subprocess
2
+ import pip
3
+
4
  import io
5
  import base64
6
  import os
 
89
  # except:
90
  # USE_GLID = False
91
 
92
+ # ******** ORIGINAL ***********
93
+ # try:
94
+ # import onnxruntime
95
+ # onnx_available = True
96
+ # onnx_providers = ["CUDAExecutionProvider", "DmlExecutionProvider", "OpenVINOExecutionProvider", 'CPUExecutionProvider']
97
+ # available_providers = onnxruntime.get_available_providers()
98
+ # onnx_providers = [item for item in onnx_providers if item in available_providers]
99
+ # except:
100
+ # onnx_available = False
101
+ # onnx_providers = []
102
+
103
+
104
+ # try:
105
+ # cuda_available = torch.cuda.is_available()
106
+ # except:
107
+ # cuda_available = False
108
+ # finally:
109
+ # if sys.platform == "darwin":
110
+ # device = "mps" if torch.backends.mps.is_available() else "cpu"
111
+ # elif cuda_available:
112
+ # device = "cuda"
113
+ # else:
114
+ # device = "cpu"
115
+
116
+ # if device != "cuda":
117
+ # import contextlib
118
+
119
+ # autocast = contextlib.nullcontext
120
+
121
+ # with open("config.yaml", "r") as yaml_in:
122
+ # yaml_object = yaml.safe_load(yaml_in)
123
+ # config_json = json.dumps(yaml_object)
124
+
125
+ # ******** ^ ORIGINAL ^ ***********
126
 
127
  try:
128
  cuda_available = torch.cuda.is_available()
 
136
  else:
137
  device = "cpu"
138
 
139
+ import contextlib
 
140
 
141
+ autocast = contextlib.nullcontext
142
 
143
  with open("config.yaml", "r") as yaml_in:
144
  yaml_object = yaml.safe_load(yaml_in)
145
  config_json = json.dumps(yaml_object)
 
146
 
147
 
148
+ # new ^
149
+
150
  def load_html():
151
  body, canvaspy = "", ""
152
  with open("index.html", encoding="utf8") as f:
 
343
  vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
344
  # if device == "cuda" and not args.fp32:
345
  # vae.to(torch.float16)
346
+ vae.to(torch.float16)
347
  if original_checkpoint:
348
  print(f"Converting & Loading {model_path}")
349
  from convert_checkpoint import convert_checkpoint
350
 
351
  pipe = convert_checkpoint(model_path, inpainting=True)
352
+ if device == "cuda":
353
  pipe.to(torch.float16)
354
  inpaint = StableDiffusionInpaintPipeline(
355
  vae=vae,
 
362
  )
363
  else:
364
  print(f"Loading {model_name}")
365
+ if device == "cuda":
366
  inpaint = StableDiffusionInpaintPipeline.from_pretrained(
367
  model_name,
368
  revision="fp16",
 
505
  )["images"]
506
  return images
507
 
 
508
  class StableDiffusion:
509
  def __init__(
510
  self,
 
516
  ):
517
  self.token = token
518
  original_checkpoint = False
519
+ vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
520
+ vae.to(torch.float16)
521
+ if model_path and os.path.exists(model_path):
522
+ if model_path.endswith(".ckpt"):
523
+ original_checkpoint = True
524
+ elif model_path.endswith(".json"):
525
+ model_name = os.path.dirname(model_path)
526
+ else:
527
+ model_name = model_path
528
+ if original_checkpoint:
529
+ print(f"Converting & Loading {model_path}")
530
+ from convert_checkpoint import convert_checkpoint
531
+
532
+ text2img = convert_checkpoint(model_path)
533
+ if device == "cuda" and not args.fp32:
534
+ text2img.to(torch.float16)
 
 
 
 
 
 
 
 
 
 
 
535
  else:
536
+ print(f"Loading {model_name}")
 
 
 
 
 
 
 
537
  if device == "cuda" and not args.fp32:
538
+ text2img = StableDiffusionPipeline.from_pretrained(
539
+ "runwayml/stable-diffusion-v1-5",
540
+ revision="fp16",
541
+ torch_dtype=torch.float16,
542
+ use_auth_token=token,
543
+ vae=vae
 
 
 
 
 
 
 
 
 
 
544
  )
545
  else:
546
+ text2img = StableDiffusionPipeline.from_pretrained(
547
+ model_name, use_auth_token=token,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
548
  )
549
+ if inpainting_model:
550
+ # can reduce vRAM by reusing models except unet
551
+ text2img_unet = text2img.unet
552
+ del text2img.vae
553
+ del text2img.text_encoder
554
+ del text2img.tokenizer
555
+ del text2img.scheduler
556
+ del text2img.safety_checker
557
+ del text2img.feature_extractor
558
+ import gc
559
+
560
+ gc.collect()
561
+ if device == "cuda":
562
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
563
+ "runwayml/stable-diffusion-inpainting",
564
+ revision="fp16",
565
+ torch_dtype=torch.float16,
566
+ use_auth_token=token,
567
+ vae=vae
568
+ ).to(device)
569
  else:
570
+ inpaint = StableDiffusionInpaintPipeline.from_pretrained(
571
+ "runwayml/stable-diffusion-inpainting", use_auth_token=token,
 
 
 
 
 
 
572
  ).to(device)
573
+ text2img_unet.to(device)
574
+ del text2img
575
+ gc.collect()
576
+ text2img = StableDiffusionPipeline(
577
+ vae=inpaint.vae,
578
+ text_encoder=inpaint.text_encoder,
579
+ tokenizer=inpaint.tokenizer,
580
+ unet=text2img_unet,
581
+ scheduler=inpaint.scheduler,
582
+ safety_checker=inpaint.safety_checker,
583
+ feature_extractor=inpaint.feature_extractor,
584
+ )
585
+ else:
586
+ inpaint = StableDiffusionInpaintPipelineLegacy(
587
  vae=text2img.vae,
588
  text_encoder=text2img.text_encoder,
589
  tokenizer=text2img.tokenizer,
 
592
  safety_checker=text2img.safety_checker,
593
  feature_extractor=text2img.feature_extractor,
594
  ).to(device)
595
+ text_encoder = text2img.text_encoder
596
+ tokenizer = text2img.tokenizer
597
+ if os.path.exists("./embeddings"):
598
+ for item in os.listdir("./embeddings"):
599
+ if item.endswith(".bin"):
600
+ load_learned_embed_in_clip(
601
+ os.path.join("./embeddings", item),
602
+ text2img.text_encoder,
603
+ text2img.tokenizer,
604
+ )
605
+ text2img.to(device)
606
+ if device == "mps":
607
+ _ = text2img("", num_inference_steps=1)
608
  scheduler_dict["PLMS"] = text2img.scheduler
609
  scheduler_dict["DDIM"] = prepare_scheduler(
610
  DDIMScheduler(
 
620
  beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
621
  )
622
  )
 
 
 
 
 
 
623
  scheduler_dict["DPM"] = prepare_scheduler(
624
  DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
625
  )
626
  self.safety_checker = text2img.safety_checker
627
+ img2img = StableDiffusionImg2ImgPipeline(
628
+ vae=text2img.vae,
629
+ text_encoder=text2img.text_encoder,
630
+ tokenizer=text2img.tokenizer,
631
+ unet=text2img.unet,
632
+ scheduler=text2img.scheduler,
633
+ safety_checker=text2img.safety_checker,
634
+ feature_extractor=text2img.feature_extractor,
635
+ ).to(device)
636
  save_token(token)
637
  try:
638
  total_memory = torch.cuda.get_device_properties(0).total_memory // (
639
  1024 ** 3
640
  )
641
+ if total_memory <= 5:
642
  inpaint.enable_attention_slicing()
 
 
 
 
643
  except:
644
  pass
645
  self.text2img = text2img
646
  self.inpaint = inpaint
647
  self.img2img = img2img
648
+ self.unified = UnifiedPipeline(
649
+ vae=text2img.vae,
650
+ text_encoder=text2img.text_encoder,
651
+ tokenizer=text2img.tokenizer,
652
+ unet=text2img.unet,
653
+ scheduler=text2img.scheduler,
654
+ safety_checker=text2img.safety_checker,
655
+ feature_extractor=text2img.feature_extractor,
656
+ ).to(device)
 
 
 
657
  self.inpainting_model = inpainting_model
658
 
659
  def run(
 
684
  selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
685
  for item in [text2img, inpaint, img2img, unified]:
686
  item.scheduler = selected_scheduler
687
+ if enable_safety:
688
  item.safety_checker = self.safety_checker
689
  else:
690
  item.safety_checker = lambda images, **kwargs: (images, False)
 
720
  if True:
721
  images = img2img(
722
  prompt=prompt,
723
+ init_image=init_image.resize(
724
  (process_width, process_height), resample=SAMPLING_MODE
725
  ),
726
  strength=strength,
 
730
  if fill_mode == "g_diffuser" and not self.inpainting_model:
731
  mask = 255 - mask
732
  mask = mask[:, :, np.newaxis].repeat(3, axis=2)
733
+ img, mask, out_mask = functbl[fill_mode](img, mask)
734
  extra_kwargs["strength"] = 1.0
735
+ extra_kwargs["out_mask"] = Image.fromarray(out_mask)
736
  inpaint_func = unified
737
  else:
738
  img, mask = functbl[fill_mode](img, mask)
739
  mask = 255 - mask
740
  mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
741
  mask = mask.repeat(8, axis=0).repeat(8, axis=1)
742
+ extra_kwargs["strength"] = strength
743
  inpaint_func = inpaint
744
  init_image = Image.fromarray(img)
745
  mask_image = Image.fromarray(mask)
746
  # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
747
+ if True:
748
+ input_image = init_image.resize(
749
+ (process_width, process_height), resample=SAMPLING_MODE
750
+ )
751
  images = inpaint_func(
752
  prompt=prompt,
753
+ init_image=input_image,
754
  image=input_image,
755
  width=process_width,
756
  height=process_height,
757
  mask_image=mask_image.resize((process_width, process_height)),
758
  **extra_kwargs,
759
  )["images"]
 
 
 
 
 
 
 
 
 
760
  else:
761
  if True:
762
  images = text2img(
 
768
  return images
769
 
770
 
771
+ # class StableDiffusion:
772
+ # def __init__(
773
+ # self,
774
+ # token: str = "",
775
+ # model_name: str = "runwayml/stable-diffusion-v1-5",
776
+ # model_path: str = None,
777
+ # inpainting_model: bool = False,
778
+ # **kwargs,
779
+ # ):
780
+ # self.token = token
781
+ # original_checkpoint = False
782
+ # if device=="cpu" and onnx_available:
783
+ # from diffusers import OnnxStableDiffusionPipeline, OnnxStableDiffusionInpaintPipelineLegacy, OnnxStableDiffusionImg2ImgPipeline
784
+ # text2img = OnnxStableDiffusionPipeline.from_pretrained(
785
+ # model_name,
786
+ # revision="onnx",
787
+ # provider=onnx_providers[0] if onnx_providers else None
788
+ # )
789
+ # inpaint = OnnxStableDiffusionInpaintPipelineLegacy(
790
+ # vae_encoder=text2img.vae_encoder,
791
+ # vae_decoder=text2img.vae_decoder,
792
+ # text_encoder=text2img.text_encoder,
793
+ # tokenizer=text2img.tokenizer,
794
+ # unet=text2img.unet,
795
+ # scheduler=text2img.scheduler,
796
+ # safety_checker=text2img.safety_checker,
797
+ # feature_extractor=text2img.feature_extractor,
798
+ # )
799
+ # img2img = OnnxStableDiffusionImg2ImgPipeline(
800
+ # vae_encoder=text2img.vae_encoder,
801
+ # vae_decoder=text2img.vae_decoder,
802
+ # text_encoder=text2img.text_encoder,
803
+ # tokenizer=text2img.tokenizer,
804
+ # unet=text2img.unet,
805
+ # scheduler=text2img.scheduler,
806
+ # safety_checker=text2img.safety_checker,
807
+ # feature_extractor=text2img.feature_extractor,
808
+ # )
809
+ # else:
810
+ # if model_path and os.path.exists(model_path):
811
+ # if model_path.endswith(".ckpt"):
812
+ # original_checkpoint = True
813
+ # elif model_path.endswith(".json"):
814
+ # model_name = os.path.dirname(model_path)
815
+ # else:
816
+ # model_name = model_path
817
+ # vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
818
+ # if device == "cuda" and not args.fp32:
819
+ # vae.to(torch.float16)
820
+ # if original_checkpoint:
821
+ # print(f"Converting & Loading {model_path}")
822
+ # from convert_checkpoint import convert_checkpoint
823
+
824
+ # pipe = convert_checkpoint(model_path)
825
+ # if device == "cuda" and not args.fp32:
826
+ # pipe.to(torch.float16)
827
+ # text2img = StableDiffusionPipeline(
828
+ # vae=vae,
829
+ # text_encoder=pipe.text_encoder,
830
+ # tokenizer=pipe.tokenizer,
831
+ # unet=pipe.unet,
832
+ # scheduler=pipe.scheduler,
833
+ # safety_checker=pipe.safety_checker,
834
+ # feature_extractor=pipe.feature_extractor,
835
+ # )
836
+ # else:
837
+ # print(f"Loading {model_name}")
838
+ # if device == "cuda" and not args.fp32:
839
+ # text2img = StableDiffusionPipeline.from_pretrained(
840
+ # model_name,
841
+ # revision="fp16",
842
+ # torch_dtype=torch.float16,
843
+ # use_auth_token=token,
844
+ # vae=vae,
845
+ # )
846
+ # else:
847
+ # text2img = StableDiffusionPipeline.from_pretrained(
848
+ # model_name, use_auth_token=token, vae=vae
849
+ # )
850
+ # if inpainting_model:
851
+ # # can reduce vRAM by reusing models except unet
852
+ # text2img_unet = text2img.unet
853
+ # del text2img.vae
854
+ # del text2img.text_encoder
855
+ # del text2img.tokenizer
856
+ # del text2img.scheduler
857
+ # del text2img.safety_checker
858
+ # del text2img.feature_extractor
859
+ # import gc
860
+
861
+ # gc.collect()
862
+ # if device == "cuda" and not args.fp32:
863
+ # inpaint = StableDiffusionInpaintPipeline.from_pretrained(
864
+ # "runwayml/stable-diffusion-inpainting",
865
+ # revision="fp16",
866
+ # torch_dtype=torch.float16,
867
+ # use_auth_token=token,
868
+ # vae=vae,
869
+ # ).to(device)
870
+ # else:
871
+ # inpaint = StableDiffusionInpaintPipeline.from_pretrained(
872
+ # "runwayml/stable-diffusion-inpainting",
873
+ # use_auth_token=token,
874
+ # vae=vae,
875
+ # ).to(device)
876
+ # text2img_unet.to(device)
877
+ # text2img = StableDiffusionPipeline(
878
+ # vae=inpaint.vae,
879
+ # text_encoder=inpaint.text_encoder,
880
+ # tokenizer=inpaint.tokenizer,
881
+ # unet=text2img_unet,
882
+ # scheduler=inpaint.scheduler,
883
+ # safety_checker=inpaint.safety_checker,
884
+ # feature_extractor=inpaint.feature_extractor,
885
+ # )
886
+ # else:
887
+ # inpaint = StableDiffusionInpaintPipelineLegacy(
888
+ # vae=text2img.vae,
889
+ # text_encoder=text2img.text_encoder,
890
+ # tokenizer=text2img.tokenizer,
891
+ # unet=text2img.unet,
892
+ # scheduler=text2img.scheduler,
893
+ # safety_checker=text2img.safety_checker,
894
+ # feature_extractor=text2img.feature_extractor,
895
+ # ).to(device)
896
+ # text_encoder = text2img.text_encoder
897
+ # tokenizer = text2img.tokenizer
898
+ # if os.path.exists("./embeddings"):
899
+ # for item in os.listdir("./embeddings"):
900
+ # if item.endswith(".bin"):
901
+ # load_learned_embed_in_clip(
902
+ # os.path.join("./embeddings", item),
903
+ # text2img.text_encoder,
904
+ # text2img.tokenizer,
905
+ # )
906
+ # text2img.to(device)
907
+ # if device == "mps":
908
+ # _ = text2img("", num_inference_steps=1)
909
+ # img2img = StableDiffusionImg2ImgPipeline(
910
+ # vae=text2img.vae,
911
+ # text_encoder=text2img.text_encoder,
912
+ # tokenizer=text2img.tokenizer,
913
+ # unet=text2img.unet,
914
+ # scheduler=text2img.scheduler,
915
+ # safety_checker=text2img.safety_checker,
916
+ # feature_extractor=text2img.feature_extractor,
917
+ # ).to(device)
918
+ # scheduler_dict["PLMS"] = text2img.scheduler
919
+ # scheduler_dict["DDIM"] = prepare_scheduler(
920
+ # DDIMScheduler(
921
+ # beta_start=0.00085,
922
+ # beta_end=0.012,
923
+ # beta_schedule="scaled_linear",
924
+ # clip_sample=False,
925
+ # set_alpha_to_one=False,
926
+ # )
927
+ # )
928
+ # scheduler_dict["K-LMS"] = prepare_scheduler(
929
+ # LMSDiscreteScheduler(
930
+ # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
931
+ # )
932
+ # )
933
+ # scheduler_dict["PNDM"] = prepare_scheduler(
934
+ # PNDMScheduler(
935
+ # beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
936
+ # skip_prk_steps=True
937
+ # )
938
+ # )
939
+ # scheduler_dict["DPM"] = prepare_scheduler(
940
+ # DPMSolverMultistepScheduler.from_config(text2img.scheduler.config)
941
+ # )
942
+ # self.safety_checker = text2img.safety_checker
943
+ # save_token(token)
944
+ # try:
945
+ # total_memory = torch.cuda.get_device_properties(0).total_memory // (
946
+ # 1024 ** 3
947
+ # )
948
+ # if total_memory <= 5 or args.lowvram:
949
+ # inpaint.enable_attention_slicing()
950
+ # inpaint.enable_sequential_cpu_offload()
951
+ # if inpainting_model:
952
+ # text2img.enable_attention_slicing()
953
+ # text2img.enable_sequential_cpu_offload()
954
+ # except:
955
+ # pass
956
+ # self.text2img = text2img
957
+ # self.inpaint = inpaint
958
+ # self.img2img = img2img
959
+ # if True:
960
+ # self.unified = inpaint
961
+ # else:
962
+ # self.unified = UnifiedPipeline(
963
+ # vae=text2img.vae,
964
+ # text_encoder=text2img.text_encoder,
965
+ # tokenizer=text2img.tokenizer,
966
+ # unet=text2img.unet,
967
+ # scheduler=text2img.scheduler,
968
+ # safety_checker=text2img.safety_checker,
969
+ # feature_extractor=text2img.feature_extractor,
970
+ # ).to(device)
971
+ # self.inpainting_model = inpainting_model
972
+
973
+ # def run(
974
+ # self,
975
+ # image_pil,
976
+ # prompt="",
977
+ # negative_prompt="",
978
+ # guidance_scale=7.5,
979
+ # resize_check=True,
980
+ # enable_safety=True,
981
+ # fill_mode="patchmatch",
982
+ # strength=0.75,
983
+ # step=50,
984
+ # enable_img2img=False,
985
+ # use_seed=False,
986
+ # seed_val=-1,
987
+ # generate_num=1,
988
+ # scheduler="",
989
+ # scheduler_eta=0.0,
990
+ # **kwargs,
991
+ # ):
992
+ # text2img, inpaint, img2img, unified = (
993
+ # self.text2img,
994
+ # self.inpaint,
995
+ # self.img2img,
996
+ # self.unified,
997
+ # )
998
+ # selected_scheduler = scheduler_dict.get(scheduler, scheduler_dict["PLMS"])
999
+ # for item in [text2img, inpaint, img2img, unified]:
1000
+ # item.scheduler = selected_scheduler
1001
+ # if enable_safety or self.safety_checker is None:
1002
+ # item.safety_checker = self.safety_checker
1003
+ # else:
1004
+ # item.safety_checker = lambda images, **kwargs: (images, False)
1005
+ # if RUN_IN_SPACE:
1006
+ # step = max(150, step)
1007
+ # image_pil = contain_func(image_pil, (1024, 1024))
1008
+ # width, height = image_pil.size
1009
+ # sel_buffer = np.array(image_pil)
1010
+ # img = sel_buffer[:, :, 0:3]
1011
+ # mask = sel_buffer[:, :, -1]
1012
+ # nmask = 255 - mask
1013
+ # process_width = width
1014
+ # process_height = height
1015
+ # if resize_check:
1016
+ # process_width, process_height = my_resize(width, height)
1017
+ # extra_kwargs = {
1018
+ # "num_inference_steps": step,
1019
+ # "guidance_scale": guidance_scale,
1020
+ # "eta": scheduler_eta,
1021
+ # }
1022
+ # if RUN_IN_SPACE:
1023
+ # generate_num = max(
1024
+ # int(4 * 512 * 512 // process_width // process_height), generate_num
1025
+ # )
1026
+ # if USE_NEW_DIFFUSERS:
1027
+ # extra_kwargs["negative_prompt"] = negative_prompt
1028
+ # extra_kwargs["num_images_per_prompt"] = generate_num
1029
+ # if use_seed:
1030
+ # generator = torch.Generator(text2img.device).manual_seed(seed_val)
1031
+ # extra_kwargs["generator"] = generator
1032
+ # if nmask.sum() < 1 and enable_img2img:
1033
+ # init_image = Image.fromarray(img)
1034
+ # if True:
1035
+ # images = img2img(
1036
+ # prompt=prompt,
1037
+ # image=init_image.resize(
1038
+ # (process_width, process_height), resample=SAMPLING_MODE
1039
+ # ),
1040
+ # strength=strength,
1041
+ # **extra_kwargs,
1042
+ # )["images"]
1043
+ # elif mask.sum() > 0:
1044
+ # if fill_mode == "g_diffuser" and not self.inpainting_model:
1045
+ # mask = 255 - mask
1046
+ # mask = mask[:, :, np.newaxis].repeat(3, axis=2)
1047
+ # img, mask = functbl[fill_mode](img, mask)
1048
+ # extra_kwargs["strength"] = 1.0
1049
+ # extra_kwargs["out_mask"] = Image.fromarray(mask)
1050
+ # inpaint_func = unified
1051
+ # else:
1052
+ # img, mask = functbl[fill_mode](img, mask)
1053
+ # mask = 255 - mask
1054
+ # mask = skimage.measure.block_reduce(mask, (8, 8), np.max)
1055
+ # mask = mask.repeat(8, axis=0).repeat(8, axis=1)
1056
+ # inpaint_func = inpaint
1057
+ # init_image = Image.fromarray(img)
1058
+ # mask_image = Image.fromarray(mask)
1059
+ # # mask_image=mask_image.filter(ImageFilter.GaussianBlur(radius = 8))
1060
+ # input_image = init_image.resize(
1061
+ # (process_width, process_height), resample=SAMPLING_MODE
1062
+ # )
1063
+ # if self.inpainting_model:
1064
+ # images = inpaint_func(
1065
+ # prompt=prompt,
1066
+ # image=input_image,
1067
+ # width=process_width,
1068
+ # height=process_height,
1069
+ # mask_image=mask_image.resize((process_width, process_height)),
1070
+ # **extra_kwargs,
1071
+ # )["images"]
1072
+ # else:
1073
+ # extra_kwargs["strength"] = strength
1074
+ # if True:
1075
+ # images = inpaint_func(
1076
+ # prompt=prompt,
1077
+ # image=input_image,
1078
+ # mask_image=mask_image.resize((process_width, process_height)),
1079
+ # **extra_kwargs,
1080
+ # )["images"]
1081
+ # else:
1082
+ # if True:
1083
+ # images = text2img(
1084
+ # prompt=prompt,
1085
+ # height=process_width,
1086
+ # width=process_height,
1087
+ # **extra_kwargs,
1088
+ # )["images"]
1089
+ # return images
1090
+
1091
+
1092
  def get_model(token="", model_choice="", model_path=""):
1093
  if "model" not in model:
1094
  model_name = ""