import base64
import binascii
import traceback
from functools import lru_cache
from io import BytesIO

import cv2
import gradio as gr
import numpy as np
import timm
import torch
import torch.nn as nn
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from rembg import remove as bgrem_remove
from starlette.exceptions import HTTPException as StarletteHTTPException


# ===============================
# SIMPLE DPT MODEL (DEPTH ESTIMATION)
# ===============================
class SimpleDPT(nn.Module):
    """Depth estimator: a timm feature backbone followed by a small conv decoder.

    Only the last feature map produced by the backbone is decoded. The
    decoder output (one channel) is bilinearly resized to ``target_size``.

    NOTE(review): only the backbone is created with ``pretrained=True``; the
    decoder layers are freshly initialised here and no other weights are
    loaded in this file -- confirm whether trained decoder weights are
    expected to be loaded elsewhere.
    """

    def __init__(self, backbone_name='vit_base_patch16_384'):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=True, features_only=True)
        feature_info = self.backbone.feature_info
        channels = [f['num_chs'] for f in feature_info]
        self.decoder = nn.Sequential(
            nn.Conv2d(channels[-1], 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1),
        )

    def forward(self, x, target_size):
        """Return a one-channel map for ``x`` resized to ``target_size`` (H, W)."""
        features = self.backbone(x)
        depth = self.decoder(features[-1])
        return nn.functional.interpolate(
            depth, size=target_size, mode='bilinear', align_corners=False
        )


@lru_cache(maxsize=2)
def _get_depth_model(backbone_name: str, device_type: str) -> SimpleDPT:
    """Build the depth model once per (backbone, device) and reuse it.

    Constructing ``SimpleDPT`` on every request re-creates the backbone and
    re-initialises the decoder, which is slow and gives a different decoder
    per call; caching keeps one instance per process.
    """
    model = SimpleDPT(backbone_name=backbone_name).to(torch.device(device_type))
    model.eval()
    return model


# ===============================
# DEPTH → NORMAL MAP
# ===============================
def depth_to_normal(depth):
    """Convert a 2-D depth array into a normal map with values in [0, 1].

    Per-pixel vectors ``(-dx, -dy, 1)`` are built from the depth gradients,
    scaled to unit length, then remapped from [-1, 1] to [0, 1].
    """
    dy, dx = np.gradient(depth)
    normal = np.dstack((-dx, -dy, np.ones_like(depth)))
    n = np.linalg.norm(normal, axis=2, keepdims=True)
    normal /= (n + 1e-8)  # epsilon guards against a zero-length vector
    normal = (normal + 1) / 2
    return normal


