Commit
·
471adc0
1
Parent(s):
63859a4
- handler.py +30 -5
handler.py
CHANGED
|
@@ -20,16 +20,41 @@ dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.
|
|
| 20 |
|
| 21 |
class EndpointHandler():
|
| 22 |
def __init__(self, path=""):
|
| 23 |
-
|
| 24 |
|
| 25 |
|
| 26 |
|
| 27 |
|
| 28 |
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
|
| 34 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 35 |
# """
|
|
|
|
| 20 |
|
| 21 |
class EndpointHandler():
|
| 22 |
def __init__(self, path=""):
|
| 23 |
+
self.stable_diffusion_id = "Lykon/dreamshaper-8"
|
| 24 |
|
| 25 |
|
| 26 |
|
| 27 |
|
| 28 |
|
| 29 |
+
self.pipe = StableDiffusionPipeline.from_pretrained(self.stable_diffusion_id,torch_dtype=dtype,safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker", torch_dtype=dtype)).to(device.type)
|
| 30 |
+
#self.pipe.enable_xformers_memory_efficient_attention()
|
| 31 |
+
#self.pipe.enable_vae_tiling()
|
| 32 |
+
self.generator = torch.Generator(device=device.type).manual_seed(3)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
from typing import Optional
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
from torch.nn import functional as F
|
| 38 |
+
from torch.nn import Conv2d
|
| 39 |
+
from torch.nn.modules.utils import _pair
|
| 40 |
+
|
| 41 |
+
def asymmetricConv2DConvForward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
|
| 42 |
+
self.paddingX = (self._reversed_padding_repeated_twice[0], self._reversed_padding_repeated_twice[1], 0, 0)
|
| 43 |
+
self.paddingY = (0, 0, self._reversed_padding_repeated_twice[2], self._reversed_padding_repeated_twice[3])
|
| 44 |
+
working = F.pad(input, self.paddingX, mode='circular')
|
| 45 |
+
working = F.pad(working, self.paddingY, mode='constant')
|
| 46 |
+
return F.conv2d(working, weight, bias, self.stride, _pair(0), self.dilation, self.groups)
|
| 47 |
+
|
| 48 |
+
targets = [pipe.vae, pipe.text_encoder, pipe.unet,]
|
| 49 |
+
conv_layers = []
|
| 50 |
+
for target in targets:
|
| 51 |
+
for module in target.modules():
|
| 52 |
+
if isinstance(module, torch.nn.Conv2d):
|
| 53 |
+
conv_layers.append(module)
|
| 54 |
+
|
| 55 |
+
for cl in conv_layers:
|
| 56 |
+
cl._conv_forward = asymmetricConv2DConvForward.__get__(cl, Conv2d)
|
| 57 |
+
|
| 58 |
|
| 59 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
| 60 |
# """
|