Spaces:
Running
Running
Add mask dilation before LaMa inpainting
Browse files
app.py
CHANGED
|
@@ -21,7 +21,6 @@ def load_models():
|
|
| 21 |
print("Loading LaMa Inpainting Model...")
|
| 22 |
# 2. LaMa Inpainting Model (TorchScript)
|
| 23 |
# We download the .pt file directly from a repository that hosts the compiled JIT version.
|
| 24 |
-
# This avoids dealing with .ckpt files and source code dependencies.
|
| 25 |
try:
|
| 26 |
model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
|
| 27 |
|
|
@@ -89,13 +88,20 @@ def run_local_lama(image_bgr, mask_float):
|
|
| 89 |
image_bgr: HxWx3 uint8 numpy array
|
| 90 |
mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
|
| 91 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
# 1. Resize to be divisible by 8 (LaMa requirement)
|
| 93 |
h, w = image_bgr.shape[:2]
|
| 94 |
new_h = (h // 8) * 8
|
| 95 |
new_w = (w // 8) * 8
|
| 96 |
|
| 97 |
img_resized = cv2.resize(image_bgr, (new_w, new_h))
|
| 98 |
-
mask_resized = cv2.resize(mask_float, (new_w, new_h), ...)  # [remaining arguments truncated in the extracted diff]
|
| 99 |
|
| 100 |
# 2. Convert to Torch Tensors
|
| 101 |
# Image: (1, 3, H, W), RGB, 0-1
|
|
@@ -104,7 +110,7 @@ def run_local_lama(image_bgr, mask_float):
|
|
| 104 |
img_t = img_t[:, [2, 1, 0], :, :]
|
| 105 |
|
| 106 |
# Mask: (1, 1, H, W), 0-1
|
| 107 |
-
mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)
|
| 108 |
# Binary threshold just in case
|
| 109 |
mask_t = (mask_t > 0.5).float()
|
| 110 |
|
|
@@ -112,6 +118,9 @@ def run_local_lama(image_bgr, mask_float):
|
|
| 112 |
mask_t = mask_t.to(device)
|
| 113 |
|
| 114 |
# 3. Inference
|
|
|
|
|
|
|
|
|
|
| 115 |
inpainted_t = lama_model(img_t, mask_t)
|
| 116 |
|
| 117 |
# 4. Post-process
|
|
|
|
| 21 |
print("Loading LaMa Inpainting Model...")
|
| 22 |
# 2. LaMa Inpainting Model (TorchScript)
|
| 23 |
# We download the .pt file directly from a repository that hosts the compiled JIT version.
|
|
|
|
| 24 |
try:
|
| 25 |
model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
|
| 26 |
|
|
|
|
| 88 |
image_bgr: HxWx3 uint8 numpy array
|
| 89 |
mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
|
| 90 |
"""
|
| 91 |
+
# 0. Dilate Mask (Fixes smearing/streaking)
|
| 92 |
+
# We expand the "hole" area (values of 1) to cover the jagged edges
|
| 93 |
+
# created by the pixel shift. This forces LaMa to regenerate the boundary.
|
| 94 |
+
kernel = np.ones((5, 5), np.uint8)
|
| 95 |
+
mask_uint8 = (mask_float * 255).astype(np.uint8)
|
| 96 |
+
mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
|
| 97 |
+
|
| 98 |
# 1. Resize to be divisible by 8 (LaMa requirement)
|
| 99 |
h, w = image_bgr.shape[:2]
|
| 100 |
new_h = (h // 8) * 8
|
| 101 |
new_w = (w // 8) * 8
|
| 102 |
|
| 103 |
img_resized = cv2.resize(image_bgr, (new_w, new_h))
|
| 104 |
+
mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
|
| 105 |
|
| 106 |
# 2. Convert to Torch Tensors
|
| 107 |
# Image: (1, 3, H, W), RGB, 0-1
|
|
|
|
| 110 |
img_t = img_t[:, [2, 1, 0], :, :]
|
| 111 |
|
| 112 |
# Mask: (1, 1, H, W), 0-1
|
| 113 |
+
mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
|
| 114 |
# Binary threshold just in case
|
| 115 |
mask_t = (mask_t > 0.5).float()
|
| 116 |
|
|
|
|
| 118 |
mask_t = mask_t.to(device)
|
| 119 |
|
| 120 |
# 3. Inference
|
| 121 |
+
# LaMa expects the image to be masked (zeroed out) in the hole regions for best results
|
| 122 |
+
img_t = img_t * (1 - mask_t)
|
| 123 |
+
|
| 124 |
inpainted_t = lama_model(img_t, mask_t)
|
| 125 |
|
| 126 |
# 4. Post-process
|