File size: 1,865 Bytes
943fefc 286af0e 4ea7620 9f27dd6 943fefc 286af0e edffe22 286af0e 943fefc 286af0e 943fefc 286af0e edffe22 943fefc 286af0e 943fefc 286af0e edffe22 943fefc 286af0e 943fefc 286af0e 943fefc 286af0e 943fefc 286af0e 943fefc 286af0e 9f27dd6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import base64
import io
import cv2
import numpy as np
from PIL import Image
import torch
from diffusers import AutoPipelineForInpainting
class EndpointHandler:
    """Inference-endpoint handler that cleans up an image with an SDXL inpainting pipeline.

    Expected request payload (fields may also be given at the top level)::

        {"inputs": {"image": "<base64>", "prompt": "<str>", "mask": "<base64>"}}

    ``prompt`` and ``mask`` are optional. The response is
    ``{"image": "<base64 PNG>"}`` with the same pixel size as the input image.
    """

    # Hub repository the pipeline is loaded from.
    MODEL_ID = "SG161222/RealVisXL_V4.0_Nano-Banana"
    # Longest side, in pixels, of the resolution diffusion runs at
    # (1024 is SDXL's native resolution).
    max_side = 1024

    def __init__(self, path=""):
        """Load the inpainting pipeline.

        Args:
            path: Model directory supplied by the endpoint runtime. Unused:
                weights are always fetched from ``MODEL_ID``.
        """
        print("[INIT] Loading Nano Banana SDXL Inpainting pipeline...")
        # fp16 on CUDA; fall back to float32 on CPU instead of crashing on .to("cuda").
        use_cuda = torch.cuda.is_available()
        self.pipe = AutoPipelineForInpainting.from_pretrained(
            self.MODEL_ID,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            variant="fp16",
        ).to("cuda" if use_cuda else "cpu")
        # Prompt used when the request does not supply one.
        self.default_prompt = "remove text captions, natural background, realistic restoration"
        print("[READY] Nano Banana model loaded successfully.")

    def _decode_image(self, b64_image, mode="RGB"):
        """Decode a base64 string (optionally a ``data:`` URI) into a PIL image in ``mode``."""
        if isinstance(b64_image, str) and b64_image.startswith("data:"):
            # Strip the "data:<mime>;base64," header. b64decode would otherwise
            # discard the ':' ';' ',' characters and decode the header text as data.
            b64_image = b64_image.partition(",")[2]
        img_bytes = base64.b64decode(b64_image)
        return Image.open(io.BytesIO(img_bytes)).convert(mode)

    def _encode_image(self, pil_img):
        """Encode a PIL image as a base64 PNG string."""
        buf = io.BytesIO()
        pil_img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    def _working_size(self, size):
        """Return the (width, height) to run diffusion at for an image of ``size``.

        Keeps the aspect ratio, scales the longest side to ``max_side`` and rounds
        both sides down to a multiple of 8 (minimum 8), since the pipeline requires
        dimensions divisible by 8.
        """
        width, height = size
        longest = max(width, height)
        # Integer arithmetic keeps the result exact (no float rounding surprises).
        scaled_w = width * self.max_side // longest
        scaled_h = height * self.max_side // longest
        return max(8, scaled_w // 8 * 8), max(8, scaled_h // 8 * 8)

    def __call__(self, data):
        """Run the clean-up for one request.

        Args:
            data: Request payload. Fields are read from ``data["inputs"]`` when
                present, otherwise from ``data`` itself: ``image`` (required,
                base64), ``prompt`` (optional str) and ``mask`` (optional base64;
                white pixels are repainted, black pixels are kept).

        Returns:
            ``{"image": <base64 PNG>}`` at the input image's original size.

        Raises:
            ValueError: If no ``image`` field is present.
        """
        inputs = data.get("inputs", data)
        if not isinstance(inputs, dict) or "image" not in inputs:
            raise ValueError("Missing 'image' field in inputs")
        # `or` also falls back on an explicit null/empty prompt.
        prompt = inputs.get("prompt") or self.default_prompt

        # Decode base64 → PIL
        img_pil = self._decode_image(inputs["image"])
        if inputs.get("mask"):
            mask_pil = self._decode_image(inputs["mask"], mode="L")
        else:
            # No mask supplied: an all-white mask marks every pixel for repainting
            # (whole-image generative clean-up). The pipeline rejects mask_image=None.
            mask_pil = Image.new("L", img_pil.size, 255)
        width, height = self._working_size(img_pil.size)

        print(f"[PROCESS] Running Nano Banana with prompt: '{prompt}'")
        result = self.pipe(
            prompt=prompt,
            image=img_pil,
            mask_image=mask_pil,
            width=width,
            height=height,
            guidance_scale=3.0,
            strength=0.85,
            num_inference_steps=25,
        ).images[0]

        # Undo the working-resolution resize so the output matches the input size.
        if result.size != img_pil.size:
            result = result.resize(img_pil.size, Image.LANCZOS)

        # Encode result back to base64
        return {"image": self._encode_image(result)}
|