aladhefafalquran commited on
Commit
24169f7
·
1 Parent(s): a6ff18b

Add PDF support

Browse files
Files changed (3) hide show
  1. app.py +105 -30
  2. packages.txt +1 -0
  3. requirements.txt +2 -1
app.py CHANGED
@@ -2,25 +2,30 @@ import gradio as gr
2
  import numpy as np
3
  import cv2
4
  from simple_lama import SimpleLama
 
 
 
 
5
 
6
  # Initialize model
7
  print("Initializing LaMa model...")
8
  lama = SimpleLama(device='cpu')
9
 
10
- def process_image_dict(image_dict):
11
- """UI Handler: Processes the dictionary from the ImageEditor"""
12
- image = image_dict["background"]
13
-
14
- # Ensure image is RGB (3 channels)
15
  if len(image.shape) == 3 and image.shape[2] == 4:
16
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
17
-
 
 
 
 
18
  mask = np.zeros(image.shape[:2], dtype=np.uint8)
19
 
20
- # Combine layers
21
  if image_dict.get("layers"):
22
  for layer in image_dict["layers"]:
23
- # Check shape to avoid crashes
24
  if len(layer.shape) == 3 and layer.shape[2] == 4:
25
  alpha = layer[:, :, 3]
26
  mask = cv2.bitwise_or(mask, alpha)
@@ -29,33 +34,85 @@ def process_image_dict(image_dict):
29
  _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
30
  mask = cv2.bitwise_or(mask, thresh)
31
  elif len(layer.shape) == 2:
32
- # Grayscale layer
33
  _, thresh = cv2.threshold(layer, 1, 255, cv2.THRESH_BINARY)
34
  mask = cv2.bitwise_or(mask, thresh)
35
-
36
  _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
 
 
 
 
 
 
37
  return lama.predict(image, mask)
38
 
39
  def process_simple_api(image, mask):
40
- """API Handler: Accepts two separate images (Original and Mask)"""
41
- # Ensure image is RGB
42
- if len(image.shape) == 3 and image.shape[2] == 4:
43
- image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
 
44
 
45
- # Ensure mask is single channel
46
- if len(mask.shape) == 3:
47
- mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
 
 
48
 
