Spaces:
Sleeping
Sleeping
| """ | |
| Lama-Cleaner: Image Inpainting with LaMa | |
| CPU inference for HuggingFace Spaces free tier | |
| Based on https://github.com/Sanster/lama-cleaner | |
| """ | |
| import argparse | |
| import gc | |
| import os | |
| import sys | |
| from pathlib import Path | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
# CPU-only inference: hide all CUDA devices and pin the device object that the
# rest of the module moves tensors to.
os.environ.update(CUDA_VISIBLE_DEVICES="")
DEVICE = torch.device("cpu")

# Where the TorchScript checkpoint lives on the Hub, and where it is cached
# locally (directory is created at import time if missing).
HF_REPO = "fashn-ai/LaMa"
MODEL_FILE = "big-lama.pt"
CACHE_DIR = Path("models")
CACHE_DIR.mkdir(exist_ok=True)

# Loaded model; stays None until load_model() populates it on first use.
MODEL = None
def download_model():
    """Return the local path of the LaMa checkpoint, fetching it if absent.

    The checkpoint is expected at ``CACHE_DIR / MODEL_FILE``. When that file
    does not exist yet, it is downloaded from the ``HF_REPO`` Hub repository
    into ``CACHE_DIR`` first.
    """
    target = CACHE_DIR / MODEL_FILE
    if target.exists():
        return target
    print(f"Downloading {MODEL_FILE}...")
    hf_hub_download(repo_id=HF_REPO, filename=MODEL_FILE, local_dir=CACHE_DIR)
    return target
def load_model():
    """Return the TorchScript LaMa model, loading it on the first call only.

    The loaded module is stored in the module-level ``MODEL`` global (in eval
    mode, mapped onto ``DEVICE``), so later calls return the same object
    without touching the checkpoint again.
    """
    global MODEL
    if MODEL is None:
        print("Loading LaMa model...")
        checkpoint = download_model()
        MODEL = torch.jit.load(str(checkpoint), map_location=DEVICE)
        MODEL.eval()
        # Reclaim temporaries created while deserializing the checkpoint.
        gc.collect()
        print("Model loaded!")
    return MODEL
def norm_img(np_img):
    """Turn an HWC (or HW) image into a CHW float32 array scaled by 1/255.

    Mirrors ``norm_img()`` in the original lama_cleaner/helper.py. A 2-D
    input is treated as a single-channel image, giving shape (1, H, W).
    """
    hwc = np_img[:, :, np.newaxis] if np_img.ndim == 2 else np_img
    chw = np.transpose(hwc, (2, 0, 1))
    return chw.astype("float32") / 255
def ceil_modulo(x, mod):
    """Round ``x`` up to the next multiple of ``mod`` (``x`` itself if already one)."""
    quotient, remainder = divmod(x, mod)
    return x if remainder == 0 else (quotient + 1) * mod
