jquenum commited on
Commit
9b3c89e
·
verified ·
1 Parent(s): 2a73a32

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -31
app.py CHANGED
@@ -1,30 +1,31 @@
1
  """
2
- HuggingFace Space for Clothing Segmentation
3
  """
4
 
5
- import os
6
- from pathlib import Path
7
  import numpy as np
8
  from PIL import Image
9
- import gradio as gr
 
 
 
 
10
  from huggingface_hub import hf_hub_download
11
 
12
- # Clothing classes from the model
13
- CLOTHING_CLASSES = [5, 9] # top-clothes, bottom-clothes
14
 
15
  print("Downloading model...")
16
  model_path = hf_hub_download(
17
  repo_id="Metal3d/deeplabv3p-resnet50-human",
18
  filename="deeplabv3p-resnet50-human.onnx"
19
  )
20
- print(f"Model downloaded to: {model_path}")
21
 
22
- import onnxruntime as ort
23
  session = ort.InferenceSession(model_path)
24
  print("Model loaded!")
25
 
 
 
26
  def preprocess(img):
27
- """Preprocess image for model"""
28
  img = img.resize((512, 512))
29
  arr = np.array(img).astype(np.float32) / 127.5 - 1
30
  if len(arr.shape) == 2:
@@ -33,37 +34,26 @@ def preprocess(img):
33
  arr = arr[:, :, :3]
34
  return np.transpose(arr, (2, 0, 1))[np.newaxis, :, :, :]
35
 
36
- def process(user_img, fabric_img):
37
- """Process images"""
38
- if user_img is None or fabric_img is None:
39
- return None
40
 
41
- input_data = preprocess(user_img)
42
  input_name = session.get_inputs()[0].name
43
  output_name = session.get_outputs()[0].name
44
  result = session.run([output_name], {input_name: input_data[0]})[0]
45
  result = np.argmax(result[0], axis=0)
46
  mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255
47
- mask_img = Image.fromarray(mask).resize(user_img.size, Image.NEAREST)
48
 
49
- fabric_arr = np.array(fabric_img.resize(user_img.size, Image.LANCZOS))
50
- user_arr = np.array(user_img)
51
  mask_arr = np.array(mask_img) / 255.0
52
 
53
  output = (fabric_arr * mask_arr[:, :, np.newaxis] +
54
  user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8)
55
 
56
- return Image.fromarray(output)
57
-
58
- with gr.Blocks() as demo:
59
- gr.Markdown("# 👗 Virtual Try-On")
60
- with gr.Row():
61
- with gr.Column():
62
- user = gr.Image(type="pil", label="Your Photo")
63
- fabric = gr.Image(type="pil", label="Fabric")
64
- with gr.Column():
65
- result = gr.Image(type="pil", label="Result")
66
-
67
- gr.Button("Apply").click(fn=process, inputs=[user, fabric], outputs=result)
68
-
69
- demo.launch(server_port=7860)
 
1
  """
2
+ Simple clothing segmentation API - No Gradio
3
  """
4
 
 
 
5
  import numpy as np
6
  from PIL import Image
7
+ import io
8
+ import base64
9
+ from fastapi import FastAPI, File, UploadFile
10
+ from fastapi.responses import Response
11
+ import onnxruntime as ort
12
  from huggingface_hub import hf_hub_download
13
 
14
+ app = FastAPI()
 
15
 
16
  print("Downloading model...")
17
  model_path = hf_hub_download(
18
  repo_id="Metal3d/deeplabv3p-resnet50-human",
19
  filename="deeplabv3p-resnet50-human.onnx"
20
  )
21
+ print(f"Model from: {model_path}")
22
 
 
23
  session = ort.InferenceSession(model_path)
24
  print("Model loaded!")
25
 
26
+ CLOTHING_CLASSES = [5, 9]
27
+
28
  def preprocess(img):
 
29
  img = img.resize((512, 512))
30
  arr = np.array(img).astype(np.float32) / 127.5 - 1
31
  if len(arr.shape) == 2:
 
34
  arr = arr[:, :, :3]
35
  return np.transpose(arr, (2, 0, 1))[np.newaxis, :, :, :]
36
 
37
+ @app.post("/process")
38
+ async def process(user_image: UploadFile = File(...), fabric_image: UploadFile = File(...)):
39
+ user = Image.open(io.BytesIO(await user_image.read())).convert("RGB")
40
+ fabric = Image.open(io.BytesIO(await fabric_image.read())).convert("RGB")
41
 
42
+ input_data = preprocess(user)
43
  input_name = session.get_inputs()[0].name
44
  output_name = session.get_outputs()[0].name
45
  result = session.run([output_name], {input_name: input_data[0]})[0]
46
  result = np.argmax(result[0], axis=0)
47
  mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255
48
+ mask_img = Image.fromarray(mask).resize(user.size, Image.NEAREST)
49
 
50
+ fabric_arr = np.array(fabric.resize(user.size, Image.LANCZOS))
51
+ user_arr = np.array(user)
52
  mask_arr = np.array(mask_img) / 255.0
53
 
54
  output = (fabric_arr * mask_arr[:, :, np.newaxis] +
55
  user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8)
56
 
57
+ buf = io.BytesIO()
58
+ Image.fromarray(output).save(buf, format="PNG")
59
+ return Response(content=buf.getvalue(), media_type="image/png")