# model-sd-multi / util/lora_style.py
# Last change: "Update util/lora_style.py" by jayparmr (commit bb243dd)
import os
from typing import Any, Dict, Union
from lora_diffusion import patch_pipe, tune_lora_scale
class LoraStyle:
    """Registry of LoRA style models plus a factory for pipeline patchers.

    Call :meth:`load` with the model directory before using any other method:
    the style table is only created there.
    """

    class LoraPatcher:
        """Apply one LoRA style to a pipeline and zero its scale afterwards."""

        def __init__(self, pipe, style: Dict[str, Any]):
            """
            Args:
                pipe: Object exposing ``unet`` and ``text_encoder`` attributes,
                    which :meth:`patch` and :meth:`cleanup` operate on.
                style: A style entry built by :meth:`LoraStyle.load`; only the
                    ``"path"`` and ``"weight"`` keys are read.
            """
            self.__style = style
            self.pipe = pipe

        def patch(self):
            """Patch ``pipe`` with the style's LoRA weights and set their scale."""
            patch_pipe(self.pipe, self.__style["path"])
            weight = self.__style["weight"]
            tune_lora_scale(self.pipe.unet, weight)
            tune_lora_scale(self.pipe.text_encoder, weight)

        def kwargs(self):
            """Return extra keyword arguments for the pipeline call (none here)."""
            return {}

        def cleanup(self):
            """Reset the LoRA scale on ``unet`` and ``text_encoder`` to 0.0.

            The patched layers themselves are not removed here; only their
            scale is zeroed, which presumably neutralizes the style for later
            generations.
            """
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """No-op patcher returned when the requested style is not registered."""

        def patch(self):
            """Do nothing."""
            pass

        def kwargs(self):
            """Return no extra keyword arguments."""
            return {}

        def cleanup(self):
            """Do nothing."""
            pass

    def load(self, model_dir: str):
        """Register the built-in LoRA styles under *model_dir* and verify them.

        Args:
            model_dir: Root directory containing the ``laur_style`` folder.
                Paths are built by plain string concatenation
                (``model_dir + "/laur_style/..."``).

        Raises:
            FileNotFoundError: If any registered style path does not exist.
        """
        self.__styles = {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            # NOTE(review): the ghibli weights live under the nq6akX1CIp
            # folder -- confirm this is intentional.
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "jim lee": {
                "path": model_dir + "/laur_style/e2j9mz0jqj/final_lora.bin",
                "weight": 0.8,
                "negativePrompt": [""],
                "type": "custom",
            },
        }
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``"<key> style <prompt>"`` for a registered style.

        Unknown keys return *prompt* unchanged.
        """
        if key in self.__styles:
            return f"{key} style {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for style *key*.

        Args:
            pipe: Pipeline handed to the :class:`LoraPatcher`.
            key: Style name as registered in :meth:`load`.

        Returns:
            A :class:`LoraPatcher` for a registered style. Otherwise an
            :class:`EmptyLoraPatcher`, so callers can use either
            unconditionally.
        """
        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher()
        return self.LoraPatcher(pipe, style)

    def __verify(self):
        """Check that every registered style path exists on disk.

        Raises:
            FileNotFoundError: For the first style whose path is missing. It
                subclasses ``Exception``, so existing broad handlers still
                catch it.
        """
        for name, style in self.__styles.items():
            path = style["path"]
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Lora style model {name} not found at path: {path}"
                )