# Hugging Face Space: Saree Depth + Pattern Draping (FastAPI + Gradio)
| import gradio as gr | |
| from fastapi import FastAPI, Request, HTTPException | |
| from fastapi.responses import JSONResponse | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import cv2 | |
| import numpy as np | |
| from PIL import Image, UnidentifiedImageError | |
| from io import BytesIO | |
| import base64 | |
| import traceback | |
| from starlette.exceptions import HTTPException as StarletteHTTPException | |
| from rembg import remove as bgrem_remove | |
# ===============================
# SIMPLE DPT MODEL (DEPTH ESTIMATION)
# ===============================
class SimpleDPT(nn.Module):
    """Minimal DPT-style monocular depth estimator.

    A timm ViT backbone (``features_only=True``) supplies feature maps;
    only the deepest one is fed to a small convolutional decoder that
    predicts a single-channel depth map, which is bilinearly resized to
    the requested output size.
    """

    def __init__(self, backbone_name='vit_base_patch16_384'):
        super().__init__()
        # Pretrained ViT trunk exposing intermediate feature maps.
        self.backbone = timm.create_model(
            backbone_name, pretrained=True, features_only=True
        )
        # Channel count of the deepest feature map drives the decoder input.
        feat_channels = [info['num_chs'] for info in self.backbone.feature_info]
        self.decoder = nn.Sequential(
            nn.Conv2d(feat_channels[-1], 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 1, kernel_size=1),
        )

    def forward(self, x, target_size):
        """Predict depth for ``x`` and resize it to ``target_size`` (H, W)."""
        deepest = self.backbone(x)[-1]
        depth = self.decoder(deepest)
        return nn.functional.interpolate(
            depth, size=target_size, mode='bilinear', align_corners=False
        )
# ===============================
# DEPTH → NORMAL MAP
# ===============================
def depth_to_normal(depth):
    """Convert a 2-D depth map into an RGB-encoded surface-normal map.

    The per-pixel normal is (-dz/dx, -dz/dy, 1), unit-normalized, then
    remapped from [-1, 1] into [0, 1] so it can be used as an image.

    Args:
        depth: 2-D float array of per-pixel depth values.

    Returns:
        (H, W, 3) float array of normals in [0, 1].
    """
    grad_y, grad_x = np.gradient(depth)
    normals = np.dstack((-grad_x, -grad_y, np.ones_like(depth)))
    magnitude = np.linalg.norm(normals, axis=2, keepdims=True)
    normals /= (magnitude + 1e-8)  # epsilon guards degenerate zero-length normals
    return (normals + 1) / 2
