mastari committed on
Commit
943fefc
·
1 Parent(s): be56361
Files changed (3) hide show
  1. .DS_Store +0 -0
  2. handler.py +79 -162
  3. requirements.txt +7 -4
.DS_Store ADDED
Binary file (6.15 kB). View file
 
handler.py CHANGED
@@ -1,181 +1,98 @@
1
- import base64, cv2, numpy as np, importlib.util
2
- from typing import Dict, Any
3
-
 
 
 
 
 
4
 
5
  class EndpointHandler:
6
- """
7
- Robust hybrid text-removal handler:
8
- • Uses EasyOCR (pixel-level) if available
9
- • Falls back to EAST detector otherwise
10
- • Expands & merges masks for full caption coverage
11
- • Returns both mask overlay and inpainted (cleaned) image
12
- """
13
-
14
- def __init__(self, path: str = ""):
15
- easyocr_spec = importlib.util.find_spec("easyocr")
16
- if easyocr_spec:
17
- import easyocr
18
- self.reader = easyocr.Reader(["en"], gpu=False)
19
- self.use_easyocr = True
20
- print("[INIT] Using EasyOCR text detector")
21
- else:
22
- model_path = f"{path}/frozen_east_text_detection.pb"
23
- self.net = cv2.dnn.readNet(model_path)
24
- self.use_easyocr = False
25
- print(f"[INIT] Using EAST model from {model_path}")
26
-
27
- # ----------------------------- INFERENCE -----------------------------
28
- def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
29
- inputs = data.get("inputs", data)
30
- image_b64 = inputs.get("image")
31
- if not image_b64:
32
- raise ValueError("Missing 'image' in inputs")
33
-
34
- img = self._decode_image(image_b64)
35
- mask = self._make_mask(img)
36
- cleaned = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
37
 
38
- # visualize mask overlay
39
- vis = img.copy()
40
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
41
- cv2.drawContours(vis, contours, -1, (0, 0, 255), 2)
42
 
43
- return {
44
- "mask_image": self._encode_image(vis),
45
- "cleaned_image": self._encode_image(cleaned),
46
- }
 
 
 
47
 
48
- # ----------------------------- UTILITIES -----------------------------
49
- def _decode_image(self, b64):
50
- data = base64.b64decode(b64)
51
- np_arr = np.frombuffer(data, np.uint8)
52
- return cv2.imdecode(np_arr, cv2.IMREAD_COLOR)
53
 
54
- def _encode_image(self, im):
55
- _, buf = cv2.imencode(".png", im)
56
- return base64.b64encode(buf).decode("utf-8")
 
57
 
58
- # ----------------------------- MASK CREATION -----------------------------
59
  def _make_mask(self, img):
60
  mask = np.zeros(img.shape[:2], np.uint8)
61
  h, w = img.shape[:2]
62
 
63
- if self.use_easyocr:
64
- results = self.reader.readtext(img)
65
- for det in results:
66
- try:
67
- box, text, conf = det
68
- if conf < 0.6:
69
- continue
70
-
71
- pts = np.array(box, np.int32)
72
- x, y, bw, bh = cv2.boundingRect(pts)
73
-
74
- # Skip very small noise
75
- if bw < 0.015 * w or bh < 0.015 * h:
76
- continue
77
-
78
- # Calculate local contrast
79
- roi = img[max(0, y):min(h, y + bh), max(0, x):min(w, x + bw)]
80
- gray = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
81
- contrast = gray.std()
82
-
83
- # Base padding — increase if high contrast or thick captions
84
- pad_scale = 0.03 # 3% of image width
85
- pad = max(int(w * pad_scale), 10)
86
- if contrast > 25:
87
- pad = int(pad * 1.5)
88
-
89
- # Expand more vertically — typical caption boxes have extra height
90
- pad_x = pad
91
- pad_y = int(pad * 1.4)
92
-
93
- x0, y0 = max(0, x - pad_x), max(0, y - pad_y)
94
- x1, y1 = min(w, x + bw + pad_x), min(h, y + bh + pad_y)
95
- cv2.rectangle(mask, (x0, y0), (x1, y1), 255, -1)
96
-
97
- except Exception as e:
98
- print(f"[WARN] Skipped invalid detection: {e}")
99
-
100
- else:
101
- boxes = self._east_boxes(img)
102
- for (x0, y0, x1, y1) in boxes:
103
- pad = 10
104
- cv2.rectangle(
105
- mask,
106
- (max(0, x0 - pad), max(0, y0 - pad)),
107
- (min(w, x1 + pad), min(h, y1 + pad)),
108
- 255,
109
- -1,
110
- )
111
-
112
- # Merge nearby boxes and smooth edges
113
  kernel = np.ones((9, 9), np.uint8)
