import gradio import torch import modules.devices as devices import modules.scripts as scripts import modules.script_callbacks as script_callbacks import modules.sd_unet as sd_unet import modules.shared as shared from ldm.modules.diffusionmodules.util import timestep_embedding as timestep_embedding class DeepShrinkHiresFixAction(): def __init__(self, enable: bool, timestep: float, depth: int, scale: float): self.enable = enable self.timestep = timestep self.depth = depth self.scale = scale pass pass class DeepShrinkHiresFix(scripts.Script): deepShrinkHiresFixActions: list[DeepShrinkHiresFixAction] = [] enableExperimental: bool = False experimentalTimestep: float = 900 experimentalScales: list[float] = [] def __init__(self): pass def title(self): return "Deep Shrink Hires.fix" pass def show(self, is_img2img): return scripts.AlwaysVisible pass def ui(self, is_img2img): with gradio.Accordion(label="Deep Shrink Hires.fix", open=False): with gradio.Row(): Enable_1 = gradio.Checkbox(value=True, label="Enable 1") Timestep_1 = gradio.Number(value=900, label="Timestep 1") Depth_1 = gradio.Number(value=3, label="Block Depth 1", precision=0) Scale_1 = gradio.Number(value=2, label="Scale factor 1") pass with gradio.Row(): Enable_2 = gradio.Checkbox(value=True, label="Enable 2") Timestep_2 = gradio.Number(value=650, label="Timestep 2") Depth_2 = gradio.Number(value=3, label="Block Depth 2", precision=0) Scale_2 = gradio.Number(value=2, label="Scale factor 2") pass with gradio.Accordion(label="Advanced Settings", open=False): with gradio.Row(): Enable_3 = gradio.Checkbox(value=False, label="Enable 3") Timestep_3 = gradio.Number(value=900, label="Timestep 3") Depth_3 = gradio.Number(value=3, label="Block Depth 3", precision=0) Scale_3 = gradio.Number(value=2, label="Scale factor 3") pass with gradio.Row(): Enable_4 = gradio.Checkbox(value=False, label="Enable 4") Timestep_4 = gradio.Number(value=650, label="Timestep 4") Depth_4 = gradio.Number(value=3, label="Block Depth 4", precision=0) Scale_4 = gradio.Number(value=2, label="Scale factor 4") pass with gradio.Row(): Enable_5 = gradio.Checkbox(value=False, label="Enable 5") Timestep_5 = gradio.Number(value=900, label="Timestep 5") Depth_5 = gradio.Number(value=3, label="Block Depth 5", precision=0) Scale_5 = gradio.Number(value=2, label="Scale factor 5") pass with gradio.Row(): Enable_6 = gradio.Checkbox(value=False, label="Enable 6") Timestep_6 = gradio.Number(value=650, label="Timestep 6") Depth_6 = gradio.Number(value=3, label="Block Depth 6", precision=0) Scale_6 = gradio.Number(value=2, label="Scale factor 6") pass with gradio.Row(): Enable_7 = gradio.Checkbox(value=False, label="Enable 7") Timestep_7 = gradio.Number(value=900, label="Timestep 7") Depth_7 = gradio.Number(value=3, label="Block Depth 7", precision=0) Scale_7 = gradio.Number(value=2, label="Scale factor 7") pass with gradio.Row(): Enable_8 = gradio.Checkbox(value=False, label="Enable 8") Timestep_8 = gradio.Number(value=650, label="Timestep 8") Depth_8 = gradio.Number(value=3, label="Block Depth 8", precision=0) Scale_8 = gradio.Number(value=2, label="Scale factor 8") pass pass with gradio.Accordion(label="Experimental Settings", open=False): with gradio.Row(): Enable_Experimental = gradio.Checkbox(value=False, label="Enable Experimental Mode") Timestep_Experimental = gradio.Number(value=900, label="Timestep") Scale_Experimental = gradio.Textbox(value="1,1,1, 1,1,1, 1,1,1, 1,1,1, 2, 1,1,1, 1,1,1, 1,1,1, 1,1,1", label="Scale Factor List") pass pass pass return [Enable_1, Timestep_1, Depth_1, Scale_1, Enable_2, Timestep_2, Depth_2, Scale_2, Enable_3, Timestep_3, Depth_3, Scale_3, Enable_4, Timestep_4, Depth_4, Scale_4, Enable_5, Timestep_5, Depth_5, Scale_5, Enable_6, Timestep_6, Depth_6, Scale_6, Enable_7, Timestep_7, Depth_7, Scale_7, Enable_8, Timestep_8, Depth_8, Scale_8, Enable_Experimental, Timestep_Experimental, Scale_Experimental] pass def process(self, p, *args): del DeepShrinkHiresFix.deepShrinkHiresFixActions[:] for i in range(8): DeepShrinkHiresFix.deepShrinkHiresFixActions.append(DeepShrinkHiresFixAction(args[i*4], args[i*4+1], args[i*4+2], args[i*4+3])) pass del DeepShrinkHiresFix.experimentalScales[:] DeepShrinkHiresFix.enableExperimental = args[8*4] DeepShrinkHiresFix.experimentalTimestep = args[8*4+1] scaleFactorsTexts: str = args[8*4+2] scaleFactorsTextsList = scaleFactorsTexts.split(",") for scaleFactorsText in scaleFactorsTextsList: DeepShrinkHiresFix.experimentalScales.append(float(scaleFactorsText)) pass pass class DeepShrinkHiresFixUNet(sd_unet.SdUnet): def __init__(self, _model): super().__init__() self.model = _model.to(devices.device) pass def forward(self, x, timesteps, context, y=None, **kwargs): assert (y is not None) == ( self.model.num_classes is not None ), "must specify y if and only if the model is class-conditional" hs = [] ss = [] t_emb = timestep_embedding(timesteps, self.model.model_channels, repeat_only=False) emb = self.model.time_embed(t_emb) if self.model.num_classes is not None: assert y.shape[0] == x.shape[0] emb = emb + self.model.label_emb(y) h = x.type(self.model.dtype) depth = 0 block = 0 scale = 1 for module in self.model.input_blocks: for action in DeepShrinkHiresFix.deepShrinkHiresFixActions: if action.enable == True and action.depth == depth and action.timestep < timesteps[0]: h = torch.nn.functional.interpolate(h.float(), scale_factor=1/action.scale, mode="bicubic", align_corners=False).to(h.dtype) break pass pass if DeepShrinkHiresFix.enableExperimental and timesteps[0] >= DeepShrinkHiresFix.experimentalTimestep: h = torch.nn.functional.interpolate(h.float(), scale_factor=scale/DeepShrinkHiresFix.experimentalScales[block], mode="bicubic", align_corners=False).to(h.dtype) scale = DeepShrinkHiresFix.experimentalScales[block] ss.append(scale) pass h = module(h, emb, context) hs.append(h) depth += 1 block += 1 pass for action in DeepShrinkHiresFix.deepShrinkHiresFixActions: if action.enable == True and action.depth == depth and action.timestep < timesteps[0]: h = torch.nn.functional.interpolate(h.float(), scale_factor=1/action.scale, mode="bicubic", align_corners=False).to(h.dtype) break pass pass if DeepShrinkHiresFix.enableExperimental and timesteps[0] >= DeepShrinkHiresFix.experimentalTimestep: h = torch.nn.functional.interpolate(h.float(), scale_factor=scale/DeepShrinkHiresFix.experimentalScales[block], mode="bicubic", align_corners=False).to(h.dtype) scale = DeepShrinkHiresFix.experimentalScales[block] pass h = self.model.middle_block(h, emb, context) for action in DeepShrinkHiresFix.deepShrinkHiresFixActions: if action.enable == True and action.depth == depth and action.timestep < timesteps[0]: h = torch.nn.functional.interpolate(h.float(), scale_factor=action.scale, mode="bicubic", align_corners=False).to(h.dtype) break pass pass block += 1 for module in self.model.output_blocks: depth -= 1 if DeepShrinkHiresFix.enableExperimental and timesteps[0] >= DeepShrinkHiresFix.experimentalTimestep: h = torch.cat([torch.nn.functional.interpolate(h.float(), scale_factor=scale/DeepShrinkHiresFix.experimentalScales[block], mode="bicubic", align_corners=False).to(h.dtype), torch.nn.functional.interpolate(hs.pop().float(), scale_factor=ss.pop()/DeepShrinkHiresFix.experimentalScales[block], mode="bicubic", align_corners=False).to(h.dtype)], dim=1) scale = DeepShrinkHiresFix.experimentalScales[block] pass else: h = torch.cat([h, hs.pop()], dim=1) pass h = module(h, emb, context) for action in DeepShrinkHiresFix.deepShrinkHiresFixActions: if action.enable == True and action.depth == depth and action.timestep < timesteps[0]: h = torch.nn.functional.interpolate(h.float(), scale_factor=action.scale, mode="bicubic", align_corners=False).to(h.dtype) break pass pass block += 1 pass if DeepShrinkHiresFix.enableExperimental and timesteps[0] >= DeepShrinkHiresFix.experimentalTimestep: h = torch.nn.functional.interpolate(h.float(), scale_factor=scale, mode="bicubic", align_corners=False).to(h.dtype) pass h = h.type(x.dtype) if self.model.predict_codebook_ids: return self.model.id_predictor(h) else: return self.model.out(h) pass pass DeepShrinkHiresFixUNetOption = sd_unet.SdUnetOption() DeepShrinkHiresFixUNetOption.label = "Deep Shrink Hires.fix" DeepShrinkHiresFixUNetOption.create_unet = lambda: DeepShrinkHiresFix.DeepShrinkHiresFixUNet(shared.sd_model.model.diffusion_model) pass script_callbacks.on_list_unets(lambda unets: unets.append(DeepShrinkHiresFix.DeepShrinkHiresFixUNetOption))