Spaces:
Running
Running
| import os | |
| import io | |
| import time | |
| import logging | |
| import asyncio | |
| from contextlib import asynccontextmanager | |
| from concurrent.futures import ThreadPoolExecutor | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import Response, JSONResponse | |
| from PIL import Image, ImageFilter | |
| from simple_lama_inpainting import SimpleLama | |
| from ultralytics import FastSAM | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer | |
| from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation | |
| from rembg import remove as rembg_remove, new_session as rembg_new_session | |
# ─────────────────────────────────────────────
# Logging
# ─────────────────────────────────────────────
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────
# Global model cache + thread pool
# ─────────────────────────────────────────────
# Each model global starts as None and is assigned once during startup.
# Request handlers treat None as "model not loaded".
lama_model: SimpleLama | None = None
fastsam_model: FastSAM | None = None
clipseg_processor = None
clipseg_model = None
# Real-ESRGAN upsampler instances (one per supported scale factor)
realesrgan_x2: RealESRGANer | None = None
realesrgan_x4: RealESRGANer | None = None
# ─── FREE TIER CPU OPTIMIZATIONS ─────────────────────────────────────────────
# Real-ESRGAN is very heavy. On a 2-vCPU Hugging Face Space, processing several
# images at once makes the threads fight over the CPU and throughput collapses.
# Limiting torch to a single thread and running one inference job at a time
# keeps throughput predictable.
# NOTE: torch is already imported at the top of the file; this repeated import is redundant.
import torch
torch.set_num_threads(1)
# A dedicated single-worker thread pool keeps model inference OFF the async
# event loop so FastAPI remains responsive, and serializes the jobs.
_executor = ThreadPoolExecutor(max_workers=1)
# ─────────────────────────────────────────────
# Model weight paths (baked into Docker image)
# ─────────────────────────────────────────────
WEIGHTS_DIR = os.path.join(os.path.dirname(__file__), "weights")
REALESRGAN_X2_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x2plus.pth")
REALESRGAN_X4_PATH = os.path.join(WEIGHTS_DIR, "RealESRGAN_x4plus.pth")
FASTSAM_MODEL_PATH = os.path.join(WEIGHTS_DIR, "FastSAM-s.pt")
# ─────────────────────────────────────────────
# Inpainting resolution caps (longest edge, in pixels)
# ─────────────────────────────────────────────
FREE_MAX_PX = 720 # free users — fast inference
PREMIUM_MAX_PX = 1280 # premium — higher quality inpainting
def _wrap_bare_state_dict(path: str) -> None:
    """Rewrite the checkpoint at *path* if it is a bare state dict.

    RealESRGANer looks for a ``params_ema`` or ``params`` key. Checkpoints that
    are a bare state dict (e.g. 4x-UltraSharp) are re-saved wrapped under
    ``params_ema``. The rewrite goes through a temporary file and an atomic
    ``os.replace`` so an interrupted write cannot corrupt the weights.

    Best effort: any failure is logged and swallowed so the caller can still
    attempt to load the file as-is. The loaded state dict is local to this
    helper, so it is released before the upsampler loads the weights again.
    """
    try:
        sd = torch.load(path, map_location="cpu")
        if "params_ema" not in sd and "params" not in sd:
            logger.info(f"Wrapping state dict for {path} to be compatible with RealESRGANer...")
            tmp_path = f"{path}.tmp"
            try:
                torch.save({"params_ema": sd}, tmp_path)
                os.replace(tmp_path, path)
            finally:
                # Only present if the save or the replace failed part-way.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
    except Exception as e:
        logger.error(f"Error checking/wrapping weights {path}: {e}")


