Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Error Handling Updated
Browse files
app.py
CHANGED
|
@@ -1,15 +1,15 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
-
from fastapi import FastAPI, Request
|
| 3 |
from fastapi.responses import JSONResponse
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import timm
|
| 7 |
import cv2
|
| 8 |
import numpy as np
|
| 9 |
-
from PIL import Image
|
| 10 |
from io import BytesIO
|
| 11 |
import base64
|
| 12 |
-
import
|
| 13 |
|
| 14 |
# ===============================
|
| 15 |
# SIMPLE DPT MODEL (DEPTH ESTIMATION)
|
|
@@ -51,7 +51,6 @@ def depth_to_normal(depth):
|
|
| 51 |
# CORE PROCESSING FUNCTION
|
| 52 |
# ===============================
|
| 53 |
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
| 54 |
-
# (Depth estimation + pattern blending logic — unchanged)
|
| 55 |
img_pil = base_image.convert("RGB")
|
| 56 |
img_np = np.array(img_pil)
|
| 57 |
|
|
@@ -146,18 +145,25 @@ def process_saree(data):
|
|
| 146 |
Accepts [base_blob, pattern_blob] as bytes OR base64 strings
|
| 147 |
"""
|
| 148 |
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
| 149 |
-
raise
|
| 150 |
|
| 151 |
-
|
|
|
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
-
|
| 160 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
return _process_saree_core(base_image, pattern_image)
|
| 163 |
|
|
@@ -180,15 +186,30 @@ app = gr.mount_gradio_app(app, gradio_iface, path="/")
|
|
| 180 |
# Custom named API endpoint
|
| 181 |
@app.post("/predict-saree")
|
| 182 |
async def predict_saree(request: Request):
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
|
| 191 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
# Run (Hugging Face will call uvicorn automatically)
|
| 194 |
if __name__ == "__main__":
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from fastapi import FastAPI, Request, HTTPException
|
| 3 |
from fastapi.responses import JSONResponse
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import timm
|
| 7 |
import cv2
|
| 8 |
import numpy as np
|
| 9 |
+
from PIL import Image, UnidentifiedImageError
|
| 10 |
from io import BytesIO
|
| 11 |
import base64
|
| 12 |
+
import traceback
|
| 13 |
|
| 14 |
# ===============================
|
| 15 |
# SIMPLE DPT MODEL (DEPTH ESTIMATION)
|
|
|
|
| 51 |
# CORE PROCESSING FUNCTION
|
| 52 |
# ===============================
|
| 53 |
def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
|
|
|
|
| 54 |
img_pil = base_image.convert("RGB")
|
| 55 |
img_np = np.array(img_pil)
|
| 56 |
|
|
|
|
| 145 |
Accepts [base_blob, pattern_blob] as bytes OR base64 strings
|
| 146 |
"""
|
| 147 |
if not isinstance(data, (list, tuple)) or len(data) != 2:
|
| 148 |
+
raise HTTPException(status_code=422, detail="Expected an array with two elements: [base_blob, pattern_blob]")
|
| 149 |
|
| 150 |
+
try:
|
| 151 |
+
base_blob, pattern_blob = data
|
| 152 |
|
| 153 |
+
# Convert base64 to bytes if needed
|
| 154 |
+
if isinstance(base_blob, str):
|
| 155 |
+
base_blob = base64.b64decode(base_blob.split(",")[-1])
|
| 156 |
+
if isinstance(pattern_blob, str):
|
| 157 |
+
pattern_blob = base64.b64decode(pattern_blob.split(",")[-1])
|
| 158 |
|
| 159 |
+
base_image = Image.open(BytesIO(base_blob)).convert("RGBA")
|
| 160 |
+
pattern_image = Image.open(BytesIO(pattern_blob)).convert("RGBA")
|
| 161 |
+
|
| 162 |
+
except (base64.binascii.Error, UnidentifiedImageError) as e:
|
| 163 |
+
raise HTTPException(status_code=422, detail=f"Invalid image data: {str(e)}")
|
| 164 |
+
|
| 165 |
+
except Exception as e:
|
| 166 |
+
raise HTTPException(status_code=400, detail=f"Error reading input images: {str(e)}")
|
| 167 |
|
| 168 |
return _process_saree_core(base_image, pattern_image)
|
| 169 |
|
|
|
|
| 186 |
# Custom named API endpoint
|
| 187 |
@app.post("/predict-saree")
|
| 188 |
async def predict_saree(request: Request):
|
| 189 |
+
try:
|
| 190 |
+
body = await request.json()
|
| 191 |
+
|
| 192 |
+
if "data" not in body:
|
| 193 |
+
raise HTTPException(status_code=422, detail="Missing 'data' field in request body")
|
| 194 |
+
|
| 195 |
+
result_img = process_saree(body["data"])
|
| 196 |
|
| 197 |
+
# Convert output image to base64 PNG
|
| 198 |
+
buf = BytesIO()
|
| 199 |
+
result_img.save(buf, format="PNG")
|
| 200 |
+
base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
|
| 201 |
|
| 202 |
+
return JSONResponse(content={"image_base64": base64_img})
|
| 203 |
+
|
| 204 |
+
except HTTPException as e:
|
| 205 |
+
return JSONResponse(status_code=e.status_code, content={"error": "Input Error", "details": e.detail})
|
| 206 |
+
|
| 207 |
+
except Exception as e:
|
| 208 |
+
tb = traceback.format_exc()
|
| 209 |
+
return JSONResponse(
|
| 210 |
+
status_code=500,
|
| 211 |
+
content={"error": "Processing Error", "details": str(e), "trace": tb}
|
| 212 |
+
)
|
| 213 |
|
| 214 |
# Run (Hugging Face will call uvicorn automatically)
|
| 215 |
if __name__ == "__main__":
|