| | import os |
| | import cv2 |
| | import numpy as np |
| | import onnxruntime as ort |
| | import uuid |
| | import base64 |
| | from io import BytesIO |
| | from PIL import Image |
| | from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Request |
| | from fastapi.responses import FileResponse |
| | from fastapi.staticfiles import StaticFiles |
| | from fastapi.templating import Jinja2Templates |
| | from fastapi.middleware.cors import CORSMiddleware |
| | import gradio as gr |
| | import shutil |
| |
|
| | |
# Shared secret checked by verify_api_key; None when the env var is unset.
API_KEY = os.getenv("API_KEY")

app = FastAPI(title="Background Removal API")

# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# effectively lets any site make credentialed cross-origin requests. The API
# authenticates with a form field, not cookies, so credentials are probably
# unnecessary — confirm with front-end owners before tightening.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Directory containing this script; tmp/ and (when present) templates/ are
# resolved against it so the app behaves the same regardless of launch CWD.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))

TMP_FOLDER = os.path.join(_APP_DIR, "tmp")
os.makedirs(TMP_FOLDER, exist_ok=True)
print(f"Created tmp folder at: {TMP_FOLDER}")

# Generated images are served without authentication at /tmp/<name>.
app.mount("/tmp", StaticFiles(directory=TMP_FOLDER), name="tmp")

# Prefer the templates/ folder next to this script (consistent with
# TMP_FOLDER); fall back to a CWD-relative "templates" for deployments that
# depend on the launch directory.
_script_templates = os.path.join(_APP_DIR, "templates")
templates = Jinja2Templates(
    directory=_script_templates if os.path.isdir(_script_templates) else "templates"
)
| |
|
| | |
# ONNX model file (resolved against the current working directory), CPU only.
model_path = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Model input resolution as (width, height), the order cv2.resize expects.
INPUT_SIZE = (512, 512)

input_info = session.get_inputs()[0]
# Feed the graph's declared input name instead of a hard-coded guess, so
# session.run() keeps working if the export names its input differently.
input_name = input_info.name

print("Input name:", input_info.name)
print("Input shape:", input_info.shape)
print("Input type:", input_info.type)
| |
|
| | |
def verify_api_key(api_key: str = Form(...)):
    """Reject the request unless *api_key* matches the ``API_KEY`` env var.

    Returns:
        The key unchanged when it matches.

    Raises:
        HTTPException: 401 when the key does not match, is not a string, or
            when no ``API_KEY`` is configured (fail closed).
    """
    import hmac  # local import, matching the file's inline-import style

    if (
        API_KEY is None
        or not isinstance(api_key, str)
        # Constant-time comparison so response timing does not leak how many
        # leading characters of the secret were guessed correctly.
        or not hmac.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8"))
    ):
        raise HTTPException(status_code=401, detail="Invalid API key")
    return api_key
