jquenum commited on
Commit
2a73a32
·
verified ·
1 Parent(s): 6f10db2

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +38 -90
app.py CHANGED
@@ -1,121 +1,69 @@
1
  """
2
  HuggingFace Space for Clothing Segmentation
3
- Uses DeepLabV3+ ONNX model to segment human body parts
4
  """
5
 
6
  import os
 
7
  import numpy as np
8
  from PIL import Image
9
- import onnxruntime
10
  import gradio as gr
11
- from gradio import processing_utils
12
 
13
- # Labels for the model
14
- LABELS = [
15
- "background", "unknown", "hair", "unknown", "glasses",
16
- "top-clothes", "unknown", "unknown", "unknown", "bottom-clothes",
17
- "torso-skin", "unknown", "unknown", "face", "left-arm",
18
- "right-arm", "left-leg", "right-leg", "left-foot", "right-foot"
19
- ]
20
-
21
- # Clothing-related classes (top-clothes, bottom-clothes)
22
  CLOTHING_CLASSES = [5, 9] # top-clothes, bottom-clothes
23
 
24
- def load_model():
25
- """Load ONNX model"""
26
- model_path = "deeplabv3p-resnet50-human.onnx"
27
- if not os.path.exists(model_path):
28
- # Try to download from HuggingFace
29
- from huggingface_hub import hf_hub_download
30
- model_path = hf_hub_download(
31
- repo_id="Metal3d/deeplabv3p-resnet50-human",
32
- filename="deeplabv3p-resnet50-human.onnx"
33
- )
34
- session = onnxruntime.InferenceSession(model_path)
35
- return session
36
 
37
- # Load model at startup
38
- print("Loading model...")
39
- model = load_model()
40
  print("Model loaded!")
41
 
42
- def preprocess_image(img: Image.Image) -> np.ndarray:
43
  """Preprocess image for model"""
44
  img = img.resize((512, 512))
45
- img_array = np.array(img).astype(np.float32) / 127.5 - 1
46
- # Ensure 3 channels
47
- if len(img_array.shape) == 2:
48
- img_array = np.stack([img_array] * 3, axis=-1)
49
- elif img_array.shape[-1] == 4:
50
- img_array = img_array[:, :, :3]
51
- return img_array
52
 
53
- def segment_clothing(img: Image.Image) -> Image.Image:
54
- """Segment clothing from image"""
55
- # Preprocess
56
- input_data = preprocess_image(img)
57
- input_data = np.transpose(input_data, (2, 0, 1))[np.newaxis, :, :, :]
58
-
59
- # Run inference
60
- input_name = model.get_inputs()[0].name
61
- output_name = model.get_outputs()[0].name
62
- result = model.run([output_name], {input_name: input_data[0]})[0]
63
 
64
- # Get argmax
 
 
 
65
  result = np.argmax(result[0], axis=0)
 
 
66
 
67
- # Create clothing mask (top + bottom clothes)
68
- clothing_mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255
69
-
70
- # Resize back to original
71
- mask_img = Image.fromarray(clothing_mask).resize(img.size, Image.NEAREST)
72
- return mask_img
73
-
74
- def apply_fabric(user_img: Image.Image, fabric_img: Image.Image) -> Image.Image:
75
- """Apply fabric to segmented clothing area"""
76
- # First segment the clothing
77
- mask = segment_clothing(user_img)
78
-
79
- # Convert to numpy
80
- user_arr = np.array(user_img)
81
  fabric_arr = np.array(fabric_img.resize(user_img.size, Image.LANCZOS))
82
- mask_arr = np.array(mask) / 255.0
 
83
 
84
- # Apply: where mask=1, use fabric; where mask=0, use original
85
- result = (fabric_arr * mask_arr[:, :, np.newaxis] +
86
- user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8)
87
 
88
- return Image.fromarray(result)
89
 
90
- # Gradio Interface
91
  with gr.Blocks() as demo:
92
- gr.Markdown("# Clothing Virtual Try-On")
93
- gr.Markdown("Upload your photo and select a fabric to try it on!")
94
-
95
  with gr.Row():