def _load_realesrgan() -> None:
    """Load the Real-ESRGAN ×2 and ×4 upsamplers into the module globals.

    A scale whose weight file is missing, or whose loading fails, is logged
    and left as ``None``; this function never raises for those cases.
    """
    global realesrgan_x2, realesrgan_x4
    device = torch.device("cpu")  # Hugging Face free tier is CPU-only
    for scale, path in ((2, REALESRGAN_X2_PATH), (4, REALESRGAN_X4_PATH)):
        if not os.path.exists(path):
            logger.warning(f"Real-ESRGAN x{scale} weights not found at {path} — upscaling at x{scale} disabled.")
            continue
        try:
            # Make bare state dicts loadable before handing the path over.
            _wrap_bare_state_dict(path)
            # RRDBNet is the generator backbone of Real-ESRGAN
            model = RRDBNet(
                num_in_ch=3, num_out_ch=3,
                num_feat=64, num_block=23,
                num_grow_ch=32, scale=scale,
            )
            upsampler = RealESRGANer(
                scale=scale,
                model_path=path,
                model=model,
                tile=128,  # process the image in 128px tiles to bound memory use on CPU
                tile_pad=10,
                pre_pad=0,
                half=False,  # CPU requires float32
                device=device,
            )
            if scale == 2:
                realesrgan_x2 = upsampler
            else:
                realesrgan_x4 = upsampler
            logger.info(f"✅ Real-ESRGAN x{scale} loaded.")
        except Exception as e:
            logger.error(f"Failed to load Real-ESRGAN x{scale}: {e}", exc_info=True)
| # ───────────────────────────────────────────── | |
| # Startup: pre-load all models once | |
| # ──────────────────────────────────────────── | |
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load every model once at startup.

    Each loader is wrapped in its own ``try`` so a single failing model is
    logged and leaves its global as ``None`` instead of preventing startup.
    The code before ``yield`` runs at startup, the code after it at shutdown.

    Decorated with ``asynccontextmanager`` because the ``lifespan`` argument
    of FastAPI expects an async context manager, not a bare async generator.
    """
    global lama_model, fastsam_model, clipseg_processor, clipseg_model
    logger.info("Loading FastSAM model…")
    try:
        # Try local weight path first, then fall back to download
        if os.path.exists(FASTSAM_MODEL_PATH):
            logger.info(f"Loading FastSAM from local path: {FASTSAM_MODEL_PATH}")
            fastsam_model = FastSAM(FASTSAM_MODEL_PATH)
        else:
            logger.warning(f"FastSAM weights not found at {FASTSAM_MODEL_PATH}, downloading...")
            fastsam_model = FastSAM("FastSAM-s.pt")
    except Exception as e:
        logger.error(f"FastSAM load failed: {e}")
    logger.info("Loading CLIPSeg model for text-based removal…")
    try:
        clipseg_processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
        clipseg_model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
    except Exception as e:
        logger.error(f"CLIPSeg load failed: {e}")
    logger.info("Loading LaMa model…")
    try:
        lama_model = SimpleLama()
    except Exception as e:
        logger.error(f"LaMa load failed: {e}")
    logger.info("Loading Real-ESRGAN upscaling models…")
    _load_realesrgan()
    logger.info("✅ All models ready.")
    yield
    logger.info("Shutting down.")
| # ───────────────────────────────────────────── | |
| # App | |
| # ───────────────────────────────────────────── | |
# The FastAPI application. `lifespan` is the startup/shutdown hook that loads the models.
app = FastAPI(
    title="Vanishly AI Backend",
    description="High-performance AI inpainting + super-resolution powered by LaMa & Real-ESRGAN.",
    version="4.0.0",
    lifespan=lifespan,
)
# CORS: requests from any origin are accepted, limited to GET and POST,
# with any request header allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
| # ───────────────────────────────────────────── | |
| # Helpers | |
| # ───────────────────────────────────────────── | |
def _fit(img: Image.Image, max_px: int) -> Image.Image:
    """Return *img* shrunk so its longest edge is at most *max_px*.

    Images that already fit are returned unchanged (never upscaled). The
    target width and height are forced to even numbers and are at least 2.
    """
    width, height = img.size
    longest = max(width, height)
    if longest <= max_px:
        return img
    factor = max_px / longest
    # `& ~1` clears the lowest bit, i.e. rounds down to an even number.
    target = tuple(max(2, round(side * factor) & ~1) for side in (width, height))
    return img.resize(target, Image.LANCZOS)
def _encode_jpeg(img: Image.Image, quality: int = 88) -> bytes:
    """Encode *img* as JPEG bytes (converted to RGB, 4:2:0 chroma subsampling)."""
    rgb = img.convert("RGB")
    with io.BytesIO() as sink:
        rgb.save(sink, format="JPEG", quality=quality, optimize=False, subsampling=2)
        return sink.getvalue()
def _encode_png_fast(img: Image.Image) -> bytes:
    """Encode *img* as PNG bytes using the fastest (level 1) compression."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG", compress_level=1, optimize=False)
        return sink.getvalue()
