Spaces:
Build error
Build error
Commit
·
69f10d4
1
Parent(s):
017604b
refactor: cto for better optimization
Browse files- src/api/nto_api.py +58 -34
src/api/nto_api.py
CHANGED
|
@@ -57,6 +57,7 @@ class NecklaceTryOnIDEntity(BaseModel):
|
|
| 57 |
necklaceImageId: str
|
| 58 |
necklaceCategory: str
|
| 59 |
storename: str
|
|
|
|
| 60 |
api_token: str
|
| 61 |
|
| 62 |
|
|
@@ -128,55 +129,66 @@ async def clothing_try_on_v2(image: UploadFile = File(...), clothing_type: str =
|
|
| 128 |
|
| 129 |
@nto_cto_router.post("/clothingTryOn")
|
| 130 |
async def clothing_try_on(image: UploadFile = File(...),
|
| 131 |
-
|
|
|
|
| 132 |
logger.info("-" * 50)
|
| 133 |
logger.info(">>> CLOTHING TRY ON STARTED <<<")
|
| 134 |
start_time = time.time()
|
| 135 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
try:
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
| 141 |
except Exception as e:
|
| 142 |
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
| 143 |
-
return JSONResponse(status_code=500, content={"error":
|
| 144 |
|
| 145 |
try:
|
|
|
|
| 146 |
actual_image = image.copy()
|
| 147 |
jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
|
| 148 |
arr_orig = np.array(grayscale(mask))
|
| 149 |
|
| 150 |
-
|
| 151 |
-
image = Image.fromarray(
|
|
|
|
|
|
|
| 152 |
|
|
|
|
| 153 |
arr = arr_orig.copy()
|
| 154 |
mask_y = np.where(arr == arr[arr != 0][0])[0][0]
|
| 155 |
arr[mask_y:, :] = 255
|
| 156 |
-
|
| 157 |
mask = Image.fromarray(arr).resize((512, 512))
|
|
|
|
| 158 |
logger.info(">>> IMAGE PROCESSING COMPLETED <<<")
|
| 159 |
except Exception as e:
|
| 160 |
logger.error(f">>> IMAGE PROCESSING ERROR: {str(e)} <<<")
|
| 161 |
-
return JSONResponse(status_code=500,
|
| 162 |
-
content={"error": f"Error processing image or mask", "code": 500})
|
| 163 |
|
| 164 |
try:
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
mask_bytes_ = base64.b64encode(mask_img_base_64.getvalue()).decode("utf-8")
|
| 169 |
-
image_bytes_ = base64.b64encode(act_img_base_64.getvalue()).decode("utf-8")
|
| 170 |
-
|
| 171 |
-
mask_data_uri = f"data:image/webp;base64,{mask_bytes_}"
|
| 172 |
-
image_data_uri = f"data:image/webp;base64,{image_bytes_}"
|
| 173 |
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
| 174 |
except Exception as e:
|
| 175 |
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
| 176 |
-
return JSONResponse(status_code=500,
|
| 177 |
-
content={"error": f"Error encoding images", "code": 500})
|
| 178 |
|
| 179 |
-
|
|
|
|
| 180 |
"mask": mask_data_uri,
|
| 181 |
"image": image_data_uri,
|
| 182 |
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
|
@@ -185,33 +197,45 @@ async def clothing_try_on(image: UploadFile = File(...),
|
|
| 185 |
}
|
| 186 |
|
| 187 |
try:
|
| 188 |
-
output = replicate_run_cto(
|
| 189 |
logger.info(">>> REPLICATE PROCESSING COMPLETED <<<")
|
| 190 |
except Exception as e:
|
| 191 |
logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
|
| 192 |
-
return JSONResponse(content={"error":
|
| 193 |
|
| 194 |
try:
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
total_inference_time = round((time.time() - start_time), 2)
|
| 205 |
logger.info(">>> OUTPUT IMAGE PROCESSING COMPLETED <<<")
|
| 206 |
|
| 207 |
response = {
|
| 208 |
-
"output":
|
| 209 |
"code": 200,
|
| 210 |
"inference_time": total_inference_time
|
| 211 |
}
|
| 212 |
except Exception as e:
|
| 213 |
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
|
| 214 |
-
return JSONResponse(status_code=500, content={"error":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
| 217 |
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|
|
|
|
| 57 |
necklaceImageId: str
|
| 58 |
necklaceCategory: str
|
| 59 |
storename: str
|
| 60 |
+
|
| 61 |
api_token: str
|
| 62 |
|
| 63 |
|
|
|
|
| 129 |
|
| 130 |
@nto_cto_router.post("/clothingTryOn")
|
| 131 |
async def clothing_try_on(image: UploadFile = File(...),
|
| 132 |
+
mask: UploadFile = File(...),
|
| 133 |
+
clothing_type: str = Form(...)):
|
| 134 |
logger.info("-" * 50)
|
| 135 |
logger.info(">>> CLOTHING TRY ON STARTED <<<")
|
| 136 |
start_time = time.time()
|
| 137 |
|
| 138 |
+
# Helper function to convert image to base64
|
| 139 |
+
def image_to_base64(img: Image.Image, format="WEBP", quality=85) -> str:
|
| 140 |
+
with BytesIO() as buffer:
|
| 141 |
+
img.save(buffer, format=format, quality=quality)
|
| 142 |
+
return f"data:image/{format.lower()};base64,{base64.b64encode(buffer.getvalue()).decode('utf-8')}"
|
| 143 |
+
|
| 144 |
try:
|
| 145 |
+
# Load images concurrently using asyncio
|
| 146 |
+
image_data, mask_data = await asyncio.gather(
|
| 147 |
+
image.read(),
|
| 148 |
+
mask.read()
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# Convert bytes to PIL Images
|
| 152 |
+
image = Image.open(BytesIO(image_data)).convert("RGB")
|
| 153 |
+
mask = Image.open(BytesIO(mask_data)).convert("RGB")
|
| 154 |
logger.info(">>> IMAGES LOADED SUCCESSFULLY <<<")
|
| 155 |
except Exception as e:
|
| 156 |
logger.error(f">>> IMAGE LOADING ERROR: {str(e)} <<<")
|
| 157 |
+
return JSONResponse(status_code=500, content={"error": "Error reading image or mask", "code": 500})
|
| 158 |
|
| 159 |
try:
|
| 160 |
+
# Process images
|
| 161 |
actual_image = image.copy()
|
| 162 |
jewellery_mask = Image.fromarray(np.bitwise_and(np.array(mask), np.array(image)))
|
| 163 |
arr_orig = np.array(grayscale(mask))
|
| 164 |
|
| 165 |
+
# Process image with inpainting
|
| 166 |
+
image = Image.fromarray(
|
| 167 |
+
cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
|
| 168 |
+
).resize((512, 512))
|
| 169 |
|
| 170 |
+
# Process mask
|
| 171 |
arr = arr_orig.copy()
|
| 172 |
mask_y = np.where(arr == arr[arr != 0][0])[0][0]
|
| 173 |
arr[mask_y:, :] = 255
|
|
|
|
| 174 |
mask = Image.fromarray(arr).resize((512, 512))
|
| 175 |
+
|
| 176 |
logger.info(">>> IMAGE PROCESSING COMPLETED <<<")
|
| 177 |
except Exception as e:
|
| 178 |
logger.error(f">>> IMAGE PROCESSING ERROR: {str(e)} <<<")
|
| 179 |
+
return JSONResponse(status_code=500, content={"error": "Error processing image or mask", "code": 500})
|
|
|
|
| 180 |
|
| 181 |
try:
|
| 182 |
+
# Convert images to base64 more efficiently
|
| 183 |
+
mask_data_uri = image_to_base64(mask)
|
| 184 |
+
image_data_uri = image_to_base64(image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
logger.info(">>> IMAGE ENCODING COMPLETED <<<")
|
| 186 |
except Exception as e:
|
| 187 |
logger.error(f">>> IMAGE ENCODING ERROR: {str(e)} <<<")
|
| 188 |
+
return JSONResponse(status_code=500, content={"error": "Error encoding images", "code": 500})
|
|
|
|
| 189 |
|
| 190 |
+
# Prepare replicate input
|
| 191 |
+
input_data = {
|
| 192 |
"mask": mask_data_uri,
|
| 193 |
"image": image_data_uri,
|
| 194 |
"prompt": f"Dull {clothing_type}, non-reflective clothing, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple, perfect eyes, perfect face, perfect body, high quality, realistic, photorealistic, high resolution,traditional full sleeve blouse",
|
|
|
|
| 197 |
}
|
| 198 |
|
| 199 |
try:
|
| 200 |
+
output = replicate_run_cto(input_data)
|
| 201 |
logger.info(">>> REPLICATE PROCESSING COMPLETED <<<")
|
| 202 |
except Exception as e:
|
| 203 |
logger.error(f">>> REPLICATE PROCESSING ERROR: {str(e)} <<<")
|
| 204 |
+
return JSONResponse(content={"error": "Error running clothing try on", "code": 500}, status_code=500)
|
| 205 |
|
| 206 |
try:
|
| 207 |
+
async with aiohttp.ClientSession() as session:
|
| 208 |
+
async with session.get(output[0]) as response:
|
| 209 |
+
output_bytes = await response.read()
|
| 210 |
+
|
| 211 |
+
output_image = Image.open(BytesIO(output_bytes)).resize(actual_image.size)
|
| 212 |
+
|
| 213 |
+
# Process final image
|
| 214 |
+
output_array = np.bitwise_and(
|
| 215 |
+
np.array(output_image),
|
| 216 |
+
np.bitwise_not(np.array(Image.fromarray(arr_orig).convert("RGB")))
|
| 217 |
+
)
|
| 218 |
+
result = Image.fromarray(np.bitwise_or(output_array, np.array(jewellery_mask)))
|
| 219 |
+
|
| 220 |
+
# Convert result to base64
|
| 221 |
+
result_base64 = image_to_base64(result)
|
| 222 |
+
|
| 223 |
total_inference_time = round((time.time() - start_time), 2)
|
| 224 |
logger.info(">>> OUTPUT IMAGE PROCESSING COMPLETED <<<")
|
| 225 |
|
| 226 |
response = {
|
| 227 |
+
"output": result_base64,
|
| 228 |
"code": 200,
|
| 229 |
"inference_time": total_inference_time
|
| 230 |
}
|
| 231 |
except Exception as e:
|
| 232 |
logger.error(f">>> OUTPUT IMAGE PROCESSING ERROR: {str(e)} <<<")
|
| 233 |
+
return JSONResponse(status_code=500, content={"error": "Error processing output image", "code": 500})
|
| 234 |
+
finally:
|
| 235 |
+
# Clean up resources
|
| 236 |
+
if 'output_image' in locals(): del output_image
|
| 237 |
+
if 'output_array' in locals(): del output_array
|
| 238 |
+
gc.collect()
|
| 239 |
|
| 240 |
logger.info(f">>> TOTAL INFERENCE TIME: {total_inference_time}s <<<")
|
| 241 |
logger.info(">>> REQUEST COMPLETED SUCCESSFULLY <<<")
|