Spaces:
Sleeping
Sleeping
K as param
Browse files
app.py
CHANGED
|
@@ -52,7 +52,7 @@ def depth_to_normal(depth):
|
|
| 52 |
# ===============================
|
| 53 |
# CORE PROCESSING FUNCTION
|
| 54 |
# ===============================
|
| 55 |
-
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
| 56 |
img_pil = base_image.convert("RGB")
|
| 57 |
img_np = np.array(img_pil)
|
| 58 |
|
|
@@ -123,8 +123,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
|
| 123 |
# Extract alpha and clean edges
|
| 124 |
mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
|
| 125 |
|
| 126 |
-
# 1. Slightly stronger shrink (balanced)
|
| 127 |
-
k = 8
|
| 128 |
kernel = np.ones((k, k), np.uint8) # slightly larger kernel
|
| 129 |
mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
|
| 130 |
mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
|
|
@@ -149,7 +148,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
|
| 149 |
# ===============================
|
| 150 |
# WRAPPER: ACCEPT BYTES OR BASE64
|
| 151 |
# ===============================
|
| 152 |
-
def process_saree(data):
|
| 153 |
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
| 154 |
raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
|
| 155 |
|
|
@@ -170,7 +169,7 @@ def process_saree(data):
|
|
| 170 |
except Exception as e:
|
| 171 |
raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
|
| 172 |
|
| 173 |
-
return _process_saree_core(base_image, pattern_image)
|
| 174 |
|
| 175 |
# ===============================
|
| 176 |
# GRADIO + FASTAPI APP
|
|
@@ -198,31 +197,29 @@ async def root():
|
|
| 198 |
# Mount Gradio at /gradio
|
| 199 |
app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
|
| 200 |
|
|
|
|
| 201 |
# Custom API endpoint
|
|
|
|
| 202 |
@app.post("/predict-saree")
|
| 203 |
async def predict_saree(request: Request):
|
| 204 |
try:
|
| 205 |
body = await request.json()
|
| 206 |
-
|
| 207 |
if "data" not in body:
|
| 208 |
raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
|
| 209 |
|
| 210 |
-
|
|
|
|
| 211 |
|
| 212 |
buf = BytesIO()
|
| 213 |
result_img.save(buf, format="PNG")
|
| 214 |
base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 215 |
-
|
| 216 |
return JSONResponse(content={"image_base64": base64_img})
|
| 217 |
|
| 218 |
except HTTPException as e:
|
| 219 |
return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
|
| 220 |
except Exception as e:
|
| 221 |
tb = traceback.format_exc()
|
| 222 |
-
return JSONResponse(
|
| 223 |
-
status_code=500,
|
| 224 |
-
content={"error": "Processing Error", "details": str(e), "trace": tb}
|
| 225 |
-
)
|
| 226 |
|
| 227 |
# Alias for backward compatibility
|
| 228 |
@app.post("/api/predict/")
|
|
|
|
| 52 |
# ===============================
|
| 53 |
# CORE PROCESSING FUNCTION
|
| 54 |
# ===============================
|
| 55 |
+
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image, k: int = 5):
|
| 56 |
img_pil = base_image.convert("RGB")
|
| 57 |
img_np = np.array(img_pil)
|
| 58 |
|
|
|
|
| 123 |
# Extract alpha and clean edges
|
| 124 |
mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
|
| 125 |
|
| 126 |
+
# 1. Slightly stronger shrink (balanced)
|
|
|
|
| 127 |
kernel = np.ones((k, k), np.uint8) # slightly larger kernel
|
| 128 |
mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
|
| 129 |
mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
|
|
|
|
| 148 |
# ===============================
|
| 149 |
# WRAPPER: ACCEPT BYTES OR BASE64
|
| 150 |
# ===============================
|
| 151 |
+
def process_saree(data, k: int = 5):
|
| 152 |
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
| 153 |
raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
|
| 154 |
|
|
|
|
| 169 |
except Exception as e:
|
| 170 |
raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
|
| 171 |
|
| 172 |
+
return _process_saree_core(base_image, pattern_image, k=k)
|
| 173 |
|
| 174 |
# ===============================
|
| 175 |
# GRADIO + FASTAPI APP
|
|
|
|
| 197 |
# Mount Gradio at /gradio
|
| 198 |
app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
|
| 199 |
|
| 200 |
+
# ===============================
|
| 201 |
# Custom API endpoint
|
| 202 |
+
# ===============================
|
| 203 |
@app.post("/predict-saree")
|
| 204 |
async def predict_saree(request: Request):
|
| 205 |
try:
|
| 206 |
body = await request.json()
|
|
|
|
| 207 |
if "data" not in body:
|
| 208 |
raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
|
| 209 |
|
| 210 |
+
k = body.get("k", 5) # Default to 5 if not provided
|
| 211 |
+
result_img = process_saree(body["data"], k=int(k))
|
| 212 |
|
| 213 |
buf = BytesIO()
|
| 214 |
result_img.save(buf, format="PNG")
|
| 215 |
base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
|
|
|
|
| 216 |
return JSONResponse(content={"image_base64": base64_img})
|
| 217 |
|
| 218 |
except HTTPException as e:
|
| 219 |
return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
|
| 220 |
except Exception as e:
|
| 221 |
tb = traceback.format_exc()
|
| 222 |
+
return JSONResponse(status_code=500, content={"error": "Processing Error", "details": str(e), "trace": tb})
|
|
|
|
|
|
|
|
|
|
| 223 |
|
| 224 |
# Alias for backward compatibility
|
| 225 |
@app.post("/api/predict/")
|