def _run_lama(src: Image.Image, msk: Image.Image) -> Image.Image:
    """Inpaint *src* under mask *msk* with the globally loaded LaMa model."""
    inpainted = lama_model(src, msk)
    return inpainted
def _run_fastsam(src: Image.Image, px: int, py: int) -> Image.Image | None:
    """Segment the object at point (*px*, *py*) of *src* with FastSAM.

    Returns an "L" mode mask the same size as *src*, built from the first
    mask FastSAM reports, or ``None`` when FastSAM returns no masks.
    """
    predictions = fastsam_model(
        src,
        imgsz=512,
        conf=0.3,
        iou=0.7,
        points=[[px, py]],
        labels=[1],
        retina_masks=False,
    )
    if not predictions:
        return None
    masks = predictions[0].masks
    if masks is None:
        return None
    first_mask = masks.data[0].cpu().numpy()
    as_uint8 = (first_mask * 255).astype(np.uint8)
    mask_img = Image.fromarray(as_uint8, mode="L")
    if mask_img.size == src.size:
        return mask_img
    # The model output may differ in size from the input; bring it back.
    return mask_img.resize(src.size, Image.NEAREST)
def _run_clipseg(src: Image.Image, prompt: str) -> Image.Image | None:
    """Build a mask for the region of *src* described by the text *prompt*.

    Returns an "L" mode mask the same size as *src*, or ``None`` when the
    CLIPSeg processor or model is not loaded.
    """
    models_ready = clipseg_processor is not None and clipseg_model is not None
    if not models_ready:
        return None
    batch = clipseg_processor(text=[prompt], images=[src], padding="max_length", return_tensors="pt")
    with torch.no_grad():
        outputs = clipseg_model(**batch)
    probabilities = torch.sigmoid(outputs.logits.squeeze())
    # Threshold 0.25 (previously 0.4): lower selects more aggressively,
    # higher is more conservative.
    binary = (probabilities > 0.25).cpu().numpy().astype(np.uint8) * 255
    # Grow the mask with two passes of a 9x9 dilation.
    structuring = np.ones((9, 9), np.uint8)
    grown = cv2.dilate(binary, structuring, iterations=2)
    mask_img = Image.fromarray(grown, mode="L")
    if mask_img.size == src.size:
        return mask_img
    return mask_img.resize(src.size, Image.NEAREST)
def _run_superres(src: Image.Image, scale: int) -> Image.Image:
    """
    Upscale *src* by *scale* with Real-ESRGAN.

    ``scale == 2`` selects the ×2 upsampler; any other value selects the ×4
    upsampler. When the selected upsampler is not loaded, the image is resized
    with LANCZOS and sharpened with an unsharp mask instead.
    """
    if scale == 2:
        upsampler = realesrgan_x2
    else:
        upsampler = realesrgan_x4
    if upsampler is not None:
        logger.info(f" Real-ESRGAN x{scale}: input={src.size}")
        # Real-ESRGAN works on uint8 BGR arrays (OpenCV layout), so convert
        # from PIL RGB on the way in and back again on the way out.
        rgb_in = np.array(src.convert("RGB"))
        bgr_in = cv2.cvtColor(rgb_in, cv2.COLOR_RGB2BGR)
        bgr_out, _ = upsampler.enhance(bgr_in, outscale=scale)
        result = Image.fromarray(cv2.cvtColor(bgr_out, cv2.COLOR_BGR2RGB))
        logger.info(f" Real-ESRGAN x{scale}: output={result.size}")
        return result
    # Fallback when the model failed to load
    logger.warning(f"Real-ESRGAN x{scale} not available, falling back to LANCZOS+sharpen")
    width, height = src.size
    enlarged = src.resize((width * scale, height * scale), Image.LANCZOS)
    sharpen = ImageFilter.UnsharpMask(radius=2.0, percent=200, threshold=2)
    return enlarged.filter(sharpen)
| # ───────────────────────────────────────────── | |
| # Routes | |
| # ───────────────────────────────────────────── | |
# rembg session, created on first use and then reused. Without an explicit
# session, rembg builds a new one (reloading the model) on every call.
_rembg_session = None