def _tile_pattern(pattern_np: np.ndarray, target_h: int, target_w: int) -> np.ndarray:
    """Repeat ``pattern_np`` from the top-left corner and crop to (target_h, target_w)."""
    pattern_h, pattern_w = pattern_np.shape[:2]
    reps_y = -(-target_h // pattern_h)  # ceiling division
    reps_x = -(-target_w // pattern_w)
    return np.tile(pattern_np, (reps_y, reps_x, 1))[:target_h, :target_w]


def _feathered_foreground_mask(base_image: Image.Image, k: int) -> np.ndarray:
    """Return a float32 mask in [0, 1] of the foreground found by rembg.

    The rembg alpha channel is thresholded at ``k / 100``, eroded three times
    with a ``k`` x ``k`` kernel to pull the edge inwards, then Gaussian
    blurred to feather the edge.
    """
    buf = BytesIO()
    base_image.save(buf, format="PNG")
    result_no_bg = bgrem_remove(buf.getvalue())
    mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
    mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0

    kernel = np.ones((k, k), np.uint8)
    mask_binary = (mask_alpha > k / 100).astype(np.uint8) * 255
    mask_eroded = cv2.erode(mask_binary, kernel, iterations=3)
    mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
    return mask_blurred.astype(np.float32) / 255.0


# ===============================
# CORE PROCESSING FUNCTION
# ===============================
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image, k: int = 5):
    """Drape ``pattern_image`` over the foreground of ``base_image``.

    Steps: estimate a depth map, derive a normal map from it, build a
    CLAHE-equalised shading map from the base image, tile the pattern to the
    base image size, modulate it by shading and the normal map's third
    channel, and mask the result with the background-removal alpha.

    Args:
        base_image: Image supplying geometry, shading and the foreground mask.
        pattern_image: Image tiled across the base image.
        k: Erosion kernel size; ``k / 100`` is also the alpha threshold.

    Returns:
        A PIL image built from an (H, W, 4) uint8 array at the base image size.
    """
    # Auto-set k for BASE-BOTTOM images.
    # NOTE(review): images decoded from in-memory bytes normally carry no
    # usable ``filename``, so this override may never trigger via the API --
    # confirm.
    filename = getattr(base_image, "filename", "") or ""
    if any(name in filename for name in ["BASE-BOTTOM-2.png"]):
        k = 12
    print(f"[DEBUG] base_image filename: {filename}")

    # Prepare tensor: 384x384 RGB, scaled to [0, 1], then normalised to [-1, 1].
    img_pil = base_image.convert("RGB")
    img_np = np.array(img_pil)
    img_resized = img_pil.resize((384, 384))
    img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    std = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    img_tensor = (img_tensor - mean) / std

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = _get_depth_model('vit_base_patch16_384', device.type)

    # Depth inference
    with torch.no_grad():
        target_size = img_pil.size[::-1]  # PIL size is (W, H); the model wants (H, W)
        depth_map = model(img_tensor.to(device), target_size=target_size)
    # Index explicitly instead of squeeze() so a 1-pixel-wide/high image keeps
    # both spatial axes.
    depth_map = depth_map[0, 0].cpu().numpy()

    # Normalize depth to [0, 1]; a flat map would otherwise divide by zero.
    depth_min = depth_map.min()
    depth_range = depth_map.max() - depth_min
    if depth_range > 0:
        depth_vis = (depth_map - depth_min) / depth_range
    else:
        depth_vis = np.zeros_like(depth_map)

    # Normal map
    normal_map = depth_to_normal(depth_vis)

    # Shading map (CLAHE on the L channel of LAB)
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    l_channel, _, _ = cv2.split(img_lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    shading_map = clahe.apply(l_channel) / 255.0

    # Tile pattern
    pattern_np = np.array(pattern_image.convert("RGB"))
    target_h, target_w = img_np.shape[:2]
    pattern_tiled = _tile_pattern(pattern_np, target_h, target_w)

    # Blend pattern: 70% shading, 30% flat white, then scale by the normal's
    # third channel.
    normal_map_loaded = normal_map.astype(np.float32)
    shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
    alpha = 0.7
    blended_shading = alpha * shading_map_loaded + (1 - alpha)
    pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
    normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
    pattern_folded *= normal_boost
    pattern_folded = np.clip(pattern_folded, 0, 1)

    # Background removal with post-processing
    mask_blurred = _feathered_foreground_mask(base_image, k)

    # Final RGBA: colour is pre-multiplied by the mask, which is also the alpha.
    mask_stack = np.stack([mask_blurred] * 3, axis=-1)
    pattern_final = pattern_folded * mask_stack
    pattern_rgb = (pattern_final * 255).astype(np.uint8)
    alpha_channel = (mask_blurred * 255).astype(np.uint8)
    pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
    return Image.fromarray(pattern_rgba)


# ===============================
# WRAPPER: ACCEPT BYTES OR BASE64
# ===============================
def process_saree(data, k: int = 5):
    """Decode ``[base_blob, pattern_blob]`` and run the draping pipeline.

    Each blob may be raw image bytes or a base64 string; for strings, only
    the part after the last comma is decoded, so a ``data:...;base64,``
    prefix is tolerated.

    Raises:
        HTTPException: 422 if ``data`` is not a two-element list/tuple or the
            blobs are not valid base64 / a recognisable image; 400 for any
            other failure while reading the images.
    """
    if not isinstance(data, (list, tuple)) or len(data) != 2:
        raise HTTPException(
            status_code=422,
            detail="Expected an array with two elements: [base_blob, pattern_blob]",
        )
    try:
        base_blob, pattern_blob = data
        # Convert base64 to bytes if needed
        if isinstance(base_blob, str):
            base_blob = base64.b64decode(base_blob.split(",")[-1])
        if isinstance(pattern_blob, str):
            pattern_blob = base64.b64decode(pattern_blob.split(",")[-1])
        base_image = Image.open(BytesIO(base_blob)).convert("RGBA")
        pattern_image = Image.open(BytesIO(pattern_blob)).convert("RGBA")
    except (binascii.Error, UnidentifiedImageError) as e:
        raise HTTPException(status_code=422, detail=f"Invalid image data: {str(e)}") from e
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}") from e
    return _process_saree_core(base_image, pattern_image, k=k)


# ===============================
# GRADIO + FASTAPI APP
# ===============================
gradio_iface = gr.Interface(
    fn=process_saree,
    inputs=gr.Dataframe(headers=["Base Blob", "Pattern Blob"], type="array"),
    outputs=gr.Image(type="pil", label="Final Saree Output"),
    title="Saree Depth + Pattern Draping",
    description="Blob or base64 API compatible",
)

app = FastAPI()


# Root endpoint
@app.get("/")
async def root():
    """Return a short JSON description of the service and its endpoints."""
    return JSONResponse(
        content={
            "message": "Saree Depth + Pattern Draping API",
            "endpoints": ["/predict-saree", "/api/predict/", "/gradio"],
        }
    )


# Mount Gradio at /gradio
app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")


# ===============================
# Custom API endpoint
# ===============================
@app.post("/predict-saree")
async def predict_saree(request: Request):
    """Run the pipeline on a JSON body ``{"data": [base, pattern], "k": int}``.

    Returns ``{"image_base64": ...}`` (PNG) on success. Client mistakes
    (malformed JSON, missing ``data``, bad ``k``, undecodable images) are
    reported as 4xx with ``{"error": "Input Error", "details": ...}``; any
    other failure is a 500.
    """
    try:
        try:
            body = await request.json()
        except ValueError as e:
            # Malformed JSON is a client error, not a server fault.
            raise HTTPException(status_code=400, detail=f"Request body is not valid JSON: {e}") from e

        if not isinstance(body, dict) or "data" not in body:
            raise HTTPException(status_code=422, detail="Missing 'data' field in request body")

        try:
            k = int(body.get("k", 5))  # Default to 5 if not provided
        except (TypeError, ValueError) as e:
            raise HTTPException(status_code=422, detail="'k' must be an integer") from e
        if k < 0:
            # A negative kernel size cannot be allocated downstream.
            raise HTTPException(status_code=422, detail="'k' must not be negative")

        result_img = process_saree(body["data"], k=k)
        buf = BytesIO()
        result_img.save(buf, format="PNG")
        base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
        return JSONResponse(content={"image_base64": base64_img})
    except HTTPException as e:
        return JSONResponse(
            status_code=e.status_code,
            content={"error": "Input Error", "details": e.detail},
        )
    except Exception as e:
        # NOTE(review): the traceback is returned to the caller, which exposes
        # server internals; kept for compatibility with existing clients --
        # consider logging it server-side instead.
        tb = traceback.format_exc()
        return JSONResponse(
            status_code=500,
            content={"error": "Processing Error", "details": str(e), "trace": tb},
        )


# Alias for backward compatibility
@app.post("/api/predict/")
async def alias_predict(request: Request):
    """Forward to ``predict_saree`` so the older URL keeps working."""
    return await predict_saree(request)


# ===============================
# GLOBAL ERROR HANDLERS
# ===============================
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Render HTTP errors as JSON, with tailored messages for 404 and 405.

    Headers attached to the exception (for example ``Allow`` on a 405) are
    forwarded to the response instead of being dropped.
    """
    # getattr keeps this working if the exception object carries no headers.
    headers = getattr(exc, "headers", None)

    if exc.status_code == 404:
        return JSONResponse(
            status_code=404,
            content={
                "error": "Endpoint Not Found",
                "details": f"The requested URL {request.url.path} does not exist. "
                           "Valid endpoints: /predict-saree or /api/predict/.",
            },
            headers=headers,
        )
    elif exc.status_code == 405:
        return JSONResponse(
            status_code=405,
            content={
                "error": "Method Not Allowed",
                "details": f"Method {request.method} not allowed on {request.url.path}",
            },
            headers=headers,
        )
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail or "HTTP Error"},
        headers=headers,
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Last-resort handler: report any uncaught exception as a JSON 500."""
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "details": str(exc),
        },
    )


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)