# app.py — Hugging Face Space "bgg" (Munaf1987, commit 8967da4)
import os
import cv2
import numpy as np
import onnxruntime as ort
import base64
from io import BytesIO
import gradio as gr
import spaces
import re
# ---------------------------------------------------------------------------
# Configuration and model setup
# ---------------------------------------------------------------------------
API_KEY = os.getenv("API_KEY", "demo")  # shared secret checked by the API tabs
INPUT_SIZE = (512, 512)                 # (width, height) expected by the model
MODEL_PATH = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"

# Fail fast with an explicit exception instead of `assert`, which is
# silently stripped when Python runs with the -O flag.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")

# CPU-only ONNX Runtime session; the model's first input is the image tensor.
session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name
# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
def preprocess_image(image: np.ndarray):
    """Prepare an image for the BiRefNet ONNX model.

    Args:
        image: H x W x C uint8 image (BGR or BGRA; an alpha channel is
            dropped for the model input but preserved in the returned image).

    Returns:
        Tuple of (input_tensor, original_shape, image): the NCHW float32
        tensor, the original (height, width), and the untouched input image.
    """
    original_shape = image.shape[:2]
    # The model expects 3 channels; drop the alpha plane if the caller
    # decoded a BGRA image (cv2.IMREAD_UNCHANGED can produce 4 channels).
    three_ch = image[:, :, :3] if image.ndim == 3 and image.shape[2] == 4 else image
    resized = cv2.resize(three_ch, INPUT_SIZE)
    # Scale to [0, 1] then normalize to [-1, 1] (mean 0.5, std 0.5).
    normalized = (resized.astype(np.float32) / 255.0 - 0.5) / 0.5
    # HWC -> CHW, then add the batch dimension -> NCHW.
    transposed = np.transpose(normalized, (2, 0, 1))
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
    return input_tensor, original_shape, image
# ---------------------------------------------------------------------------
# Mask compositing
# ---------------------------------------------------------------------------
def apply_mask(original_img: np.ndarray, mask_array: np.ndarray,
               original_shape, threshold: float = 0.5) -> np.ndarray:
    """Attach the predicted foreground mask to the image as an alpha channel.

    Args:
        original_img: full-resolution source image (BGR, BGRA, or grayscale).
        mask_array: raw model output; squeezed down to a 2-D mask.
        original_shape: (height, width) of the source image.
        threshold: mask values above this are treated as foreground
            (default 0.5, matching the original hard-coded cutoff).

    Returns:
        BGRA image whose alpha channel is the binarized, upscaled mask.
    """
    mask = np.squeeze(mask_array)
    # Upscale the low-resolution mask back to the source resolution;
    # cv2.resize takes (width, height).
    resized_mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
    binary_mask = (resized_mask > threshold).astype(np.uint8)
    alpha = (binary_mask * 255).astype(np.uint8)
    # Promote to BGRA without resampling the original pixels.
    # Grayscale (2-D) input previously raised IndexError on shape[2].
    if original_img.ndim == 2:
        bgra = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGRA)
    elif original_img.shape[2] == 3:
        bgra = cv2.cvtColor(original_img, cv2.COLOR_BGR2BGRA)
    else:
        bgra = original_img.copy()
    bgra[:, :, 3] = alpha
    return bgra
# ============ UI ============ #
@spaces.GPU
def remove_background_ui(image, bg=None):
    """Gradio UI handler: cut out the foreground, optionally over a new background.

    Args:
        image: H x W x 3 numpy image from the Gradio widget.
        bg: optional replacement background image.

    Returns:
        BGRA cut-out, or an H x W x 3 blend when `bg` is given.

    Raises:
        gr.Error: if no image was uploaded.
    """
    # Gradio passes None when the user submits without uploading an image.
    if image is None:
        raise gr.Error("❌ Please upload an image.")
    input_tensor, original_shape, original_img = preprocess_image(image)
    mask = session.run(None, {input_name: input_tensor})[0]
    result = apply_mask(original_img, mask, original_shape)
    if bg is not None:
        # Composite: resize the background to match the source, then
        # alpha-blend. Slice to 3 channels so a BGRA background cannot
        # break the broadcast against the 3-channel foreground.
        bg_resized = cv2.resize(bg[:, :, :3], (original_shape[1], original_shape[0]))
        alpha = result[:, :, 3] / 255.0
        fg = result[:, :, :3]
        blended = (fg * alpha[..., None] + bg_resized * (1 - alpha[..., None])).astype(np.uint8)
        return blended
    return result
