File size: 1,958 Bytes
3ea407b
9b3c89e
3ea407b
 
 
 
9b3c89e
 
 
 
 
2a73a32
3ea407b
9b3c89e
3ea407b
2a73a32
 
 
 
 
9b3c89e
3ea407b
2a73a32
3ea407b
 
9b3c89e
 
2a73a32
3ea407b
2a73a32
 
 
 
 
 
3ea407b
9b3c89e
 
 
 
3ea407b
9b3c89e
2a73a32
 
 
3ea407b
2a73a32
9b3c89e
3ea407b
9b3c89e
 
2a73a32
3ea407b
2a73a32
 
3ea407b
9b3c89e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
Minimal FastAPI service (no Gradio UI) that segments the clothing in a user
photo and fills it with a fabric texture.
"""

import base64
import io

import numpy as np
import onnxruntime as ort
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import Response
from huggingface_hub import hf_hub_download
from PIL import Image

app = FastAPI()

# Hugging Face Hub coordinates of the ONNX segmentation network.
_MODEL_REPO = "Metal3d/deeplabv3p-resnet50-human"
_MODEL_FILE = "deeplabv3p-resnet50-human.onnx"

# Runs once at import time: fetch the model file and build the single
# inference session shared by every request.
print("Downloading model...")
model_path = hf_hub_download(repo_id=_MODEL_REPO, filename=_MODEL_FILE)
print(f"Model from: {model_path}")

session = ort.InferenceSession(model_path)
print("Model loaded!")

# Label ids (from the model's per-pixel argmax) that count as clothing.
# NOTE(review): presumably upper-clothes (5) and pants (9) in a CIHP/LIP-style
# label map -- confirm against the model card.
CLOTHING_CLASSES = [5, 9]

def preprocess(img, size=(512, 512)):
    """Turn a PIL image into a normalized, channels-first batch of one.

    The image is converted to RGB, resized to *size* and scaled from
    [0, 255] to [-1, 1].

    Args:
        img: a PIL ``Image`` in any mode that supports ``convert("RGB")``
            (RGB, L, LA, RGBA, P, ...).
        size: target ``(width, height)`` handed to ``Image.resize``; the
            default matches the 512x512 input this service has always used.

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(1, 3, height, width)``.
    """
    # Converting before anything else covers every mode uniformly.
    # For palette ("P") images the raw array holds palette indices, not
    # intensities. Two-band "LA" images would otherwise yield a 2-channel
    # tensor.
    rgb = img.convert("RGB").resize(size)
    arr = np.asarray(rgb, dtype=np.float32) / 127.5 - 1
    # HWC -> CHW, then prepend the batch axis.
    return np.transpose(arr, (2, 0, 1))[np.newaxis, ...]

def _decode_rgb(data, field):
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException: 400 when *data* is not an image Pillow can read.
    """
    try:
        return Image.open(io.BytesIO(data)).convert("RGB")
    except OSError as err:
        # Pillow's UnidentifiedImageError and truncated-file errors are both
        # OSError subclasses; report them as a client error, not a 500.
        raise HTTPException(
            status_code=400, detail=f"{field} is not a readable image"
        ) from err


def _clothing_mask(user):
    """Return a boolean ``(height, width)`` mask of clothing pixels in *user*."""
    model_input = session.get_inputs()[0]
    output_name = session.get_outputs()[0].name

    # preprocess() returns a batch of one, channels-first. The batch axis is
    # kept: the output below is indexed as batched too.
    batch = preprocess(user)

    # Follow the layout the graph declares instead of assuming NCHW.
    # Channels-last graphs (common for Keras exports) get NHWC input; an
    # unknown or channels-first shape keeps the NCHW path.
    input_shape = model_input.shape
    channels_last = bool(input_shape) and input_shape[-1] == 3
    if channels_last:
        batch = np.transpose(batch, (0, 2, 3, 1))

    scores = session.run(
        [output_name], {model_input.name: np.ascontiguousarray(batch)}
    )[0]
    # NOTE(review): assumes the class scores follow the input layout, i.e.
    # class axis first (NCHW) or last (NHWC) per sample -- confirm on the graph.
    labels = np.argmax(scores[0], axis=-1 if channels_last else 0)

    mask = np.isin(labels, CLOTHING_CLASSES).astype(np.uint8) * 255
    # NEAREST keeps the upscaled mask strictly binary.
    mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST)
    return np.asarray(mask_img) > 0


def _render_png(user_bytes, fabric_bytes):
    """Paint the fabric over the clothing in the user photo; return PNG bytes."""
    user = _decode_rgb(user_bytes, "user_image")
    fabric = _decode_rgb(fabric_bytes, "fabric_image")

    mask = _clothing_mask(user)
    fabric_arr = np.asarray(fabric.resize(user.size, Image.LANCZOS))
    user_arr = np.asarray(user)
    # The mask is binary, so pixel selection gives exactly the values of a
    # 0/1 alpha blend, without float temporaries.
    output = np.where(mask[:, :, np.newaxis], fabric_arr, user_arr)

    buf = io.BytesIO()
    Image.fromarray(output).save(buf, format="PNG")
    return buf.getvalue()


@app.post("/process")
async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)):
    """Replace the clothing in ``user_image`` with the ``fabric_image`` texture.

    The fabric is stretched to the user photo's size and shown wherever the
    segmentation model labels a pixel as one of ``CLOTHING_CLASSES``.

    Returns:
        A PNG ``Response`` with the same dimensions as the user photo.

    Raises:
        HTTPException: 400 when either upload cannot be decoded as an image.
    """
    user_bytes = await user_image.read()
    fabric_bytes = await fabric_image.read()
    # Decoding, ONNX inference and PNG encoding are CPU-bound. Running them
    # in the threadpool keeps this async endpoint from stalling the event loop.
    png = await run_in_threadpool(_render_png, user_bytes, fabric_bytes)
    return Response(content=png, media_type="image/png")