# Hugging Face Space (status: Running)
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| from simple_lama import SimpleLama | |
| import pdf2image | |
| import tempfile | |
| import os | |
| from PIL import Image | |
# Build the LaMa inpainting model once at import time (CPU device); the
# handlers defined in this file call ``lama.predict`` on this shared instance.
print("Initializing LaMa model...")
lama = SimpleLama(device='cpu')
def ensure_rgb(image):
    """Return *image* as a 3-channel array.

    Args:
        image: ``numpy.ndarray`` shaped ``(H, W)``, ``(H, W, 1)``,
            ``(H, W, 3)`` or ``(H, W, 4)``.

    Returns:
        An ``(H, W, 3)`` array of the same dtype. Single-channel input has
        its channel replicated three times; 4-channel input has its last
        (alpha) channel dropped without compositing. Any other shape is
        returned unchanged (same object).
    """
    if image.ndim == 2:
        # Grayscale: replicate the single plane into three identical channels.
        return np.stack((image,) * 3, axis=-1)
    if image.ndim == 3:
        channels = image.shape[2]
        if channels == 1:
            # (H, W, 1) used to fall through untouched and reach the model
            # as a one-channel image.
            return np.repeat(image, 3, axis=2)
        if channels == 4:
            # Copy so the result is a contiguous array rather than a view
            # into the RGBA buffer.
            return np.ascontiguousarray(image[:, :, :3])
    return image
def _binarize(values, threshold):
    """Return a uint8 array that is 255 where ``values > threshold``, else 0."""
    return np.where(values > threshold, 255, 0).astype(np.uint8)


def get_mask_from_dict(image_dict):
    """Extract a binary mask from a Gradio ImageEditor dictionary.

    Args:
        image_dict: Mapping with a ``"background"`` array (used only for its
            height and width) and an optional ``"layers"`` sequence of
            arrays. ``None`` entries in ``"layers"`` are ignored.

    Returns:
        A ``uint8`` array of the background's height and width holding 255
        where any layer is painted and 0 elsewhere. Layer contributions:

        * 4-channel layer: its alpha channel is OR-ed in as-is;
        * 3-channel layer: pixels whose grayscale value exceeds 1;
        * 2-D layer: pixels whose value exceeds 1.

        The accumulated values are finally thresholded at 10.

    Raises:
        ValueError: If ``image_dict`` is empty or has no background image,
            so there is nothing to size the mask from.
    """
    background = image_dict.get("background") if image_dict else None
    if background is None:
        # Previously surfaced as an opaque AttributeError on ``None.shape``.
        raise ValueError(
            "ImageEditor value has no background image to size the mask from"
        )

    mask = np.zeros(background.shape[:2], dtype=np.uint8)
    for layer in image_dict.get("layers") or ():
        if layer is None:
            continue
        if layer.ndim == 3 and layer.shape[2] == 4:
            # Brush strokes live in the alpha channel of an RGBA layer.
            mask = np.bitwise_or(mask, layer[:, :, 3])
        elif layer.ndim == 3 and layer.shape[2] == 3:
            gray = cv2.cvtColor(layer, cv2.COLOR_RGB2GRAY)
            mask = np.bitwise_or(mask, _binarize(gray, 1))
        elif layer.ndim == 2:
            mask = np.bitwise_or(mask, _binarize(layer, 1))

    # Drop faint (<= 10) accumulated values, e.g. anti-aliased stroke edges.
    return _binarize(mask, 10)
def process_image_dict(image_dict):
    """Inpaint the painted region of a single image from the ImageEditor.

    Args:
        image_dict: Gradio ImageEditor value; its ``"background"`` entry is
            the image and its layers define the mask.

    Returns:
        Whatever ``lama.predict`` returns for the RGB image and mask, or
        ``None`` when no background image is present.
    """
    # Same convention as the PDF handlers in this file: missing input yields
    # no output instead of an AttributeError on ``None``.
    if not image_dict or image_dict.get("background") is None:
        return None

    image = ensure_rgb(image_dict["background"])
    mask = get_mask_from_dict(image_dict)
    return lama.predict(image, mask)