114
  mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=3)
115
  mask = cv2.dilate(mask, kernel, iterations=2)
116
-
117
- # Feather slightly to eliminate border seams
118
- mask = cv2.GaussianBlur(mask, (7, 7), 2)
119
  mask = (mask > 100).astype(np.uint8) * 255
120
 
121
  return mask
122
 
123
- # ----------------------------- EAST FALLBACK -----------------------------
124
- def _east_boxes(self, image, conf_threshold=0.5):
125
- h, w = image.shape[:2]
126
- new_w, new_h = 320, 320
127
- r_w, r_h = w / new_w, h / new_h
128
- blob = cv2.dnn.blobFromImage(
129
- image,
130
- 1.0,
131
- (new_w, new_h),
132
- (123.68, 116.78, 103.94),
133
- swapRB=True,
134
- crop=False,
135
- )
136
- self.net.setInput(blob)
137
- scores, geometry = self.net.forward(
138
- ["feature_fusion/Conv_7/Sigmoid", "feature_fusion/concat_3"]
139
- )
140
- rects, confidences = self._decode(scores, geometry, conf_threshold)
141
- indices = cv2.dnn.NMSBoxes(rects, confidences, conf_threshold, 0.4)
142
- boxes = []
143
- if len(indices) > 0:
144
- for i in indices.flatten():
145
- x0, y0, x1, y1 = rects[i]
146
- boxes.append(
147
- [
148
- max(0, int(x0 * r_w)),
149
- max(0, int(y0 * r_h)),
150
- min(w, int(x1 * r_w)),
151
- min(h, int(y1 * r_h)),
152
- ]
153
- )
154
- return boxes
155
-
156
- def _decode(self, scores, geometry, conf_threshold):
157
- num_rows, num_cols = scores.shape[2:4]
158
- rects, confidences = [], []
159
- for y in range(num_rows):
160
- scores_data = scores[0, 0, y]
161
- x0 = geometry[0, 0, y]
162
- x1 = geometry[0, 1, y]
163
- x2 = geometry[0, 2, y]
164
- x3 = geometry[0, 3, y]
165
- angles = geometry[0, 4, y]
166
- for x in range(num_cols):
167
- if scores_data[x] < conf_threshold:
168
- continue
169
- offset_x, offset_y = x * 4.0, y * 4.0
170
- angle = angles[x]
171
- cos, sin = np.cos(angle), np.sin(angle)
172
- h_ = x0[x] + x2[x]
173
- w_ = x1[x] + x3[x]
174
- end_x = int(offset_x + cos * x1[x] + sin * x2[x])
175
- end_y = int(offset_y - sin * x1[x] + cos * x2[x])
176
- start_x = int(end_x - w_)
177
- start_y = int(end_y - h_)
178
- rects.append((start_x, start_y, end_x, end_y))
179
- confidences.append(float(scores_data[x]))
180
- return rects, confidences
181
 
 
1
+ import base64
2
+ import io
3
+ import cv2
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ from diffusers import StableDiffusionInpaintPipeline
8
+ import easyocr
9
 
