enoky committed on
Commit
79bdec3
·
verified ·
1 Parent(s): db1a689

add mask dilation

Browse files
Files changed (1) hide show
  1. app.py +12 -3
app.py CHANGED
@@ -21,7 +21,6 @@ def load_models():
21
  print("Loading LaMa Inpainting Model...")
22
  # 2. LaMa Inpainting Model (TorchScript)
23
  # We download the .pt file directly from a repository that hosts the compiled JIT version.
24
- # This avoids dealing with .ckpt files and source code dependencies.
25
  try:
26
  model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
27
 
@@ -89,13 +88,20 @@ def run_local_lama(image_bgr, mask_float):
89
  image_bgr: HxWx3 uint8 numpy array
90
  mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
91
  """
 
 
 
 
 
 
 
92
  # 1. Resize to be divisible by 8 (LaMa requirement)
93
  h, w = image_bgr.shape[:2]
94
  new_h = (h // 8) * 8
95
  new_w = (w // 8) * 8
96
 
97
  img_resized = cv2.resize(image_bgr, (new_w, new_h))
98
- mask_resized = cv2.resize(mask_float, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
99
 
100
  # 2. Convert to Torch Tensors
101
  # Image: (1, 3, H, W), RGB, 0-1
@@ -104,7 +110,7 @@ def run_local_lama(image_bgr, mask_float):
104
  img_t = img_t[:, [2, 1, 0], :, :]
105
 
106
  # Mask: (1, 1, H, W), 0-1
107
- mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0)
108
  # Binary threshold just in case
109
  mask_t = (mask_t > 0.5).float()
110
 
@@ -112,6 +118,9 @@ def run_local_lama(image_bgr, mask_float):
112
  mask_t = mask_t.to(device)
113
 
114
  # 3. Inference
 
 
 
115
  inpainted_t = lama_model(img_t, mask_t)
116
 
117
  # 4. Post-process
 
21
  print("Loading LaMa Inpainting Model...")
22
  # 2. LaMa Inpainting Model (TorchScript)
23
  # We download the .pt file directly from a repository that hosts the compiled JIT version.
 
24
  try:
25
  model_path = hf_hub_download(repo_id="fashn-ai/LaMa", filename="big-lama.pt")
26
 
 
88
  image_bgr: HxWx3 uint8 numpy array
89
  mask_float: HxW float32 numpy array (1.0 = hole, 0.0 = valid)
90
  """
91
+ # 0. Dilate Mask (Fixes smearing/streaking)
92
+ # We expand the "hole" area (values of 1) to cover the jagged edges
93
+ # created by the pixel shift. This forces LaMa to regenerate the boundary.
94
+ kernel = np.ones((5, 5), np.uint8)
95
+ mask_uint8 = (mask_float * 255).astype(np.uint8)
96
+ mask_dilated = cv2.dilate(mask_uint8, kernel, iterations=1)
97
+
98
  # 1. Resize to be divisible by 8 (LaMa requirement)
99
  h, w = image_bgr.shape[:2]
100
  new_h = (h // 8) * 8
101
  new_w = (w // 8) * 8
102
 
103
  img_resized = cv2.resize(image_bgr, (new_w, new_h))
104
+ mask_resized = cv2.resize(mask_dilated, (new_w, new_h), interpolation=cv2.INTER_NEAREST)
105
 
106
  # 2. Convert to Torch Tensors
107
  # Image: (1, 3, H, W), RGB, 0-1
 
110
  img_t = img_t[:, [2, 1, 0], :, :]
111
 
112
  # Mask: (1, 1, H, W), 0-1
113
+ mask_t = torch.from_numpy(mask_resized).float().unsqueeze(0).unsqueeze(0) / 255.0
114
  # Binary threshold just in case
115
  mask_t = (mask_t > 0.5).float()
116
 
 
118
  mask_t = mask_t.to(device)
119
 
120
  # 3. Inference
121
+ # LaMa expects the image to be masked (zeroed out) in the hole regions for best results
122
+ img_t = img_t * (1 - mask_t)
123
+
124
  inpainted_t = lama_model(img_t, mask_t)
125
 
126
  # 4. Post-process