"""Background Removal API.

FastAPI service (plus a mounted Gradio UI at ``/gradio``) that removes image
backgrounds with a BiRefNet ONNX model. Processed PNGs are written to an
ephemeral ``tmp`` folder which is also served as static files under ``/tmp``.
"""

import os
import cv2
import numpy as np
import onnxruntime as ort
import uuid
import base64
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
import shutil

# API key for authentication. NOTE(review): if the API_KEY env var is unset
# this is None, so every request fails verification with 401 — confirm that
# is the intended fail-closed behavior for deployments without a key.
API_KEY = os.getenv("API_KEY")

app = FastAPI(title="Background Removal API")

# Allow cross-origin requests from any origin (public demo service).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Ephemeral storage for processed images.
# Must exist BEFORE it is mounted as static files below.
TMP_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_FOLDER, exist_ok=True)
print(f"Created tmp folder at: {TMP_FOLDER}")

# Serve generated images directly at /tmp/<filename>.
app.mount("/tmp", StaticFiles(directory=TMP_FOLDER), name="tmp")

templates = Jinja2Templates(directory="templates")

# Load the ONNX segmentation model once at startup (CPU inference).
model_path = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
input_name = "input_image"
INPUT_SIZE = (512, 512)  # (width, height) expected by the model

input_info = session.get_inputs()[0]
print("Input name:", input_info.name)
print("Input shape:", input_info.shape)
print("Input type:", input_info.type)


def verify_api_key(api_key: str = Form(...)):
    """Raise HTTP 401 unless *api_key* matches the configured API_KEY."""
    if api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key


