TorchRik commited on
Commit
0ba4f67
·
verified ·
1 Parent(s): 9ca9730

Upload combined_stable_diffusion.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. combined_stable_diffusion.py +13 -6
combined_stable_diffusion.py CHANGED
@@ -2,11 +2,11 @@ import math
2
  import random
3
 
4
  import torch
5
- from diffusers import DiffusionPipeline
6
  from diffusers.image_processor import VaeImageProcessor
7
  from huggingface_hub import PyTorchModelHubMixin
8
  from PIL import Image
9
-
10
 
11
  class CombinedStableDiffusion(
12
  DiffusionPipeline,
@@ -20,21 +20,28 @@ class CombinedStableDiffusion(
20
  self,
21
  original_unet: torch.nn.Module,
22
  fine_tuned_unet: torch.nn.Module,
23
- scheduler,
24
  vae: torch.nn.Module,
25
- tokenizer=None,
26
- text_encoder=None,
27
  ) -> None:
28
 
29
  super().__init__()
30
 
 
 
 
 
 
 
 
31
  self.register_modules(
32
  tokenizer=tokenizer,
 
33
  original_unet=original_unet,
34
  fine_tuned_unet=fine_tuned_unet,
35
  scheduler=scheduler,
36
  vae=vae,
37
- text_encoder=text_encoder,
38
  )
39
 
40
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
 
2
  import random
3
 
4
  import torch
5
+ from diffusers import DiffusionPipeline, DDPMScheduler
6
  from diffusers.image_processor import VaeImageProcessor
7
  from huggingface_hub import PyTorchModelHubMixin
8
  from PIL import Image
9
+ from transformers import CLIPTextModel, CLIPTokenizer
10
 
11
  class CombinedStableDiffusion(
12
  DiffusionPipeline,
 
20
  self,
21
  original_unet: torch.nn.Module,
22
  fine_tuned_unet: torch.nn.Module,
23
+ scheduler: DDPMScheduler,
24
  vae: torch.nn.Module,
25
+ tokenizer: CLIPTextModel,
26
+ text_encoder: CLIPTokenizer,
27
  ) -> None:
28
 
29
  super().__init__()
30
 
31
+ # print(tokenizer)
32
+ # print(text_encoder)
33
+ # print(original_unet)
34
+ # print(fine_tuned_unet)
35
+ # print(scheduler)
36
+ # print(vae)
37
+ #
38
  self.register_modules(
39
  tokenizer=tokenizer,
40
+ text_encoder=text_encoder,
41
  original_unet=original_unet,
42
  fine_tuned_unet=fine_tuned_unet,
43
  scheduler=scheduler,
44
  vae=vae,
 
45
  )
46
 
47
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)