def _run_remove_bg(src: Image.Image) -> Image.Image:
    """
    Remove the background of *src* using rembg (u2net model).

    Returns the image produced by rembg; with a PIL input rembg returns a
    PIL image whose background is transparent.

    The session is cached in a module global. This function is only submitted
    to the single-worker executor, so the lazy initialization is not raced.
    """
    global _rembg_session
    if _rembg_session is None:
        _rembg_session = rembg_new_session("u2net")
    return rembg_remove(src, session=_rembg_session)
async def root():
    """Report service status and which models are currently loaded."""
    payload = {
        "status": "ok",
        "model": "lama+realesrgan",
        "ready": lama_model is not None,
        "upscale_x2": realesrgan_x2 is not None,
        "upscale_x4": realesrgan_x4 is not None,
    }
    return JSONResponse(payload)
async def health():
    """Health probe: 503 until both LaMa and FastSAM are loaded.

    Once healthy, also reports which Real-ESRGAN scales are available.
    """
    core_models_loaded = lama_model is not None and fastsam_model is not None
    if not core_models_loaded:
        raise HTTPException(status_code=503, detail="Models not yet loaded")
    payload = {
        "status": "healthy",
        "upscale_x2": realesrgan_x2 is not None,
        "upscale_x4": realesrgan_x4 is not None,
    }
    return JSONResponse(payload)
async def segment_image(
    image: UploadFile = File(...),
    normX: float = Form(0.5),
    normY: float = Form(0.5),
):
    """
    Point-tap segmentation via FastSAM.

    ``normX`` / ``normY`` are the tap position as fractions of the image width
    and height. The resulting pixel coordinates are clamped to the image, so a
    tap on the right or bottom edge (1.0) or a slightly out-of-range value
    still addresses a valid pixel.

    Returns a grayscale PNG mask at the uploaded image's resolution
    (white = selected object). The mask is all black when nothing is found.

    Raises HTTPException 503 when FastSAM is not loaded, 415 for an
    unsupported content type, and 500 for any processing failure.
    """
    if fastsam_model is None:
        raise HTTPException(status_code=503, detail="FastSAM not loaded")
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        # Segmentation runs on a copy whose longest edge is at most 512px.
        src_small = _fit(src, 512)
        # int(1.0 * width) == width is one past the last pixel, and negative
        # fractions give negative indices, so clamp into [0, size - 1].
        px = min(max(int(normX * src_small.width), 0), src_small.width - 1)
        py = min(max(int(normY * src_small.height), 0), src_small.height - 1)
        logger.info(f"Segmenting @ ({px},{py}) on {src_small.size}")
        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_fastsam, src_small, px, py)
        if mask_img is None:
            mask_img = Image.new("L", src.size, 0)
        elif mask_img.size != src.size:
            # Scale the mask back up to the uploaded image's resolution.
            mask_img = mask_img.resize(src.size, Image.NEAREST)
        output_bytes = _encode_png_fast(mask_img)
        logger.info(f"Segmentation done in {time.perf_counter()-t0:.2f}s, {len(output_bytes)//1024}KB")
        return Response(content=output_bytes, media_type="image/png", headers={"X-Model": "fastsam"})
    except Exception as e:
        logger.error(f"Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Segmentation error: {str(e)}")
async def segment_text(
    image: UploadFile = File(...),
    prompt: str = Form(...),
):
    """
    Segment the object described by *prompt* using CLIPSeg.

    Responds with a grayscale PNG mask at the uploaded image's resolution
    (white = selected object); the mask is all black if CLIPSeg yields nothing.
    Raises HTTPException 503 / 415 / 500 for model-not-loaded, unsupported
    content type and processing failures respectively.
    """
    if clipseg_model is None:
        raise HTTPException(status_code=503, detail="CLIPSeg not loaded")
    accepted_types = ("image/png", "image/jpeg", "image/jpg", "application/octet-stream")
    if image.content_type not in accepted_types:
        raise HTTPException(status_code=415, detail="Unsupported image format")
    started = time.perf_counter()
    try:
        raw = await image.read()
        full_res = Image.open(io.BytesIO(raw)).convert("RGB")
        # CLIPSeg runs on a copy whose longest edge is at most 512px.
        reduced = _fit(full_res, 512)
        logger.info(f"Segmenting text '{prompt}' on {reduced.size}")
        event_loop = asyncio.get_event_loop()
        mask = await event_loop.run_in_executor(_executor, _run_clipseg, reduced, prompt)
        if mask is None:
            mask = Image.new("L", full_res.size, 0)
        elif mask.size != full_res.size:
            mask = mask.resize(full_res.size, Image.NEAREST)
        payload = _encode_png_fast(mask)
        logger.info(f"Text Segmentation done in {time.perf_counter()-started:.2f}s, {len(payload)//1024}KB")
        return Response(content=payload, media_type="image/png", headers={"X-Model": "clipseg"})
    except Exception as e:
        logger.error(f"Text Segmentation failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text Segmentation error: {str(e)}")
