mohammadhakimi commited on
Commit
8c7592a
·
1 Parent(s): ce450be

add blur strength in inputs

Browse files
Files changed (3) hide show
  1. app.py +10 -5
  2. processing/mask.py +3 -3
  3. processing/mask.pyi +2 -2
app.py CHANGED
@@ -45,7 +45,7 @@ stable_diffusion_negative_prompt = ("low quality, unformed, ugly, low resolution
45
  depth_backbone="vitb_rn50_384")
46
 
47
 
48
- @spaces.GPU(duration=120)
49
  def greet(material_exemplar: Image.Image,
50
  raw_cloth: Image.Image = None,
51
  _controlnet_conditioning_scale=0.9,
@@ -55,7 +55,8 @@ def greet(material_exemplar: Image.Image,
55
  occasion=None,
56
  clothing_type=None,
57
  cloth_attrs=None,
58
- mask_padding=5):
 
59
  """
60
  Compute depth map from input_image
61
  """
@@ -91,7 +92,7 @@ def greet(material_exemplar: Image.Image,
91
 
92
  depth_map = create_depth_map(img, model)
93
 
94
- mask = create_mask(raw_cloth, padding=mask_padding)
95
 
96
  """
97
  Process material exemplar and resize all images
@@ -141,9 +142,12 @@ def get_tab(gender: Literal["Male", "Female", "Unisex"]):
141
  guidance_scale = gr.Slider(minimum=0, maximum=15, step=0.1, value=7.5,
142
  label="Guidance Scale",
143
  interactive=True)
144
- mask_padding = gr.Slider(minimum=0, maximum=15, step=1, value=5,
145
  label="Mask Padding",
146
  interactive=True)
 
 
 
147
  with gr.Column():
148
  material_image = gr.Image(type="pil", label="material examplar")
149
  raw_image_input = gr.Image(type="pil", label="raw image")
@@ -189,7 +193,8 @@ def get_tab(gender: Literal["Male", "Female", "Unisex"]):
189
  occasion,
190
  clothing_type_input,
191
  state,
192
- mask_padding],
 
193
  outputs=[raw_image, output_image, mask_image, depth_image, credits])
194
 
195
 
 
45
  depth_backbone="vitb_rn50_384")
46
 
47
 
48
+ @spaces.GPU(duration=180)
49
  def greet(material_exemplar: Image.Image,
50
  raw_cloth: Image.Image = None,
51
  _controlnet_conditioning_scale=0.9,
 
55
  occasion=None,
56
  clothing_type=None,
57
  cloth_attrs=None,
58
+ mask_padding=0,
59
+ blur_mask=3):
60
  """
61
  Compute depth map from input_image
62
  """
 
92
 
93
  depth_map = create_depth_map(img, model)
94
 
95
+ mask = create_mask(raw_cloth, blur=blur_mask, padding=mask_padding)
96
 
97
  """
98
  Process material exemplar and resize all images
 
142
  guidance_scale = gr.Slider(minimum=0, maximum=15, step=0.1, value=7.5,
143
  label="Guidance Scale",
144
  interactive=True)
145
+ mask_padding = gr.Slider(minimum=-15, maximum=15, step=1, value=0,
146
  label="Mask Padding",
147
  interactive=True)
148
+ blur_mask = gr.Slider(minimum=0, maximum=15, step=1, value=3,
149
+ label="Mask Blur",
150
+ interactive=True)
151
  with gr.Column():
152
  material_image = gr.Image(type="pil", label="material examplar")
153
  raw_image_input = gr.Image(type="pil", label="raw image")
 
193
  occasion,
194
  clothing_type_input,
195
  state,
196
+ mask_padding,
197
+ blur_mask],
198
  outputs=[raw_image, output_image, mask_image, depth_image, credits])
199
 
200
 
processing/mask.py CHANGED
@@ -8,7 +8,7 @@ from torchvision.transforms import Compose
8
  from DPT.dpt.transforms import PrepareForNet, NormalizeImage, Resize
9
 
10
 
11
- def create_mask(image, blur=False, padding=0):
12
  rm_bg = remove(np.array(image), post_process_mask=True, only_mask=True)
13
  rm_bg = Image.fromarray((rm_bg * 255).astype(np.uint8))
14
  rm_bg = rm_bg.resize(image.size, resample=Image.BILINEAR)
@@ -17,8 +17,8 @@ def create_mask(image, blur=False, padding=0):
17
 
18
  # Convert mask back to uint8 for PIL compatibility
19
  pil_mask = Image.fromarray((padded_mask * 255).astype(np.uint8))
20
- if blur:
21
- pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(3))
22
  return pil_mask.resize((1024, 1024))
23
 
24
 
 
8
  from DPT.dpt.transforms import PrepareForNet, NormalizeImage, Resize
9
 
10
 
11
+ def create_mask(image, blur=0, padding=0):
12
  rm_bg = remove(np.array(image), post_process_mask=True, only_mask=True)
13
  rm_bg = Image.fromarray((rm_bg * 255).astype(np.uint8))
14
  rm_bg = rm_bg.resize(image.size, resample=Image.BILINEAR)
 
17
 
18
  # Convert mask back to uint8 for PIL compatibility
19
  pil_mask = Image.fromarray((padded_mask * 255).astype(np.uint8))
20
+ if blur > 0:
21
+ pil_mask = pil_mask.filter(ImageFilter.GaussianBlur(blur))
22
  return pil_mask.resize((1024, 1024))
23
 
24
 
processing/mask.pyi CHANGED
@@ -3,11 +3,11 @@ import torch
3
  from PIL.Image import Image
4
 
5
 
6
- def create_mask(image: Image, blur: bool = ..., padding: int = ...) -> Image:
7
  """
8
  Create a mask from the input image.
9
  :param image: The input image.
10
- :param blur: Whether to blur the mask. Default is True.
11
  :param padding: The padding around the object in the mask. Default is 0.
12
  """
13
  ...
 
3
  from PIL.Image import Image
4
 
5
 
6
+ def create_mask(image: Image, blur: int = ..., padding: int = ...) -> Image:
7
  """
8
  Create a mask from the input image.
9
  :param image: The input image.
10
+ :param blur: Blur strength. Default is 0.
11
  :param padding: The padding around the object in the mask. Default is 0.
12
  """
13
  ...