49
- # Binarize mask
50
- _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
- return lama.predict(image, mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
- # Build App
55
  with gr.Blocks(title="AI Watermark Remover") as app:
56
  gr.Markdown("# 💧 AI Watermark Remover (LaMa)")
57
 
58
- with gr.Tab("UI Mode"):
59
  with gr.Row():
60
  input_editor = gr.ImageEditor(
61
  label="Draw Mask", type="numpy",
@@ -64,18 +121,36 @@ with gr.Blocks(title="AI Watermark Remover") as app:
64
  )
65
  ui_output = gr.Image(label="Result")
66
  ui_btn = gr.Button("Remove Watermark", variant="primary")
67
-
68
  ui_btn.click(process_image_dict, inputs=input_editor, outputs=ui_output)
69
 
70
- with gr.Tab("API Mode (Simple)"):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  gr.Markdown("Use this endpoint for API calls: `/predict_api`")
72
- api_image = gr.Image(label="Original Image", type="numpy")
73
- api_mask = gr.Image(label="Mask Image (Black/White)", type="numpy")
74
  api_output = gr.Image(label="Result")
75
  api_btn = gr.Button("Run API")
76
-
77
- # This creates the endpoint "/predict_api"
78
  api_btn.click(process_simple_api, inputs=[api_image, api_mask], outputs=api_output, api_name="predict_api")
79
 
80
  if __name__ == "__main__":
81
- app.launch()
 
2
  import numpy as np
3
  import cv2
4
  from simple_lama import SimpleLama
5
+ import pdf2image
6
+ import tempfile
7
+ import os
8
+ from PIL import Image
9
 
10
  # Initialize model
11
  print("Initializing LaMa model...")
12
  lama = SimpleLama(device='cpu')
13
 
14
+ def ensure_rgb(image):
15
+ """Convert RGBA/Grayscale to RGB"""
16
+ if len(image.shape) == 2:
17
+ return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
 
18
  if len(image.shape) == 3 and image.shape[2] == 4:
19
+ return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
20
+ return image
21
+
22
+ def get_mask_from_dict(image_dict):
23
+ """Extract binary mask from Gradio ImageEditor dictionary"""
24
+ image = image_dict["background"]
25
  mask = np.zeros(image.shape[:2], dtype=np.uint8)
26
 
 
27
  if image_dict.get("layers"):
28
  for layer in image_dict["layers"]:
 
29
  if len(layer.shape) == 3 and layer.shape[2] == 4:
30
  alpha = layer[:, :, 3]
31
  mask = cv2.bitwise_or(mask, alpha)
 
34
  _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
35
  mask = cv2.bitwise_or(mask, thresh)
36
  elif len(layer.shape) == 2:
 
37
  _, thresh = cv2.threshold(layer, 1, 255, cv2.THRESH_BINARY)
38
  mask = cv2.bitwise_or(mask, thresh)
39
+
40
  _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
41
+ return mask
42
+
43
+ def process_image_dict(image_dict):
44
+ """Single Image Processing"""
45
+ image = ensure_rgb(image_dict["background"])
46
+ mask = get_mask_from_dict(image_dict)
47
  return lama.predict(image, mask)
48
 
49
  def process_simple_api(image, mask):
50
+ """API Handler"""
51
+ image = ensure_rgb(image)
52
+ if len(mask.shape) == 3: mask = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
53
+ _, mask = cv2.threshold(mask, 127, 255, cv2.THRESH_BINARY)
54
+ return lama.predict(image, mask)
55
 
56
+ # --- PDF Functions ---
57
+
58
+ def pdf_preview(pdf_file):
59
+ """Convert first page of PDF to image for masking"""
60
+ if pdf_file is None: return None
61
 
62
+ # Convert first page only
63
+ images = pdf2image.convert_from_path(pdf_file.name, first_page=1, last_page=1)
64
+ if images:
65
+ return np.array(images[0])
66
+ return None
67
+
68
+ def process_pdf(pdf_file, image_editor_data):
69
+ """Process all pages in PDF using the mask from the editor"""
70
+ if pdf_file is None or image_editor_data is None:
71
+ return None
72
+
73
+ # 1. Get the mask defined by user on Page 1
74
+ # We ignore the background image here and just want the mask layer
75
+ mask = get_mask_from_dict(image_editor_data)
76
 
77
+ # 2. Convert all PDF pages to images
78
+ print("Converting PDF to images...")
79
+ pages = pdf2image.convert_from_path(pdf_file.name)
80
+
81
+ cleaned_pages = []
82
+ print(f"Processing {len(pages)} pages...")
83
+
84
+ for i, page in enumerate(pages):
85
+ # Convert PIL to Numpy
86
+ img_np = np.array(page)
87
+ img_np = ensure_rgb(img_np)
88
+
89
+ # Resize mask if page sizes differ (simple safety check)
90
+ if img_np.shape[:2] != mask.shape[:2]:
91
+ current_mask = cv2.resize(mask, (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST)
92
+ else:
93
+ current_mask = mask
94
+
95
+ # Run AI
96
+ result = lama.predict(img_np, current_mask)
97
+
98
+ # Convert back to PIL
99
+ cleaned_pages.append(Image.fromarray(result))
100
+
101
+ # 3. Save back to PDF
102
+ output_path = tempfile.mktemp(suffix=".pdf")
103
+ if cleaned_pages:
104
+ cleaned_pages[0].save(output_path, save_all=True, append_images=cleaned_pages[1:])
105
+ return output_path
106
+
107
+ return None
108
+
109
+
110
+ # --- UI Construction ---
111
 
 
112
  with gr.Blocks(title="AI Watermark Remover") as app:
113
  gr.Markdown("# 💧 AI Watermark Remover (LaMa)")
114
 
115
+ with gr.Tab("Image Mode"):
116
  with gr.Row():
117
  input_editor = gr.ImageEditor(
118
  label="Draw Mask", type="numpy",
 
121
  )
122
  ui_output = gr.Image(label="Result")
123
  ui_btn = gr.Button("Remove Watermark", variant="primary")
 
124
  ui_btn.click(process_image_dict, inputs=input_editor, outputs=ui_output)
125
 
126
+ with gr.Tab("PDF Mode"):
127
+ gr.Markdown("### 1. Upload PDF & Preview Page 1")
128
+ with gr.Row():
129
+ pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
130
+ preview_btn = gr.Button("Load Preview")
131
+
132
+ gr.Markdown("### 2. Draw Mask on Page 1 (Applied to ALL pages)")
133
+ pdf_editor = gr.ImageEditor(
134
+ label="Draw Mask Here", type="numpy",
135
+ brush=gr.Brush(colors=["#FF0000"], default_size=20),
136
+ interactive=True
137
+ )
138
+
139
+ gr.Markdown("### 3. Process Full PDF")
140
+ pdf_run_btn = gr.Button("Clean Entire PDF", variant="primary")
141
+ pdf_output = gr.File(label="Download Cleaned PDF")
142
+
143
+ # Wiring
144
+ preview_btn.click(pdf_preview, inputs=pdf_input, outputs=pdf_editor)
145
+ pdf_run_btn.click(process_pdf, inputs=[pdf_input, pdf_editor], outputs=pdf_output)
146
+
147
+ with gr.Tab("API Mode"):
148
  gr.Markdown("Use this endpoint for API calls: `/predict_api`")
149
+ api_image = gr.Image(label="Original", type="numpy")
150
+ api_mask = gr.Image(label="Mask", type="numpy")
151
  api_output = gr.Image(label="Result")
152
  api_btn = gr.Button("Run API")
 
 
153
  api_btn.click(process_simple_api, inputs=[api_image, api_mask], outputs=api_output, api_name="predict_api")
154
 
155
  if __name__ == "__main__":
156
+ app.launch()
packages.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ poppler-utils
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  opencv-python-headless
2
  numpy
3
  onnxruntime
4
- gradio
 
 
1
  opencv-python-headless
2
  numpy
3
  onnxruntime
4
+ gradio
5
+ pdf2image