| import gradio |
| import torch |
|
|
| import modules.devices as devices |
| import modules.scripts as scripts |
| import modules.script_callbacks as script_callbacks |
| import modules.sd_unet as sd_unet |
| import modules.shared as shared |
|
|
| from ldm.modules.diffusionmodules.util import timestep_embedding as timestep_embedding |
|
|
class DeepShrinkHiresFixAction:
    """One Deep Shrink rule, i.e. the values of a single row of UI controls.

    Attributes:
        enable: whether this rule is considered at all.
        timestep: the rule is applied while the current timestep is strictly
            greater than this value.
        depth: UNet block depth at which the latent is shrunk (and later
            restored on the way back up).
        scale: shrink factor; the latent is resized by ``1 / scale`` going
            down and by ``scale`` going up.
    """

    def __init__(self, enable: bool, timestep: float, depth: int, scale: float):
        self.enable, self.timestep, self.depth, self.scale = enable, timestep, depth, scale
|
|
class DeepShrinkHiresFix(scripts.Script):
    """Always-visible WebUI script exposing the Deep Shrink Hires.fix settings.

    :meth:`process` copies the UI values into class-level attributes, and the
    nested :class:`DeepShrinkHiresFixUNet` reads those attributes on every
    forward pass. The state is therefore shared by all instances of this class.
    """

    # Number of (enable, timestep, depth, scale) rows shown in the UI.
    _ACTION_COUNT = 8

    deepShrinkHiresFixActions: list[DeepShrinkHiresFixAction] = []
    enableExperimental: bool = False
    experimentalTimestep: float = 900
    experimentalScales: list[float] = []

    def __init__(self):
        pass

    def title(self):
        return "Deep Shrink Hires.fix"

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    @staticmethod
    def _action_row(index, enabled):
        """Build one row of action controls and return its four components.

        Odd-numbered rows default to timestep 900, even-numbered rows to 650.
        """
        default_timestep = 900 if index % 2 == 1 else 650
        with gradio.Row():
            return [
                gradio.Checkbox(value=enabled, label=f"Enable {index}"),
                gradio.Number(value=default_timestep, label=f"Timestep {index}"),
                gradio.Number(value=3, label=f"Block Depth {index}", precision=0),
                gradio.Number(value=2, label=f"Scale factor {index}"),
            ]

    def ui(self, is_img2img):
        """Build the settings UI and return its components.

        The order is rows 1..8 as (enable, timestep, depth, scale), followed by
        (enable, timestep, scale list) for experimental mode. :meth:`process`
        relies on exactly this order.
        """
        components = []
        with gradio.Accordion(label="Deep Shrink Hires.fix", open=False):
            for index in (1, 2):
                components += self._action_row(index, enabled=True)
            with gradio.Accordion(label="Advanced Settings", open=False):
                for index in range(3, self._ACTION_COUNT + 1):
                    components += self._action_row(index, enabled=False)
            with gradio.Accordion(label="Experimental Settings", open=False):
                with gradio.Row():
                    components.append(gradio.Checkbox(value=False, label="Enable Experimental Mode"))
                    components.append(gradio.Number(value=900, label="Timestep"))
                    components.append(gradio.Textbox(value="1,1,1, 1,1,1, 1,1,1, 1,1,1, 2, 1,1,1, 1,1,1, 1,1,1, 1,1,1", label="Scale Factor List"))
        return components

    @staticmethod
    def _parse_scales(text):
        """Parse a comma-separated list of floats.

        Blank entries (empty text, trailing commas, ",,") are ignored.

        Raises:
            ValueError: if a non-blank entry is not a number.
        """
        scales = []
        for token in text.split(","):
            token = token.strip()
            if not token:
                # Without this, an empty box or a trailing comma would break
                # every generation, even with experimental mode disabled.
                continue
            try:
                scales.append(float(token))
            except ValueError as err:
                raise ValueError(
                    f"Deep Shrink Hires.fix: invalid entry {token!r} in Scale Factor List"
                ) from err
        return scales

    def process(self, p, *args):
        """Store this generation's UI values in the class-level state.

        ``args`` holds the values of the components returned by :meth:`ui`, in
        the same order.
        """
        actions = DeepShrinkHiresFix.deepShrinkHiresFixActions
        actions.clear()
        for i in range(self._ACTION_COUNT):
            actions.append(DeepShrinkHiresFixAction(*args[i * 4:i * 4 + 4]))
        offset = self._ACTION_COUNT * 4
        DeepShrinkHiresFix.enableExperimental = args[offset]
        DeepShrinkHiresFix.experimentalTimestep = args[offset + 1]
        scaleFactorsTexts: str = args[offset + 2]
        # Assign in place so the shared list object keeps its identity.
        DeepShrinkHiresFix.experimentalScales[:] = self._parse_scales(scaleFactorsTexts)

    class DeepShrinkHiresFixUNet(sd_unet.SdUnet):
        """UNet forward pass that resizes the latent around selected blocks.

        This re-implements the wrapped model's forward pass (input blocks,
        middle block, then output blocks with skip connections). It inserts
        bicubic resizes driven by the settings stored on
        ``DeepShrinkHiresFix``.
        """

        def __init__(self, _model):
            super().__init__()
            self.model = _model.to(devices.device)

        @staticmethod
        def _resize(tensor, scale_factor=None, size=None):
            """Bicubic-resize ``tensor`` by ``scale_factor`` or to ``size``.

            The result is cast back to the input's dtype.
            """
            # Interpolate in float32 and cast back. Presumably this avoids
            # dtype/device combinations where bicubic is unsupported.
            return torch.nn.functional.interpolate(
                tensor.float(), size=size, scale_factor=scale_factor, mode="bicubic", align_corners=False
            ).to(tensor.dtype)

        def _match_spatial(self, tensor, reference):
            """Return ``tensor`` resized to ``reference``'s last two dimensions.

            If the sizes already match, ``tensor`` is returned unchanged.
            """
            if tensor.shape[-2:] == reference.shape[-2:]:
                return tensor
            return self._resize(tensor, size=tuple(reference.shape[-2:]))

        def _apply_actions(self, h, depth, t, upscale):
            """Apply the first matching user action.

            An action matches when it is enabled, targets ``depth``, and
            ``t`` is above its timestep. The latent is shrunk by ``1/scale``,
            or grown by ``scale`` when ``upscale`` is true. If no action
            matches, ``h`` is returned unchanged.
            """
            for action in DeepShrinkHiresFix.deepShrinkHiresFixActions:
                if action.enable and action.depth == depth and action.timestep < t:
                    return self._resize(h, action.scale if upscale else 1 / action.scale)
            return h

        def forward(self, x, timesteps, context, y=None, **kwargs):
            """Run the wrapped UNet with Deep Shrink resizing applied.

            Only ``timesteps[0]`` is used to decide which resizes apply. The
            whole batch is therefore treated as being at that timestep.

            Raises:
                ValueError: if experimental mode is active and the scale list
                    has fewer entries than the model has blocks.
            """
            assert (y is not None) == (
                self.model.num_classes is not None
            ), "must specify y if and only if the model is class-conditional"
            hs = []  # input-block outputs, consumed in reverse by the output blocks
            ss = []  # experimental mode: the scale each entry of ``hs`` was produced at
            t_emb = timestep_embedding(timesteps, self.model.model_channels, repeat_only=False)
            emb = self.model.time_embed(t_emb)

            if self.model.num_classes is not None:
                assert y.shape[0] == x.shape[0]
                emb = emb + self.model.label_emb(y)

            t = timesteps[0]
            experimental = bool(
                DeepShrinkHiresFix.enableExperimental and t >= DeepShrinkHiresFix.experimentalTimestep
            )
            scales = DeepShrinkHiresFix.experimentalScales
            if experimental:
                # Entries needed: one per input block, one for the middle
                # block, one per output block.
                needed = len(self.model.input_blocks) + 1 + len(self.model.output_blocks)
                if len(scales) < needed:
                    raise ValueError(
                        f"Deep Shrink Hires.fix: Scale Factor List has {len(scales)} entries "
                        f"but this model needs {needed}"
                    )

            h = x.type(self.model.dtype)
            depth = 0  # rises through the input blocks, falls through the output blocks
            block = 0  # running index into the experimental scale list
            scale = 1  # experimental scale currently applied to h
            for module in self.model.input_blocks:
                h = self._apply_actions(h, depth, t, upscale=False)
                if experimental:
                    h = self._resize(h, scale / scales[block])
                    scale = scales[block]
                    ss.append(scale)
                h = module(h, emb, context)
                hs.append(h)
                depth += 1
                block += 1

            # Middle block: shrink before it and restore after it, both at the same depth.
            h = self._apply_actions(h, depth, t, upscale=False)
            if experimental:
                h = self._resize(h, scale / scales[block])
                scale = scales[block]
            h = self.model.middle_block(h, emb, context)
            h = self._apply_actions(h, depth, t, upscale=True)
            block += 1

            for module in self.model.output_blocks:
                depth -= 1
                skip = hs.pop()
                if experimental:
                    target = scales[block]
                    skip = self._resize(skip, ss.pop() / target).to(h.dtype)
                    h = self._resize(h, scale / target)
                    scale = target
                # Resize rounding, or latent sizes that do not divide evenly,
                # can leave h and the skip connection with different spatial
                # sizes. Align h before concatenating instead of crashing.
                h = torch.cat([self._match_spatial(h, skip), skip], dim=1)
                h = module(h, emb, context)
                h = self._apply_actions(h, depth, t, upscale=True)
                block += 1

            if experimental:
                h = self._resize(h, scale)
            # The prediction must have the same spatial size as the input latent.
            h = self._match_spatial(h, x)
            h = h.type(x.dtype)
            if self.model.predict_codebook_ids:
                return self.model.id_predictor(h)
            return self.model.out(h)

    DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption()
    DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix"
    DeepShrinkHiresFixUNetOption.create_unet = lambda: DeepShrinkHiresFix.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model)
|
|
def _list_deep_shrink_unets(unets):
    """Callback for ``on_list_unets``: append this script's UNet option to ``unets``."""
    unets.append(DeepShrinkHiresFix.DeepShrinkHiresFixUNetOption)


script_callbacks.on_list_unets(_list_deep_shrink_unets)