"""
Simple clothing segmentation API - No Gradio
"""
import numpy as np
from PIL import Image
import io
import base64
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import Response
import onnxruntime as ort
from huggingface_hub import hf_hub_download
# ASGI application object; the /process route further down is registered on it.
app = FastAPI()
# The model is fetched and loaded at module import time, so importing this
# module performs network/disk I/O before the app can serve requests.
# NOTE(review): hf_hub_download presumably reuses a local cache on later
# starts — confirm against the huggingface_hub documentation.
print("Downloading model...")
model_path = hf_hub_download(
repo_id="Metal3d/deeplabv3p-resnet50-human",
filename="deeplabv3p-resnet50-human.onnx"
)
print(f"Model from: {model_path}")
# Single module-level inference session, used by the /process handler.
session = ort.InferenceSession(model_path)
print("Model loaded!")
# Class indices that /process turns into the replacement mask (via np.isin).
# NOTE(review): what labels 5 and 9 stand for depends on the model's label
# map, which is not visible in this file — confirm against the model card.
CLOTHING_CLASSES = [5, 9]
def preprocess(img, size=(512, 512)):
    """Turn a PIL image into a batched, channel-first float32 tensor.

    The image is converted to RGB, resized to ``size`` and scaled from
    the 0..255 byte range to -1..1.

    Args:
        img: A PIL image in any mode.
        size: Target ``(width, height)`` passed to ``img.resize``.
            Defaults to the 512x512 this file has always used.

    Returns:
        A ``float32`` array of shape ``(1, 3, height, width)``.
    """
    # Converting to RGB up front replaces the old shape-based guessing, which
    # only coped with 2-D greyscale and 4-channel arrays; palette, "LA",
    # "CMYK" and 1-bit images now also yield three well-defined channels.
    rgb = img.convert("RGB").resize(size)
    scaled = np.asarray(rgb).astype(np.float32) / 127.5 - 1
    # HWC -> CHW, then prepend the batch axis.
    return np.transpose(scaled, (2, 0, 1))[np.newaxis, :, :, :]
async def _read_rgb_image(upload, field_name):
    """Decode one uploaded file into an RGB PIL image, or fail with HTTP 400."""
    # Local import: fastapi is already a dependency of this file, and keeping
    # the import here leaves the file-level import block untouched.
    from fastapi import HTTPException

    payload = await upload.read()
    try:
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except (OSError, ValueError) as err:
        # Pillow reports unidentifiable or truncated data through OSError
        # subclasses; answer with a client error instead of a bare 500.
        raise HTTPException(
            status_code=400,
            detail=f"{field_name} is not a readable image",
        ) from err


@app.post("/process")
async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)):
    """Replace the clothing region of ``user_image`` with ``fabric_image``.

    The user photo is segmented with the module-level ONNX ``session``.
    Pixels whose predicted class is in ``CLOTHING_CLASSES`` are taken from
    the fabric image (resized to the photo's size); all other pixels keep
    the original photo.

    Args:
        user_image: Uploaded photo to segment.
        fabric_image: Uploaded texture painted over the masked region.

    Returns:
        A ``Response`` carrying the composited picture as PNG bytes.

    Raises:
        HTTPException: 400 if either upload cannot be decoded as an image.
    """
    user = await _read_rgb_image(user_image, "user_image")
    fabric = await _read_rgb_image(fabric_image, "fabric_image")

    input_data = preprocess(user)
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name
    # Feed the whole batched tensor. The previous code passed input_data[0],
    # which dropped the batch axis preprocess() had just added, even though
    # the output is indexed as a batch right below.
    # NOTE(review): assumes the model takes channel-first (1, 3, H, W) input —
    # confirm against the model's declared input shape.
    scores = session.run([output_name], {input_name: input_data})[0]
    # NOTE(review): argmax over axis 0 of the first batch item assumes a
    # (1, classes, H, W) output — confirm against the model.
    labels = np.argmax(scores[0], axis=0)

    mask = np.isin(labels, CLOTHING_CLASSES).astype(np.uint8) * 255
    # NEAREST keeps the mask strictly binary after scaling to the photo size.
    mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST)
    fabric_arr = np.array(fabric.resize(user.size, Image.LANCZOS))
    user_arr = np.array(user)

    # The mask is 0 or 255 everywhere, so a per-pixel selection gives the same
    # result as the earlier float blend without the float64 round trip.
    use_fabric = np.array(mask_img) > 0
    output = np.where(use_fabric[:, :, np.newaxis], fabric_arr, user_arr)

    buf = io.BytesIO()
    Image.fromarray(output).save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")