# ============ File API ============ #
@spaces.GPU
def remove_background_api(image_file, api_key=""):
    """File-upload API: return the cut-out as a base64 PNG data URI.

    Args:
        image_file: filesystem path of the uploaded image.
        api_key: must match the configured API_KEY.

    Returns:
        "data:image/png;base64,..." string.

    Raises:
        gr.Error: on a bad key, an undecodable image, or a failed encode.
    """
    if api_key != API_KEY:
        raise gr.Error("❌ Invalid API Key")
    np_arr = np.fromfile(image_file, np.uint8)
    original_img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
    if original_img is None:
        raise gr.Error("❌ Unable to decode image.")
    # The model expects 3 channels; promote grayscale input.
    if len(original_img.shape) == 2:
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
    input_tensor, original_shape, _ = preprocess_image(original_img)
    mask = session.run(None, {input_name: input_tensor})[0]
    result = apply_mask(original_img, mask, original_shape)
    success, buffer = cv2.imencode(".png", result)
    # Previously the success flag was ignored; a failed encode would have
    # produced a garbage base64 payload instead of a clear error.
    if not success:
        raise gr.Error("❌ Unable to encode image.")
    return f"data:image/png;base64,{base64.b64encode(buffer).decode('utf-8')}"
# ============ Base64 API ============ #
def clean_base64_string(b64_string):
    """Normalize a base64 payload: strip any data-URI prefix and all
    whitespace, then right-pad with '=' to a multiple of four characters."""
    prefix, sep, payload = b64_string.partition(',')
    # Only treat the text before the comma as a prefix when it is a
    # data-URI header; otherwise the whole string is the payload.
    if not (sep and prefix.startswith('data:image')):
        payload = b64_string
    payload = ''.join(payload.split())
    remainder = len(payload) % 4
    if remainder:
        payload += '=' * (4 - remainder)
    return payload
@spaces.GPU
def remove_background_base64_api(base64_image, api_key=""):
    """Base64 API: decode a base64 image, remove its background, return a data URI.

    Args:
        base64_image: raw base64, or a full "data:image/...;base64," string.
        api_key: must match the configured API_KEY.

    Returns:
        "data:image/png;base64,..." string.

    Raises:
        gr.Error: on a bad key, malformed base64, or a processing failure.
    """
    # Validate outside the try block so these gr.Errors reach the caller
    # as-is. Previously they were caught by `except Exception` below and
    # double-wrapped ("❌ Error processing image: ❌ Invalid API Key").
    if api_key != API_KEY:
        raise gr.Error("❌ Invalid API Key")
    if not re.match(r'^(data:image\/[a-zA-Z]+;base64,)?[A-Za-z0-9+/=\s]+$', base64_image):
        raise gr.Error("❌ Invalid base64 image format")
    try:
        cleaned_b64 = clean_base64_string(base64_image)
        image_data = base64.b64decode(cleaned_b64)
        np_arr = np.frombuffer(image_data, np.uint8)
        original_img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise gr.Error("❌ Unable to decode image.")
        # The model expects 3 channels; promote grayscale input.
        if len(original_img.shape) == 2:
            original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
        input_tensor, original_shape, _ = preprocess_image(original_img)
        mask = session.run(None, {input_name: input_tensor})[0]
        result = apply_mask(original_img, mask, original_shape)
        success, buffer = cv2.imencode('.png', result)
        # The success flag was previously ignored; fail loudly instead of
        # base64-encoding an invalid buffer.
        if not success:
            raise gr.Error("❌ Unable to encode image.")
        result_base64 = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/png;base64,{result_base64}"
    except gr.Error:
        # User-facing errors pass through untouched.
        raise
    except Exception as e:
        raise gr.Error(f"❌ Error processing image: {str(e)}")
# ============ Gradio Interfaces ============ #
# Three tabs backed by the same ONNX session: an interactive demo plus two
# programmatic endpoints (file upload and base64 payload).
# NOTE: the tab titles previously contained mojibake ("πŸ–ΌοΈ", "πŸ”") —
# UTF-8 emoji bytes decoded through a legacy codepage; restored to the
# intended 🖼️ / 🔐 glyphs, matching the correctly-encoded ❌ strings.
ui = gr.Interface(
    fn=remove_background_ui,
    inputs=[
        gr.Image(type="numpy", label="Main Image"),
        gr.Image(type="numpy", label="Optional Background"),
    ],
    outputs=gr.Image(type="numpy", label="Result"),
    title="🖼️ Background Remover",
    description="Upload a photo (and optionally a background).",
)
api = gr.Interface(
    fn=remove_background_api,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Text(label="API Key", type="password"),
    ],
    outputs=gr.Text(label="Base64 PNG"),
    title="🔐 File API Access",
    description="POST to `/run/predict` with file + API key.",
)
api_base64 = gr.Interface(
    fn=remove_background_base64_api,
    inputs=[
        gr.Text(label="Base64 Image String"),
        gr.Text(label="API Key", type="password"),
    ],
    outputs=gr.Text(label="Base64 PNG"),
    title="🔐 Base64 API Access",
    description="POST to `/run/predict` with base64 image + API key.",
)
# Final Gradio app: one tab per interface.
demo = gr.TabbedInterface([ui, api, api_base64], ["Web UI", "File API", "Base64 API"])
if __name__ == "__main__":
    demo.launch()