| import os | |
| from typing import Any, Dict, Union | |
| from lora_diffusion import patch_pipe, tune_lora_scale | |
class LoraStyle:
    """Registry of named LoRA styles and factory for pipeline patchers.

    Call :meth:`load` once with the model directory to register the known
    styles; afterwards :meth:`get_patcher` and
    :meth:`prepend_style_to_prompt` look styles up by key. Before ``load``
    is called the registry is empty, so every key is treated as unknown.
    """

    class LoraPatcher:
        """Patcher that applies one registered LoRA style to a pipeline."""

        def __init__(self, pipe, style: Dict[str, Any]):
            """Store the pipeline and the style entry to apply.

            Args:
                pipe: Pipeline object; ``patch``/``cleanup`` read its
                    ``unet`` and ``text_encoder`` attributes.
                style: Style entry; ``patch`` reads its ``"path"`` and
                    ``"weight"`` keys.
            """
            self.__style = style
            self.pipe = pipe

        def patch(self):
            """Patch the pipeline with the style's LoRA and set its scale.

            Calls ``patch_pipe`` with the style's path, then
            ``tune_lora_scale`` on the pipe's unet and text encoder using
            the style's weight.
            """
            weight = self.__style["weight"]
            patch_pipe(self.pipe, self.__style["path"])
            tune_lora_scale(self.pipe.unet, weight)
            tune_lora_scale(self.pipe.text_encoder, weight)

        def kwargs(self):
            """Return an empty dict.

            NOTE(review): presumably extra keyword arguments for the
            pipeline call -- confirm with callers.
            """
            return {}

        def cleanup(self):
            """Set the LoRA scale of the unet and text encoder to 0.0.

            NOTE(review): this does not undo ``patch_pipe``; presumably a
            scale of 0 neutralises the LoRA contribution -- confirm against
            lora_diffusion.
            """
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """No-op patcher returned when the requested style key is unknown."""

        def patch(self):
            """Do nothing."""

        def kwargs(self):
            """Return an empty dict, mirroring ``LoraPatcher.kwargs``."""
            return {}

        def cleanup(self):
            """Do nothing."""

    def __init__(self):
        """Start with an empty registry so lookups are safe before ``load``."""
        self.__styles: Dict[str, Dict[str, Any]] = {}

    def load(self, model_dir: str):
        """Register the known LoRA styles rooted at ``model_dir``.

        Args:
            model_dir: Directory that contains the ``laur_style`` folder.

        Raises:
            FileNotFoundError: If any registered style's path does not
                exist. The registry is assigned before verification, so the
                styles stay registered even when this is raised.
        """
        # Paths are built with plain "/" joining so the exact strings handed
        # to patch_pipe (and shown in error messages) are unchanged.
        style_root = f"{model_dir}/laur_style"
        self.__styles = {
            "nq6akX1CIp": {
                "path": f"{style_root}/nq6akX1CIp/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            # NOTE(review): ghibli lives under the nq6akX1CIp directory --
            # confirm this is intentional.
            "ghibli": {
                "path": f"{style_root}/nq6akX1CIp/ghibli.bin",
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": f"{style_root}/eQAmnK2kB2/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": f"{style_root}/rpjgusOgqD/final_lora.bin",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "jim lee": {
                "path": f"{style_root}/e2j9mz0jqj/final_lora.bin",
                "weight": 0.8,
                "negativePrompt": [""],
                "type": "custom",
            },
        }
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``"<key> style <prompt>"`` for a registered key.

        Unknown keys leave the prompt unchanged.
        """
        if key not in self.__styles:
            return prompt
        return f"{key} style {prompt}"

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for ``key``, or a no-op patcher if it is unknown.

        Args:
            pipe: Pipeline object handed to the ``LoraPatcher``.
            key: Style key to look up in the registry.
        """
        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher()
        return self.LoraPatcher(pipe, style)

    def __verify(self):
        """Verify that every registered style's LoRA path exists.

        Raises:
            FileNotFoundError: For the first style (in registration order)
                whose path is missing.
        """
        for name, style in self.__styles.items():
            path = style["path"]
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Lora style model {name} not found at path: {path}"
                )