10
class EndpointHandler:
    """
    Text-removal inference endpoint.

    Pipeline:
      • EasyOCR locates caption/text regions in the input image.
      • A padded, merged, feathered mask is built over those regions.
      • Stable Diffusion 2 inpainting fills the masked regions.

    Returns the cleaned image and a mask-overlay visualization, each as a
    base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the OCR reader and the inpainting pipeline.

        Picks CUDA + float16 when a GPU is available, otherwise falls back
        to CPU + float32 instead of crashing on `.to("cuda")`.
        """
        print("[INIT] Loading EasyOCR and Stable Diffusion Inpainting model...")

        use_cuda = torch.cuda.is_available()

        # Text detector — only request GPU when one actually exists.
        self.reader = easyocr.Reader(["en"], gpu=use_cuda)

        # SOTA inpainting model; fp16 weights only make sense on CUDA.
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-inpainting",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
        ).to("cuda" if use_cuda else "cpu")

        print("[READY] Handler initialized successfully.")

    # Decode incoming base64 image → numpy (BGR, as OpenCV expects)
    def _decode_image(self, b64_image):
        """Decode a base64 string into a BGR uint8 numpy image."""
        img_bytes = base64.b64decode(b64_image)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)

    # Encode numpy (BGR) → base64 PNG
    def _encode_image(self, img):
        """Encode a BGR numpy image as a base64-encoded PNG string."""
        _, buffer = cv2.imencode(".png", img)
        return base64.b64encode(buffer).decode("utf-8")

    # Make mask from detected text boxes
    def _make_mask(self, img):
        """Build a binary (0/255) mask covering detected text regions.

        Detections below confidence 0.6 or smaller than ~2% of image width /
        1.5% of height are treated as noise. Surviving boxes are padded
        (more vertically — caption boxes tend to have extra height), merged
        morphologically, then feathered and re-binarized so the inpainting
        region has no hard seams.
        """
        mask = np.zeros(img.shape[:2], np.uint8)
        h, w = img.shape[:2]

        results = self.reader.readtext(img)
        for det in results:
            try:
                box, text, conf = det
                if conf < 0.6:
                    continue

                pts = np.array(box, np.int32)
                x, y, bw, bh = cv2.boundingRect(pts)
                if bw < 0.02 * w or bh < 0.015 * h:
                    continue  # skip very small noise

                pad_scale = 0.03  # padding relative to image width
                pad = max(int(w * pad_scale), 12)
                pad_x, pad_y = pad, int(pad * 1.4)
                x0, y0 = max(0, x - pad_x), max(0, y - pad_y)
                x1, y1 = min(w, x + bw + pad_x), min(h, y + bh + pad_y)
                cv2.rectangle(mask, (x0, y0), (x1, y1), 255, -1)
            except Exception:
                # A malformed detection tuple should not fail the request.
                continue

        # Merge and feather mask
        kernel = np.ones((9, 9), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=3)
        mask = cv2.dilate(mask, kernel, iterations=2)
        mask = cv2.GaussianBlur(mask, (9, 9), 3)
        mask = (mask > 100).astype(np.uint8) * 255

        return mask

    def __call__(self, data):
        """Handle one inference request.

        Accepts either {"inputs": {"image": <b64>}} or a bare
        {"image": <b64>} payload (the previous handler's contract);
        raises ValueError when no image is supplied.

        Returns {"image": <b64 PNG>, "mask_overlay": <b64 PNG>}.
        """
        # Tolerate both wrapped and bare payloads instead of raising
        # KeyError on a missing "inputs" key.
        inputs = data.get("inputs", data)
        if "image" not in inputs:
            raise ValueError("Missing 'image' field in inputs")

        # Decode input image
        img = self._decode_image(inputs["image"])
        mask = self._make_mask(img)

        # Convert to PIL for pipeline
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        mask_pil = Image.fromarray(mask)

        # Run inpainting (prompt left blank to stay realistic)
        print("[INPAINT] Running Stable Diffusion 2 inpainting...")
        out = self.pipe(prompt="", image=img_pil, mask_image=mask_pil).images[0]

        # The pipeline works at its own resolution (512px default,
        # multiples of 8) — resize back so the cleaned image matches the
        # input and the overlay.
        if out.size != img_pil.size:
            out = out.resize(img_pil.size, Image.LANCZOS)
        cleaned = cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR)

        # Optional mask overlay for visualization
        mask_overlay = img.copy()
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(mask_overlay, contours, -1, (0, 0, 255), 2)

        # Encode results
        return {
            "image": self._encode_image(cleaned),
            "mask_overlay": self._encode_image(mask_overlay),
        }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
 
requirements.txt CHANGED
@@ -1,7 +1,10 @@
 
 
 
 
 
1
  opencv-python-headless>=4.8.0
2
- numpy>=1.26.0
3
- Pillow
4
- # Optional craft replacement – pure Python, compatible with Py3.11
5
  easyocr>=1.7.1
6
- torch>=2.1.0
 
7
 
 
1
+ torch>=2.1.0
2
+ torchvision
3
+ diffusers>=0.29.0
4
+ transformers>=4.41.0
5
+ accelerate
6
  opencv-python-headless>=4.8.0
 
 
 
7
  easyocr>=1.7.1
8
+ Pillow>=10.2.0
9
+ numpy>=1.26.0
10