import os
import io
import time
import logging
import asyncio
from contextlib import asynccontextmanager
from concurrent.futures import ThreadPoolExecutor

import cv2
import torch
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response, JSONResponse
from PIL import Image, ImageFilter
from simple_lama_inpainting import SimpleLama
from ultralytics import FastSAM
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from rembg import remove as rembg_remove, new_session as rembg_new_session

# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ─────────────────────────────────────────────
# Global model cache + thread pool
# ─────────────────────────────────────────────
# Each model is populated once by the FastAPI lifespan hook. ``None`` means the
# model is not loaded (yet) or failed to load; the routes answer 503 for that.
lama_model: SimpleLama | None = None
fastsam_model: FastSAM | None = None
clipseg_processor: CLIPSegProcessor | None = None
clipseg_model: CLIPSegForImageSegmentation | None = None

# Real-ESRGAN upsampler instances (one per supported scale factor).
realesrgan_x2: RealESRGANer | None = None
realesrgan_x4: RealESRGANer | None = None

# ─── FREE TIER CPU OPTIMIZATIONS ─────────────────────────────────────────────
# Real-ESRGAN is very heavy on CPU. On a 2-vCPU Hugging Face Space, running
# several images at once with multi-threaded kernels oversubscribes the cores,
# so torch is limited to one intra-op thread and inference jobs are queued one
# at a time on a single-worker executor.
# NOTE(review): with one torch thread and one worker, each job uses a single
# vCPU; benchmark set_num_threads(2) before assuming 1 is the fastest setting.
torch.set_num_threads(1)

# Dedicated thread pool keeps blocking model inference OFF the async event loop so FastAPI remains responsive.
_executor = ThreadPoolExecutor(max_workers=1)  # serialises all model inference: one job at a time

# rembg ONNX session, created once (at startup, or lazily on the first request)
# and reused. rembg's ``remove`` builds a fresh u2net session whenever no
# session is passed, which would reload the model on every /remove_bg call.
_rembg_session = None

# ─────────────────────────────────────────────
# Model weight paths (baked into Docker image)
# ─────────────────────────────────────────────
WEIGHTS_DIR = os.path.join(os.path.dirname(__file__), "weights")
REALESRGAN_X2_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x2plus.pth")
REALESRGAN_X4_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x4plus.pth")
FASTSAM_MODEL_PATH = os.path.join(WEIGHTS_DIR, "FastSAM-s.pt")

# ─────────────────────────────────────────────
# Inpainting resolution caps
# ─────────────────────────────────────────────
# NOTE(review): FREE_MAX_PX / PREMIUM_MAX_PX are not applied by any route in
# this file; every tier inpaints at up to INPAINT_CROP_MAX_PX. They are kept in
# case other modules import them.
FREE_MAX_PX = 720  # free users — fast inference
PREMIUM_MAX_PX = 1280  # premium — higher quality inpainting
INPAINT_CROP_MAX_PX = 2560  # longest edge of the crop handed to LaMa (guards against OOM)

# Mask post-processing shared by CLIPSeg and the inpainting pipeline.
_DILATE_KERNEL = np.ones((9, 9), np.uint8)
_DILATE_ITERATIONS = 2
_FEATHER_RADIUS = 25  # Gaussian radius used to feather inpainted pixels into the original
_CROP_MIN_PAD = 128  # minimum context (px) kept around the mask's bounding box

# Upload content types accepted by every endpoint.
_ALLOWED_CONTENT_TYPES = ("image/png", "image/jpeg", "image/jpg", "application/octet-stream")
# Custom response headers; exposed via CORS so browser clients can read them.
_EXPOSED_HEADERS = ["X-Model", "X-Time", "X-Scale", "X-Output-Width", "X-Output-Height"]
# Form values treated as "true" for boolean flags such as ``premium``.
_TRUTHY_FLAGS = ("true", "1", "yes")


