Update handler.py
Browse files- handler.py +6 -3
handler.py
CHANGED
|
@@ -27,6 +27,9 @@ class EndpointHandler:
|
|
| 27 |
tensor = cast(torch.Tensor, data["inputs"])
|
| 28 |
parameters = cast(dict, data.get("parameters", {}))
|
| 29 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
|
|
|
|
|
|
|
|
|
| 30 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
| 31 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
| 32 |
if partial_postprocess and output_type != "pt":
|
|
@@ -34,8 +37,8 @@ class EndpointHandler:
|
|
| 34 |
|
| 35 |
tensor = tensor.to(self.device, self.dtype)
|
| 36 |
|
| 37 |
-
if
|
| 38 |
-
tensor = tensor /
|
| 39 |
|
| 40 |
with torch.no_grad():
|
| 41 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|
|
@@ -55,4 +58,4 @@ class EndpointHandler:
|
|
| 55 |
elif output_type == "pt":
|
| 56 |
frames = frames
|
| 57 |
|
| 58 |
-
return frames
|
|
|
|
| 27 |
tensor = cast(torch.Tensor, data["inputs"])
|
| 28 |
parameters = cast(dict, data.get("parameters", {}))
|
| 29 |
do_scaling = cast(bool, parameters.get("do_scaling", True))
|
| 30 |
+
scaling_factor = cast(float, parameters.get("scaling_factor", None))
|
| 31 |
+
if do_scaling and scaling_factor is None:
|
| 32 |
+
scaling_factor = self.vae.config.scaling_factor
|
| 33 |
output_type = cast(str, parameters.get("output_type", "pil"))
|
| 34 |
partial_postprocess = cast(bool, parameters.get("partial_postprocess", False))
|
| 35 |
if partial_postprocess and output_type != "pt":
|
|
|
|
| 37 |
|
| 38 |
tensor = tensor.to(self.device, self.dtype)
|
| 39 |
|
| 40 |
+
if scaling_factor is not None:
|
| 41 |
+
tensor = tensor / scaling_factor
|
| 42 |
|
| 43 |
with torch.no_grad():
|
| 44 |
frames = cast(torch.Tensor, self.vae.decode(tensor, return_dict=False)[0])
|
|
|
|
| 58 |
elif output_type == "pt":
|
| 59 |
frames = frames
|
| 60 |
|
| 61 |
+
return frames
|