import base64
import io
import logging
import os
from typing import Any, Dict

import numpy as np
from PIL import Image
import cv2
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline, DPMSolverMultistepScheduler

# === LOGGING SETUP ===
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

HANDLER_VERSION = "v7-debug"  # bump this to force redeploy reload


# === IMAGE HELPERS ===
def _decode_base64_image(b64: str) -> Image.Image:
    """Decode a base64 string into a PIL RGB image.

    Raises:
        ValueError: if the payload is not valid base64 or not a decodable image.
    """
    logger.debug("[HANDLER] Decoding base64 image (%d chars)", len(b64))
    try:
        img_bytes = base64.b64decode(b64)
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        logger.debug("[HANDLER] ✅ Image decoded successfully: %s, mode=%s", image.size, image.mode)
        return image
    except Exception as e:
        logger.exception("[HANDLER] ❌ Failed to decode base64 image")
        raise ValueError(f"Invalid base64 image data: {e}") from e


def _encode_base64_image(img: Image.Image) -> str:
    """Encode a PIL RGB image into base64 PNG."""
    logger.debug("[HANDLER] Encoding image back to base64: %s", img.size)
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    b64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    logger.debug("[HANDLER] ✅ Image encoded (len=%d)", len(b64))
    return b64


# === CANVAS / MASK CREATION ===
def _build_canvases_and_mask(
    pil_image: Image.Image,
    top: int,
    bottom: int,
    left: int,
    right: int,
    mask_offset: int = 50,
    blur_radius: int = 101,
    max_size: int = 1024,
):
    """Create Telea-filled canvas and soft mask for blending.

    The original image is placed on an enlarged canvas with ``top``/``bottom``/
    ``left``/``right`` pixels of new border. The new border region is
    pre-filled with OpenCV's Telea inpainting so the diffusion model has a
    plausible starting point, and a Gaussian-blurred mask is produced so the
    diffusion output can later be feathered into the untouched original pixels.

    Args:
        pil_image: Source RGB image.
        top/bottom/left/right: Border sizes (pixels) to add on each side.
        mask_offset: Extra pixels of original image included in the blend mask
            so the seam falls inside the diffusion-modified region.
        blur_radius: Gaussian kernel size for mask feathering (forced odd, >= 3).
        max_size: Longest canvas dimension allowed; larger canvases are
            downscaled to fit.

    Returns:
        (base_pil, telea_pil, blend_mask): the un-inpainted canvas, the
        Telea-filled canvas, and a float32 mask in [0, 1] (1 = use diffusion
        output, 0 = keep original pixels).
    """
    logger.debug(
        "[HANDLER] Building canvases: top=%d bottom=%d left=%d right=%d blur=%d offset=%d",
        top, bottom, left, right, blur_radius, mask_offset,
    )

    # PIL is RGB; OpenCV works in BGR, so convert once at the boundary.
    np_orig = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
    h, w, _ = np_orig.shape
    new_h, new_w = h + top + bottom, w + left + right
    logger.debug("[HANDLER] Original size=(%d,%d) → canvas size=(%d,%d)", h, w, new_h, new_w)

    base_canvas = np.zeros((new_h, new_w, 3), dtype=np.uint8)
    base_canvas[top : top + h, left : left + w] = np_orig
    telea_canvas = base_canvas.copy()

    # Mark exactly the new border pixels for inpainting.
    inpaint_mask = np.zeros((new_h, new_w), dtype=np.uint8)
    if top > 0:
        inpaint_mask[:top, :] = 255
    if bottom > 0:
        inpaint_mask[new_h - bottom :, :] = 255
    if left > 0:
        inpaint_mask[:, :left] = 255
    if right > 0:
        inpaint_mask[:, new_w - right :] = 255

    if np.any(inpaint_mask):
        logger.debug("[HANDLER] Running Telea inpaint on new borders...")
        telea_canvas = cv2.inpaint(telea_canvas, inpaint_mask, 3, cv2.INPAINT_TELEA)
    else:
        logger.debug("[HANDLER] No inpainting needed (no new borders)")

    # The blend mask extends `mask_offset` px into the original image so the
    # blurred seam lands inside the region the diffusion model repainted.
    hard_mask = np.zeros((new_h, new_w), dtype=np.uint8)
    if top > 0:
        hard_mask[: top + mask_offset, :] = 255
    if bottom > 0:
        hard_mask[new_h - (bottom + mask_offset) :, :] = 255
    if left > 0:
        hard_mask[:, : left + mask_offset] = 255
    if right > 0:
        hard_mask[:, new_w - (right + mask_offset) :] = 255

    # GaussianBlur requires an odd kernel size >= 3.
    if blur_radius % 2 == 0:
        blur_radius += 1
    blur_radius = max(3, blur_radius)
    logger.debug("[HANDLER] Blurring mask with radius=%d", blur_radius)
    soft_mask = cv2.GaussianBlur(hard_mask, (blur_radius, blur_radius), 0)

    # Downscale everything together if the canvas exceeds the model-friendly size.
    scale = 1.0
    max_dim = max(new_h, new_w)
    if max_dim > max_size:
        scale = max_size / max_dim
        logger.debug("[HANDLER] Resizing large canvas by scale=%.3f", scale)
        new_w_resized, new_h_resized = int(new_w * scale), int(new_h * scale)
        base_canvas = cv2.resize(base_canvas, (new_w_resized, new_h_resized), interpolation=cv2.INTER_LANCZOS4)
        telea_canvas = cv2.resize(telea_canvas, (new_w_resized, new_h_resized), interpolation=cv2.INTER_LANCZOS4)
        soft_mask = cv2.resize(soft_mask, (new_w_resized, new_h_resized), interpolation=cv2.INTER_LANCZOS4)

    base_pil = Image.fromarray(cv2.cvtColor(base_canvas, cv2.COLOR_BGR2RGB))
    telea_pil = Image.fromarray(cv2.cvtColor(telea_canvas, cv2.COLOR_BGR2RGB))
    blend_mask = soft_mask.astype(np.float32) / 255.0

    logger.debug(
        "[HANDLER] ✅ Canvas/mask ready: base=%s telea=%s mask=%s scale=%.3f",
        base_pil.size, telea_pil.size, blend_mask.shape, scale,
    )
    return base_pil, telea_pil, blend_mask