| |
|
def preprocess_image(image):
    """Process image from various input types into the model input tensor.

    Args:
        image: a file path (``str``), an already-decoded ``np.ndarray``
            (treated as OpenCV BGR; grayscale and 4-channel arrays are
            converted to 3-channel BGR), or encoded image bytes.

    Returns:
        ``(input_tensor, original_shape, original_img)`` where
        ``input_tensor`` is float32 of shape ``(1, 3, H, W)`` with
        ``(W, H) == INPUT_SIZE``, RGB channel order, scaled to ``[-1, 1]``;
        ``original_shape`` is ``(height, width)`` of the source image; and
        ``original_img`` is a 3-channel BGR copy of the source image.

    Raises:
        ValueError: if the image cannot be read/decoded or has an
            unsupported shape.
    """
    if isinstance(image, str):
        img = cv2.imread(image)  # None when the file is missing/unreadable
    elif isinstance(image, np.ndarray):
        img = image
    else:
        nparr = np.frombuffer(image, np.uint8)
        # cv2.imdecode throws on an empty buffer; treat "no bytes" as a
        # decode failure so callers get one consistent error.
        img = cv2.imdecode(nparr, cv2.IMREAD_COLOR) if nparr.size else None

    if img is None:
        raise ValueError("Could not read or decode the input image")

    # The BGR2RGB conversion below and the masking step need exactly three
    # channels; path/bytes inputs are decoded as color, but raw arrays may not be.
    if img.ndim == 2 or (img.ndim == 3 and img.shape[2] == 1):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    elif img.ndim == 3 and img.shape[2] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    elif img.ndim != 3 or img.shape[2] != 3:
        raise ValueError(f"Unsupported image shape: {img.shape}")

    original_img = img.copy()
    original_shape = img.shape[:2]

    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    resized = cv2.resize(rgb, INPUT_SIZE)

    # Scale to [0, 1] then to [-1, 1].
    # NOTE(review): BiRefNet is commonly trained with ImageNet mean/std
    # normalization — confirm this ONNX export really expects [-1, 1] input.
    normalized = resized.astype(np.float32) / 255.0
    normalized = (normalized - 0.5) / 0.5
    # HWC -> CHW, then add the batch axis.
    transposed = np.transpose(normalized, (2, 0, 1))
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)

    return input_tensor, original_shape, original_img
| |
|
| |
|
def apply_mask(original_img, mask_array, original_shape, output_path):
    """Cut out the foreground of *original_img* and save it as a BGRA PNG.

    Args:
        original_img: BGR image; assumed to be 3-channel (H, W, 3) — TODO
            confirm for every caller.
        mask_array: raw model output; squeezed to 2-D, resized to the
            original size, clipped to [0, 1] and thresholded at 0.5.
        original_shape: ``(height, width)`` of *original_img*.
        output_path: destination PNG path; its parent directory is created.

    Returns:
        ``(bgra, True)`` on success, where background pixels are zeroed and
        alpha is 255 on the foreground and 0 elsewhere; ``(None, False)`` on
        any error (the error is printed, not raised).
    """
    try:
        mask = np.squeeze(mask_array)
        # cv2.resize takes (width, height).
        mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
        mask = np.clip(mask, 0, 1)

        # Hard 0/1 mask: any soft edge values from the model are discarded here.
        binary_mask = (mask > 0.5).astype(np.uint8)

        img = original_img.astype(np.uint8)
        masked_img = cv2.bitwise_and(img, img, mask=binary_mask)

        alpha = (binary_mask * 255).astype(np.uint8)

        bgra = cv2.cvtColor(masked_img, cv2.COLOR_BGR2BGRA)
        bgra[:, :, 3] = alpha

        out_dir = os.path.dirname(output_path)
        if out_dir:  # os.makedirs("") raises for a bare filename
            os.makedirs(out_dir, exist_ok=True)

        # Compression level 0: fastest encode, largest files.
        # cv2.imwrite reports failure by returning False rather than raising.
        if not cv2.imwrite(output_path, bgra, [cv2.IMWRITE_PNG_COMPRESSION, 0]):
            raise OSError(f"cv2.imwrite failed for {output_path}")
        print(f"Saved masked object image to: {output_path} with size {bgra.shape[:2]}")

        return bgra, True
    except Exception as e:
        # Best-effort contract: callers branch on the success flag.
        print(f"Error applying mask: {e}")
        return None, False