def process_simple_api(image, mask):
    """API handler: inpaint *image* wherever *mask* is bright.

    Args:
        image: Image array; converted to 3 channels with ``ensure_rgb``.
        mask: Mask array, 2-D or with 1, 3 or 4 channels. It is reduced to a
            single channel, resized to the image's height and width if they
            differ, and binarized (values above 127 become 255).

    Returns:
        Whatever ``lama.predict`` returns, or ``None`` when either input is
        missing.
    """
    if image is None or mask is None:
        return None

    image = ensure_rgb(image)

    if mask.ndim == 3:
        channels = mask.shape[2]
        if channels == 4:
            # RGB2GRAY rejects 4-channel input; use the RGBA variant.
            mask = cv2.cvtColor(mask, cv2.COLOR_RGBA2GRAY)
        elif channels == 1:
            mask = np.ascontiguousarray(mask[:, :, 0])
        else:
            mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)

    # Align mask and image sizes the same way process_pdf does for pages;
    # nearest-neighbour keeps the mask free of interpolated in-between values.
    if mask.shape[:2] != image.shape[:2]:
        mask = cv2.resize(
            mask,
            (image.shape[1], image.shape[0]),
            interpolation=cv2.INTER_NEAREST,
        )

    _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
    return lama.predict(image, mask)
| # --- PDF Functions --- | |
def pdf_preview(pdf_file):
    """Render the first page of a PDF as an array for the mask editor.

    Args:
        pdf_file: Uploaded file, either an object exposing the path as
            ``.name`` or a plain path string; ``None`` when nothing has been
            uploaded.

    Returns:
        The first page rasterized at 300 DPI as a ``numpy.ndarray``, or
        ``None`` when there is no file or no page was produced.
    """
    if pdf_file is None:
        return None

    # Accept both a file-like wrapper (``.name``) and a bare path string.
    pdf_path = getattr(pdf_file, "name", pdf_file)

    # First page only; 300 DPI matches the resolution process_pdf renders at,
    # so a mask drawn on the preview lines up with the processed pages.
    pages = pdf2image.convert_from_path(
        pdf_path, first_page=1, last_page=1, dpi=300
    )
    return np.array(pages[0]) if pages else None
