|
|
import base64 |
|
|
import io |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
import torch |
|
|
from diffusers import AutoPipelineForInpainting |
|
|
|
|
|
class EndpointHandler:
    """Request handler that runs an SDXL inpainting pipeline on a base64 image.

    The handler is callable: it takes a request payload (a dict), decodes the
    base64-encoded input image, runs the inpainting pipeline with a text
    prompt, and returns the result as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the inpainting pipeline in half precision and move it to CUDA.

        Args:
            path: Accepted for interface compatibility; it is not used here
                because the model identifier is hard-coded below.
        """
        print("[INIT] Loading Nano Banana SDXL Inpainting pipeline...")

        self.pipe = AutoPipelineForInpainting.from_pretrained(
            "SG161222/RealVisXL_V4.0_Nano-Banana",
            torch_dtype=torch.float16,
            variant="fp16",
        ).to("cuda")

        # Prompt used when the request does not supply one.
        self.default_prompt = "remove text captions, natural background, realistic restoration"
        print("[READY] Nano Banana model loaded successfully.")

    def _decode_image(self, b64_image):
        """Decode a base64 string (optionally a data URI) into an RGB PIL image.

        Args:
            b64_image: Base64-encoded image bytes. A leading
                ``data:<mime>;base64,`` header is stripped if present.

        Returns:
            The decoded image converted to RGB mode.
        """
        # b64decode() silently drops ':', ';' and ',' but would decode the
        # letters of a data-URI header as payload, corrupting the image bytes.
        # Strip the header explicitly instead.
        if isinstance(b64_image, str) and b64_image.startswith("data:"):
            _, _, b64_image = b64_image.partition(",")

        raw_bytes = base64.b64decode(b64_image)
        return Image.open(io.BytesIO(raw_bytes)).convert("RGB")

    def _encode_image(self, pil_img):
        """Serialize a PIL image to PNG and return it as a base64 UTF-8 string."""
        buffer = io.BytesIO()
        pil_img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")

    def __call__(self, data):
        """Run inpainting for a single request.

        Args:
            data: Request payload. Expected shape is
                ``{"inputs": {"image": <b64>, "prompt": <str>, "mask": <b64>}}``
                where ``prompt`` and ``mask`` are optional. A payload without
                the ``"inputs"`` wrapper is also accepted and read directly.

        Returns:
            ``{"image": <base64-encoded PNG of the pipeline output>}``.

        Raises:
            ValueError: If the payload is not a dict or has no ``"image"``.
        """
        if not isinstance(data, dict):
            raise ValueError("Request payload must be a dict")

        inputs = data.get("inputs", data)
        # The isinstance check matters: with a string payload, `"image" in`
        # would be a substring test and let invalid input through.
        if not isinstance(inputs, dict) or "image" not in inputs:
            raise ValueError("Missing 'image' field in inputs")

        prompt = inputs.get("prompt")
        if prompt is None:
            prompt = self.default_prompt

        img_pil = self._decode_image(inputs["image"])

        # An inpainting pipeline needs a mask; passing None fails. Use the
        # caller's mask when given, otherwise an all-white mask so the whole
        # image is regenerated (white marks the region to repaint, per the
        # diffusers inpainting docs).
        mask_b64 = inputs.get("mask")
        if mask_b64 is None:
            mask_pil = Image.new("L", img_pil.size, 255)
        else:
            mask_pil = self._decode_image(mask_b64).convert("L")

        print(f"[PROCESS] Running Nano Banana with prompt: '{prompt}'")

        result = self.pipe(
            prompt=prompt,
            image=img_pil,
            mask_image=mask_pil,
            guidance_scale=3.0,
            strength=0.85,
            num_inference_steps=25,
        ).images[0]

        cleaned_b64 = self._encode_image(result)
        return {"image": cleaned_b64}
|
|
|
|
|
|