import os
from typing import Any, Dict, Union

from lora_diffusion import patch_pipe, tune_lora_scale


class LoraStyle:
    """Registry of named LoRA styles and factories for pipeline patchers.

    Call :meth:`load` once with the model directory to populate the registry
    (it verifies every referenced weight file exists), then use
    :meth:`get_patcher` to obtain an object that applies/removes a style on a
    diffusers pipeline.
    """

    class LoraPatcher:
        """Applies one LoRA style to a pipeline and can neutralize it again."""

        def __init__(self, pipe, style: Dict[str, Any]):
            # `style` is one entry of LoraStyle.__styles: keys "path",
            # "weight", "negativePrompt", "type".
            self.__style = style
            self.pipe = pipe

        def patch(self):
            """Load the LoRA weights and scale unet + text encoder."""
            patch_pipe(self.pipe, self.__style["path"])
            tune_lora_scale(self.pipe.unet, self.__style["weight"])
            tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self) -> Dict[str, Any]:
            """Extra pipeline call kwargs contributed by this style (none)."""
            return {}

        def cleanup(self):
            """Neutralize the patched LoRA by scaling its effect to zero."""
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """No-op patcher returned when the requested style key is unknown."""

        def patch(self):
            pass

        def kwargs(self) -> Dict[str, Any]:
            return {}

        def cleanup(self):
            pass

    def load(self, model_dir: str):
        """Populate the style registry rooted at ``model_dir``.

        Raises:
            FileNotFoundError: if any style's weight file is missing on disk.
        """
        self.__styles = {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "jim lee": {
                "path": model_dir + "/laur_style/e2j9mz0jqj/final_lora.bin",
                "weight": 0.8,
                "negativePrompt": [""],
                "type": "custom",
            },
        }
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Prefix ``prompt`` with "<key> style " when ``key`` is a known style."""
        if key in self.__styles:
            return f"{key} style {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for ``key``, or a no-op patcher if it is unknown."""
        if key in self.__styles:
            style = self.__styles[key]
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher()

    def __verify(self):
        """Verify each registered LoRA file exists, else raise FileNotFoundError."""
        # FileNotFoundError is a subclass of Exception, so existing callers
        # catching Exception still work.
        for name, style in self.__styles.items():
            if not os.path.exists(style["path"]):
                raise FileNotFoundError(
                    f"Lora style model {name} not found at path: {style['path']}"
                )