Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import Response | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| from carvekit.api.high import HiInterface | |
| from PIL import Image | |
| import io | |
| import base64 | |
| import asyncio | |
| import threading | |
| import numpy as np | |
| from typing import Optional | |
| import logging | |
import os

# Module-wide logging at INFO level, plus this module's own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared CarveKit interface. initialize_carvekit() assigns it, and leaves it
# as None when initialization fails.
interface: "Optional[HiInterface]" = None
def initialize_carvekit():
    """Build the shared CarveKit ``HiInterface`` and store it in the global ``interface``.

    The cache directory comes from ``CARVEKIT_CACHE_DIR`` (default
    ``/app/.cache/carvekit``). It is created if missing and exported back to
    the environment. The inference device comes from ``CARVEKIT_DEVICE``
    (default ``'cpu'``; set e.g. ``'cuda'`` on a GPU host).

    Returns:
        True on success. On any failure the error is logged together with its
        traceback, ``interface`` is reset to None, and False is returned.
    """
    global interface
    try:
        cache_dir = os.environ.get('CARVEKIT_CACHE_DIR', '/app/.cache/carvekit')
        os.makedirs(cache_dir, exist_ok=True)
        # Export the resolved path so CarveKit can pick it up.
        os.environ['CARVEKIT_CACHE_DIR'] = cache_dir
        # NOTE(review): HiInterface is also imported at the top of this module,
        # so carvekit is already loaded before this function runs. If CarveKit
        # only reads CARVEKIT_CACHE_DIR at import time, exporting it here is too
        # late. Confirm against the installed CarveKit version.
        from carvekit.api.high import HiInterface

        interface = HiInterface(
            object_type="object",  # Can be "object" or "hairs-like"
            batch_size_seg=5,
            batch_size_matting=1,
            device=os.environ.get('CARVEKIT_DEVICE', 'cpu'),
            seg_mask_size=640,
            matting_mask_size=2048,
            trimap_prob_threshold=231,
            trimap_kernel_size=30,
            trimap_erosion_iters=5,
            fp16=False,
        )
        logger.info("CarveKit interface initialized successfully")
        return True
    except Exception as e:
        # logger.exception records the traceback; logger.error did not.
        logger.exception(f"Failed to initialize CarveKit: {e}")
        interface = None
        return False
# Try to initialize CarveKit at import time. On failure `interface` stays None
# and image requests report "CarveKit interface not initialized".
carvekit_ready = initialize_carvekit()

# Create FastAPI app
app = FastAPI(
    title="CarveKit Background Remover API",
    description="API for removing backgrounds from images using CarveKit",
    version="1.0.0",
)

# CORS: any origin may call the API. Credentials stay disabled for two reasons.
# The API uses no cookies or auth, and FastAPI/Starlette document that wildcard
# origins/methods/headers must not be combined with allow_credentials=True.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
def process_image_carvekit(image: Image.Image) -> tuple[Optional[Image.Image], str]:
    """Remove the background from one PIL image using the shared CarveKit interface.

    Args:
        image: Input image. Images not in ``RGB`` mode are converted to RGB first.

    Returns:
        ``(result, message)``. On success ``result`` is the first image
        CarveKit returned. Otherwise ``result`` is None and ``message``
        explains why. Exceptions raised during conversion or inference are
        caught, logged with their traceback, and reported through ``message``.
    """
    try:
        if interface is None:
            return None, "CarveKit interface not initialized"
        if image is None:
            return None, "No image provided"

        if image.mode != 'RGB':
            image = image.convert('RGB')

        # HiInterface is called with a list of images; this passes a one-item batch.
        # NOTE(review): this call is synchronous and CPU-bound, and the async
        # API handlers invoke it directly, so it blocks the event loop while
        # running.
        images_without_bg = interface([image])
        if not images_without_bg:
            return None, "Failed to process image"
        return images_without_bg[0], "Background removed successfully!"
    except Exception as e:
        logger.exception(f"Error processing image: {e}")
        return None, f"Error processing image: {str(e)}"
# API Endpoints
@app.get("/")
async def root():
    """Return the API's name, version, and the paths of its main routes.

    The original function was never registered with ``app``. The decorator
    exposes it at ``/``, which does not collide with the Gradio UI mounted
    at ``/gradio``.
    """
    return {
        "message": "CarveKit Background Remover API",
        "version": "1.0.0",
        "endpoints": {
            "remove_background": "/api/remove-background",
            "remove_background_base64": "/api/remove-background-base64",
            "health": "/health",
        },
        "docs": "/docs",
        "gradio_interface": "/gradio",
    }