async def text_erase(
    image: UploadFile = File(...),
    prompt: str = Form(...),
    premium: str = Form("false"),
):
    """
    AI text-based object removal in a single request.

    Pipeline: image + prompt -> CLIPSeg mask (computed on a copy whose longest
    edge is at most 512px, then scaled to the image) -> LaMa inpainting of the
    padded bounding box around the mask -> feathered composite onto the
    original image.

    The response is a PNG at the uploaded image's resolution. The crop given
    to LaMa is only downscaled when its longest edge exceeds 2560px.
    ``premium`` is parsed and logged but does not change resolution or
    encoding. When nothing matches the prompt, the image is returned unedited.

    Raises HTTPException 503 when CLIPSeg or LaMa is not loaded, 415 for an
    unsupported content type, and 500 for any processing failure.
    """
    if clipseg_model is None or lama_model is None:
        raise HTTPException(status_code=503, detail="Models not loaded")
    if image.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
        raise HTTPException(status_code=415, detail="Unsupported image format")
    is_premium = premium.lower() in ("true", "1", "yes")
    t0 = time.perf_counter()
    try:
        img_data = await image.read()
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        orig_size = src.size
        # Step 1: Generate mask from text prompt via CLIPSeg
        src_small = _fit(src, 512)
        logger.info(f"Text erase '{prompt}' — generating mask on {src_small.size}")
        loop = asyncio.get_running_loop()
        mask_img = await loop.run_in_executor(_executor, _run_clipseg, src_small, prompt)
        if mask_img is None:
            logger.warning("CLIPSeg returned None — nothing detected for prompt")
            mask_img = Image.new("L", orig_size, 0)
        elif mask_img.size != orig_size:
            mask_img = mask_img.resize(orig_size, Image.NEAREST)
        # Step 2: High-Resolution Cropped Inpainting
        # Dilate so the inpainted area also covers the object's boundary.
        msk_dilated_np = cv2.dilate(np.array(mask_img), np.ones((9, 9), np.uint8), iterations=2)
        coords = cv2.findNonZero(msk_dilated_np)
        if coords is None:
            # Empty mask: compositing through an all-zero mask would reproduce
            # the source, so skip the full-resolution blur and composite.
            result = src
        else:
            msk_dilated = Image.fromarray(msk_dilated_np)
            x, y, w, h = cv2.boundingRect(coords)
            # Pad by 50% of the box size or 128px, whichever is larger.
            pad_x = max(128, int(w * 0.5))
            pad_y = max(128, int(h * 0.5))
            x1 = max(0, x - pad_x)
            y1 = max(0, y - pad_y)
            x2 = min(orig_size[0], x + w + pad_x)
            y2 = min(orig_size[1], y + h + pad_y)
            crop_src = src.crop((x1, y1, x2, y2))
            crop_msk = msk_dilated.crop((x1, y1, x2, y2))
            # Process crop at native resolution (cap at 2560 to prevent OOM)
            inpaint_src = _fit(crop_src, 2560)
            if crop_msk.size != inpaint_src.size:
                inpaint_msk = crop_msk.resize(inpaint_src.size, Image.NEAREST)
            else:
                inpaint_msk = crop_msk
            logger.info(f"Text erase inpainting crop {crop_src.size}→{inpaint_src.size} (premium={is_premium})")
            inpainted_small = await loop.run_in_executor(_executor, _run_lama, inpaint_src, inpaint_msk)
            # Bring the inpainted crop back to the crop's own size.
            if inpainted_small.size != crop_src.size:
                inpainted_large = inpainted_small.resize(crop_src.size, Image.LANCZOS)
            else:
                inpainted_large = inpainted_small
            pasted_result = src.copy()
            pasted_result.paste(inpainted_large, (x1, y1))
            # Feather blend for seamless compositing
            blend_mask = msk_dilated.filter(ImageFilter.GaussianBlur(radius=25))
            result = Image.composite(pasted_result, src, blend_mask)
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Text erase done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "clipseg+lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Text erase failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Text erase error: {str(e)}")
async def inpaint(
    image: UploadFile = File(...),
    mask: UploadFile = File(...),
    premium: str = Form("false"),
):
    """
    Generative object removal with LaMa.

    The mask (white = remove) is scaled to the image size if the two differ,
    dilated, and only the padded bounding box around it is inpainted. The
    inpainted crop is pasted back and feather-blended onto the original.

    The response is a PNG at the uploaded image's resolution. The crop given
    to LaMa is only downscaled when its longest edge exceeds 2560px.
    ``premium`` is parsed and logged but does not change resolution or
    encoding. An empty mask returns the image unedited.

    Raises HTTPException 503 when LaMa is not loaded, 415 for an unsupported
    content type, and 500 for any processing failure.
    """
    if lama_model is None:
        raise HTTPException(status_code=503, detail="Model not loaded, retry in a moment")
    for upload, name in [(image, "image"), (mask, "mask")]:
        if upload.content_type not in ("image/png", "image/jpeg", "image/jpg", "application/octet-stream"):
            raise HTTPException(status_code=415, detail=f"Unsupported type for '{name}'")
    is_premium = premium.lower() in ("true", "1", "yes")
    t0 = time.perf_counter()
    try:
        img_data, mask_data = await asyncio.gather(image.read(), mask.read())
        src = Image.open(io.BytesIO(img_data)).convert("RGB")
        msk = Image.open(io.BytesIO(mask_data)).convert("L")
        orig_size = src.size
        # The bounding box below is computed in mask coordinates but applied to
        # the image, so both must share one coordinate space.
        if msk.size != orig_size:
            msk = msk.resize(orig_size, Image.NEAREST)
        # Dilate mask so LaMa overwrites the exact boundary and a bit of the background
        msk_dilated_np = cv2.dilate(np.array(msk), np.ones((9, 9), np.uint8), iterations=2)
        # High-Resolution Cropped Inpainting
        coords = cv2.findNonZero(msk_dilated_np)
        if coords is None:
            # Empty mask: compositing through an all-zero mask would reproduce
            # the source, so skip the full-resolution blur and composite.
            result = src
        else:
            msk_dilated = Image.fromarray(msk_dilated_np)
            x, y, w, h = cv2.boundingRect(coords)
            # Add generous padding (50% of dimension or 128px, whichever is larger)
            pad_x = max(128, int(w * 0.5))
            pad_y = max(128, int(h * 0.5))
            x1 = max(0, x - pad_x)
            y1 = max(0, y - pad_y)
            x2 = min(orig_size[0], x + w + pad_x)
            y2 = min(orig_size[1], y + h + pad_y)
            crop_src = src.crop((x1, y1, x2, y2))
            crop_msk = msk_dilated.crop((x1, y1, x2, y2))
            # Process crop at native resolution (cap at 2560 to prevent OOM)
            inpaint_src = _fit(crop_src, 2560)
            if crop_msk.size != inpaint_src.size:
                inpaint_msk = crop_msk.resize(inpaint_src.size, Image.NEAREST)
            else:
                inpaint_msk = crop_msk
            logger.info(f"Inpainting Crop {crop_src.size}→{inpaint_src.size} (premium={is_premium})")
            loop = asyncio.get_running_loop()
            inpainted_small = await loop.run_in_executor(_executor, _run_lama, inpaint_src, inpaint_msk)
            # Restore inpainted crop to its original crop size
            if inpainted_small.size != crop_src.size:
                inpainted_large = inpainted_small.resize(crop_src.size, Image.LANCZOS)
            else:
                inpainted_large = inpainted_small
            # Paste the inpainted crop back into the full-size image
            pasted_result = src.copy()
            pasted_result.paste(inpainted_large, (x1, y1))
            # Feather blend for seamless compositing onto the untouched original high-res image
            blend_mask = msk_dilated.filter(ImageFilter.GaussianBlur(radius=25))
            result = Image.composite(pasted_result, src, blend_mask)
        # PNG output: no lossy re-compression of the result
        output_bytes = _encode_png_fast(result)
        elapsed = time.perf_counter() - t0
        logger.info(f"Inpainting done in {elapsed:.2f}s, {len(output_bytes)//1024}KB")
        return Response(
            content=output_bytes,
            media_type="image/png",
            headers={"X-Model": "lama", "X-Time": f"{elapsed:.2f}"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Inpainting failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Inpainting error: {str(e)}")
async def upscale_image(
    image: UploadFile = File(...),
    scale: str = Form("2"),
    premium: str = Form("false"),
):
    """
    AI super-resolution with Real-ESRGAN (PRO-only).

    ``scale`` accepts "2" or "4"; any other value is treated as 2. Requests
    without a truthy ``premium`` flag ("true", "1", "yes") are rejected with
    403. The upscaled image is returned as a PNG, with the scale, timing and
    output dimensions in the response headers.

    Raises HTTPException 403 / 415 / 500.
    """
    if premium.lower() not in ("true", "1", "yes"):
        raise HTTPException(
            status_code=403,
            detail="AI upscaling is a PRO feature. Upgrade to access it."
        )
    accepted_types = ("image/png", "image/jpeg", "image/jpg", "application/octet-stream")
    if image.content_type not in accepted_types:
        raise HTTPException(status_code=415, detail="Unsupported image format")
    # Unknown scale strings fall back to 2.
    scale_int = {"2": 2, "4": 4}.get(scale, 2)
    started = time.perf_counter()
    try:
        raw = await image.read()
        src = Image.open(io.BytesIO(raw)).convert("RGB")
        logger.info(f"Upscaling {src.size} × x{scale_int} with Real-ESRGAN")
        event_loop = asyncio.get_event_loop()
        result = await event_loop.run_in_executor(_executor, _run_superres, src, scale_int)
        # PNG keeps the upscaled output free of lossy compression.
        payload = _encode_png_fast(result)
        elapsed = time.perf_counter() - started
        logger.info(
            f"Upscaling done in {elapsed:.2f}s | "
            f"{src.size} → {result.size} | {len(payload)//1024}KB"
        )
        response_headers = {
            "X-Model": "realesrgan",
            "X-Scale": str(scale_int),
            "X-Time": f"{elapsed:.2f}",
            "X-Output-Width": str(result.width),
            "X-Output-Height": str(result.height),
        }
        return Response(content=payload, media_type="image/png", headers=response_headers)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Upscaling failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Upscaling error: {str(e)}")
async def remove_background(
    image: UploadFile = File(...),
):
    """
    AI background removal using rembg.

    The upload is decoded to RGBA, processed on the worker thread, and the
    result is always returned as PNG because JPEG cannot carry an alpha
    channel. Timing and output dimensions are reported in response headers.

    Raises HTTPException 415 for an unsupported content type and 500 for any
    processing failure.
    """
    accepted_types = ("image/png", "image/jpeg", "image/jpg", "application/octet-stream")
    if image.content_type not in accepted_types:
        raise HTTPException(status_code=415, detail="Unsupported image format")
    started = time.perf_counter()
    try:
        raw = await image.read()
        src = Image.open(io.BytesIO(raw)).convert("RGBA")
        logger.info(f"Background removal: input {src.size}")
        event_loop = asyncio.get_event_loop()
        result = await event_loop.run_in_executor(_executor, _run_remove_bg, src)
        with io.BytesIO() as sink:
            result.save(sink, format="PNG", optimize=False)
            payload = sink.getvalue()
        elapsed = time.perf_counter() - started
        logger.info(
            f"Background removal done in {elapsed:.2f}s | "
            f"{src.size} | {len(payload) // 1024}KB"
        )
        response_headers = {
            "X-Model": "rembg-u2net",
            "X-Time": f"{elapsed:.2f}",
            "X-Output-Width": str(result.width),
            "X-Output-Height": str(result.height),
        }
        return Response(content=payload, media_type="image/png", headers=response_headers)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Background removal failed: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Background removal error: {str(e)}")
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn on all interfaces.
    # The port comes from the PORT environment variable, defaulting to 7860.
    import uvicorn

    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", 7860)),
        log_level="info",
    )