bg / app.py
Munaf1987's picture
Update app.py
a723b58 verified
import os
import cv2
import numpy as np
import onnxruntime as ort
import uuid
import base64
from io import BytesIO
from PIL import Image
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Request
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
import shutil
# API key for authentication (None when the API_KEY environment variable is unset)
API_KEY = os.getenv("API_KEY")

# Initialize FastAPI app
app = FastAPI(title="Background Removal API")

# Add CORS middleware for cross-origin requests
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended before relying on cookie-based auth.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Create tmp folder for ephemeral storage if it doesn't exist
# Make sure this runs BEFORE mounting static files
TMP_FOLDER = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
os.makedirs(TMP_FOLDER, exist_ok=True)
print(f"Created tmp folder at: {TMP_FOLDER}")

# Mount tmp folder as static files too
app.mount("/tmp", StaticFiles(directory=TMP_FOLDER), name="tmp")
templates = Jinja2Templates(directory="templates")

# Load ONNX model
model_path = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
INPUT_SIZE = (512, 512)
input_info = session.get_inputs()[0]
# Use the input name reported by the loaded model instead of a hard-coded
# string, so the feed dict passed to session.run always matches the graph.
input_name = input_info.name
print("Input name:", input_info.name)
print("Input shape:", input_info.shape)
print("Input type:", input_info.type)
# Authentication dependency
def verify_api_key(api_key: str = Form(...)):
    """Check a submitted API key against the configured ``API_KEY``.

    Args:
        api_key: Key supplied by the client as a form field.

    Returns:
        The submitted key, unchanged, when it matches.

    Raises:
        HTTPException: 401 when the key does not match, or when no
            ``API_KEY`` is configured (every request is rejected then).
    """
    import hmac

    # No configured key: nothing can match, so reject explicitly rather than
    # relying on a comparison against None.
    if API_KEY is None or not isinstance(api_key, str):
        raise HTTPException(status_code=401, detail="Invalid API key")

    # compare_digest avoids leaking how many leading characters matched via
    # response timing; bytes are used because str inputs must be ASCII-only.
    if not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
def preprocess_image(image):
    """Load an image and build the model input tensor from it.

    Args:
        image: A file path (``str``), an already-decoded ``numpy.ndarray``
            (treated as BGR, the OpenCV channel order), or a bytes-like
            object holding an encoded image.

    Returns:
        A tuple ``(input_tensor, original_shape, original_img)`` where
        ``input_tensor`` is a float32 array laid out as (1, channels, H, W)
        after resizing to ``INPUT_SIZE`` and scaling values to [-1, 1],
        ``original_shape`` is ``(H, W)`` of the source image, and
        ``original_img`` is a copy of the decoded source image.

    Raises:
        ValueError: If no image is given, the payload is empty, or the data
            cannot be decoded as an image.
    """
    if image is None:
        raise ValueError("No image provided")

    if isinstance(image, str):  # Path to image file
        img = cv2.imread(image)
    elif isinstance(image, np.ndarray):  # Already a numpy array
        img = image
    else:  # Bytes or other buffer-like format
        nparr = np.frombuffer(image, np.uint8)
        if nparr.size == 0:
            raise ValueError("Image data is empty")
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)

    # cv2.imread / cv2.imdecode signal failure by returning None, not raising.
    if img is None:
        raise ValueError("Could not decode image")

    # Store original image and its shape
    original_img = img.copy()
    original_shape = img.shape[:2]  # (H, W)

    # Convert to RGB for model input
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # Resize for model inference (the model requires a specific input size)
    resized = cv2.resize(rgb, INPUT_SIZE)

    # Normalize for model: [0, 255] -> [0, 1] -> [-1, 1]
    normalized = resized.astype(np.float32) / 255.0
    normalized = (normalized - 0.5) / 0.5

    # HWC -> CHW, then add the batch axis
    transposed = np.transpose(normalized, (2, 0, 1))
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
    return input_tensor, original_shape, original_img