# === MAIN HANDLER ===
class EndpointHandler:
    """SDXL img2img outpainting handler.

    Loads the pipeline once at construction time and serves dict-in/dict-out
    requests: the input carries a base64 image plus border sizes, the output
    carries the base64 outpainted result.
    """

    def __init__(self, path: str = "") -> None:
        """Load the SDXL img2img pipeline onto the best available device.

        Args:
            path: Model directory supplied by the hosting runtime (unused here;
                the model id comes from the MODEL_ID environment variable).
        """
        logger.debug("[HANDLER] v%s __init__ path=%s", HANDLER_VERSION, path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug("[HANDLER] Using device=%s torch=%s", self.device, torch.__version__)

        model_id = os.environ.get("MODEL_ID", "SG161222/RealVisXL_V4.0")
        logger.debug("[HANDLER] Loading pipeline: %s", model_id)
        # fp16 weights only make sense on CUDA; CPU falls back to fp32.
        self.pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            variant="fp16" if self.device == "cuda" else None,
        )
        self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config, use_karras_sigmas=True
        )
        self.pipe.to(self.device)
        self.pipe.enable_attention_slicing("max")
        logger.debug("[HANDLER] ✅ Model loaded successfully")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        # NOTE(review): the Hugging Face Inference Endpoints toolkit invokes the
        # handler via __call__, not predict — this delegation keeps the public
        # `predict` API intact while satisfying that contract. Confirm against
        # the deployment runtime.
        return self.predict(data)

    def predict(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Outpaint an image by the requested border sizes.

        Args:
            data: Request payload; parameters may be at the top level or nested
                under an "inputs" key. Required: "image" (base64). Optional:
                "top"/"bottom"/"left"/"right" border pixels, "prompt",
                "negative_prompt", "num_inference_steps", "guidance_scale",
                "strength", "seed".

        Returns:
            {"image": <base64 PNG of the blended result>}

        Raises:
            ValueError: on a missing image, undecodable image data, or
                negative border sizes.
        """
        logger.debug("[HANDLER] 📩 Predict called (keys=%s)", list(data.keys()))
        payload = data.get("inputs", data)
        logger.debug("[HANDLER] Payload keys: %s", list(payload.keys()))

        b64_image = payload.get("image")
        if not b64_image:
            raise ValueError("Missing 'image' field")

        # Parameters
        top, bottom, left, right = (
            int(payload.get("top", 0)),
            int(payload.get("bottom", 0)),
            int(payload.get("left", 0)),
            int(payload.get("right", 0)),
        )
        # Negative margins would produce corrupt numpy slices downstream —
        # reject them up front with a clear error.
        if min(top, bottom, left, right) < 0:
            raise ValueError("Border sizes (top/bottom/left/right) must be non-negative")

        prompt = payload.get("prompt", "")
        negative_prompt = payload.get("negative_prompt", "")
        steps = int(payload.get("num_inference_steps", 25))
        guidance = float(payload.get("guidance_scale", 6.0))
        strength = float(payload.get("strength", 0.85))
        seed = payload.get("seed", None)

        logger.debug(
            "[HANDLER] Params top=%d bottom=%d left=%d right=%d steps=%d guide=%.2f strength=%.2f seed=%s",
            top, bottom, left, right, steps, guidance, strength, seed,
        )

        orig_pil = _decode_base64_image(b64_image)
        base_pil, telea_pil, blend_mask = _build_canvases_and_mask(
            orig_pil, top, bottom, left, right, 50, 101, 1024
        )

        generator = (
            torch.Generator(device=self.device).manual_seed(int(seed))
            if seed is not None
            else None
        )
        if generator is None:
            logger.debug("[HANDLER] Using random seed")
        else:
            logger.debug("[HANDLER] Using manual seed=%s", seed)

        logger.debug("[HANDLER] 🚀 Starting diffusion inference...")

        # === SAFE autocast ===
        # Prefer the modern torch.amp API; older torch versions only accept the
        # legacy torch.autocast(device_type) positional form.
        device_type = "cuda" if torch.cuda.is_available() else "cpu"
        try:
            ctx = torch.amp.autocast(device_type=device_type)
            logger.debug("[HANDLER] Using torch.amp.autocast(%s)", device_type)
        except Exception as e:
            logger.warning("[HANDLER] amp.autocast failed (%s), using legacy torch.autocast", e)
            ctx = torch.autocast(device_type)

        with ctx:
            out = self.pipe(
                prompt=prompt,
                negative_prompt=negative_prompt,
                image=telea_pil,
                strength=strength,
                guidance_scale=guidance,
                num_inference_steps=steps,
                generator=generator,
            )
        result_pil = out.images[0]
        logger.debug("[HANDLER] ✅ Diffusion complete, result size=%s", result_pil.size)

        # === BLENDING ===
        # Feather the diffusion output into the untouched original pixels using
        # the soft mask (1 = diffusion output, 0 = original).
        logger.debug("[HANDLER] Blending outputs...")
        res_np = np.array(result_pil).astype(np.float32) / 255.0
        base_np = np.array(base_pil.resize(result_pil.size, Image.LANCZOS)).astype(np.float32) / 255.0
        mask = cv2.resize(blend_mask, (result_pil.size[0], result_pil.size[1]))[:, :, None]

        final_np = np.clip(res_np * mask + base_np * (1.0 - mask), 0, 1)
        final_pil = Image.fromarray((final_np * 255).astype(np.uint8))

        b64_out = _encode_base64_image(final_pil)
        logger.debug("[HANDLER] ✅ Returning base64 image len=%d", len(b64_out))
        return {"image": b64_out}