Spaces:
Running on Zero
Running on Zero
| """SD1.5 model implementation.""" | |
| import torch | |
| from src.Utilities import util | |
| from src.Model import ModelBase | |
| from src.SD15 import SDClip, SDToken, SDXL | |
| from src.Utilities import Latent | |
| from src.clip import Clip | |
class sm_SD15(ModelBase.BASE):
    """SD1.5 model class.

    Declares the UNet configuration, latent format and CLIP target for
    Stable Diffusion 1.5, and normalises SD1.5-style CLIP state-dict keys
    into the ``clip_l.`` layout.
    """

    # UNet hyper-parameters for SD1.5. NOTE(review): presumably consumed by
    # ModelBase.BASE to match/build the UNet — confirm against the base class.
    unet_config: dict = {
        "context_dim": 768,
        "model_channels": 320,
        "use_linear_in_transformer": False,
        "adm_in_channels": None,
        "use_temporal_attention": False,
    }
    # Extra UNet options: 8 attention heads; num_head_channels=-1 presumably
    # means "derive from num_heads" — TODO confirm.
    unet_extra_config: dict = {"num_heads": 8, "num_head_channels": -1}
    # The latent-format *class* (not an instance) is stored here, hence
    # ``type[...]``. The annotation is a string so it is never evaluated at
    # class-creation time.
    latent_format: "type[Latent.SD15]" = Latent.SD15

    def process_clip_state_dict(self, state_dict: dict) -> dict:
        """Normalise SD1.5 CLIP keys and re-prefix them for ``clip_l``.

        Steps:
          1. Keys starting with ``cond_stage_model.transformer.`` but not
             ``cond_stage_model.transformer.text_model.`` are moved under the
             ``...text_model.`` level (only the leading prefix is rewritten).
          2. If ``...text_model.embeddings.position_ids`` is float32, it is
             rounded.
          3. ``cond_stage_model.`` is replaced with ``clip_l.`` via
             ``util.state_dict_prefix_replace(..., filter_keys=True)``
             (presumably dropping keys without a matching prefix — TODO
             confirm).

        Args:
            state_dict: Checkpoint state dict. It is mutated in place by
                steps 1 and 2.

        Returns:
            The dict produced by ``util.state_dict_prefix_replace``.
        """
        old_prefix = "cond_stage_model.transformer."
        new_prefix = "cond_stage_model.transformer.text_model."
        # Snapshot the keys because the dict is modified inside the loop.
        for key in list(state_dict.keys()):
            if key.startswith(old_prefix) and not key.startswith(new_prefix):
                # Rewrite only the leading prefix; str.replace would also
                # rewrite any later occurrence of the substring in the key.
                state_dict[new_prefix + key[len(old_prefix):]] = state_dict.pop(key)

        pos_key = new_prefix + "embeddings.position_ids"
        # Rounding float32 ids presumably repairs slightly-off integer values
        # stored by some checkpoints — TODO confirm.
        if pos_key in state_dict and state_dict[pos_key].dtype == torch.float32:
            state_dict[pos_key] = state_dict[pos_key].round()

        return util.state_dict_prefix_replace(
            state_dict, {"cond_stage_model.": "clip_l."}, filter_keys=True
        )

    def clip_target(self) -> Clip.ClipTarget:
        """Return the CLIP target (SD1 tokenizer + SD1 CLIP model) for SD1.5."""
        return Clip.ClipTarget(SDToken.SD1Tokenizer, SDClip.SD1ClipModel)
# Model classes exposed by this module, in their original order.
# NOTE(review): if callers try these in sequence to detect a checkpoint's
# architecture, order may matter — confirm before reordering.
models = [
    sm_SD15,
    SDXL.SDXLRefiner,
    SDXL.SDXL,
    SDXL.SSD1B,
    SDXL.Segmind_Vega,
    SDXL.KOALA_700M,
    SDXL.KOALA_1B,
]