def apply_mask(original_img, mask_array, original_shape, output_path):
    """Cut the foreground out of ``original_img`` and save it as a BGRA PNG.

    Args:
        original_img: Source image array (passed to OpenCV as BGR).
        mask_array: Model output; squeezed and resized to the source size.
        original_shape: ``(H, W)`` of the source image.
        output_path: Destination path of the PNG file.

    Returns:
        ``(bgra, True)`` on success, where ``bgra`` is the masked image with
        an alpha channel that is 255 inside the mask and 0 elsewhere.
        ``(None, False)`` if any step fails, including the file write; the
        error is printed rather than raised.
    """
    try:
        # Get the mask and resize to match original image size
        mask = np.squeeze(mask_array)
        if mask.dtype == np.float16:
            # NOTE(review): the model file name suggests fp16 weights; widen
            # half-precision output before handing it to cv2.resize.
            mask = mask.astype(np.float32)
        mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
        mask = np.clip(mask, 0, 1)

        # Convert mask to binary (0 or 1)
        binary_mask = (mask > 0.5).astype(np.uint8)

        # Apply binary mask to each channel
        img = original_img.astype(np.uint8)
        masked_img = cv2.bitwise_and(img, img, mask=binary_mask)

        # Create an alpha channel where the mask is 1
        alpha = (binary_mask * 255).astype(np.uint8)

        # Combine with original image (now only foreground remains)
        bgra = cv2.cvtColor(masked_img, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = alpha

        # Ensure output directory exists; a bare filename has no directory
        # part and os.makedirs("") would raise.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

        # Save the image with only the masked object. cv2.imwrite reports
        # failure through its return value, so check it explicitly.
        written = cv2.imwrite(output_path, bgra, [cv2.IMWRITE_PNG_COMPRESSION, 0])
        if not written:
            raise OSError(f"Could not write image to {output_path}")

        print(f"Saved masked object image to: {output_path} with size {bgra.shape[:2]}")
        return bgra, True
    except Exception as e:
        print(f"Error applying mask: {e}")
        return None, False
# UI endpoint for web interface
@app.post("/")
async def index_post(
    request: Request,
    main_photo: UploadFile = File(...),
    bg_photo: UploadFile = File(None),  # Optional background image
):
    """Handle the web form: remove the background, optionally replace it.

    Args:
        request: Incoming request, forwarded to the template context.
        main_photo: Uploaded image whose background is removed.
        bg_photo: Optional replacement background; ignored when absent or
            when the uploaded part is empty.

    Returns:
        The rendered ``index.html`` with ``output_image`` set to the result
        file name, or with ``error`` set when processing fails.
    """
    try:
        # Process main photo
        main_image_data = await main_photo.read()
        input_tensor, original_shape, original_img = preprocess_image(main_image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        # Create alpha-masked image
        transparent_img, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            # Without this check the page would link to a file that was
            # never written (or crash below on a None image).
            raise RuntimeError("Failed to process image")
        final_result_path = output_path

        # If background provided, replace it. A form submitted without
        # choosing a file can still carry an empty file part, so the payload
        # must be non-empty, not just the field present.
        bg_image_data = await bg_photo.read() if bg_photo else b""
        if bg_image_data:
            bg_np = np.frombuffer(bg_image_data, np.uint8)
            bg_img = cv2.imdecode(bg_np, cv2.IMREAD_COLOR)
            if bg_img is None:
                raise ValueError("Could not decode background image")
            bg_img_resized = cv2.resize(bg_img, (original_shape[1], original_shape[0]))

            # Alpha-blend foreground over the resized background
            alpha = transparent_img[:, :, 3] / 255.0
            foreground = transparent_img[:, :, :3]
            blended = (
                foreground * alpha[..., None] + bg_img_resized * (1 - alpha[..., None])
            ).astype(np.uint8)

            final_result_path = os.path.join(TMP_FOLDER, f"bg_replaced_{uuid.uuid4()}.png")
            if not cv2.imwrite(final_result_path, blended):
                raise OSError(f"Could not write image to {final_result_path}")

        return templates.TemplateResponse("index.html", {
            "request": request,
            "output_image": os.path.basename(final_result_path)
        })
    except Exception as e:
        import traceback
        print("Error in index_post:", str(e))
        print(traceback.format_exc())
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Error: {str(e)}"
        })