96
  with gr.Column():
97
- user_image = gr.Image(type="pil", label="Your Photo")
98
- fabric_image = gr.Image(type="pil", label="Fabric")
99
-
100
  with gr.Column():
101
- output_image = gr.Image(type="pil", label="Result")
102
-
103
- with gr.Row():
104
- submit_btn = gr.Button("Apply Fabric", variant="primary")
105
-
106
- submit_btn.click(
107
- fn=apply_fabric,
108
- inputs=[user_image, fabric_image],
109
- outputs=output_image
110
- )
111
 
112
- gr.Examples(
113
- examples=[
114
- ["https://example.com/person.jpg", "https://example.com/fabric.jpg"],
115
- ],
116
- inputs=[user_image, fabric_image],
117
- outputs=output_image,
118
- )
119
 
120
- # Launch space
121
  demo.launch(server_port=7860)
 
1
  """
2
  HuggingFace Space for Clothing Segmentation
 
3
  """
4
 
5
  import os
6
+ from pathlib import Path
7
  import numpy as np
8
  from PIL import Image
 
9
  import gradio as gr
10
+ from huggingface_hub import hf_hub_download
11
 
12
+ # Clothing classes from the model
 
 
 
 
 
 
 
 
13
  CLOTHING_CLASSES = [5, 9] # top-clothes, bottom-clothes
14
 
15
+ print("Downloading model...")
16
+ model_path = hf_hub_download(
17
+ repo_id="Metal3d/deeplabv3p-resnet50-human",
18
+ filename="deeplabv3p-resnet50-human.onnx"
19
+ )
20
+ print(f"Model downloaded to: {model_path}")
 
 
 
 
 
 
21
 
22
+ import onnxruntime as ort
23
+ session = ort.InferenceSession(model_path)
 
24
  print("Model loaded!")
25
 
26
+ def preprocess(img):
27
  """Preprocess image for model"""
28
  img = img.resize((512, 512))
29
+ arr = np.array(img).astype(np.float32) / 127.5 - 1
30
+ if len(arr.shape) == 2:
31
+ arr = np.stack([arr] * 3, axis=-1)
32
+ elif arr.shape[-1] == 4:
33
+ arr = arr[:, :, :3]
34
+ return np.transpose(arr, (2, 0, 1))[np.newaxis, :, :, :]
 
35
 
36
+ def process(user_img, fabric_img):
37
+ """Process images"""
38
+ if user_img is None or fabric_img is None:
39
+ return None
 
 
 
 
 
 
40
 
41
+ input_data = preprocess(user_img)
42
+ input_name = session.get_inputs()[0].name
43
+ output_name = session.get_outputs()[0].name
44
+ result = session.run([output_name], {input_name: input_data[0]})[0]
45
  result = np.argmax(result[0], axis=0)
46
+ mask = np.isin(result, CLOTHING_CLASSES).astype(np.uint8) * 255
47
+ mask_img = Image.fromarray(mask).resize(user_img.size, Image.NEAREST)
48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  fabric_arr = np.array(fabric_img.resize(user_img.size, Image.LANCZOS))
50
+ user_arr = np.array(user_img)
51
+ mask_arr = np.array(mask_img) / 255.0
52
 
53
+ output = (fabric_arr * mask_arr[:, :, np.newaxis] +
54
+ user_arr * (1 - mask_arr[:, :, np.newaxis])).astype(np.uint8)
 
55
 
56
+ return Image.fromarray(output)
57
 
 
58
  with gr.Blocks() as demo:
59
+ gr.Markdown("# 👗 Virtual Try-On")
 
 
60
  with gr.Row():
61
  with gr.Column():
62
+ user = gr.Image(type="pil", label="Your Photo")
63
+ fabric = gr.Image(type="pil", label="Fabric")
 
64
  with gr.Column():
65
+ result = gr.Image(type="pil", label="Result")
 
 
 
 
 
 
 
 
 
66
 
67
+ gr.Button("Apply").click(fn=process, inputs=[user, fabric], outputs=result)
 
 
 
 
 
 
68
 
 
69
  demo.launch(server_port=7860)