""" Simple clothing segmentation API - No Gradio """ import numpy as np from PIL import Image import io import base64 from fastapi import FastAPI, File, UploadFile from fastapi.responses import Response import onnxruntime as ort from huggingface_hub import hf_hub_download app = FastAPI() print("Downloading model...") model_path = hf_hub_download( repo_id="Metal3d/deeplabv3p-resnet50-human", filename="deeplabv3p-resnet50-human.onnx" ) print(f"Model from: {model_path}") session = ort.InferenceSession(model_path) print("Model loaded!") CLOTHING_CLASSES = [5, 9] def preprocess(img): img = img.resize((512, 512)) arr = np.array(img).astype(np.float32) / 127.5 - 1 if len(arr.shape) == 2: arr = np.stack([arr] * 3, axis=-1) elif arr.shape[-1] == 4: arr = arr[:, :, :3] return np.transpose(arr, (2, 0, 1))[np.newaxis, :, :, :] @app.post("/process") async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)): user = Image.open(io.BytesIO(await user_image.read())).convert("RGB") fabric = Image.open(io.BytesIO(await fabric_image.read())).convert("RGB") input_data = preprocess(user) input_name = session.get_inputs()[0].name output_name = session.get_outputs()[0].name result = session.run([output_name], {input_name: input_data[0]})[0] result = np.argmax(result[0], axis=0) mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255 mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST) fabric_arr = np.array(fabric.resize(user.size, Image.LANCZOS)) user_arr = np.array(user) mask_arr = np.array(mask_img) / 255.0 output = (fabric_arr * mask_arr[:, :, np.newaxis] + user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8) buf = io.BytesIO() Image.fromarray(output).save(buf, format="PNG") return Response(content=buf.getvalue(), media_type="image/png")