# watermark-remover / simple_lama.py
# Author: aladhefafalquran
# Commit 0d3fffb: "Increase AI context padding"
import os
import cv2
import numpy as np
import onnxruntime
import urllib.request
class SimpleLama:
    """Thin wrapper around the LaMa inpainting model served via ONNX Runtime.

    Downloads the ONNX weights from Hugging Face on first use (cached next to
    the script), then exposes ``predict(image, mask)``, which inpaints the
    masked region of an HxWx3 uint8 image and returns a new image of the same
    shape.
    """

    def __init__(self, device='cpu'):
        """Download the model if needed and create the ONNX Runtime session.

        Parameters
        ----------
        device : str
            ``'cuda'`` to prefer the CUDA execution provider (with CPU
            fallback); anything else uses CPU only.
        """
        self.model_url = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx"
        self.model_path = "big-lama.onnx"
        if not os.path.exists(self.model_path):
            print(f"Downloading LaMa model to {self.model_path} (this happens once)...")
            # Download to a temp file and rename atomically: an interrupted
            # download must never leave a partial file at self.model_path,
            # because the existence check above would then skip re-downloading
            # and we would try to load a corrupt model on the next run.
            tmp_path = self.model_path + ".part"
            try:
                # Add headers to avoid 403 Forbidden errors if sites block script bots
                opener = urllib.request.build_opener()
                opener.addheaders = [('User-agent', 'Mozilla/5.0')]
                urllib.request.install_opener(opener)
                urllib.request.urlretrieve(self.model_url, tmp_path)
                os.replace(tmp_path, self.model_path)
                print("Download complete.")
            except Exception as e:
                # Clean up any partial file before propagating the error.
                if os.path.exists(tmp_path):
                    try:
                        os.remove(tmp_path)
                    except OSError:
                        pass
                print(f"Failed to download model: {e}")
                raise
        print("Loading LaMa model...")
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if device == 'cuda' else ['CPUExecutionProvider']
        self.session = onnxruntime.InferenceSession(self.model_path, providers=providers)

    def pad_img_to_modulo(self, img, mod):
        """Reflect-pad ``img`` on the bottom/right so both dims are multiples of ``mod``."""
        h, w = img.shape[:2]
        h = int(np.ceil(h / mod) * mod)
        w = int(np.ceil(w / mod) * mod)
        return cv2.copyMakeBorder(img, 0, h - img.shape[0], 0, w - img.shape[1], cv2.BORDER_REFLECT)

    def predict(self, image, mask):
        """Inpaint the masked region of ``image`` and return the result.

        Parameters
        ----------
        image : np.ndarray
            HxWx3 uint8 image.
        mask : np.ndarray
            HxW (or HxWxC, collapsed channel-wise) array; nonzero pixels mark
            the region to remove/inpaint.

        Returns
        -------
        np.ndarray
            The original ``image`` object if the mask is empty, otherwise a
            new HxWx3 uint8 array with the masked region inpainted.
        """
        # Accept 3-channel masks (e.g. cv2.imread without IMREAD_GRAYSCALE):
        # np.where on a 3-D array returns three index arrays and the two-way
        # unpacking below would crash. Collapse channels with a max so any
        # nonzero channel marks the pixel as masked.
        if mask.ndim == 3:
            mask = mask.max(axis=2)
        # If mask is empty there is nothing to inpaint; return input as-is.
        if np.max(mask) == 0:
            return image
        # 1. Find bounding box of the mask (inclusive pixel coordinates).
        rows, cols = np.where(mask > 0)
        y1, y2 = np.min(rows), np.max(rows)
        x1, x2 = np.min(cols), np.max(cols)
        # 2. Add padding to context.
        # Increased padding from 50 to 200 to give the AI more context of the
        # background. This helps prevent "flat color" patches on gradients or
        # textured backgrounds.
        pad = 200
        y1 = max(0, y1 - pad)
        y2 = min(image.shape[0], y2 + pad)
        x1 = max(0, x1 - pad)
        x2 = min(image.shape[1], x2 + pad)
        # 3. Crop the context window around the mask.
        crop_img = image[y1:y2, x1:x2]
        crop_mask = mask[y1:y2, x1:x2]
        crop_h, crop_w = crop_img.shape[:2]
        # 4. Resize to 512x512 (model's expected input size). Note cv2.resize
        # takes (width, height); square target so order is moot here.
        target_size = (512, 512)
        img_resized = cv2.resize(crop_img, target_size, interpolation=cv2.INTER_AREA)
        # NEAREST keeps the mask binary (no interpolated gray values).
        mask_resized = cv2.resize(crop_mask, target_size, interpolation=cv2.INTER_NEAREST)
        # Prepare for ONNX: float32 in [0, 1], NCHW layout.
        img_onnx = img_resized.astype(np.float32) / 255.0
        img_onnx = img_onnx.transpose(2, 0, 1)  # HWC -> CHW
        img_onnx = img_onnx[None, ...]          # add batch dim
        mask_onnx = mask_resized.astype(np.float32) / 255.0
        mask_onnx = (mask_onnx > 0).astype(np.float32)  # hard-binarize
        if len(mask_onnx.shape) == 2:
            mask_onnx = mask_onnx[None, ...]
        mask_onnx = mask_onnx[None, ...]        # add batch/channel dim
        # Run inference.
        outputs = self.session.run(None, {'image': img_onnx, 'mask': mask_onnx})
        output = outputs[0][0]  # remove batch dim
        # Post-process back to HWC uint8.
        output = output.transpose(1, 2, 0)  # CHW -> HWC
        # Check if output is already 0-255 (the model from HF seems to output
        # 0-255 directly); otherwise assume [0, 1] and scale up.
        if output.max() > 2.0:
            output = output.clip(0, 255).astype(np.uint8)
        else:
            output = (output * 255.0).clip(0, 255).astype(np.uint8)
        # 5. Resize back to the crop's original size.
        output_resized = cv2.resize(output, (crop_w, crop_h), interpolation=cv2.INTER_CUBIC)
        # 6. Paste back. LaMa regenerates the whole context window, so a
        # direct paste of the square crop usually blends cleanly; a mask-only
        # feathered blend would avoid any residual seams on the crop border.
        result = image.copy()
        result[y1:y2, x1:x2] = output_resized
        return result