NitinBot001 commited on
Commit
85d3d94
·
verified ·
1 Parent(s): e40d29d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +350 -0
app.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
+ from fastapi.responses import Response
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+ import uvicorn
6
+ from carvekit.api.high import HiInterface
7
+ from PIL import Image
8
+ import io
9
+ import base64
10
+ import asyncio
11
+ import threading
12
+ import numpy as np
13
+ from typing import Optional
14
+ import logging
15
+
16
+ # Configure logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
+
20
+ # Initialize CarveKit
21
+ try:
22
+ interface = HiInterface(
23
+ object_type="object", # Can be "object" or "hairs-like"
24
+ batch_size_seg=5,
25
+ batch_size_matting=1,
26
+ device='cpu', # Use 'cuda' if GPU is available
27
+ seg_mask_size=640,
28
+ matting_mask_size=2048,
29
+ trimap_prob_threshold=231,
30
+ trimap_kernel_size=30,
31
+ trimap_erosion_iters=5,
32
+ fp16=False
33
+ )
34
+ logger.info("CarveKit interface initialized successfully")
35
+ except Exception as e:
36
+ logger.error(f"Failed to initialize CarveKit: {e}")
37
+ interface = None
38
+
39
+ # Create FastAPI app
40
+ app = FastAPI(
41
+ title="CarveKit Background Remover API",
42
+ description="API for removing backgrounds from images using CarveKit",
43
+ version="1.0.0"
44
+ )
45
+
46
+ # Add CORS middleware
47
+ app.add_middleware(
48
+ CORSMiddleware,
49
+ allow_origins=["*"],
50
+ allow_credentials=True,
51
+ allow_methods=["*"],
52
+ allow_headers=["*"],
53
+ )
54
+
55
+ def process_image_carvekit(image: Image.Image) -> tuple[Optional[Image.Image], str]:
56
+ """Process image with CarveKit to remove background"""
57
+ try:
58
+ if interface is None:
59
+ return None, "CarveKit interface not initialized"
60
+
61
+ if image is None:
62
+ return None, "No image provided"
63
+
64
+ # Convert to RGB if necessary
65
+ if image.mode != 'RGB':
66
+ image = image.convert('RGB')
67
+
68
+ # Process the image
69
+ images_without_bg = interface([image])
70
+
71
+ if images_without_bg and len(images_without_bg) > 0:
72
+ return images_without_bg[0], "Background removed successfully!"
73
+ else:
74
+ return None, "Failed to process image"
75
+
76
+ except Exception as e:
77
+ logger.error(f"Error processing image: {e}")
78
+ return None, f"Error processing image: {str(e)}"
79
+
80
+ # API Endpoints
81
+ @app.get("/")
82
+ async def root():
83
+ """Root endpoint with API information"""
84
+ return {
85
+ "message": "CarveKit Background Remover API",
86
+ "version": "1.0.0",
87
+ "endpoints": {
88
+ "remove_background": "/api/remove-background",
89
+ "remove_background_base64": "/api/remove-background-base64",
90
+ "health": "/health"
91
+ },
92
+ "docs": "/docs",
93
+ "gradio_interface": "/gradio"
94
+ }
95
+
96
+ @app.get("/health")
97
+ async def health_check():
98
+ """Health check endpoint"""
99
+ return {
100
+ "status": "healthy",
101
+ "carvekit_ready": interface is not None
102
+ }
103
+
104
+ @app.post("/api/remove-background")
105
+ async def remove_background_api(file: UploadFile = File(...)):
106
+ """Remove background from uploaded image file"""
107
+ try:
108
+ # Validate file type
109
+ if not file.content_type.startswith('image/'):
110
+ raise HTTPException(status_code=400, detail="File must be an image")
111
+
112
+ # Read and process image
113
+ contents = await file.read()
114
+ image = Image.open(io.BytesIO(contents))
115
+
116
+ # Process with CarveKit
117
+ result_image, message = process_image_carvekit(image)
118
+
119
+ if result_image is None:
120
+ raise HTTPException(status_code=500, detail=message)
121
+
122
+ # Convert result to bytes
123
+ img_byte_arr = io.BytesIO()
124
+ result_image.save(img_byte_arr, format='PNG')
125
+ img_byte_arr.seek(0)
126
+
127
+ return Response(
128
+ content=img_byte_arr.getvalue(),
129
+ media_type="image/png",
130
+ headers={"Content-Disposition": "attachment; filename=result.png"}
131
+ )
132
+
133
+ except HTTPException:
134
+ raise
135
+ except Exception as e:
136
+ logger.error(f"API error: {e}")
137
+ raise HTTPException(status_code=500, detail=str(e))
138
+
139
+ @app.post("/api/remove-background-base64")
140
+ async def remove_background_base64(data: dict):
141
+ """Remove background from base64 encoded image"""
142
+ try:
143
+ if "image" not in data:
144
+ raise HTTPException(status_code=400, detail="Missing 'image' field in request body")
145
+
146
+ # Decode base64 image
147
+ try:
148
+ image_data = base64.b64decode(data["image"])
149
+ image = Image.open(io.BytesIO(image_data))
150
+ except Exception as e:
151
+ raise HTTPException(status_code=400, detail="Invalid base64 image data")
152
+
153
+ # Process with CarveKit
154
+ result_image, message = process_image_carvekit(image)
155
+
156
+ if result_image is None:
157
+ raise HTTPException(status_code=500, detail=message)
158
+
159
+ # Convert result to base64
160
+ img_byte_arr = io.BytesIO()
161
+ result_image.save(img_byte_arr, format='PNG')
162
+ img_byte_arr.seek(0)
163
+ result_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
164
+
165
+ return {
166
+ "success": True,
167
+ "message": message,
168
+ "result": result_base64
169
+ }
170
+
171
+ except HTTPException:
172
+ raise
173
+ except Exception as e:
174
+ logger.error(f"API error: {e}")
175
+ raise HTTPException(status_code=500, detail=str(e))
176
+
177
+ # Gradio Interface Functions
178
+ def remove_background_gradio(image):
179
+ """Gradio interface function"""
180
+ if image is None:
181
+ return None, "Please upload an image first."
182
+
183
+ result_image, message = process_image_carvekit(image)
184
+ return result_image, message
185
+
186
+ # Create Gradio interface
187
+ with gr.Blocks(
188
+ title="CarveKit Background Remover",
189
+ theme=gr.themes.Soft(),
190
+ css="""
191
+ .gradio-container {
192
+ max-width: 1200px !important;
193
+ }
194
+ .api-info {
195
+ background: #f0f0f0;
196
+ padding: 15px;
197
+ border-radius: 10px;
198
+ margin: 10px 0;
199
+ }
200
+ """
201
+ ) as gradio_app:
202
+
203
+ gr.Markdown("# 🎨 CarveKit Background Remover")
204
+ gr.Markdown("Upload an image to automatically remove its background using CarveKit's advanced AI models.")
205
+
206
+ with gr.Tabs():
207
+ with gr.TabItem("🖼️ Web Interface"):
208
+ with gr.Row():
209
+ with gr.Column(scale=1):
210
+ gr.Markdown("### Input")
211
+ input_image = gr.Image(
212
+ label="Upload Image",
213
+ type="pil",
214
+ height=400,
215
+ sources=["upload", "clipboard"]
216
+ )
217
+
218
+ with gr.Row():
219
+ process_btn = gr.Button(
220
+ "🚀 Remove Background",
221
+ variant="primary",
222
+ size="lg"
223
+ )
224
+ clear_btn = gr.Button(
225
+ "🗑️ Clear",
226
+ variant="secondary"
227
+ )
228
+
229
+ with gr.Column(scale=1):
230
+ gr.Markdown("### Result")
231
+ output_image = gr.Image(
232
+ label="Background Removed",
233
+ type="pil",
234
+ height=400
235
+ )
236
+ status_text = gr.Textbox(
237
+ label="Status",
238
+ value="Ready to process images...",
239
+ interactive=False,
240
+ lines=2
241
+ )
242
+
243
+ with gr.TabItem("🔌 API Documentation"):
244
+ gr.Markdown("""
245
+ ## API Endpoints
246
+
247
+ ### 1. File Upload Endpoint
248
+ **POST** `/api/remove-background`
249
+
250
+ Upload an image file to remove its background.
251
+
252
+ **cURL Example:**
253
+ ```bash
254
+ curl -X POST "https://YOUR_SPACE_URL/api/remove-background" \\
255
+ -H "accept: image/png" \\
256
+ -H "Content-Type: multipart/form-data" \\
257
+ -F "file=@your_image.jpg" \\
258
+ --output result.png
259
+ ```
260
+
261
+ **Python Example:**
262
+ ```python
263
+ import requests
264
+
265
+ url = "https://YOUR_SPACE_URL/api/remove-background"
266
+
267
+ with open("your_image.jpg", "rb") as f:
268
+ files = {"file": f}
269
+ response = requests.post(url, files=files)
270
+
271
+ if response.status_code == 200:
272
+ with open("result.png", "wb") as f:
273
+ f.write(response.content)
274
+ ```
275
+
276
+ ### 2. Base64 Endpoint
277
+ **POST** `/api/remove-background-base64`
278
+
279
+ Send base64 encoded image data.
280
+
281
+ **Request Body:**
282
+ ```json
283
+ {
284
+ "image": "base64_encoded_image_data"
285
+ }
286
+ ```
287
+
288
+ **Python Example:**
289
+ ```python
290
+ import requests
291
+ import base64
292
+
293
+ # Read and encode image
294
+ with open("your_image.jpg", "rb") as f:
295
+ image_data = base64.b64encode(f.read()).decode('utf-8')
296
+
297
+ url = "https://YOUR_SPACE_URL/api/remove-background-base64"
298
+ payload = {"image": image_data}
299
+
300
+ response = requests.post(url, json=payload)
301
+ result = response.json()
302
+
303
+ if result["success"]:
304
+ # Decode result
305
+ result_image = base64.b64decode(result["result"])
306
+ with open("result.png", "wb") as f:
307
+ f.write(result_image)
308
+ ```
309
+
310
+ ### 3. Health Check
311
+ **GET** `/health`
312
+
313
+ Check if the service is running properly.
314
+
315
+ ### 4. API Documentation
316
+ **GET** `/docs` - Interactive API documentation (Swagger UI)
317
+ """, elem_classes=["api-info"])
318
+
319
+ # Event handlers
320
+ process_btn.click(
321
+ fn=remove_background_gradio,
322
+ inputs=[input_image],
323
+ outputs=[output_image, status_text]
324
+ )
325
+
326
+ input_image.change(
327
+ fn=remove_background_gradio,
328
+ inputs=[input_image],
329
+ outputs=[output_image, status_text]
330
+ )
331
+
332
+ clear_btn.click(
333
+ fn=lambda: (None, None, "Ready to process images..."),
334
+ outputs=[input_image, output_image, status_text]
335
+ )
336
+
337
+ # Mount Gradio app
338
+ app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
339
+
340
+ def run_server():
341
+ """Run the FastAPI server"""
342
+ uvicorn.run(
343
+ app,
344
+ host="0.0.0.0",
345
+ port=7860,
346
+ log_level="info"
347
+ )
348
+
349
+ if __name__ == "__main__":
350
+ run_server()