nishanth-saka commited on
Commit
414709f
·
verified ·
1 Parent(s): d938499

K as param

Browse files
Files changed (1) hide show
  1. app.py +9 -12
app.py CHANGED
@@ -52,7 +52,7 @@ def depth_to_normal(depth):
52
  # ===============================
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
- def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
56
  img_pil = base_image.convert("RGB")
57
  img_np = np.array(img_pil)
58
 
@@ -123,8 +123,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
123
  # Extract alpha and clean edges
124
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
125
 
126
- # 1. Slightly stronger shrink (balanced)
127
- k = 8
128
  kernel = np.ones((k, k), np.uint8) # slightly larger kernel
129
  mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
130
  mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
@@ -149,7 +148,7 @@ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
149
  # ===============================
150
  # WRAPPER: ACCEPT BYTES OR BASE64
151
  # ===============================
152
- def process_saree(data):
153
  if not isinstance(data, (list, tuple)) or len(data) != 2:
154
  raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
155
 
@@ -170,7 +169,7 @@ def process_saree(data):
170
  except Exception as e:
171
  raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
172
 
173
- return _process_saree_core(base_image, pattern_image)
174
 
175
  # ===============================
176
  # GRADIO + FASTAPI APP
@@ -198,31 +197,29 @@ async def root():
198
  # Mount Gradio at /gradio
199
  app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
200
 
 
201
  # Custom API endpoint
 
202
  @app.post("/predict-saree")
203
  async def predict_saree(request: Request):
204
  try:
205
  body = await request.json()
206
-
207
  if "data" not in body:
208
  raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
209
 
210
- result_img = process_saree(body["data"])
 
211
 
212
  buf = BytesIO()
213
  result_img.save(buf, format="PNG")
214
  base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
215
-
216
  return JSONResponse(content={"image_base64": base64_img})
217
 
218
  except HTTPException as e:
219
  return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
220
  except Exception as e:
221
  tb = traceback.format_exc()
222
- return JSONResponse(
223
- status_code=500,
224
- content={"error": "Processing Error", "details": str(e), "trace": tb}
225
- )
226
 
227
  # Alias for backward compatibility
228
  @app.post("/api/predict/")
 
52
  # ===============================
53
  # CORE PROCESSING FUNCTION
54
  # ===============================
55
+ def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image, k: int = 5):
56
  img_pil = base_image.convert("RGB")
57
  img_np = np.array(img_pil)
58
 
 
123
  # Extract alpha and clean edges
124
  mask_alpha = np.array(mask_img)[:, :, 3].astype(np.float32) / 255.0
125
 
126
+ # 1. Slightly stronger shrink (balanced)
 
127
  kernel = np.ones((k, k), np.uint8) # slightly larger kernel
128
  mask_binary = (mask_alpha > k/100).astype(np.uint8) * 255 # slightly stricter threshold
129
  mask_eroded = cv2.erode(mask_binary, kernel, iterations=3) # balanced erosion
 
148
  # ===============================
149
  # WRAPPER: ACCEPT BYTES OR BASE64
150
  # ===============================
151
+ def process_saree(data, k: int = 5):
152
  if not isinstance(data, (list, tuple)) or len(data) != 2:
153
  raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
154
 
 
169
  except Exception as e:
170
  raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
171
 
172
+ return _process_saree_core(base_image, pattern_image, k=k)
173
 
174
  # ===============================
175
  # GRADIO + FASTAPI APP
 
197
  # Mount Gradio at /gradio
198
  app = gr.mount_gradio_app(app, gradio_iface, path="/gradio")
199
 
200
+ # ===============================
201
  # Custom API endpoint
202
+ # ===============================
203
  @app.post("/predict-saree")
204
  async def predict_saree(request: Request):
205
  try:
206
  body = await request.json()
 
207
  if "data" not in body:
208
  raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
209
 
210
+ k = body.get("k", 5) # Default to 5 if not provided
211
+ result_img = process_saree(body["data"], k=int(k))
212
 
213
  buf = BytesIO()
214
  result_img.save(buf, format="PNG")
215
  base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
 
216
  return JSONResponse(content={"image_base64": base64_img})
217
 
218
  except HTTPException as e:
219
  return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
220
  except Exception as e:
221
  tb = traceback.format_exc()
222
+ return JSONResponse(status_code=500, content={"error": "Processing Error", "details": str(e), "trace": tb})
 
 
 
223
 
224
  # Alias for backward compatibility
225
  @app.post("/api/predict/")