# API endpoint for background removal
@app.post("/remove-background")
async def remove_background(request: Request, api_key: str = Form(...), main_photo: UploadFile = File(...)):
    """Remove the background of an uploaded image and return its URL.

    Args:
        request: Incoming request; its base URL is used to build ``image_url``.
        api_key: API key form field, checked with ``verify_api_key``.
        main_photo: Uploaded image to process.

    Returns:
        A dict with ``status`` ``"success"`` plus ``filename`` and
        ``image_url`` (pointing at the ``/tmp`` static mount), or ``status``
        ``"failure"`` with a ``message``.

    Raises:
        HTTPException: 401 from ``verify_api_key`` when the key is invalid.
    """
    # Verify API key (outside the try so the 401 is not turned into a
    # "failure" payload)
    verify_api_key(api_key)

    try:
        # Read and process the uploaded image
        image_data = await main_photo.read()

        # Generate a unique filename for the processed image
        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        # Ensure tmp directory exists
        os.makedirs(TMP_FOLDER, exist_ok=True)

        # Process the image
        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        # Apply mask and save the result
        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        # The same /tmp/<file> URL shape is used for local runs and for
        # Huggingface Spaces; only the base URL differs.
        base_url = str(request.base_url).rstrip("/")
        full_url = f"{base_url}/tmp/{result_filename}"
        return {
            "status": "success",
            "message": "Background removed successfully",
            "filename": result_filename,
            "image_url": full_url
        }
    except Exception as e:
        import traceback
        print(f"Error in remove_background: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }
# Gradio interface
def process_image_gradio(image):
    """Remove the background of an image submitted through the Gradio UI.

    Args:
        image: Image array from the Gradio ``Image`` input, or ``None`` when
            nothing was submitted.

    Returns:
        A PIL RGBA image with the background made transparent, or ``None``
        when no image was given or processing failed.
    """
    if image is None:
        return None

    # Gradio hands over RGB arrays, while preprocess_image/apply_mask treat
    # arrays as BGR. Convert up front so the model sees the right channel
    # order and the final BGRA -> RGBA conversion restores true colours.
    bgr_image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    input_tensor, original_shape, original_img = preprocess_image(bgr_image)
    output = session.run(None, {input_name: input_tensor})
    mask = output[0]

    # Generate a unique filename
    filename = f"{uuid.uuid4()}.png"
    output_path = os.path.join(TMP_FOLDER, filename)

    # Ensure tmp directory exists
    os.makedirs(TMP_FOLDER, exist_ok=True)

    # Apply mask and save
    result_img, success = apply_mask(original_img, mask, original_shape, output_path)
    if not success:
        return None

    # Convert to PIL Image for Gradio
    return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
# Gradio UI: one numpy image in, one PIL image out, backed by
# process_image_gradio.
interface = gr.Interface(
    process_image_gradio,
    gr.Image(type="numpy"),
    gr.Image(type="pil"),
    title="Background Removal",
    description="Upload an image to remove its background",
)

# Serve the Gradio UI from the FastAPI app under /gradio
app = gr.mount_gradio_app(app, interface, path="/gradio")
# GET handler for the root path: renders the upload form
@app.get("/")
async def index_get(request: Request):
    """Render ``index.html`` with only the request in its context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
# API endpoint for image processing with base64 response
@app.post("/process_image")
async def process_image(request: Request, image: UploadFile = File(...), api_key: str = Form(...)):
    """Remove the background of an uploaded image and return it inline.

    Args:
        request: Incoming request (unused by the handler body).
        image: Uploaded image to process.
        api_key: API key form field, checked with ``verify_api_key``.

    Returns:
        A dict with ``status`` ``"success"`` and ``image_code`` (the PNG
        encoded as base64 text), or ``status`` ``"failure"`` with a
        ``message``.

    Raises:
        HTTPException: 401 from ``verify_api_key`` when the key is invalid.
    """
    # Verify API key
    verify_api_key(api_key)

    try:
        # Read and process the uploaded image
        image_data = await image.read()

        # Generate a unique filename for the processed image
        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        # Ensure tmp directory exists
        os.makedirs(TMP_FOLDER, exist_ok=True)

        # Process the image
        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        # Apply mask and save the result
        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        try:
            # Convert the processed image to base64
            with open(output_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode('utf-8')
        finally:
            # The response carries the image inline and never exposes the
            # file name, so the intermediate file is only disk clutter.
            try:
                os.remove(output_path)
            except OSError:
                pass

        # Return JSON response with base64 image
        return {
            "status": "success",
            "image_code": base64_image
        }
    except Exception as e:
        import traceback
        print(f"Error in process_image: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }
# Download route for processed images
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Send a processed image from ``TMP_FOLDER`` as a PNG download.

    Args:
        filename: Name of a file located directly inside ``TMP_FOLDER``.

    Returns:
        A ``FileResponse`` for the file with media type ``image/png``.

    Raises:
        HTTPException: 404 when the name does not resolve to a regular file
            directly inside ``TMP_FOLDER``.
    """
    tmp_root = os.path.realpath(TMP_FOLDER)
    file_path = os.path.realpath(os.path.join(tmp_root, filename))

    # Reject anything that resolves outside the tmp folder (e.g. "..") and
    # anything that is not a regular file (e.g. a directory), instead of
    # handing it to FileResponse.
    if os.path.dirname(file_path) != tmp_root or not os.path.isfile(file_path):
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(
        path=file_path,
        filename=filename,
        media_type="image/png"
    )
# Entry point for running the server directly (e.g. Huggingface deployment)
if __name__ == "__main__":
    import uvicorn

    # Debug output: where the process runs and where the tmp folder lives
    print(f"Current working directory: {os.getcwd()}")
    print(f"TMP_FOLDER absolute path: {os.path.abspath(TMP_FOLDER)}")

    # Listen on all interfaces, port 7860
    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)