@app.get("/health")
async def health_check():
    """Report liveness and whether the CarveKit model is available.

    ``status`` is always ``"healthy"`` as long as the process answers.
    ``carvekit_ready`` tells whether the interface is currently set, and
    ``carvekit_initialized`` is the result of the startup initialization.
    """
    return {
        "status": "healthy",
        "carvekit_ready": interface is not None,
        "carvekit_initialized": carvekit_ready,
    }
@app.post("/api/remove-background")
async def remove_background_api(file: UploadFile = File(...)):
    """Remove the background from an uploaded image and return the result as PNG.

    Args:
        file: Multipart upload whose declared content type must start with ``image/``.

    Returns:
        A ``Response`` containing the PNG bytes, offered as the attachment ``result.png``.

    Raises:
        HTTPException: 400 if the upload is not declared as an image or cannot
            be decoded. 500 if background removal fails or an unexpected error
            occurs.
    """
    try:
        # content_type is None when the client omits it. Treat that as
        # "not an image" instead of crashing on None.startswith.
        if not (file.content_type or "").startswith('image/'):
            raise HTTPException(status_code=400, detail="File must be an image")

        contents = await file.read()
        try:
            image = Image.open(io.BytesIO(contents))
            # Image.open only reads the header. Decode now so corrupt or
            # truncated uploads are reported as a client error (400), not a 500.
            image.load()
        except Exception as e:
            raise HTTPException(status_code=400, detail="Invalid image data") from e

        result_image, message = process_image_carvekit(image)
        if result_image is None:
            raise HTTPException(status_code=500, detail=message)

        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format='PNG')
        return Response(
            content=img_byte_arr.getvalue(),
            media_type="image/png",
            headers={"Content-Disposition": "attachment; filename=result.png"},
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"API error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/api/remove-background-base64")
async def remove_background_base64(data: dict):
    """Remove the background from a base64-encoded image sent in a JSON body.

    Args:
        data: JSON object with an ``"image"`` field holding base64 image bytes,
            either bare or as a ``data:<mime>;base64,<payload>`` data URL.

    Returns:
        ``{"success": True, "message": <status>, "result": <base64 PNG>}``.

    Raises:
        HTTPException: 400 if ``image`` is missing or does not decode to an
            image. 500 if background removal fails or an unexpected error
            occurs.
    """
    try:
        if "image" not in data:
            raise HTTPException(status_code=400, detail="Missing 'image' field in request body")

        encoded = data["image"]
        # Non-validating b64decode discards only non-alphabet characters. The
        # letters of a data-URL header ("data", "image", "base64", ...) would be
        # kept and corrupt the payload, so keep only the text after the first comma.
        if isinstance(encoded, str) and encoded.startswith("data:"):
            encoded = encoded.partition(",")[2]

        try:
            image_data = base64.b64decode(encoded)
            image = Image.open(io.BytesIO(image_data))
            # Image.open is lazy. Force a full decode so truncated or corrupt
            # data is rejected here with 400 instead of failing later with 500.
            image.load()
        except Exception as e:
            raise HTTPException(status_code=400, detail="Invalid base64 image data") from e

        result_image, message = process_image_carvekit(image)
        if result_image is None:
            raise HTTPException(status_code=500, detail=message)

        img_byte_arr = io.BytesIO()
        result_image.save(img_byte_arr, format='PNG')
        result_base64 = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')
        return {
            "success": True,
            "message": message,
            "result": result_base64,
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"API error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Gradio Interface Functions
def remove_background_gradio(image):
    """Gradio callback: return ``(result_image, status_message)`` for one input image.

    A missing image is answered with a prompt instead of being sent to CarveKit.
    """
    if image is not None:
        return process_image_carvekit(image)
    return None, "Please upload an image first."
# Create Gradio interface

# Idle status text, shown initially and after the Clear button is pressed.
_READY_STATUS = "Ready to process images..."

# Markdown for the "API Documentation" tab. It is kept flush-left so that
# indentation is not rendered as code blocks.
_API_DOCS_MARKDOWN = """
## API Endpoints

### 1. File Upload Endpoint

**POST** `/api/remove-background`

Upload an image file to remove its background.

**cURL Example:**

```bash
curl -X POST "https://YOUR_SPACE_URL/api/remove-background" \\
  -H "accept: image/png" \\
  -H "Content-Type: multipart/form-data" \\
  -F "file=@your_image.jpg" \\
  --output result.png
```

**Python Example:**

```python
import requests

url = "https://YOUR_SPACE_URL/api/remove-background"

with open("your_image.jpg", "rb") as f:
    files = {"file": f}
    response = requests.post(url, files=files)

if response.status_code == 200:
    with open("result.png", "wb") as f:
        f.write(response.content)
```

### 2. Base64 Endpoint

**POST** `/api/remove-background-base64`

Send base64 encoded image data.

**Request Body:**

```json
{
  "image": "base64_encoded_image_data"
}
```

**Python Example:**

```python
import requests
import base64

# Read and encode image
with open("your_image.jpg", "rb") as f:
    image_data = base64.b64encode(f.read()).decode('utf-8')

url = "https://YOUR_SPACE_URL/api/remove-background-base64"
payload = {"image": image_data}
response = requests.post(url, json=payload)
# Error responses only contain {"detail": ...}, so check the status first
response.raise_for_status()
result = response.json()

if result["success"]:
    # Decode result
    result_image = base64.b64decode(result["result"])
    with open("result.png", "wb") as f:
        f.write(result_image)
```

### 3. Health Check

**GET** `/health`

Check if the service is running properly.

### 4. API Documentation

**GET** `/docs` - Interactive API documentation (Swagger UI)
"""


def _on_input_change(image):
    """Process a newly set input image. For a cleared input, show the idle status."""
    # .change also fires when the Clear button resets the input to None.
    # Without this guard, "Please upload an image first." would immediately
    # replace the idle status that the Clear button just set.
    if image is None:
        return None, _READY_STATUS
    return remove_background_gradio(image)


with gr.Blocks(
    title="CarveKit Background Remover",
    theme=gr.themes.Soft(),
    css="""
.gradio-container {
    max-width: 1200px !important;
}
.api-info {
    background: #f0f0f0;
    padding: 15px;
    border-radius: 10px;
    margin: 10px 0;
}
""",
) as gradio_app:
    gr.Markdown("# 🎨 CarveKit Background Remover")
    gr.Markdown("Upload an image to automatically remove its background using CarveKit's advanced AI models.")

    with gr.Tabs():
        with gr.TabItem("🖼️ Web Interface"):
            with gr.Row():
                with gr.Column(scale=1):
                    gr.Markdown("### Input")
                    input_image = gr.Image(
                        label="Upload Image",
                        type="pil",
                        height=400,
                        sources=["upload", "clipboard"],
                    )
                    with gr.Row():
                        process_btn = gr.Button(
                            "🚀 Remove Background",
                            variant="primary",
                            size="lg",
                        )
                        clear_btn = gr.Button(
                            "🗑️ Clear",
                            variant="secondary",
                        )
                with gr.Column(scale=1):
                    gr.Markdown("### Result")
                    output_image = gr.Image(
                        label="Background Removed",
                        type="pil",
                        height=400,
                    )
                    status_text = gr.Textbox(
                        label="Status",
                        value=_READY_STATUS,
                        interactive=False,
                        lines=2,
                    )

        with gr.TabItem("🔌 API Documentation"):
            gr.Markdown(_API_DOCS_MARKDOWN, elem_classes=["api-info"])

    # Event handlers
    process_btn.click(
        fn=remove_background_gradio,
        inputs=[input_image],
        outputs=[output_image, status_text],
    )
    input_image.change(
        fn=_on_input_change,
        inputs=[input_image],
        outputs=[output_image, status_text],
    )
    clear_btn.click(
        fn=lambda: (None, None, _READY_STATUS),
        outputs=[input_image, output_image, status_text],
    )
# Mount Gradio app: serve the Blocks UI under /gradio on the same FastAPI app.
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")


def run_server(host: str = "0.0.0.0", port: int = 7860):
    """Serve ``app`` (the FastAPI API with the Gradio UI mounted) with uvicorn.

    Args:
        host: Address to bind. The default listens on all interfaces.
        port: TCP port to listen on. Defaults to 7860, the previously hard-coded value.
    """
    uvicorn.run(
        app,
        host=host,
        port=port,
        log_level="info",
    )


if __name__ == "__main__":
    run_server()