Duskfallcrew committed on
Commit
487dcba
·
verified ·
1 Parent(s): e20a6db

Update app.py

Browse files

GEMINI IS A GOD IT KNOWS HOW TO LOAD A KAGGLE INTERFACE (Lulz.)

Files changed (1) hide show
  1. app.py +19 -11
app.py CHANGED
@@ -14,6 +14,7 @@ from huggingface_hub.utils import validate_repo_id, HFValidationError
14
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
15
  from huggingface_hub.utils import HfHubHTTPError
16
 
 
17
  # ---------------------- UTILITY FUNCTIONS ----------------------
18
  # (download_model, create_model_repo, etc. - All unchanged, but included for completeness)
19
 
@@ -97,13 +98,13 @@ def load_sdxl_checkpoint(checkpoint_path):
97
  unet_state = OrderedDict()
98
 
99
  for key, value in state_dict.items():
100
- if key.startswith("first_stage_model."):
101
  vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
102
- elif key.startswith("condition_model.model.text_encoder."):
103
  text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
104
- elif key.startswith("condition_model.model.text_encoder_2."):
105
  text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
106
- elif key.startswith("model.diffusion_model."):
107
  unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
108
 
109
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
@@ -115,16 +116,22 @@ def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, u
115
  if not reference_model_path:
116
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
117
 
118
- config_text_encoder1 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder")
119
- config_text_encoder2 = CLIPTextConfig.from_pretrained(reference_model_path, subfolder="text_encoder_2")
120
- config_vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae").config
121
- config_unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet").config
 
 
 
 
 
 
122
 
 
123
  text_encoder1 = CLIPTextModel(config_text_encoder1)
124
- text_encoder2 = CLIPTextModelWithProjection(config_text_encoder2) # Correct class
125
- vae = AutoencoderKL(config=config_vae)
126
- unet = UNet2DConditionModel(config=config_unet)
127
 
 
128
  text_encoder1.load_state_dict(text_encoder1_state, strict=False)
129
  text_encoder2.load_state_dict(text_encoder2_state, strict=False)
130
  vae.load_state_dict(vae_state, strict=False)
@@ -135,6 +142,7 @@ def build_diffusers_model(text_encoder1_state, text_encoder2_state, vae_state, u
135
  vae.to(torch.float16).to("cpu")
136
  unet.to(torch.float16).to("cpu")
137
 
 
138
  return text_encoder1, text_encoder2, vae, unet
139
 
140
  def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, reference_model_path):
 
14
  from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
15
  from huggingface_hub.utils import HfHubHTTPError
16
 
17
+
18
  # ---------------------- UTILITY FUNCTIONS ----------------------
19
  # (download_model, create_model_repo, etc. - All unchanged, but included for completeness)
20
 
 
98
  unet_state = OrderedDict()
99
 
100
  for key, value in state_dict.items():
101
+ if key.startswith("first_stage_model."): # VAE
102
  vae_state[key.replace("first_stage_model.", "")] = value.to(torch.float16)
103
+ elif key.startswith("condition_model.model.text_encoder."): # First Text Encoder
104
  text_encoder1_state[key.replace("condition_model.model.text_encoder.", "")] = value.to(torch.float16)
105
+ elif key.startswith("condition_model.model.text_encoder_2."): # Second Text Encoder
106
  text_encoder2_state[key.replace("condition_model.model.text_encoder_2.", "")] = value.to(torch.float16)
107
+ elif key.startswith("model.diffusion_model."): # UNet
108
  unet_state[key.replace("model.diffusion_model.", "")] = value.to(torch.float16)
109
 
110
  return text_encoder1_state, text_encoder2_state, vae_state, unet_state
 
116
  if not reference_model_path:
117
  reference_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
118
 
119
+ # Load configurations from the reference model
120
+ config_text_encoder1 = CLIPTextConfig.from_pretrained(
121
+ reference_model_path, subfolder="text_encoder"
122
+ )
123
+ config_text_encoder2 = CLIPTextConfig.from_pretrained(
124
+ reference_model_path, subfolder="text_encoder_2"
125
+ )
126
+ # Use from_pretrained with subfolder for VAE and UNet
127
+ vae = AutoencoderKL.from_pretrained(reference_model_path, subfolder="vae") # Corrected
128
+ unet = UNet2DConditionModel.from_pretrained(reference_model_path, subfolder="unet") # Corrected
129
 
130
+ # Create instances using the configurations
131
  text_encoder1 = CLIPTextModel(config_text_encoder1)
132
+ text_encoder2 = CLIPTextModelWithProjection(config_text_encoder2) # Use CLIPTextModelWithProjection
 
 
133
 
134
+ # Load state dicts with strict=False
135
  text_encoder1.load_state_dict(text_encoder1_state, strict=False)
136
  text_encoder2.load_state_dict(text_encoder2_state, strict=False)
137
  vae.load_state_dict(vae_state, strict=False)
 
142
  vae.to(torch.float16).to("cpu")
143
  unet.to(torch.float16).to("cpu")
144
 
145
+
146
  return text_encoder1, text_encoder2, vae, unet
147
 
148
  def convert_and_save_sdxl_to_diffusers(checkpoint_path_or_url, output_path, reference_model_path):