# Hugging Face Spaces status captured with this file: Runtime error
"""
Simple clothing segmentation API (FastAPI, no Gradio UI).
"""
import base64
import io

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import Response
from huggingface_hub import hf_hub_download
from PIL import Image
# FastAPI application instance for this service.
app = FastAPI()

# Location of the ONNX segmentation model on the Hugging Face Hub.
_MODEL_REPO_ID = "Metal3d/deeplabv3p-resnet50-human"
_MODEL_FILENAME = "deeplabv3p-resnet50-human.onnx"

# Download the model file and open an ONNX Runtime session on it.
# This runs once, when the module is imported.
print("Downloading model...")
model_path = hf_hub_download(repo_id=_MODEL_REPO_ID, filename=_MODEL_FILENAME)
print(f"Model from: {model_path}")

session = ort.InferenceSession(model_path)
print("Model loaded!")

# Predicted label indices that are treated as "clothing" when building the
# mask. NOTE(review): which garments 5 and 9 denote is not visible here —
# TODO confirm against the model's label map.
CLOTHING_CLASSES = [5, 9]
def preprocess(img, size=(512, 512)):
    """Turn a PIL image into a normalized, batched NCHW float32 array.

    The image is converted to RGB, resized to *size*, scaled from [0, 255]
    to [-1, 1], and rearranged from HWC to a single-item NCHW batch.

    Args:
        img: A PIL image in any mode.
        size: Target ``(width, height)`` passed to ``img.resize``. The
            default matches the previously hard-coded 512x512.

    Returns:
        ``np.ndarray`` of dtype float32 with shape ``(1, 3, height, width)``.
    """
    if img.mode != "RGB":
        # convert() handles every PIL mode correctly. The old shape-based
        # checks left LA images with 2 channels, read CMYK's 4th band as
        # alpha, and treated palette indices as intensities.
        img = img.convert("RGB")
    arr = np.asarray(img.resize(size), dtype=np.float32) / 127.5 - 1.0
    # (H, W, 3) -> (3, H, W) -> (1, 3, H, W)
    return np.transpose(arr, (2, 0, 1))[np.newaxis]
# NOTE(review): the route path is assumed from the function name; confirm
# it matches what clients call.
@app.post("/process")
async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)):
    """Replace the clothing region of a user photo with a fabric image.

    Both uploads are decoded and converted to RGB. The user image is
    segmented with the module-level ONNX ``session``. Pixels whose predicted
    label is in ``CLOTHING_CLASSES`` are taken from the fabric image, resized
    to the user image's size. All other pixels are kept from the user image.

    Returns:
        A ``Response`` containing the composited image as PNG, with the same
        size as the user image.

    Raises:
        HTTPException: 400 if either upload cannot be decoded as an image.
    """
    try:
        user = Image.open(io.BytesIO(await user_image.read())).convert("RGB")
        fabric = Image.open(io.BytesIO(await fabric_image.read())).convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        # Non-image or truncated uploads raise OSError subclasses in PIL.
        # Report them as a client error rather than an unhandled 500.
        raise HTTPException(
            status_code=400, detail="Uploaded file is not a readable image."
        ) from err

    # preprocess() already returns a (1, 3, H, W) batch. Feed it whole:
    # indexing [0] here dropped the batch axis and passed a rank-3 tensor.
    input_data = preprocess(user)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    scores = session.run([output_name], {input_name: input_data})[0]
    # Assumes scores is (batch, classes, H, W), so axis 0 of scores[0] is
    # the class axis. TODO confirm against the ONNX graph's output shape.
    labels = np.argmax(scores[0], axis=0)

    mask = np.isin(labels, CLOTHING_CLASSES).astype(np.uint8) * 255
    # NEAREST keeps the upscaled mask strictly 0/255, so a boolean select is
    # exactly equivalent to alpha-blending with a 0/1 mask.
    mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST)
    is_clothing = np.asarray(mask_img)[:, :, np.newaxis] > 0
    fabric_arr = np.asarray(fabric.resize(user.size, Image.LANCZOS))
    user_arr = np.asarray(user)
    output = np.where(is_clothing, fabric_arr, user_arr)

    buf = io.BytesIO()
    Image.fromarray(output).save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")