# ===============================
# CORE PROCESSING FUNCTION
# ===============================
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image, k: int = 5):
    """Drape `pattern_image` over `base_image` using depth/shading cues.

    Pipeline: estimate depth with SimpleDPT -> derive a normal map ->
    extract a CLAHE shading map -> tile the pattern across the base ->
    modulate the tiled pattern by shading and the normal's z-component ->
    mask it with a cleaned background-removal alpha.

    Args:
        base_image: Garment photo (converted to RGB internally).
        pattern_image: Texture to tile across the garment.
        k: Mask-erosion kernel size; also scales the binarization
           threshold (k/100). Larger k trims the mask more aggressively.

    Returns:
        PIL RGBA image of the shaded, masked pattern.
    """
    # Auto-tune k for known problem assets that need a tighter mask.
    filename = getattr(base_image, "filename", "") or ""
    if any(name in filename for name in ["BASE-BOTTOM-2.png"]):
        k = 12
    # FIX: the f-string previously had no placeholder and printed a literal.
    print(f"[DEBUG] base_image filename: {filename}")

    # --- Prepare a normalized 384x384 tensor for the ViT backbone ---
    img_pil = base_image.convert("RGB")
    img_np = np.array(img_pil)
    img_resized = img_pil.resize((384, 384))
    img_tensor = torch.from_numpy(np.array(img_resized)).permute(2, 0, 1).unsqueeze(0).float() / 255.0
    mean = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    std = torch.as_tensor([0.5, 0.5, 0.5], device=img_tensor.device).view(1, 3, 1, 1)
    img_tensor = (img_tensor - mean) / std

    # NOTE(review): the model is rebuilt on every call and its decoder is
    # randomly initialized (only the backbone loads pretrained weights) —
    # consider loading trained decoder weights and caching the model at
    # module level. Confirm against the intended deployment.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = SimpleDPT(backbone_name='vit_base_patch16_384').to(device)
    model.eval()

    # --- Depth inference at the base image's native resolution ---
    with torch.no_grad():
        target_size = img_pil.size[::-1]  # PIL size is (W, H); interpolate wants (H, W)
        depth_map = model(img_tensor.to(device), target_size=target_size)
        depth_map = depth_map.squeeze().cpu().numpy()

    # Normalize depth to [0, 1]; epsilon guards a constant depth map
    # (max == min) from dividing by zero and producing NaNs.
    depth_min = depth_map.min()
    depth_vis = (depth_map - depth_min) / (depth_map.max() - depth_min + 1e-8)

    # Surface normals from depth gradients.
    normal_map = depth_to_normal(depth_vis)

    # --- Shading map via CLAHE on the LAB lightness channel ---
    img_lab = cv2.cvtColor(img_np, cv2.COLOR_RGB2LAB)
    l_channel, _, _ = cv2.split(img_lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    shading_map = clahe.apply(l_channel) / 255.0

    # --- Tile the pattern to cover the base image ---
    pattern_np = np.array(pattern_image.convert("RGB"))
    target_h, target_w = img_np.shape[:2]
    pattern_h, pattern_w = pattern_np.shape[:2]
    pattern_tiled = np.zeros((target_h, target_w, 3), dtype=np.uint8)
    for y in range(0, target_h, pattern_h):
        for x in range(0, target_w, pattern_w):
            end_y = min(y + pattern_h, target_h)
            end_x = min(x + pattern_w, target_w)
            pattern_tiled[y:end_y, x:end_x] = pattern_np[0:(end_y - y), 0:(end_x - x)]

    # --- Blend: darken by shading, boost by the normal's z-component ---
    normal_map_loaded = normal_map.astype(np.float32)
    shading_map_loaded = np.stack([shading_map] * 3, axis=-1)
    alpha = 0.7  # 0.7 shading + 0.3 flat light keeps folds visible without crushing darks
    blended_shading = alpha * shading_map_loaded + (1 - alpha)
    pattern_folded = pattern_tiled.astype(np.float32) / 255.0 * blended_shading
    normal_boost = 0.5 + 0.5 * normal_map_loaded[..., 2:3]
    pattern_folded *= normal_boost
    pattern_folded = np.clip(pattern_folded, 0, 1)

    # ==========================================================
    # Background removal with post-processing (no duplicate blur)
    # ==========================================================
    buf = BytesIO()
    base_image.save(buf, format="PNG")
    result_no_bg = bgrem_remove(buf.getvalue())
    mask_img = Image.open(BytesIO(result_no_bg)).convert("RGBA")
    # Extract the alpha channel and clean up its edges.
    mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0

    # 1. Binarize (threshold scales with k) and erode to shrink halo edges.
    kernel = np.ones((k, k), np.uint8)
    mask_binary = (mask_alpha > k / 100).astype(np.uint8) * 255
    mask_eroded = cv2.erode(mask_binary, kernel, iterations=3)
    # 2. Feather the hard edge.
    mask_blurred = cv2.GaussianBlur(mask_eroded, (15, 15), sigmaX=3, sigmaY=3)
    # 3. Back to [0, 1].
    mask_blurred = mask_blurred.astype(np.float32) / 255.0

    # --- Compose final RGBA: masked pattern + feathered alpha channel ---
    mask_stack = np.stack([mask_blurred] * 3, axis=-1)
    pattern_final = pattern_folded * mask_stack
    pattern_rgb = (pattern_final * 255).astype(np.uint8)
    alpha_channel = (mask_blurred * 255).astype(np.uint8)
    pattern_rgba = np.dstack((pattern_rgb, alpha_channel))
    return Image.fromarray(pattern_rgba, mode="RGBA")
# ===============================
# WRAPPER: ACCEPT BYTES OR BASE64
# ===============================
def process_saree(data, k: int = 5):
    """Decode a [base_blob, pattern_blob] pair and run the draping pipeline.

    Each blob may be raw bytes or a base64 string, optionally prefixed as
    a data URL ("data:image/png;base64,...").

    Args:
        data: Two-element list/tuple: [base_blob, pattern_blob].
        k: Mask-erosion kernel size forwarded to the core pipeline.

    Returns:
        PIL RGBA image from `_process_saree_core`.

    Raises:
        HTTPException: 422 for malformed base64/images, 400 for any other
            read failure.
    """
    # binascii.Error is what base64.b64decode raises on bad input;
    # importing it directly avoids relying on the undocumented
    # `base64.binascii` submodule attribute.
    import binascii

    if not isinstance(data, (list, tuple)) or len(data) != 2:
        raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
    try:
        base_blob, pattern_blob = data
        # Accept base64 strings (with or without a data-URL prefix).
        if isinstance(base_blob, str):
            base_blob = base64.b64decode(base_blob.split(",")[-1])
        if isinstance(pattern_blob, str):
            pattern_blob = base64.b64decode(pattern_blob.split(",")[-1])
        base_image = Image.open(BytesIO(base_blob)).convert("RGBA")
        pattern_image = Image.open(BytesIO(pattern_blob)).convert("RGBA")
    except (binascii.Error, UnidentifiedImageError) as e:
        raise HTTPException(status_code=422, detail=f"Invalid image data: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
    return _process_saree_core(base_image, pattern_image, k=k)
# ===============================
# GRADIO + FASTAPI APP
# ===============================
# Gradio UI: a two-column dataframe carrying [base_blob, pattern_blob]
# (raw bytes or base64 strings), returning the draped saree as a PIL image.
# NOTE: the UI always calls process_saree with its default k=5.
gradio_iface = gr.Interface(
    fn=process_saree,
    inputs=gr.Dataframe(headers=["Base Blob", "Pattern Blob"], type="array"),
    outputs=gr.Image(type="pil", label="Final Saree Output"),
    title="Saree Depth + Pattern Draping",
    description="Blob or base64 API compatible"
)
app = FastAPI()


# Root endpoint.
# FIX: `root` was defined but never registered with the app, so "/" was
# unreachable even though this payload advertises the API's endpoints.
@app.get("/")
async def root():
    """Return a small JSON index of the API's endpoints."""
    return JSONResponse(
        content={
            "message": "Saree Depth + Pattern Draping API",
            "endpoints": ["/predict-saree", "/api/predict/", "/gradio"]
        }
    )


# Mount Gradio at /gradio
app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
# ===============================
# Custom API endpoint
# ===============================
# FIX: the handler was never registered as a route; the root payload and
# the 404 message both advertise /predict-saree, so register it here.
@app.post("/predict-saree")
async def predict_saree(request: Request):
    """POST {"data": [base_blob, pattern_blob], "k": int?} -> base64 PNG.

    Returns {"image_base64": ...} on success. HTTPExceptions raised by the
    pipeline are mapped to structured JSON errors with their status code;
    anything else becomes a 500 carrying the traceback.
    """
    try:
        body = await request.json()
        if "data" not in body:
            raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
        k = body.get("k", 5)  # Default to 5 if not provided
        result_img = process_saree(body["data"], k=int(k))
        buf = BytesIO()
        result_img.save(buf, format="PNG")
        base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
        return JSONResponse(content={"image_base64": base64_img})
    except HTTPException as e:
        return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
    except Exception as e:
        tb = traceback.format_exc()
        return JSONResponse(status_code=500, content={"error": "Processing Error", "details": str(e), "trace": tb})
# Alias for backward compatibility
# FIX: the alias was never registered; 404/root messages advertise it.
@app.post("/api/predict/")
async def alias_predict(request: Request):
    """Delegate to predict_saree so older clients keep working."""
    return await predict_saree(request)
# ===============================
# GLOBAL ERROR HANDLERS
# ===============================
# FIX: the handler was never installed on the app; register it so HTTP
# errors are rendered in the API's JSON error envelope.
@app.exception_handler(StarletteHTTPException)
async def http_exception_handler(request: Request, exc: StarletteHTTPException):
    """Translate HTTP errors (404/405/other) into JSON error responses."""
    if exc.status_code == 404:
        return JSONResponse(
            status_code=404,
            content={
                "error": "Endpoint Not Found",
                "details": f"The requested URL {request.url.path} does not exist. "
                           "Valid endpoints: /predict-saree or /api/predict/."
            }
        )
    elif exc.status_code == 405:
        return JSONResponse(
            status_code=405,
            content={
                "error": "Method Not Allowed",
                "details": f"Method {request.method} not allowed on {request.url.path}"
            }
        )
    # Any other HTTP error: pass the status through with its detail.
    return JSONResponse(
        status_code=exc.status_code,
        content={"error": exc.detail or "HTTP Error"}
    )
# FIX: register the catch-all handler; it was defined but never installed.
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Surface any unexpected failure as a 500 JSON error envelope."""
    return JSONResponse(
        status_code=500,
        content={
            "error": "Internal Server Error",
            "details": str(exc)
        }
    )
# Local entry point: serve the combined FastAPI + Gradio app on port 7860
# (the conventional port for Hugging Face Spaces).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)