added combine_masks
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ import argparse
|
|
| 12 |
# Load configuration and models
|
| 13 |
config = OmegaConf.load("config/inference_config.yaml")
|
| 14 |
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 15 |
-
"
|
| 16 |
)
|
| 17 |
clipaway = CLIPAway(
|
| 18 |
sd_pipe=sd_pipeline,
|
|
@@ -26,11 +26,19 @@ clipaway = CLIPAway(
|
|
| 26 |
)
|
| 27 |
|
| 28 |
def dilate_mask(mask, kernel_size=5, iterations=5):
|
| 29 |
-
mask = mask.convert("L")
|
| 30 |
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| 31 |
mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
|
| 32 |
return Image.fromarray(mask)
|
| 33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
def remove_obj(image, uploaded_mask, seed):
|
| 35 |
image_pil, sketched_mask = image["image"], image["mask"]
|
| 36 |
mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
|
|
|
|
| 12 |
# Load configuration and models
|
| 13 |
config = OmegaConf.load("config/inference_config.yaml")
|
| 14 |
sd_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
| 15 |
+
"runwayml/stable-diffusion-inpainting", torch_dtype=torch.float32
|
| 16 |
)
|
| 17 |
clipaway = CLIPAway(
|
| 18 |
sd_pipe=sd_pipeline,
|
|
|
|
| 26 |
)
|
| 27 |
|
| 28 |
def dilate_mask(mask, kernel_size=5, iterations=5):
|
| 29 |
+
mask = mask.convert("L")
|
| 30 |
kernel = np.ones((kernel_size, kernel_size), np.uint8)
|
| 31 |
mask = cv2.dilate(np.array(mask), kernel, iterations=iterations)
|
| 32 |
return Image.fromarray(mask)
|
| 33 |
|
| 34 |
+
def combine_masks(uploaded_mask, sketched_mask):
|
| 35 |
+
if uploaded_mask is not None:
|
| 36 |
+
return uploaded_mask
|
| 37 |
+
elif sketched_mask is not None:
|
| 38 |
+
return sketched_mask
|
| 39 |
+
else:
|
| 40 |
+
raise ValueError("Please provide a mask")
|
| 41 |
+
|
| 42 |
def remove_obj(image, uploaded_mask, seed):
|
| 43 |
image_pil, sketched_mask = image["image"], image["mask"]
|
| 44 |
mask = dilate_mask(combine_masks(uploaded_mask, sketched_mask))
|