| import fitz # PyMuPDF | |
| from concurrent.futures import ThreadPoolExecutor | |
def process_pdf(pdf_file, image_editor_data):
    """Inpaint the mask drawn on page 1 across every page of a PDF.

    Every page is rasterized at 300 DPI, inpainted with ``lama.predict``
    using the (dilated) editor mask, and the results are written to a new
    multi-page PDF.

    Args:
        pdf_file: Uploaded PDF, either an object exposing the path as
            ``.name`` or a plain path string.
        image_editor_data: Gradio ImageEditor value whose layers define the
            mask (see ``get_mask_from_dict``).

    Returns:
        Path of the generated PDF, or ``None`` when an input is missing, the
        PDF cannot be rasterized, or it yields no pages.
    """
    if pdf_file is None or image_editor_data is None:
        return None
    if image_editor_data.get("background") is None:
        # No preview loaded yet, so there is nothing to size a mask from.
        return None

    # 1. Mask drawn by the user on page 1, grown slightly (5x5 kernel, three
    #    passes) so the edges of the watermark are covered too.
    full_mask = get_mask_from_dict(image_editor_data)
    full_mask = cv2.dilate(full_mask, np.ones((5, 5), np.uint8), iterations=3)

    # 2. Rasterize ALL pages at 300 DPI; this flattens vector content into
    #    pixels before inpainting.
    pdf_path = getattr(pdf_file, "name", pdf_file)
    print("Rasterizing PDF to Images (300 DPI)...")
    try:
        pages = pdf2image.convert_from_path(pdf_path, dpi=300)
    except Exception as e:
        # UI boundary: report the failure and return "no file" rather than
        # propagating the rasterizer's exception.
        print(f"Error converting PDF: {e}")
        return None

    total_pages = len(pages)
    print(f"Processing {total_pages} pages...")

    cleaned_pages = []
    for page_number, page in enumerate(pages, start=1):
        img_np = ensure_rgb(np.array(page))

        # Pages may differ in size from the previewed page; stretch the mask
        # to fit (nearest-neighbour keeps it strictly binary).
        if img_np.shape[:2] != full_mask.shape[:2]:
            current_mask = cv2.resize(
                full_mask,
                (img_np.shape[1], img_np.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            )
        else:
            current_mask = full_mask

        result = lama.predict(img_np, current_mask)
        cleaned_pages.append(Image.fromarray(result))
        print(f"Processed page {page_number}/{total_pages}")

    if not cleaned_pages:
        return None

    # 3. Save back to PDF. mkstemp creates the file atomically, unlike the
    #    deprecated, race-prone tempfile.mktemp; only the path is needed, so
    #    the descriptor is closed right away.
    fd, output_path = tempfile.mkstemp(suffix=".pdf")
    os.close(fd)
    cleaned_pages[0].save(
        output_path,
        save_all=True,
        append_images=cleaned_pages[1:],
        quality=100,       # Max JPEG quality
        resolution=300.0,  # Maintain high DPI
        subsampling=0,     # Disable chroma subsampling for sharper colors
    )
    return output_path
# --- UI Construction ---
# NOTE(review): the nesting below is reconstructed from a paste that lost its
# indentation; confirm which components belong inside each Row.
with gr.Blocks(title="AI Watermark Remover") as app:
    gr.Markdown("# 💧 AI Watermark Remover (LaMa)")

    # Tab 1: paint a mask on a single image and inpaint it.
    with gr.Tab("Image Mode"):
        with gr.Row():
            input_editor = gr.ImageEditor(
                label="Draw Mask",
                type="numpy",
                brush=gr.Brush(colors=["#FF0000"], default_size=20),
                interactive=True,
            )
            ui_output = gr.Image(label="Result")
        ui_btn = gr.Button("Remove Watermark", variant="primary")
        ui_btn.click(
            fn=process_image_dict,
            inputs=input_editor,
            outputs=ui_output,
        )

    # Tab 2: paint a mask on page 1 of a PDF and apply it to every page.
    with gr.Tab("PDF Mode"):
        gr.Markdown("### 1. Upload PDF & Preview Page 1")
        with gr.Row():
            pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
            preview_btn = gr.Button("Load Preview")

        gr.Markdown("### 2. Draw Mask on Page 1 (Applied to ALL pages)")
        pdf_editor = gr.ImageEditor(
            label="Draw Mask Here",
            type="numpy",
            brush=gr.Brush(colors=["#FF0000"], default_size=20),
            interactive=True,
        )

        gr.Markdown("### 3. Process Full PDF")
        pdf_run_btn = gr.Button("Clean Entire PDF", variant="primary")
        pdf_output = gr.File(label="Download Cleaned PDF")

        # Event wiring: the preview fills the editor, the run button emits
        # the cleaned file.
        preview_btn.click(fn=pdf_preview, inputs=pdf_input, outputs=pdf_editor)
        pdf_run_btn.click(
            fn=process_pdf,
            inputs=[pdf_input, pdf_editor],
            outputs=pdf_output,
        )

    # Tab 3: plain image + mask inputs, exposed under a named API endpoint.
    with gr.Tab("API Mode"):
        gr.Markdown("Use this endpoint for API calls: `/predict_api`")
        api_image = gr.Image(label="Original", type="numpy")
        api_mask = gr.Image(label="Mask", type="numpy")
        api_output = gr.Image(label="Result")
        api_btn = gr.Button("Run API")
        api_btn.click(
            fn=process_simple_api,
            inputs=[api_image, api_mask],
            outputs=api_output,
            api_name="predict_api",
        )

# Start the server only when executed as a script, not on import.
if __name__ == "__main__":
    app.launch()