| |
|
| |
|
| | |
@app.post("/")
async def index_post(
    request: Request,
    main_photo: UploadFile = File(...),
    bg_photo: UploadFile = File(None)
):
    """Handle the HTML form: remove the background, optionally composite a new one.

    Renders ``index.html`` with ``output_image`` set to the basename of the
    result PNG in TMP_FOLDER on success, or with ``error`` on any failure.
    """
    try:
        main_image_data = await main_photo.read()
        input_tensor, original_shape, original_img = preprocess_image(main_image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        transparent_img, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            # apply_mask swallows its errors and returns (None, False); fail
            # here instead of rendering a link to a file that was never written.
            raise RuntimeError("Failed to remove background")
        final_result_path = output_path

        # An empty <input type="file"> can still arrive as an UploadFile with
        # no bytes, so decide on the content, not on the object's presence.
        bg_image_data = await bg_photo.read() if bg_photo is not None else b""
        if bg_image_data:
            bg_np = np.frombuffer(bg_image_data, np.uint8)
            bg_img = cv2.imdecode(bg_np, cv2.IMREAD_COLOR)
            if bg_img is None:
                raise ValueError("Could not decode background image")
            bg_img_resized = cv2.resize(bg_img, (original_shape[1], original_shape[0]))

            # Alpha-blend the cut-out over the new background.
            alpha = transparent_img[:, :, 3] / 255.0
            foreground = transparent_img[:, :, :3]

            blended = (foreground * alpha[..., None] + bg_img_resized * (1 - alpha[..., None])).astype(np.uint8)
            final_result_path = os.path.join(TMP_FOLDER, f"bg_replaced_{uuid.uuid4()}.png")
            if not cv2.imwrite(final_result_path, blended):
                raise OSError(f"cv2.imwrite failed for {final_result_path}")

        return templates.TemplateResponse("index.html", {
            "request": request,
            "output_image": os.path.basename(final_result_path)
        })

    except Exception as e:
        import traceback
        print("Error in index_post:", str(e))
        print(traceback.format_exc())
        return templates.TemplateResponse("index.html", {
            "request": request,
            "error": f"Error: {str(e)}"
        })
| | |
@app.post("/remove-background")
async def remove_background(request: Request, api_key: str = Form(...), main_photo: UploadFile = File(...)):
    """Authenticated endpoint: remove the background and return a URL to the PNG.

    Returns ``{"status": "success", "message", "filename", "image_url"}``
    on success, or ``{"status": "failure", "message"}`` on failure. An
    invalid key raises HTTPException(401) via verify_api_key.
    """
    # Outside the try so the 401 propagates instead of becoming a "failure" dict.
    verify_api_key(api_key)

    try:
        image_data = await main_photo.read()

        result_filename = f"{uuid.uuid4()}.png"
        output_path = os.path.join(TMP_FOLDER, result_filename)

        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        # apply_mask creates the output directory itself.
        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        base_url = str(request.base_url)
        if base_url.endswith("/"):
            base_url = base_url[:-1]

        # NOTE(review): behind a TLS-terminating proxy (e.g. hf.space)
        # request.base_url may report http:// — confirm whether the scheme
        # needs rewriting there.
        full_url = f"{base_url}/tmp/{result_filename}"

        return {
            "status": "success",
            "message": "Background removed successfully",
            "filename": result_filename,
            "image_url": full_url
        }

    except Exception as e:
        import traceback
        print(f"Error in remove_background: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }
| |
|
| | |
def process_image_gradio(image):
    """Gradio handler: remove the background from an uploaded image.

    Args:
        image: array from ``gr.Image(type="numpy")`` in RGB order (RGBA or
            grayscale are also accepted), or None when no image is supplied.

    Returns:
        An RGBA ``PIL.Image`` with a transparent background, or None when
        there is no input or masking failed.
    """
    if image is None:
        return None

    # Gradio numpy images are RGB-ordered, but preprocess_image/apply_mask
    # treat arrays as OpenCV BGR. Without this conversion the model sees
    # swapped channels and the BGRA->RGBA step below swaps red and blue in
    # the returned image.
    if image.ndim == 2:
        bgr = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    elif image.shape[2] == 4:
        bgr = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
    else:
        bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    input_tensor, original_shape, original_img = preprocess_image(bgr)
    output = session.run(None, {input_name: input_tensor})
    mask = output[0]

    # apply_mask always writes a PNG (and creates TMP_FOLDER if needed).
    output_path = os.path.join(TMP_FOLDER, f"{uuid.uuid4()}.png")
    result_img, success = apply_mask(original_img, mask, original_shape, output_path)
    if not success:
        return None

    # The PNG is only a by-product here (the image is returned in memory and
    # its name is never exposed), so remove it rather than filling TMP_FOLDER.
    try:
        os.remove(output_path)
    except OSError:
        pass  # best-effort cleanup; the result is already in memory

    return Image.fromarray(cv2.cvtColor(result_img, cv2.COLOR_BGRA2RGBA))
| |
|
| | |
# Gradio UI components: a numpy image in, a PIL image out.
_gradio_input = gr.Image(type="numpy")
_gradio_output = gr.Image(type="pil")

interface = gr.Interface(
    fn=process_image_gradio,
    inputs=_gradio_input,
    outputs=_gradio_output,
    title="Background Removal",
    description="Upload an image to remove its background",
)

# Serve the Gradio UI under /gradio on the same FastAPI application.
app = gr.mount_gradio_app(app, interface, path="/gradio")
| |
|
| | |
| |
|
| | |
@app.get("/")
async def index_get(request: Request):
    """Render the upload form (index.html) with no result or error."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
| |
|
| | |
| | |
@app.post("/process_image")
async def process_image(request: Request, image: UploadFile = File(...), api_key: str = Form(...)):
    """Authenticated endpoint: remove the background and return the PNG inline.

    Returns ``{"status": "success", "image_code": <base64 PNG>}`` on success,
    or ``{"status": "failure", "message": ...}`` on failure. An invalid key
    raises HTTPException(401) via verify_api_key.
    """
    # Outside the try so the 401 propagates instead of becoming a "failure" dict.
    verify_api_key(api_key)

    try:
        image_data = await image.read()

        # apply_mask writes the PNG here and creates TMP_FOLDER if needed.
        output_path = os.path.join(TMP_FOLDER, f"{uuid.uuid4()}.png")

        input_tensor, original_shape, original_img = preprocess_image(image_data)
        output = session.run(None, {input_name: input_tensor})
        mask = output[0]

        _, success = apply_mask(original_img, mask, original_shape, output_path)
        if not success:
            return {
                "status": "failure",
                "message": "Failed to process image"
            }

        try:
            with open(output_path, "rb") as img_file:
                base64_image = base64.b64encode(img_file.read()).decode('utf-8')
        finally:
            # The PNG is only an intermediate: its name is never returned to
            # the client, so delete it instead of letting TMP_FOLDER grow.
            try:
                os.remove(output_path)
            except OSError:
                pass  # best-effort cleanup; the response no longer needs the file

        return {
            "status": "success",
            "image_code": base64_image
        }

    except Exception as e:
        import traceback
        print(f"Error in process_image: {str(e)}")
        print(traceback.format_exc())
        return {
            "status": "failure",
            "message": f"Error: {str(e)}"
        }
| |
|
| | |
| | |
@app.get("/download/{filename}")
async def download_file(filename: str):
    """Serve a generated PNG from TMP_FOLDER as a download.

    Raises:
        HTTPException: 404 unless *filename* names a regular file located
            directly inside TMP_FOLDER.
    """
    tmp_root = os.path.realpath(TMP_FOLDER)
    file_path = os.path.realpath(os.path.join(tmp_root, filename))
    # filename comes from the URL and is untrusted. Reject anything that
    # resolves outside TMP_FOLDER ("..", absolute or backslash paths on
    # Windows, symlinks) and anything that is not a regular file (e.g. a
    # directory), which is not servable as a download.
    if os.path.dirname(file_path) == tmp_root and os.path.isfile(file_path):
        return FileResponse(
            path=file_path,
            filename=filename,
            media_type="image/png"
        )
    raise HTTPException(status_code=404, detail="File not found")
| |
|
| | |
if __name__ == "__main__":
    import uvicorn

    # Print path diagnostics before starting the server, to help debug
    # where generated files end up.
    for diagnostic in (
        f"Current working directory: {os.getcwd()}",
        f"TMP_FOLDER absolute path: {os.path.abspath(TMP_FOLDER)}",
    ):
        print(diagnostic)

    uvicorn.run(app, host="0.0.0.0", port=7860)
| |
|
| |
|