def pad_img_to_modulo(img, mod=8):
    """Pad an image on the bottom/right so height and width are multiples of ``mod``.

    Mirrors ``pad_img_to_modulo()`` in the original lama_cleaner/helper.py.
    A 2-D (H, W) input gains a trailing channel axis, so the result is always
    3-D (H', W', C). New pixels are filled with numpy's "symmetric" mode,
    i.e. mirrored copies of the edge rows/columns.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    rows, cols = img.shape[:2]
    extra_rows = ceil_modulo(rows, mod) - rows
    extra_cols = ceil_modulo(cols, mod) - cols
    pad_width = ((0, extra_rows), (0, extra_cols), (0, 0))
    return np.pad(img, pad_width, mode="symmetric")
def inpaint(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """
    Inpaint image using LaMa model.

    Matches original lama_cleaner/model/lama.py forward().

    Args:
        image: RGB image [H, W, 3] uint8. An RGBA image has its alpha channel
            dropped; a single-channel image ([H, W] or [H, W, 1]) is
            replicated to three channels.
        mask: Mask [H, W] uint8 (an [H, W, 1] array is also accepted).
            Any non-zero pixel marks an area to inpaint, 0 = keep. The
            callers in this file pass a 0/255 binarized mask.

    Returns:
        Inpainted RGB image [H, W, 3] uint8. Pixels outside the mask are
        copied unchanged from the input image.

    Raises:
        ValueError: if the mask is not 2-D or its height/width differ from
            the image's.
    """
    # Normalise the image to 3 channels before anything else.
    if image.ndim == 2:
        image = image[:, :, np.newaxis]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    if mask.ndim == 3 and mask.shape[2] == 1:
        mask = mask[:, :, 0]
    if mask.ndim != 2:
        raise ValueError(f"mask must be 2-D [H, W], got shape {mask.shape}")
    orig_h, orig_w = image.shape[:2]
    if mask.shape != (orig_h, orig_w):
        # Fail early with a clear message instead of a shape error deep
        # inside the model or the final blend.
        raise ValueError(
            f"mask size {mask.shape} does not match image size {(orig_h, orig_w)}"
        )

    # One definition of "masked" (any non-zero pixel) is shared by the model
    # input and the final blend, so the two always agree.
    hole = mask > 0

    model = load_model()

    # Pad to mod 8, then HWC -> CHW and [0, 255] -> [0, 1].
    image_norm = norm_img(pad_img_to_modulo(image, mod=8))
    # float32 binary mask (the earlier `(mask > 0) * 1` produced an integer
    # array, relying on implicit int/float promotion inside the model).
    mask_norm = (norm_img(pad_img_to_modulo(mask, mod=8)) > 0).astype(np.float32)

    # Add the batch dimension: [C, H, W] -> [1, C, H, W].
    image_tensor = torch.from_numpy(image_norm).unsqueeze(0).to(DEVICE)
    mask_tensor = torch.from_numpy(mask_norm).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        inpainted = model(image_tensor, mask_tensor)

    # [1, C, H, W] -> [H, W, C]; the output is scaled by 255 and clipped
    # (presumably the model emits values in [0, 1]), then cropped back to the
    # unpadded size.
    result = inpainted[0].permute(1, 2, 0).cpu().numpy()
    result = np.clip(result * 255, 0, 255).astype(np.uint8)
    result = result[:orig_h, :orig_w]

    # Only replace masked pixels (like the original _pad_forward blend).
    # This is a per-pixel select, so channel order is irrelevant and no
    # RGB<->BGR round-trip is needed.
    blended = np.where(hole[:, :, np.newaxis], result, image).astype(
        np.uint8, copy=False
    )

    # Drop tensor references so gc.collect() can actually reclaim them.
    del image_tensor, mask_tensor, inpainted
    gc.collect()
    return blended
def _load_background(background):
    """Convert an ImageEditor background (file path or array) to an image array.

    A path is opened with PIL and converted to RGB; an RGBA array is reduced
    to RGB; any other array is returned unchanged. Returns None for any
    other value type.
    """
    if isinstance(background, str):
        # The context manager closes the file handle PIL keeps open.
        with Image.open(background) as img:
            return np.array(img.convert("RGB"))
    if isinstance(background, np.ndarray):
        if background.ndim == 3 and background.shape[2] == 4:
            return cv2.cvtColor(background, cv2.COLOR_RGBA2RGB)
        return background
    return None


def _layer_to_mask(layer):
    """Extract a single-channel mask from one ImageEditor layer.

    Layers that carry an alpha channel use the alpha as the mask; other
    layers are converted to grayscale. Returns None for unsupported values.
    """
    if isinstance(layer, str):
        with Image.open(layer) as img:
            if "A" in img.getbands():
                return np.array(img.getchannel("A"))  # Use alpha as mask
            return np.array(img.convert("L"))
    if isinstance(layer, np.ndarray):
        if layer.ndim == 3:
            if layer.shape[2] == 4:
                return layer[:, :, 3]  # Use alpha as mask
            return cv2.cvtColor(layer, cv2.COLOR_RGB2GRAY)
        return layer
    return None


def process_image(editor_data, progress=None):
    """Process image from Gradio ImageEditor.

    Args:
        editor_data: ImageEditor value, expected to be a dict with a
            "background" image and a "layers" list (each entry a file path or
            numpy array). A pixel counts as masked if it is painted on any
            layer.
        progress: Unused; kept for signature compatibility.

    Returns:
        ``(result, status)``: the inpainted RGB array (None on failure) and a
        human-readable status message.
    """
    if editor_data is None:
        return None, "Please upload an image and draw a mask"
    if not isinstance(editor_data, dict):
        return None, "Invalid input format"

    background = editor_data.get("background")
    if background is None:
        return None, "Please upload an image"
    image = _load_background(background)
    if image is None:
        return None, "Invalid image format"

    layer_masks = [
        m for m in map(_layer_to_mask, editor_data.get("layers") or [])
        if m is not None
    ]
    if not layer_masks:
        return None, "Please draw a mask on the image"
    # Strokes may be spread across several editor layers; previously only
    # layers[0] was read and strokes on other layers were silently ignored.
    mask = np.maximum.reduce(layer_masks)

    # Binarize mask (like original: cv2.threshold)
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    if mask.max() == 0:
        return None, "Please draw a mask on the area you want to remove"

    result = inpaint(image, mask)
    return result, "Inpainting complete!"
def cli_inpaint(image_path: str, mask_path: str, output_path: str):
    """CLI mode for inpainting: read image + mask files, write the result.

    Args:
        image_path: Input image path (loaded with cv2.imread's default
            3-channel color flag).
        mask_path: Mask image path, loaded as grayscale; pixels above 127
            are inpainted.
        output_path: Destination path; cv2.imwrite picks the format from
            the file extension.

    Exits with status 1, after printing an error to stderr, if an input
    cannot be read, the mask size differs from the image size, or the result
    cannot be written.
    """
    image = cv2.imread(image_path)
    if image is None:
        print(f"Error: Could not load image from {image_path}", file=sys.stderr)
        sys.exit(1)
    # cv2 loads BGR; the inpainting pipeline works in RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"Error: Could not load mask from {mask_path}", file=sys.stderr)
        sys.exit(1)
    if mask.shape != image.shape[:2]:
        print(
            f"Error: Mask size {mask.shape[1]}x{mask.shape[0]} does not match "
            f"image size {image.shape[1]}x{image.shape[0]}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Binarize mask
    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    print(f"Input image: {image.shape}")
    print(f"Mask: {mask.shape}")

    result = inpaint(image, mask)

    # cv2.imwrite expects BGR, and signals some failures (e.g. a missing
    # output directory) only through a False return value.
    result_bgr = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)
    if not cv2.imwrite(output_path, result_bgr):
        print(f"Error: Could not write result to {output_path}", file=sys.stderr)
        sys.exit(1)
    print(f"Result saved to {output_path}")
def main():
    """Parse command-line arguments and dispatch.

    ``inpaint -i IMAGE -m MASK -o OUTPUT`` runs cli_inpaint(); with no
    sub-command the Gradio UI is launched instead.
    """
    parser = argparse.ArgumentParser(description="Lama-Cleaner: Image Inpainting")
    commands = parser.add_subparsers(dest="command")

    inpaint_cmd = commands.add_parser("inpaint", help="Inpaint an image")
    for short_flag, long_flag, help_text in (
        ("-i", "--image", "Input image path"),
        ("-m", "--mask", "Mask image path (white = area to inpaint)"),
        ("-o", "--output", "Output image path"),
    ):
        inpaint_cmd.add_argument(short_flag, long_flag, required=True, help=help_text)

    args = parser.parse_args()
    if args.command != "inpaint":
        # No sub-command: fall back to the web UI.
        launch_gradio()
        return
    cli_inpaint(args.image, args.mask, args.output)
def launch_gradio():
    """Launch Gradio UI.

    Builds a Blocks app with an ImageEditor for drawing the mask and a
    "Remove Object" button wired to process_image (registered with
    api_name="inpaint"). Result and status outputs are shown in a second
    column. The app is started with request queueing enabled.
    """
    # Imported lazily, presumably so the CLI path never has to load gradio.
    import gradio as gr
    description = """
