# NOTE: HuggingFace Space status banner ("Spaces: Sleeping") removed — it was
# page-scrape residue, not part of the source file.
import os
import urllib.request

import cv2
import numpy as np
import onnxruntime
class SimpleLama:
    """Image inpainting via the LaMa model exported to ONNX.

    Downloads the model weights on first instantiation, then runs inference
    with onnxruntime. Typical usage:

        lama = SimpleLama()
        result = lama.predict(image, mask)   # mask: nonzero pixels are filled
    """

    def __init__(self, device='cpu'):
        """Download the ONNX model if missing and create an inference session.

        Args:
            device: 'cuda' to prefer the CUDA execution provider, anything
                else falls back to CPU-only execution.

        Raises:
            Exception: re-raised from the model download on failure.
        """
        self.model_url = "https://huggingface.co/Carve/LaMa-ONNX/resolve/main/lama_fp32.onnx"
        self.model_path = "big-lama.onnx"
        if not os.path.exists(self.model_path):
            print(f"Downloading LaMa model to {self.model_path} (this happens once)...")
            try:
                # Send a browser-like User-Agent: some hosts answer 403
                # Forbidden to urllib's default agent string.
                opener = urllib.request.build_opener()
                opener.addheaders = [('User-agent', 'Mozilla/5.0')]
                urllib.request.install_opener(opener)
                urllib.request.urlretrieve(self.model_url, self.model_path)
                print("Download complete.")
            except Exception as e:
                # urlretrieve can leave a partial file behind; delete it so
                # the next run retries instead of loading a truncated model.
                if os.path.exists(self.model_path):
                    os.remove(self.model_path)
                print(f"Failed to download model: {e}")
                raise
        print("Loading LaMa model...")
        providers = (
            ['CUDAExecutionProvider', 'CPUExecutionProvider']
            if device == 'cuda'
            else ['CPUExecutionProvider']
        )
        self.session = onnxruntime.InferenceSession(self.model_path, providers=providers)

    def pad_img_to_modulo(self, img, mod):
        """Reflect-pad `img` on the bottom/right so both sides are multiples of `mod`.

        Args:
            img: HxW or HxWxC array.
            mod: the modulo both output dimensions must satisfy.

        Returns:
            The padded array (unchanged if already aligned).
        """
        h, w = img.shape[:2]
        pad_h = int(np.ceil(h / mod) * mod) - h
        pad_w = int(np.ceil(w / mod) * mod) - w
        return cv2.copyMakeBorder(img, 0, pad_h, 0, pad_w, cv2.BORDER_REFLECT)

    def predict(self, image, mask):
        """Inpaint the masked region of `image` and return the result.

        Args:
            image: HxWx3 uint8 image.
            mask:  HxW (or HxWxC) uint8 mask; nonzero pixels are inpainted.

        Returns:
            A new HxWx3 uint8 image with the masked area filled in; the
            original `image` is returned unchanged when the mask is empty.
        """
        # Accept 3-channel masks: np.where on a 3-D array would yield three
        # index arrays and break the bounding-box unpack below.
        if mask.ndim == 3:
            mask = mask[:, :, 0]
        # Empty mask: nothing to inpaint.
        if np.max(mask) == 0:
            return image
        # 1. Bounding box of the masked pixels.
        rows, cols = np.where(mask > 0)
        y1, y2 = np.min(rows), np.max(rows)
        x1, x2 = np.min(cols), np.max(cols)
        # 2. Expand the box: LaMa needs surrounding context to synthesize
        # plausible texture. 200 px (up from 50) avoids "flat color" patches
        # on gradients or textured backgrounds.
        pad = 200
        y1 = max(0, y1 - pad)
        y2 = min(image.shape[0], y2 + pad)
        x1 = max(0, x1 - pad)
        x2 = min(image.shape[1], x2 + pad)
        # 3. Crop the working region.
        crop_img = image[y1:y2, x1:x2]
        crop_mask = mask[y1:y2, x1:x2]
        crop_h, crop_w = crop_img.shape[:2]
        # 4. Resize to the model's fixed 512x512 input.
        target_size = (512, 512)
        img_resized = cv2.resize(crop_img, target_size, interpolation=cv2.INTER_AREA)
        mask_resized = cv2.resize(crop_mask, target_size, interpolation=cv2.INTER_NEAREST)
        # NCHW float image in [0, 1]; binary float mask shaped (1, 1, H, W).
        img_onnx = (img_resized.astype(np.float32) / 255.0).transpose(2, 0, 1)[None, ...]
        mask_onnx = (mask_resized > 0).astype(np.float32)[None, None, ...]
        # Run inference.
        outputs = self.session.run(None, {'image': img_onnx, 'mask': mask_onnx})
        output = outputs[0][0].transpose(1, 2, 0)  # CHW -> HWC
        # The HF export emits 0-255 directly; older exports emit 0-1, so
        # scale only when the range looks normalized.
        if output.max() <= 2.0:
            output = output * 255.0
        output = output.clip(0, 255).astype(np.uint8)
        # 5. Resize back to the crop's size (cv2.resize takes (width, height)).
        output_resized = cv2.resize(output, (crop_w, crop_h), interpolation=cv2.INTER_CUBIC)
        # 6. Paste back only the masked pixels, slightly dilated to hide the
        # seam at the mask boundary. Pasting the whole context square would
        # overwrite untouched pixels with a 512->crop resampled copy,
        # visibly softening them.
        paste = cv2.dilate(
            (crop_mask > 0).astype(np.uint8),
            np.ones((5, 5), np.uint8),
            iterations=2,
        ).astype(bool)
        result = image.copy()
        result[y1:y2, x1:x2][paste] = output_resized[paste]
        return result