nishanth-saka commited on
Commit
2d7a63f
·
verified ·
1 Parent(s): af04a83

custom endpoint added

Browse files
Files changed (1) hide show
  1. app.py +28 -7
app.py CHANGED
@@ -1,4 +1,6 @@
1
  import gradio as gr
 
 
2
  import torch
3
  import torch.nn as nn
4
  import timm
@@ -49,7 +51,7 @@ def depth_to_normal(depth):
49
  # CORE PROCESSING FUNCTION
50
  # ===============================
51
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
52
- # (Your existing depth estimation + pattern blending logic unchanged)
53
  img_pil = base_image.convert("RGB")
54
  img_np = np.array(img_pil)
55
 
@@ -160,16 +162,35 @@ def process_saree(data):
160
  return _process_saree_core(base_image, pattern_image)
161
 
162
  # ===============================
163
- # GRADIO INTERFACE
164
  # ===============================
165
- iface = gr.Interface(
166
  fn=process_saree,
167
  inputs=gr.Dataframe(headers=["Base Blob", "Pattern Blob"], type="array"),
168
  outputs=gr.Image(type="pil", label="Final Saree Output"),
169
- title="Saree Depth + Pattern Draping (Blob API Compatible)",
170
- description="Send image blobs (bytes or base64) as array [base, pattern] or use UI for testing."
171
  )
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  if __name__ == "__main__":
174
- # Enable CORS for Vite
175
- iface.launch(server_name="0.0.0.0", share=True)
 
1
  import gradio as gr
2
+ from fastapi import FastAPI, Request
3
+ from fastapi.responses import JSONResponse
4
  import torch
5
  import torch.nn as nn
6
  import timm
 
51
  # CORE PROCESSING FUNCTION
52
  # ===============================
53
  def _process_saree_core(base_image: Image.Image, pattern_image: Image.Image):
54
+ # (Depth estimation + pattern blending logic unchanged)
55
  img_pil = base_image.convert("RGB")
56
  img_np = np.array(img_pil)
57
 
 
162
  return _process_saree_core(base_image, pattern_image)
163
 
164
  # ===============================
165
+ # GRADIO + FASTAPI APP
166
  # ===============================
167
+ gradio_iface = gr.Interface(
168
  fn=process_saree,
169
  inputs=gr.Dataframe(headers=["Base Blob", "Pattern Blob"], type="array"),
170
  outputs=gr.Image(type="pil", label="Final Saree Output"),
171
+ title="Saree Depth + Pattern Draping",
172
+ description="Blob or base64 API compatible"
173
  )
174
 
175
+ app = FastAPI()
176
+
177
+ # Mount Gradio UI at root
178
+ app = gr.mount_gradio_app(app, gradio_iface, path="/")
179
+
180
+ # Custom named API endpoint
181
+ @app.post("/predict-saree")
182
+ async def predict_saree(request: Request):
183
+ body = await request.json()
184
+ result_img = process_saree(body["data"])
185
+
186
+ # Convert output image to base64 PNG
187
+ buf = BytesIO()
188
+ result_img.save(buf, format="PNG")
189
+ base64_img = base64.b64encode(buf.getvalue()).decode("utf-8")
190
+
191
+ return JSONResponse(content={"image_base64": base64_img})
192
+
193
+ # Run (Hugging Face will call uvicorn automatically)
194
  if __name__ == "__main__":
195
+ import uvicorn
196
+ uvicorn.run(app, host="0.0.0.0", port=7860)