# Lama-Cleaner: Image Inpainting
Remove unwanted objects from your images using LaMa (Large Mask Inpainting).
**How to use:**
1. Upload an image
2. Draw over the area you want to remove (use the brush tool)
3. Click "Remove Object"
"""
    with gr.Blocks(title="Lama-Cleaner") as demo:
        gr.Markdown(description)
        with gr.Row():
            # Left column: input editor and trigger button.
            with gr.Column():
                # type="numpy" delivers background/layers as arrays to
                # process_image, which reads the layer alpha channel (or
                # grayscale) as the mask; the brush is white to match the tips.
                image_editor = gr.ImageEditor(
                    label="Draw mask on area to remove",
                    type="numpy",
                    brush=gr.Brush(colors=["#FFFFFF"], default_size=30),
                    eraser=gr.Eraser(default_size=30),
                )
                process_btn = gr.Button("Remove Object", variant="primary", size="lg")
            # Right column: inpainted image and status message.
            with gr.Column():
                output_image = gr.Image(label="Result")
                status = gr.Textbox(label="Status", interactive=False)
        # process_image returns (result_image, status_message), matching the
        # two outputs listed here.
        process_btn.click(
            fn=process_image,
            inputs=[image_editor],
            outputs=[output_image, status],
            api_name="inpaint",
        )
        gr.Markdown("""
## Tips
- Draw a white mask over the area you want to remove
- For best results, extend the mask slightly beyond the object
- LaMa works best for small to medium sized areas
""")
    demo.queue().launch()
if __name__ == "__main__":
    # Any command-line arguments go through the argparse entry point; a bare
    # invocation starts the web UI directly.
    entry_point = main if sys.argv[1:] else launch_gradio
    entry_point()