Spaces:
Runtime error
Runtime error
Fix a bug of preprocess
Browse files
Rodin.py
CHANGED
|
@@ -6,6 +6,7 @@ import random
|
|
| 6 |
import base64
|
| 7 |
import io
|
| 8 |
from PIL import Image
|
|
|
|
| 9 |
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
| 10 |
from constant import *
|
| 11 |
|
|
@@ -98,6 +99,24 @@ def rodin_update(prompt, task_uuid, token, settings):
|
|
| 98 |
response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
|
| 99 |
return response.json()
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
class Generator:
|
| 103 |
def __init__(self, user_id, password) -> None:
|
|
@@ -109,22 +128,17 @@ class Generator:
|
|
| 109 |
return prompt, cache_image_base64
|
| 110 |
print("Preprocessing image...")
|
| 111 |
|
| 112 |
-
image_file =
|
| 113 |
-
if
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
if
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
if not (prompt and task_uuid):
|
| 124 |
-
prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
|
| 125 |
-
processed_image = "data:image/png;base64," + preprocess_response.get('processed_image', None)
|
| 126 |
-
finally:
|
| 127 |
-
image_file.close()
|
| 128 |
|
| 129 |
return prompt, processed_image
|
| 130 |
|
|
|
|
| 6 |
import base64
|
| 7 |
import io
|
| 8 |
from PIL import Image
|
| 9 |
+
from io import BytesIO
|
| 10 |
from requests_toolbelt.multipart.encoder import MultipartEncoder
|
| 11 |
from constant import *
|
| 12 |
|
|
|
|
| 99 |
response = requests.post(f"{BASE_URL}/task/rodin_update", data={"uuid": task_uuid, "prompt": prompt, "settings": settings}, headers=headers)
|
| 100 |
return response.json()
|
| 101 |
|
| 102 |
+
def load_image(img_path):
|
| 103 |
+
image = Image.open(img_path)
|
| 104 |
+
|
| 105 |
+
# 按比例缩小图像到长度为1024
|
| 106 |
+
width, height = image.size
|
| 107 |
+
if width > height:
|
| 108 |
+
scale = 1024 / width
|
| 109 |
+
else:
|
| 110 |
+
scale = 1024 / height
|
| 111 |
+
new_width = int(width * scale)
|
| 112 |
+
new_height = int(height * scale)
|
| 113 |
+
resized_image = image.resize((new_width, new_height))
|
| 114 |
+
|
| 115 |
+
# 将 PIL.Image 对象转换为字节流
|
| 116 |
+
byte_io = BytesIO()
|
| 117 |
+
resized_image.save(byte_io, format='PNG')
|
| 118 |
+
image_bytes = byte_io.getvalue()
|
| 119 |
+
return image_bytes
|
| 120 |
|
| 121 |
class Generator:
|
| 122 |
def __init__(self, user_id, password) -> None:
|
|
|
|
| 128 |
return prompt, cache_image_base64
|
| 129 |
print("Preprocessing image...")
|
| 130 |
|
| 131 |
+
image_file = load_image(image_path)
|
| 132 |
+
if prompt and task_uuid:
|
| 133 |
+
preprocess_response = rodin_preprocess_image(generate_prompt=False, image=image_file, name="images.png", token=self.token)
|
| 134 |
+
else:
|
| 135 |
+
preprocess_response = rodin_preprocess_image(generate_prompt=True, image=image_file, name="images.png", token=self.token)
|
| 136 |
+
if 'error' in preprocess_response:
|
| 137 |
+
print("Error in image preprocessing:", preprocess_response['error'])
|
| 138 |
+
else:
|
| 139 |
+
if not (prompt and task_uuid):
|
| 140 |
+
prompt = preprocess_response.get('prompt', 'Default prompt if none returned')
|
| 141 |
+
processed_image = "data:image/png;base64," + preprocess_response.get('processed_image', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
|
| 143 |
return prompt, processed_image
|
| 144 |
|