def _ensure_realesrgan_checkpoint(path: str) -> None:
    """Wrap a bare state-dict checkpoint as ``{"params_ema": ...}``, rewriting *path* in place.

    RealESRGANer looks its weights up under ``params_ema`` / ``params``; bare
    state dicts (e.g. 4x-UltraSharp) are wrapped once on disk. Errors are logged,
    not raised — the RealESRGANer load that follows reports any real failure.
    """
    try:
        state = torch.load(path, map_location="cpu")
        if isinstance(state, dict) and "params_ema" not in state and "params" not in state:
            logger.info(f"Wrapping state dict for {path} to be compatible with RealESRGANer...")
            torch.save({"params_ema": state}, path)
    except Exception as e:
        logger.error(f"Error checking/wrapping weights {path}: {e}")


def _load_realesrgan() -> None:
    """Load the Real-ESRGAN ×2 and ×4 upsamplers into ``realesrgan_x2`` / ``realesrgan_x4``.

    A scale whose weight file is missing, or whose load fails, is logged and left
    as ``None``; ``_run_superres`` then falls back to LANCZOS resampling.
    """
    global realesrgan_x2, realesrgan_x4
    device = torch.device("cpu")  # Hugging Face free tier is CPU-only

    for scale, path in ((2, REALESRGAN_X2_PATH), (4, REALESRGAN_X4_PATH)):
        if not os.path.exists(path):
            logger.warning(f"Real-ESRGAN x{scale} weights not found at {path} — upscaling at x{scale} disabled.")
            continue
        try:
            _ensure_realesrgan_checkpoint(path)
            # RRDBNet is the generator backbone of Real-ESRGAN
            model = RRDBNet(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_block=23, num_grow_ch=32, scale=scale,
            )
            upsampler = RealESRGANer(
                scale=scale,
                model_path=path,
                model=model,
                tile=128,  # CPU: small tiles keep peak RAM bounded on large inputs
                tile_pad=10,
                pre_pad=0,
                half=False,  # CPU requires float32
                device=device,
            )
        except Exception as e:
            logger.error(f"Failed to load Real-ESRGAN x{scale}: {e}", exc_info=True)
            continue

        if scale == 2:
            realesrgan_x2 = upsampler
        else:
            realesrgan_x4 = upsampler
        logger.info(f"✅ Real-ESRGAN x{scale} loaded.")


