# ankara-tryon / app.py
# Uploaded by jquenum via huggingface_hub (commit 9b3c89e, verified)
"""
Simple clothing segmentation API - No Gradio
"""
import numpy as np
from PIL import Image
import io
import base64
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import Response
import onnxruntime as ort
from huggingface_hub import hf_hub_download
app = FastAPI()
print("Downloading model...")
model_path = hf_hub_download(
repo_id="Metal3d/deeplabv3p-resnet50-human",
filename="deeplabv3p-resnet50-human.onnx"
)
print(f"Model from: {model_path}")
session = ort.InferenceSession(model_path)
print("Model loaded!")
CLOTHING_CLASSES = [5, 9]
def preprocess(img):
img = img.resize((512, 512))
arr = np.array(img).astype(np.float32) / 127.5 - 1
if len(arr.shape) == 2:
arr = np.stack([arr] * 3, axis=-1)
elif arr.shape[-1] == 4:
arr = arr[:, :, :3]
return np.transpose(arr, (2, 0, 1))[np.newaxis, :, :, :]
@app.post("/process")
async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)):
user = Image.open(io.BytesIO(await user_image.read())).convert("RGB")
fabric = Image.open(io.BytesIO(await fabric_image.read())).convert("RGB")
input_data = preprocess(user)
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
result = session.run([output_name], {input_name: input_data[0]})[0]
result = np.argmax(result[0], axis=0)
mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255
mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST)
fabric_arr = np.array(fabric.resize(user.size, Image.LANCZOS))
user_arr = np.array(user)
mask_arr = np.array(mask_img) / 255.0
output = (fabric_arr * mask_arr[:, :, np.newaxis] +
user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8)
buf = io.BytesIO()
Image.fromarray(output).save(buf, format="PNG")
return Response(content=buf.getvalue(), media_type="image/png")