mastari committed on
Commit
286af0e
·
1 Parent(s): 943fefc
Files changed (2) hide show
  1. handler.py +31 -70
  2. requirements.txt +0 -1
handler.py CHANGED
@@ -4,95 +4,56 @@ import cv2
4
  import numpy as np
5
  from PIL import Image
6
  import torch
7
- from diffusers import StableDiffusionInpaintPipeline
8
- import easyocr
9
 
10
  class EndpointHandler:
11
  def __init__(self, path=""):
12
- print("[INIT] Loading EasyOCR and Stable Diffusion Inpainting model...")
13
 
14
- # Text detector
15
- self.reader = easyocr.Reader(["en"], gpu=True)
16
-
17
- # SOTA inpainting model
18
- self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
19
- "stabilityai/stable-diffusion-2-inpainting",
20
  torch_dtype=torch.float16,
 
21
  ).to("cuda")
22
 
23
- print("[READY] Handler initialized successfully.")
 
 
24
 
25
- # Decode incoming base64 image → numpy
26
  def _decode_image(self, b64_image):
27
  img_bytes = base64.b64decode(b64_image)
28
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
29
- return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
30
-
31
- # Encode numpy → base64 PNG
32
- def _encode_image(self, img):
33
- _, buffer = cv2.imencode(".png", img)
34
- return base64.b64encode(buffer).decode("utf-8")
35
-
36
- # Make mask from detected text boxes
37
- def _make_mask(self, img):
38
- mask = np.zeros(img.shape[:2], np.uint8)
39
- h, w = img.shape[:2]
40
-
41
- results = self.reader.readtext(img)
42
- for det in results:
43
- try:
44
- box, text, conf = det
45
- if conf < 0.6:
46
- continue
47
 
48
- pts = np.array(box, np.int32)
49
- x, y, bw, bh = cv2.boundingRect(pts)
50
- if bw < 0.02 * w or bh < 0.015 * h:
51
- continue
52
-
53
- pad_scale = 0.03
54
- pad = max(int(w * pad_scale), 12)
55
- pad_x, pad_y = pad, int(pad * 1.4)
56
- x0, y0 = max(0, x - pad_x), max(0, y - pad_y)
57
- x1, y1 = min(w, x + bw + pad_x), min(h, y + bh + pad_y)
58
- cv2.rectangle(mask, (x0, y0), (x1, y1), 255, -1)
59
- except Exception:
60
- continue
61
-
62
- # Merge and feather mask
63
- kernel = np.ones((9, 9), np.uint8)
64
- mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel, iterations=3)
65
- mask = cv2.dilate(mask, kernel, iterations=2)
66
- mask = cv2.GaussianBlur(mask, (9, 9), 3)
67
- mask = (mask > 100).astype(np.uint8) * 255
68
-
69
- return mask
70
 
71
  def __call__(self, data):
72
  if "image" not in data["inputs"]:
73
  raise ValueError("Missing 'image' field in inputs")
74
 
75
- # Decode input image
76
- img = self._decode_image(data["inputs"]["image"])
77
- mask = self._make_mask(img)
 
78
 
79
- # Convert to PIL for pipeline
80
- img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
81
- mask_pil = Image.fromarray(mask)
82
 
83
- # Run inpainting (prompt left blank to stay realistic)
84
- print("[INPAINT] Running Stable Diffusion 2 inpainting...")
85
- out = self.pipe(prompt="", image=img_pil, mask_image=mask_pil).images[0]
86
- cleaned = cv2.cvtColor(np.array(out), cv2.COLOR_RGB2BGR)
 
 
 
 
 
87
 
88
- # Optional mask overlay for visualization
89
- mask_overlay = img.copy()
90
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
91
- cv2.drawContours(mask_overlay, contours, -1, (0, 0, 255), 2)
92
 
93
- # Encode results
94
- return {
95
- "image": self._encode_image(cleaned),
96
- "mask_overlay": self._encode_image(mask_overlay),
97
- }
98
 
 
4
  import numpy as np
5
  from PIL import Image
6
  import torch
7
+ from diffusers import AutoPipelineForInpainting
 
8
 
9
  class EndpointHandler:
10
  def __init__(self, path=""):
11
+ print("[INIT] Loading Nano Banana SDXL Inpainting pipeline...")
12
 
13
+ # Load Nano Banana (SDXL fine-tuned)
14
+ self.pipe = AutoPipelineForInpainting.from_pretrained(
15
+ "SG161222/RealVisXL_V4.0_Nano-Banana",
 
 
 
16
  torch_dtype=torch.float16,
17
+ variant="fp16"
18
  ).to("cuda")
19
 
20
+ # Default high-level removal instruction
21
+ self.default_prompt = "remove text captions, natural background, realistic restoration"
22
+ print("[READY] Nano Banana model loaded successfully.")
23
 
 
24
  def _decode_image(self, b64_image):
25
  img_bytes = base64.b64decode(b64_image)
26
  img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
27
+ return img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
+ def _encode_image(self, pil_img):
30
+ buf = io.BytesIO()
31
+ pil_img.save(buf, format="PNG")
32
+ return base64.b64encode(buf.getvalue()).decode("utf-8")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  def __call__(self, data):
35
  if "image" not in data["inputs"]:
36
  raise ValueError("Missing 'image' field in inputs")
37
 
38
+ prompt = data["inputs"].get("prompt", self.default_prompt)
39
+
40
+ # Decode base64 → PIL
41
+ img_pil = self._decode_image(data["inputs"]["image"])
42
 
43
+ print(f"[PROCESS] Running Nano Banana with prompt: '{prompt}'")
 
 
44
 
45
+ # Inpaint the whole image (no mask full generative clean-up)
46
+ result = self.pipe(
47
+ prompt=prompt,
48
+ image=img_pil,
49
+ mask_image=None,
50
+ guidance_scale=3.0,
51
+ strength=0.85,
52
+ num_inference_steps=25
53
+ ).images[0]
54
 
55
+ # Encode result back to base64
56
+ cleaned_b64 = self._encode_image(result)
 
 
57
 
58
+ return {"image": cleaned_b64}
 
 
 
 
59
 
requirements.txt CHANGED
@@ -4,7 +4,6 @@ diffusers>=0.29.0
4
  transformers>=4.41.0
5
  accelerate
6
  opencv-python-headless>=4.8.0
7
- easyocr>=1.7.1
8
  Pillow>=10.2.0
9
  numpy>=1.26.0
10
 
 
4
  transformers>=4.41.0
5
  accelerate
6
  opencv-python-headless>=4.8.0
 
7
  Pillow>=10.2.0
8
  numpy>=1.26.0
9