File size: 1,865 Bytes
943fefc
 
 
 
 
 
286af0e
4ea7620
9f27dd6
943fefc
286af0e
edffe22
286af0e
 
 
943fefc
286af0e
943fefc
 
286af0e
 
 
edffe22
943fefc
 
 
286af0e
943fefc
286af0e
 
 
 
edffe22
943fefc
 
 
 
286af0e
 
 
 
943fefc
286af0e
943fefc
286af0e
 
 
 
 
 
 
 
 
943fefc
286af0e
 
943fefc
286af0e
9f27dd6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import base64
import io
import cv2
import numpy as np
from PIL import Image
import torch
from diffusers import AutoPipelineForInpainting

class EndpointHandler:
    """Inference Endpoint handler that removes text captions from images.

    Runs an SDXL inpainting pipeline over the ENTIRE image (all-white mask)
    to generatively clean it up.

    Request payload:  {"inputs": {"image": <base64 image>, "prompt": <optional str>}}
    Response payload: {"image": <base64 PNG>}
    """

    def __init__(self, path=""):
        print("[INIT] Loading Nano Banana SDXL Inpainting pipeline...")

        # Load the SDXL inpainting checkpoint in half precision.
        # NOTE(review): assumes a CUDA device is available on the endpoint —
        # .to("cuda") will raise otherwise; confirm the deployment hardware.
        self.pipe = AutoPipelineForInpainting.from_pretrained(
            "SG161222/RealVisXL_V4.0_Nano-Banana",
            torch_dtype=torch.float16,
            variant="fp16"
        ).to("cuda")

        # Fallback instruction used when the request supplies no prompt.
        self.default_prompt = "remove text captions, natural background, realistic restoration"
        print("[READY] Nano Banana model loaded successfully.")

    def _decode_image(self, b64_image):
        """Decode a base64-encoded image payload into an RGB PIL image."""
        img_bytes = base64.b64decode(b64_image)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        return img

    def _encode_image(self, pil_img):
        """Encode a PIL image as a base64 PNG string (UTF-8 text)."""
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def __call__(self, data):
        """Run the full-image inpainting clean-up on one request.

        Args:
            data: request dict with an "inputs" dict containing a base64
                "image" and an optional "prompt" string.

        Returns:
            dict with the cleaned image as a base64 PNG under "image".

        Raises:
            ValueError: if "inputs" is missing/not a dict or lacks "image".
        """
        # BUG FIX: the original indexed data["inputs"] before validating it,
        # so a payload without "inputs" surfaced as a bare KeyError rather
        # than the clear ValueError this validation intends.
        inputs = data.get("inputs")
        if not isinstance(inputs, dict):
            raise ValueError("Missing 'inputs' field in request payload")
        if "image" not in inputs:
            raise ValueError("Missing 'image' field in inputs")

        prompt = inputs.get("prompt", self.default_prompt)

        # Decode base64 → PIL
        img_pil = self._decode_image(inputs["image"])

        print(f"[PROCESS] Running Nano Banana with prompt: '{prompt}'")

        # BUG FIX: inpainting pipelines require a mask image — passing
        # mask_image=None fails inside the pipeline's mask preparation.
        # An all-white ("L", 255) mask marks every pixel as editable,
        # matching the original comment's intent of a full generative
        # clean-up with no masked-out region.
        full_mask = Image.new("L", img_pil.size, 255)

        result = self.pipe(
            prompt=prompt,
            image=img_pil,
            mask_image=full_mask,
            guidance_scale=3.0,
            strength=0.85,
            num_inference_steps=25
        ).images[0]

        # Encode result back to base64
        cleaned_b64 = self._encode_image(result)

        return {"image": cleaned_b64}