def preprocess_image(image):
    """Decode *image* and build the model input tensor.

    Accepts a file path (str), a BGR numpy array, or raw encoded bytes.

    Returns:
        (input_tensor, original_shape, original_img) where input_tensor is
        float32 NCHW normalized to [-1, 1], original_shape is (H, W) and
        original_img is the decoded BGR image.
    """
    if isinstance(image, str):
        # Path to image file
        img = cv2.imread(image)
    elif isinstance(image, np.ndarray):
        # Already a numpy array
        img = image
    else:
        # Raw encoded bytes (e.g. an uploaded file body)
        nparr = np.frombuffer(image, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    # Keep the original around so the mask can be applied at full resolution.
    original_img = img.copy()
    original_shape = img.shape[:2]  # (H, W)

    # Model expects RGB at a fixed input size.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, INPUT_SIZE)

    # Scale to [0, 1], then normalize to [-1, 1] (mean 0.5, std 0.5).
    normalized = resized.astype(np.float32) / 255.0
    normalized = (normalized - 0.5) / 0.5
    transposed = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
    return input_tensor, original_shape, original_img


def apply_mask(original_img, mask_array, original_shape, output_path):
    """Apply the model's mask to *original_img* and save a transparent PNG.

    Args:
        original_img: BGR uint8 image at original resolution.
        mask_array: raw model output; squeezed to 2-D and resized to
            *original_shape*. NOTE(review): assumed to already be in [0, 1]
            (the clip + 0.5 threshold suggests so) — confirm no sigmoid is
            needed for this model.
        original_shape: (H, W) of the original image.
        output_path: destination PNG path (BGRA with alpha channel).

    Returns:
        (bgra_image, True) on success, (None, False) on failure.
    """
    try:
        # Defensive float32 cast: the model file is fp16 and cv2.resize
        # cannot operate on float16 arrays; no-op if already float32.
        mask = np.squeeze(mask_array).astype(np.float32)
        mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
        mask = np.clip(mask, 0, 1)

        # Hard threshold into a binary foreground mask.
        binary_mask = (mask > 0.5).astype(np.uint8)

        # Zero out background pixels, then encode the mask as alpha.
        img = original_img.astype(np.uint8)
        masked_img = cv2.bitwise_and(img, img, mask=binary_mask)
        alpha = (binary_mask * 255).astype(np.uint8)

        bgra = cv2.cvtColor(masked_img, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = alpha

        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # PNG compression 0 = fastest write, largest file.
        cv2.imwrite(output_path, bgra, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        print(f"Saved masked object image to: {output_path} with size {bgra.shape[:2]}")
        return bgra, True
    except Exception as e:
        print(f"Error applying mask: {e}")
        return None, False


def _run_inference(image_data):
    """Run the segmentation model on raw image bytes (shared by endpoints).

    Returns (mask, original_shape, original_img).
    """
    input_tensor, original_shape, original_img = preprocess_image(image_data)
    output = session.run(None, {input_name: input_tensor})
    return output[0], original_shape, original_img


# UI endpoint: process the web form (main photo + optional new background).
@app.post("/")
async def index_post(
    request: Request,
    main_photo: UploadFile = File(...),
    bg_photo: UploadFile = File(None)  # Optional background image
):
    try:
        main_image_data = await main_photo.read()
        mask, original_shape, original_img = _run_inference(main_image_data)

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        # Create alpha-masked image.
        transparent_img, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            # BUGFIX: previously fell through and crashed on
            # transparent_img[None][:, :, 3]; fail with a clear message.
            raise RuntimeError("Failed to apply segmentation mask")

        final_result_path = output_path

        # If a background was provided, composite the foreground over it.
        # BUGFIX: also require a non-empty filename — some clients submit an
        # empty file field, which would crash cv2.imdecode on empty bytes.
        if bg_photo is not None and bg_photo.filename:
            bg_image_data = await bg_photo.read()
            bg_np = np.frombuffer(bg_image_data, np.uint8)
            bg_img = cv2.imdecode(bg_np, cv2.IMREAD_COLOR)
            bg_img_resized = cv2.resize(bg_img, (original_shape[1], original_shape[0]))

            # Standard alpha blend: fg * a + bg * (1 - a).
            alpha = transparent_img[:, :, 3] / 255.0
            foreground = transparent_img[:, :, :3]
            blended = (foreground * alpha[..., None]
                       + bg_img_resized * (1 - alpha[..., None])).astype(np.uint8)

            final_result_path = os.path.join(TMP_FOLDER, f"bg_replaced_{uuid.uuid4()}.png")
            cv2.imwrite(final_result_path, blended)

        return templates.TemplateResponse("index.html", {
            "request": request,
            "output_image": os.path.basename(final_result_path)
        })
    except Exception as e:
        import traceback
        print("Error in index_post:", str(e))
        print(traceback.format_exc())
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Error: {str(e)}"
        })


# API endpoint: remove background and return a URL to the result.
@app.post("/remove-background")
async def remove_background(request: Request, api_key: str = Form(...), main_photo: UploadFile = File(...)):
    # Verify API key (raises 401 on mismatch).
    verify_api_key(api_key)
    try:
        image_data = await main_photo.read()

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)
        os.makedirs(TMP_FOLDER, exist_ok=True)

        mask, original_shape, original_img = _run_inference(image_data)
        _, success = apply_mask(original_img, mask, original_shape, output_path)

        if success:
            # Build an absolute URL to the statically-served result.
            base_url = str(request.base_url)
            if base_url.endswith("/"):
                base_url = base_url[:-1]
            # NOTE(review): the original branched on "hf.space" but both
            # branches produced the identical URL, so the branch is removed.
            full_url = f"{base_url}/tmp/{result_filename}"

            return {
                "status": "success",
                "message": "Background removed successfully",
                "filename": result_filename,
                "image_url": full_url
            }
        else:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }
    except Exception as e:
        import traceback
        print(f"Error in remove_background: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }


def process_image_gradio(image):
    """Gradio handler: remove the background from a numpy image.

    Returns an RGBA PIL image on success, None on failure.
    """
    mask, original_shape, original_img = _run_inference(image)

    filename = f"{uuid.uuid4()}.png"
    output_path = os.path.join(TMP_FOLDER, filename)
    os.makedirs(TMP_FOLDER, exist_ok=True)

    result_img, success = apply_mask(original_img, mask, original_shape, output_path)
    if success:
        # Convert BGRA (OpenCV) to RGBA for PIL/Gradio display.
        return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
    return None


# Gradio web UI, mounted into the FastAPI app below.
interface = gr.Interface(
    fn=process_image_gradio,
    inputs=gr.Image(type="numpy"),
    outputs=gr.Image(type="pil"),
    title="Background Removal",
    description="Upload an image to remove its background"
)

app = gr.mount_gradio_app(app, interface, path="/gradio")


# GET handler for the root path to display the form.
@app.get("/")
async def index_get(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


# API endpoint: remove background and return the PNG as base64.
@app.post("/process_image")
async def process_image(request: Request, image: UploadFile = File(...), api_key: str = Form(...)):
    # Verify API key (raises 401 on mismatch).
    verify_api_key(api_key)
    try:
        image_data = await image.read()

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)
        os.makedirs(TMP_FOLDER, exist_ok=True)

        mask, original_shape, original_img = _run_inference(image_data)
        bgra, success = apply_mask(original_img, mask, original_shape, output_path)

        if success:
            # Return the saved PNG inline as base64.
            with open(output_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode('utf-8')
            return {
                "status": "success",
                "image_code": base64_image
            }
        else:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }
    except Exception as e:
        import traceback
        print(f"Error in process_image: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }


# Download route for processed images.
# BUGFIX: the route path read "/download/(unknown)" (an apparent templating
# accident), leaving `filename` unbound as a path parameter; "{filename}"
# restores the evident intent.
@app.get("/download/{filename}")
async def download_file(filename: str):
    # basename() blocks path traversal (e.g. "../../etc/passwd"); generated
    # filenames are plain UUIDs, so legitimate requests are unaffected.
    safe_name = os.path.basename(filename)
    file_path = os.path.join(TMP_FOLDER, safe_name)
    if os.path.exists(file_path):
        return FileResponse(
            path=file_path,
            filename=safe_name,
            media_type="image/png"
        )
    raise HTTPException(status_code=404, detail="File not found")


# For Huggingface deployment
if __name__ == "__main__":
    import uvicorn
    # Print current working directory for debugging
    print(f"Current working directory: {os.getcwd()}")
    print(f"TMP_FOLDER absolute path: {os.path.abspath(TMP_FOLDER)}")
    uvicorn.run(app, host="0.0.0.0", port=7860)