Update handler.py
Browse files- handler.py +4 -6
handler.py
CHANGED
|
@@ -55,17 +55,15 @@ class EndpointHandler():
|
|
| 55 |
|
| 56 |
def __call__(self, data: Dict[str, Any]) -> str:
|
| 57 |
# get inputs
|
| 58 |
-
|
| 59 |
-
data = data.decode('utf-8')
|
| 60 |
-
data = json.loads(data)
|
| 61 |
|
| 62 |
batch_size = data.pop("batch_size", 1)
|
| 63 |
|
| 64 |
-
context = base64.b64decode(
|
| 65 |
context = np.frombuffer(context, dtype="float32")
|
| 66 |
context = np.reshape(context, (batch_size, 77, 768))
|
| 67 |
|
| 68 |
-
unconditional_context = base64.b64decode(
|
| 69 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
| 70 |
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
|
| 71 |
|
|
@@ -98,4 +96,4 @@ class EndpointHandler():
|
|
| 98 |
latent_b64 = base64.b64encode(latent.numpy().tobytes())
|
| 99 |
latent_b64str = latent_b64.decode()
|
| 100 |
|
| 101 |
-
return latent_b64str
|
|
|
|
| 55 |
|
| 56 |
def __call__(self, data: Dict[str, Any]) -> str:
|
| 57 |
# get inputs
|
| 58 |
+
contexts = data.pop("inputs", data)
|
|
|
|
|
|
|
| 59 |
|
| 60 |
batch_size = data.pop("batch_size", 1)
|
| 61 |
|
| 62 |
+
context = base64.b64decode(contexts[0])
|
| 63 |
context = np.frombuffer(context, dtype="float32")
|
| 64 |
context = np.reshape(context, (batch_size, 77, 768))
|
| 65 |
|
| 66 |
+
unconditional_context = base64.b64decode(contexts[1])
|
| 67 |
unconditional_context = np.frombuffer(unconditional_context, dtype="float32")
|
| 68 |
unconditional_context = np.reshape(unconditional_context, (batch_size, 77, 768))
|
| 69 |
|
|
|
|
| 96 |
latent_b64 = base64.b64encode(latent.numpy().tobytes())
|
| 97 |
latent_b64str = latent_b64.decode()
|
| 98 |
|
| 99 |
+
return latent_b64str
|