# ─────────────────────────────────────────────
# Startup: pre-load all models once
# ─────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load every model once before serving requests.

    Each load is independent: a failure is logged and leaves that model ``None``
    so the corresponding routes answer 503 (or fall back) instead of crashing
    the whole server.
    """
    global lama_model, fastsam_model, clipseg_processor, clipseg_model, _rembg_session

    logger.info("Loading FastSAM model…")
    try:
        # Try local weight path first, then fall back to download
        if os.path.exists(FASTSAM_MODEL_PATH):
            logger.info(f"Loading FastSAM from local path: {FASTSAM_MODEL_PATH}")
            fastsam_model = FastSAM(FASTSAM_MODEL_PATH)
        else:
            logger.warning(f"FastSAM weights not found at {FASTSAM_MODEL_PATH}, downloading...")
            fastsam_model = FastSAM("FastSAM-s.pt")
    except Exception as e:
        logger.error(f"FastSAM load failed: {e}")

    logger.info("Loading CLIPSeg model for text-based removal…")
    try:
        clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
    except Exception as e:
        logger.error(f"CLIPSeg load failed: {e}")

    logger.info("Loading LaMa model…")
    try:
        lama_model = SimpleLama()
    except Exception as e:
        logger.error(f"LaMa load failed: {e}")

    logger.info("Loading Real-ESRGAN upscaling models…")
    _load_realesrgan()

    logger.info("Loading rembg background-removal session…")
    try:
        _rembg_session = rembg_new_session("u2net")
    except Exception as e:
        # _run_remove_bg retries lazily on the first request.
        logger.error(f"rembg session load failed: {e}")

    availability = (
        ("FastSAM", fastsam_model is not None),
        ("CLIPSeg", clipseg_model is not None),
        ("LaMa", lama_model is not None),
        ("Real-ESRGAN x2", realesrgan_x2 is not None),
        ("Real-ESRGAN x4", realesrgan_x4 is not None),
        ("rembg", _rembg_session is not None),
    )
    unavailable = [name for name, loaded in availability if not loaded]
    if unavailable:
        logger.warning(f"Startup finished; unavailable: {', '.join(unavailable)}")
    else:
        logger.info("✅ All models ready.")
    yield
    logger.info("Shutting down.")


# ─────────────────────────────────────────────
# App
# ─────────────────────────────────────────────
app = FastAPI(
    title="Vanishly AI Backend",
    description="High-performance AI inpainting + super-resolution powered by LaMa & Real-ESRGAN.",
    version="4.0.0",
    lifespan=lifespan,
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
    # Without this, browsers hide the custom X-* response headers from JS clients.
    expose_headers=_EXPOSED_HEADERS,
)


# ─────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────
def _fit(img: Image.Image, max_px: int) -> Image.Image:
    """Downscale so the longest edge ≤ max_px (dimensions rounded down to even). Never upscales."""
    w, h = img.size
    if max(w, h) <= max_px:
        return img
    ratio = max_px / max(w, h)
    nw = max(2, round(w * ratio) & ~1)
    nh = max(2, round(h * ratio) & ~1)
    return img.resize((nw, nh), Image.LANCZOS)


def _encode_jpeg(img: Image.Image, quality: int = 88) -> bytes:
    """Encode *img* as an RGB JPEG with 4:2:0 chroma subsampling."""
    buf = io.BytesIO()
    img.convert("RGB").save(buf, format="JPEG", quality=quality, optimize=False, subsampling=2)
    return buf.getvalue()


def _encode_png_fast(img: Image.Image) -> bytes:
    """Encode *img* as a lossless PNG using the fastest zlib level (bigger files, less CPU)."""
    buf = io.BytesIO()
    img.save(buf, format="PNG", compress_level=1, optimize=False)
    return buf.getvalue()


def _parse_flag(value: str) -> bool:
    """Interpret a form string such as ``premium`` as a boolean (``true``/``1``/``yes``)."""
    return value.lower() in _TRUTHY_FLAGS


def _check_content_type(upload: UploadFile, detail: str = "Unsupported image format") -> None:
    """Raise HTTP 415 with *detail* unless *upload* has an accepted content type."""
    if upload.content_type not in _ALLOWED_CONTENT_TYPES:
        raise HTTPException(status_code=415, detail=detail)


def _decode_image(data: bytes, mode: str) -> Image.Image:
    """Decode uploaded bytes and convert them to PIL *mode*.

    Raises:
        HTTPException(400): the bytes are not a decodable image. PIL reports
            unknown formats (``UnidentifiedImageError``) and truncated data as
            ``OSError`` subclasses.
    """
    try:
        return Image.open(io.BytesIO(data)).convert(mode)
    except OSError as e:
        raise HTTPException(status_code=400, detail=f"Could not decode image: {e}") from e


def _norm_to_pixel(norm: float, size: int) -> int:
    """Map a normalised coordinate to a pixel index clamped to ``[0, size - 1]``.

    Without the clamp, ``norm == 1.0`` would address one pixel past the edge.

    Raises:
        HTTPException(422): *norm* is NaN.
    """
    if norm != norm:  # NaN is the only float that is not equal to itself
        raise HTTPException(status_code=422, detail="Coordinates must be numbers")
    norm = min(max(norm, 0.0), 1.0)
    return min(int(norm * size), size - 1)


def _mask_to_size(mask: Image.Image | None, size: tuple[int, int]) -> Image.Image:
    """Return *mask* resized (nearest-neighbour) to *size*, or an all-black ``L`` mask if ``None``."""
    if mask is None:
        return Image.new("L", size, 0)
    if mask.size != size:
        return mask.resize(size, Image.NEAREST)
    return mask


def _run_lama(src: Image.Image, msk: Image.Image) -> Image.Image:
    """Inpaint *src* where *msk* is white and return an image exactly ``src.size``.

    NOTE: simple-lama-inpainting pads its input on the bottom/right up to a
    multiple of 8 and returns the padded result. Cropping the padding off keeps
    the output pixel-aligned with *src*; resizing it instead would stretch the
    padding into the picture.
    """
    out = lama_model(src, msk)
    w, h = src.size
    if out.size != src.size and out.width >= w and out.height >= h:
        out = out.crop((0, 0, w, h))
    return out


def _run_fastsam(src: Image.Image, px: int, py: int) -> Image.Image | None:
    """Segment the object under point (*px*, *py*) with FastSAM.

    Returns an ``L`` mask the size of *src* (255 = object), or ``None`` when
    FastSAM produced no mask.
    """
    results = fastsam_model(
        src,
        imgsz=512,
        conf=0.3,
        iou=0.7,
        points=[[px, py]],
        labels=[1],
        retina_masks=False,
    )
    masks = results[0].masks if results else None
    if masks is None or len(masks.data) == 0:
        return None

    mask_array = masks.data[0].cpu().numpy()
    mask_img = Image.fromarray((mask_array * 255).astype(np.uint8))
    if mask_img.size != src.size:
        mask_img = mask_img.resize(src.size, Image.NEAREST)
    return mask_img


def _run_clipseg(src: Image.Image, prompt: str) -> Image.Image | None:
    """Segment the region described by *prompt* with CLIPSeg.

    Returns a dilated ``L`` mask the size of *src* (255 = match), or ``None``
    when CLIPSeg is not loaded.
    """
    if clipseg_processor is None or clipseg_model is None:
        return None

    inputs = clipseg_processor(text=[prompt], images=[src], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = clipseg_model(**inputs)

    probs = torch.sigmoid(outputs.logits.squeeze())
    # Threshold lowered from 0.4 to 0.25 for better object detection:
    # lower = more aggressive, higher = more conservative.
    mask_np = (probs > 0.25).cpu().numpy().astype(np.uint8) * 255
    mask_np = cv2.dilate(mask_np, _DILATE_KERNEL, iterations=_DILATE_ITERATIONS)

    mask_img = Image.fromarray(mask_np)
    if mask_img.size != src.size:
        mask_img = mask_img.resize(src.size, Image.NEAREST)
    return mask_img


def _run_superres(src: Image.Image, scale: int) -> Image.Image:
    """AI super-resolution via Real-ESRGAN; returns *src* enlarged by *scale* (2 or 4).

    Real-ESRGAN is a GAN-based upsampler that synthesises texture detail
    (hair, skin, fabric, foliage) rather than only interpolating pixels. The
    upsamplers are built with tile=128 / tile_pad=10, so large inputs are
    processed in overlapping 128px tiles to keep CPU RAM usage bounded.
    If the requested model is not loaded, falls back to LANCZOS + unsharp mask.
    """
    upsampler = realesrgan_x2 if scale == 2 else realesrgan_x4

    if upsampler is None:
        # Fallback if model failed to load
        logger.warning(f"Real-ESRGAN x{scale} not available, falling back to LANCZOS+sharpen")
        w, h = src.size
        upscaled = src.resize((w * scale, h * scale), Image.LANCZOS)
        return upscaled.filter(ImageFilter.UnsharpMask(radius=2.0, percent=200, threshold=2))

    logger.info(f"  Real-ESRGAN x{scale}: input={src.size}")

    # Real-ESRGAN expects a uint8 numpy BGR array (OpenCV format)
    img_bgr = cv2.cvtColor(np.array(src.convert("RGB")), cv2.COLOR_RGB2BGR)
    output_bgr, _ = upsampler.enhance(img_bgr, outscale=scale)
    result = Image.fromarray(cv2.cvtColor(output_bgr, cv2.COLOR_BGR2RGB))

    logger.info(f"  Real-ESRGAN x{scale}: output={result.size}")
    return result


def _run_remove_bg(src: Image.Image) -> Image.Image:
    """Remove the background with rembg (u2net ONNX); returns an RGBA image.

    Reuses one module-level session, creating it on first use if the startup
    preload failed.
    """
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg_new_session("u2net")
    # rembg returns a PIL Image when given a PIL Image
    return rembg_remove(src, session=_rembg_session)


def _inpaint_masked_region(src: Image.Image, mask: Image.Image, tag: str) -> Image.Image:
    """Erase the white area of *mask* from *src* with LaMa; return a full-size RGB result.

    Only a padded crop around the mask's bounding box is inpainted (longest edge
    capped at INPAINT_CROP_MAX_PX), so the rest of the image stays at native
    resolution. The inpainted crop is then feather-blended back through a
    Gaussian-blurred copy of the dilated mask. An all-black mask yields an
    unmodified copy of *src*. *tag* prefixes the log line.
    """
    if mask.size != src.size:
        # Crop coordinates come from the mask, so it must share the image's pixel grid.
        mask = mask.resize(src.size, Image.NEAREST)

    # Dilate so LaMa overwrites the object's boundary and a bit of the background.
    mask_np = cv2.dilate(np.array(mask), _DILATE_KERNEL, iterations=_DILATE_ITERATIONS)
    coords = cv2.findNonZero(mask_np)
    if coords is None:
        return src.copy()
    dilated = Image.fromarray(mask_np)

    # Padded bounding box: 50% of the box dimension or _CROP_MIN_PAD, whichever is larger.
    img_w, img_h = src.size
    x, y, w, h = cv2.boundingRect(coords)
    pad_x = max(_CROP_MIN_PAD, int(w * 0.5))
    pad_y = max(_CROP_MIN_PAD, int(h * 0.5))
    box = (max(0, x - pad_x), max(0, y - pad_y), min(img_w, x + w + pad_x), min(img_h, y + h + pad_y))

    crop_src = src.crop(box)
    crop_msk = dilated.crop(box)
    inpaint_src = _fit(crop_src, INPAINT_CROP_MAX_PX)
    inpaint_msk = crop_msk if crop_msk.size == inpaint_src.size else crop_msk.resize(inpaint_src.size, Image.NEAREST)
    logger.info(f"{tag} crop {crop_src.size}→{inpaint_src.size}")

    inpainted = _run_lama(inpaint_src, inpaint_msk)
    if inpainted.size != crop_src.size:
        inpainted = inpainted.resize(crop_src.size, Image.LANCZOS)

    pasted = src.copy()
    pasted.paste(inpainted, box[:2])

    # Feather blend for seamless compositing onto the untouched original.
    blend_mask = dilated.filter(ImageFilter.GaussianBlur(radius=_FEATHER_RADIUS))
    return Image.composite(pasted, src, blend_mask)


# ─────────────────────────────────────────────
# Routes
# ─────────────────────────────────────────────
@app.get("/", tags=["Health"])
async def root():
    """Liveness probe: always 200, reporting which models are loaded."""
    return JSONResponse({
        "status": "ok",
        "model": "lama+realesrgan",
        "ready": lama_model is not None,
        "upscale_x2": realesrgan_x2 is not None,
        "upscale_x4": realesrgan_x4 is not None,
    })


@app.get("/health", tags=["Health"])
async def health():
    """Readiness probe: 503 until both LaMa and FastSAM are loaded."""
    if lama_model is None or fastsam_model is None:
        raise HTTPException(status_code=503, detail="Models not yet loaded")
    return JSONResponse({
        "status": "healthy",
        "upscale_x2": realesrgan_x2 is not None,
        "upscale_x4": realesrgan_x4 is not None,
    })


@app.post("/segment", tags=["Segmentation"])
async def segment_image(
    image: UploadFile = File(...),
    normX: float = Form(0.5),
    normY: float = Form(0.5),
):
    """
    Point-tap segmentation via FastSAM.
    (normX, normY) is the tap position as a fraction of width/height, clamped to the image.
    Returns a grayscale PNG mask at the upload's resolution (white = selected object).
    """
    if fastsam_model is None:
        raise HTTPException(status_code=503, detail="FastSAM not loaded")
    _check_content_type(image)

    t0 = time.perf_counter()
    try:
        src = _decode_image(await image.read(), "RGB")
        src_small = _fit(src, 512)
        px = _norm_to_pixel(normX, src_small.width)
        py = _norm_to_pixel(normY, src_small.height)
        logger.info(f"Segmenting @ ({px},{py}) on {src_small.size}")

        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_fastsam, src_small, px, py)
        output_bytes = _encode_png_fast(_mask_to_size(mask_img, src.size))

        logger.info(f"Segmentation done in {time.perf_counter()-t0:.2f}s, {len(output_bytes)//1024}KB")
        return Response(content=output_bytes, media_type="image/png", headers={"X-Model": "fastsam"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Segmentation error: {str(e)}")


@app.post("/segment_text", tags=["Segmentation"])
async def segment_text(
    image: UploadFile = File(...),
    prompt: str = Form(...),
):
    """
    Text-based segmentation via CLIPSeg.
    Returns a grayscale PNG mask at the upload's resolution (white = selected object).
    """
    if clipseg_model is None:
        raise HTTPException(status_code=503, detail="CLIPSeg not loaded")
    _check_content_type(image)

    t0 = time.perf_counter()
    try:
        src = _decode_image(await image.read(), "RGB")
        src_small = _fit(src, 512)
        logger.info(f"Segmenting text '{prompt}' on {src_small.size}")

        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_clipseg, src_small, prompt)
        output_bytes = _encode_png_fast(_mask_to_size(mask_img, src.size))

        logger.info(f"Text Segmentation done in {time.perf_counter()-t0:.2f}s, {len(output_bytes)//1024}KB")
        return Response(content=output_bytes, media_type="image/png", headers={"X-Model": "clipseg"})
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Text Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text Segmentation error: {str(e)}")


@app.post("/text_erase", tags=["AI Edit"])
async def text_erase(
    image: UploadFile = File(...),
    prompt: str = Form(...),
    premium: str = Form("false"),
):
    """
    AI Text-Based Object Removal — single endpoint.

    Pipeline: Image + "remove the person" → CLIPSeg mask → LaMa inpaint → Result.
    Returns the edited image as a lossless PNG at the upload's resolution.

    ``premium`` is accepted for API compatibility and only logged: every tier
    inpaints the masked region's padded crop at up to INPAINT_CROP_MAX_PX.
    """
    if clipseg_model is None or lama_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    _check_content_type(image)

    is_premium = _parse_flag(premium)
    t0 = time.perf_counter()
    try:
        src = _decode_image(await image.read(), "RGB")

        # Step 1: generate a mask from the text prompt via CLIPSeg (at ≤512px).
        src_small = _fit(src, 512)
        logger.info(f"Text erase '{prompt}' — generating mask on {src_small.size}")
        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_clipseg, src_small, prompt)
        if mask_img is None:
            logger.warning("CLIPSeg returned None — nothing detected for prompt")
        mask_img = _mask_to_size(mask_img, src.size)

        # Step 2: high-resolution cropped inpainting + feathered composite.
        result = await loop.run_in_executor(
            _executor, _inpaint_masked_region, src, mask_img,
            f"Text erase inpainting (premium={is_premium})",
        )

        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Text erase done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "clipseg+lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Text erase failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text erase error: {str(e)}")


@app.post("/inpaint", tags=["Inpaint"])
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    premium: str = Form("false"),
):
    """
    Generative object removal with LaMa.

    *mask* is converted to grayscale; non-zero pixels mark the area to erase. A
    mask whose size differs from *image* is resized (nearest-neighbour) to match.
    Returns a lossless PNG at the image's resolution.

    ``premium`` is accepted for API compatibility and only logged: every tier
    inpaints the masked region's padded crop at up to INPAINT_CROP_MAX_PX.
    """
    if lama_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded, retry in a moment")
    for upload, name in ((image, "image"), (mask, "mask")):
        _check_content_type(upload, f"Unsupported type for '{name}'")

    is_premium = _parse_flag(premium)
    t0 = time.perf_counter()
    try:
        img_data, mask_data = await asyncio.gather(image.read(), mask.read())
        src = _decode_image(img_data, "RGB")
        msk = _decode_image(mask_data, "L")

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            _executor, _inpaint_masked_region, src, msk,
            f"Inpainting (premium={is_premium})",
        )

        # Lossless PNG: no recompression loss on the untouched pixels.
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Inpainting done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Inpainting failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inpainting error: {str(e)}")


@app.post("/upscale", tags=["Upscale"])
async def upscale_image(
    image: UploadFile = File(...),
    scale: str = Form("2"),
    premium: str = Form("false"),
):
    """
    AI Super-Resolution using Real-ESRGAN — PRO feature only.

    - scale=2 → 2× super-resolution (High PRO tier)
    - scale=4 → 4× super-resolution (Max PRO tier)
    Any other ``scale`` value is treated as 2.

    Returns a lossless PNG; if the requested Real-ESRGAN model is not loaded the
    image is upscaled with LANCZOS + unsharp mask instead.
    """
    # NOTE(review): ``premium`` is client-supplied and not authenticated in this
    # file; confirm an upstream gateway enforces the PRO tier.
    if not _parse_flag(premium):
        raise HTTPException(
            status_code=403,
            detail="AI upscaling is a PRO feature. Upgrade to access it."
        )
    _check_content_type(image)

    scale_int = int(scale) if scale in ("2", "4") else 2
    t0 = time.perf_counter()
    try:
        src = _decode_image(await image.read(), "RGB")
        logger.info(f"Upscaling {src.size} × x{scale_int} with Real-ESRGAN")

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(_executor, _run_superres, src, scale_int)

        # Lossless PNG — upscaled PRO image deserves zero compression loss
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(
            f"Upscaling done in {elapsed:.2f}s | "
            f"{src.size} → {result.size} | {len(output_bytes)//1024}KB"
        )
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={
                "X-Model": "realesrgan",
                "X-Scale": str(scale_int),
                "X-Time": f"{elapsed:.2f}",
                "X-Output-Width": str(result.width),
                "X-Output-Height": str(result.height),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Upscaling failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Upscaling error: {str(e)}")


@app.post("/remove_bg", tags=["AI Edit"])
async def remove_background(
    image: UploadFile = File(...),
):
    """
    AI Background Removal using rembg (u2net ONNX model).

    Accepts PNG/JPEG and returns a transparent PNG with the background removed.
    Runs on CPU; the ONNX session is created once and reused across requests.
    """
    _check_content_type(image)

    t0 = time.perf_counter()
    try:
        src = _decode_image(await image.read(), "RGBA")
        logger.info(f"Background removal: input {src.size}")

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(_executor, _run_remove_bg, src)

        # Always output transparent PNG — JPEG does not support alpha channel
        output_buf = io.BytesIO()
        result.save(output_buf, format="PNG", optimize=False)
        output_bytes = output_buf.getvalue()

        elapsed = time.perf_counter() - t0
        logger.info(
            f"Background removal done in {elapsed:.2f}s | "
            f"{src.size} | {len(output_bytes) // 1024}KB"
        )
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={
                "X-Model": "rembg-u2net",
                "X-Time": f"{elapsed:.2f}",
                "X-Output-Width": str(result.width),
                "X-Output-Height": str(result.height),
            },
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Background removal failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Background removal error: {str(e)}")


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))
    # Pass the app object rather than the "main:app" import string: it works
    # whatever this file is named and avoids importing the module a second time.
    uvicorn.run(app, host="0.0.0.0", port=port, log_level="info")