"""SD1.5 model implementation."""
import torch
from src.Utilities import util
from src.Model import ModelBase
from src.SD15 import SDClip, SDToken, SDXL
from src.Utilities import Latent
from src.clip import Clip


class sm_SD15(ModelBase.BASE):
    """SD1.5 model class."""

    # UNet configuration expected for SD1.5 checkpoints.
    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }
    unet_extra_config: dict = {"num_heads": 8, "num_head_channels": -1}
    # The class itself is stored, not an instance, so annotate it as a type.
    latent_format: type[Latent.SD15] = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """Process CLIP state dict for SD1.5."""
        # Older checkpoints omit the "text_model." segment from CLIP keys;
        # rename them to the layout the loader expects.
        for key in list(state_dict.keys()):
            if key.startswith("cond_stage_model.transformer.") and not key.startswith("cond_stage_model.transformer.text_model."):
                state_dict[key.replace("cond_stage_model.transformer.", "cond_stage_model.transformer.text_model.")] = state_dict.pop(key)
        # Some checkpoints serialize position_ids as float32; round them so
        # they cast cleanly back to integer indices.
        pos_key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
        if pos_key in state_dict and state_dict[pos_key].dtype == torch.float32:
            state_dict[pos_key] = state_dict[pos_key].round()
        return util.state_dict_prefix_replace(state_dict, {"cond_stage_model.": "clip_l."}, filter_keys=True)

    def clip_target(self) -> Clip.ClipTarget:
        """Get CLIP target for SD1.5."""
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)


# Model classes exposed by this module.
models = [sm_SD15, SDXL.SDXLRefiner, SDXL.SDXL, SDXL.SSD1B, SDXL.Segmind_Vega, SDXL.KOALA_700M, SDXL.KOALA_1B]
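

# --- Illustrative sketch, not part of the module's public surface ---
# A minimal, hypothetical demonstration of how process_clip_state_dict
# remaps legacy CLIP keys. The toy state dict is an assumption made for
# illustration; real checkpoints carry full transformer weights. It also
# assumes util.state_dict_prefix_replace renames the given prefix and
# filters out non-matching keys, as its call above suggests.
if __name__ == "__main__":
    toy = {
        # Legacy layout: no "text_model." segment, float32 position ids.
        "cond_stage_model.transformer.embeddings.position_ids":
            torch.arange(77, dtype=torch.float32).unsqueeze(0),
    }
    # `self` is unused by the method, so it can be exercised unbound here.
    remapped = sm_SD15.process_clip_state_dict(None, toy)
    # Expected key after renaming and prefix replacement:
    # "clip_l.transformer.text_model.embeddings.position_ids"
    print(list(remapped.keys()))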