File size: 2,904 Bytes
4adca93 bb243dd 4adca93 bb243dd 4adca93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
import os
from typing import Any, Dict, Union
from lora_diffusion import patch_pipe, tune_lora_scale
class LoraStyle:
    """Registry of LoRA style models plus a factory for pipeline patchers.

    Call :meth:`load` first to register the built-in styles under a model
    directory. Afterwards :meth:`get_patcher` and
    :meth:`prepend_style_to_prompt` look styles up by their key.
    """

    class LoraPatcher:
        """Applies one registered LoRA style to a diffusion pipeline."""

        def __init__(self, pipe, style: Dict[str, Any]):
            """Bind a pipeline to a style entry.

            Args:
                pipe: Pipeline object; ``patch``/``cleanup`` access its
                    ``unet`` and ``text_encoder`` attributes.
                style: One entry of the table built by ``LoraStyle.load``.
                    Only its ``"path"`` and ``"weight"`` keys are read here.
            """
            self.__style = style
            self.pipe = pipe

        def patch(self):
            """Patch the LoRA weights into the pipe and scale them by the style weight."""
            patch_pipe(self.pipe, self.__style["path"])
            tune_lora_scale(self.pipe.unet, self.__style["weight"])
            tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            """Return extra keyword arguments for the pipeline call (none here)."""
            return {}

        def cleanup(self):
            """Set the LoRA scale back to 0.0 on the unet and the text encoder.

            NOTE(review): this only calls ``tune_lora_scale`` with 0.0; no
            un-patching of the pipe is performed here.
            """
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        """No-op patcher returned when the requested style is not registered."""

        def patch(self):
            """Do nothing."""

        def kwargs(self):
            """Return no extra pipeline keyword arguments."""
            return {}

        def cleanup(self):
            """Do nothing."""

    def load(self, model_dir: str):
        """Register the built-in styles found under *model_dir* and verify them.

        Args:
            model_dir: Root directory containing the ``laur_style``
                sub-directory. Because paths are built with
                ``os.path.join``, an ``os.PathLike`` is accepted as well.

        Raises:
            FileNotFoundError: If any registered style's path does not exist.
        """
        # "laur_style" is possibly a misspelling of "lora_style", but it must
        # match the directory on disk, so it is kept verbatim.
        style_root = os.path.join(model_dir, "laur_style")
        self.__styles = {
            "nq6akX1CIp": {
                "path": os.path.join(style_root, "nq6akX1CIp", "final_lora"),
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": os.path.join(style_root, "nq6akX1CIp", "ghibli.bin"),
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": os.path.join(style_root, "eQAmnK2kB2", "final_lora"),
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": os.path.join(style_root, "rpjgusOgqD", "final_lora.bin"),
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "jim lee": {
                "path": os.path.join(style_root, "e2j9mz0jqj", "final_lora.bin"),
                "weight": 0.8,
                "negativePrompt": [""],
                "type": "custom",
            },
        }
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        """Return ``"<key> style <prompt>"`` for a registered key, else *prompt* unchanged."""
        if key in self.__styles:
            return f"{key} style {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        """Return a patcher for style *key*.

        Args:
            pipe: Pipeline the returned ``LoraPatcher`` will operate on.
            key: Style key as registered by :meth:`load`.

        Returns:
            A ``LoraPatcher`` bound to *pipe* if *key* is registered,
            otherwise a no-op ``EmptyLoraPatcher``.
        """
        style = self.__styles.get(key)
        if style is None:
            return self.EmptyLoraPatcher()
        return self.LoraPatcher(pipe, style)

    def __verify(self):
        """Ensure every registered style path exists on disk.

        Raises:
            FileNotFoundError: For the first style whose path is missing.
                It subclasses ``Exception``, so existing
                ``except Exception`` handlers still catch it.
        """
        for key, style in self.__styles.items():
            path = style["path"]
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Lora style model {key} not found at path: {path}"
                )
|