import os
import cv2
import numpy as np
import onnxruntime as ort
import base64
from io import BytesIO
import gradio as gr
import spaces
import re

# ---- Setup ----
API_KEY = os.getenv("API_KEY", "demo")
INPUT_SIZE = (512, 512)  # model input as (width, height), matching the 512x512 checkpoint
MODEL_PATH = "BiRefNet-general-resolution_512x512-fp16-epoch_216.onnx"

# Load the ONNX model once at import time.
# Explicit check instead of `assert`: asserts are stripped under `python -O`,
# which would turn a missing model into an obscure onnxruntime error later.
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Model not found: {MODEL_PATH}")
session = ort.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name


def preprocess_image(image: np.ndarray):
    """Resize and normalize an image for the model.

    Args:
        image: H x W x C uint8 image array.

    Returns:
        (input_tensor, original_shape, image) where input_tensor is a
        1 x C x 512 x 512 float32 array normalized to [-1, 1],
        original_shape is (height, width) of the input, and image is the
        untouched original (returned so callers keep full resolution).
    """
    original_shape = image.shape[:2]  # (height, width)
    resized = cv2.resize(image, INPUT_SIZE)
    # Scale to [0, 1], then normalize with mean=0.5 / std=0.5 -> [-1, 1].
    normalized = (resized.astype(np.float32) / 255.0 - 0.5) / 0.5
    transposed = np.transpose(normalized, (2, 0, 1))  # HWC -> CHW
    input_tensor = np.expand_dims(transposed, axis=0).astype(np.float32)
    return input_tensor, original_shape, image


def apply_mask(original_img, mask_array, original_shape):
    """Attach the predicted foreground mask to the image as an alpha channel.

    Args:
        original_img: original-resolution image (3 or 4 channels).
        mask_array: raw model output; squeezed to a 2-D mask.
        original_shape: (height, width) to resize the mask back to.

    Returns:
        4-channel image at the original resolution with the mask as alpha.
    """
    mask = np.squeeze(mask_array)
    # Resize the mask back to the original image size (cv2 wants (w, h)).
    resized_mask = cv2.resize(mask, (original_shape[1], original_shape[0]))
    # Hard-threshold into a binary foreground mask.
    binary_mask = (resized_mask > 0.5).astype(np.uint8)
    alpha = (binary_mask * 255).astype(np.uint8)
    # Add an alpha channel while preserving the original image content.
    if original_img.shape[2] == 3:
        bgra = cv2.cvtColor(original_img, cv2.COLOR_BGR2BGRA)
    else:
        bgra = original_img.copy()
    bgra[:, :, 3] = alpha
    return bgra


# ============ UI ============ #
@spaces.GPU
def remove_background_ui(image, bg=None):
    """Remove the background from `image`; optionally composite onto `bg`.

    Returns a 4-channel cutout, or a 3-channel blend when `bg` is given.
    """
    input_tensor, original_shape, original_img = preprocess_image(image)
    mask = session.run(None, {input_name: input_tensor})[0]
    result = apply_mask(original_img, mask, original_shape)
    if bg is not None:
        # Drop any alpha channel on the background so it broadcasts against
        # the 3-channel foreground during blending.
        if bg.ndim == 3 and bg.shape[2] == 4:
            bg = bg[:, :, :3]
        bg_resized = cv2.resize(bg, (original_shape[1], original_shape[0]))
        alpha = result[:, :, 3] / 255.0
        fg = result[:, :, :3]
        # Standard alpha compositing: fg * a + bg * (1 - a).
        blended = (fg * alpha[..., None] + bg_resized * (1 - alpha[..., None])).astype(np.uint8)
        return blended
    return result


# ============ File API ============ #
@spaces.GPU
def remove_background_api(image_file, api_key=""):
    """File-based API: decode an uploaded image and return a base64 PNG data URI.

    Raises:
        gr.Error: on a bad API key, an undecodable image, or an encode failure.
    """
    if api_key != API_KEY:
        raise gr.Error("❌ Invalid API Key")
    np_arr = np.fromfile(image_file, np.uint8)
    original_img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
    if original_img is None:
        raise gr.Error("❌ Unable to decode image.")
    if len(original_img.shape) == 2:
        # Promote grayscale to 3 channels so apply_mask can add an alpha plane.
        original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
    input_tensor, original_shape, _ = preprocess_image(original_img)
    mask = session.run(None, {input_name: input_tensor})[0]
    result = apply_mask(original_img, mask, original_shape)
    success, buffer = cv2.imencode(".png", result)
    # imencode can fail; never base64-encode an invalid buffer.
    if not success:
        raise gr.Error("❌ Unable to encode result image.")
    return f"data:image/png;base64,{base64.b64encode(buffer).decode('utf-8')}"


# ============ Base64 API ============ #
def clean_base64_string(b64_string):
    """Normalize a base64 payload: strip data-URI prefix, whitespace, fix padding."""
    if b64_string.startswith('data:image'):
        b64_string = b64_string.split(',', 1)[1]
    b64_string = re.sub(r'\s+', '', b64_string)
    # base64 length must be a multiple of 4; re-pad if it was trimmed.
    missing_padding = len(b64_string) % 4
    if missing_padding:
        b64_string += '=' * (4 - missing_padding)
    return b64_string


@spaces.GPU
def remove_background_base64_api(base64_image, api_key=""):
    """Base64 API: accept a (possibly data-URI-prefixed) base64 image and
    return the background-removed result as a base64 PNG data URI.

    Raises:
        gr.Error: on a bad API key, invalid/undecodable input, or any
            unexpected processing failure (wrapped with a generic message).
    """
    try:
        if api_key != API_KEY:
            raise gr.Error("❌ Invalid API Key")
        if not re.match(r'^(data:image\/[a-zA-Z]+;base64,)?[A-Za-z0-9+/=\s]+$', base64_image):
            raise gr.Error("❌ Invalid base64 image format")
        cleaned_b64 = clean_base64_string(base64_image)
        image_data = base64.b64decode(cleaned_b64)
        np_arr = np.frombuffer(image_data, np.uint8)
        original_img = cv2.imdecode(np_arr, cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise gr.Error("❌ Unable to decode image.")
        if len(original_img.shape) == 2:
            # Promote grayscale to 3 channels so apply_mask can add alpha.
            original_img = cv2.cvtColor(original_img, cv2.COLOR_GRAY2BGR)
        input_tensor, original_shape, _ = preprocess_image(original_img)
        mask = session.run(None, {input_name: input_tensor})[0]
        result = apply_mask(original_img, mask, original_shape)
        success, buffer = cv2.imencode('.png', result)
        # imencode can fail; never base64-encode an invalid buffer.
        if not success:
            raise gr.Error("❌ Unable to encode result image.")
        result_base64 = base64.b64encode(buffer).decode('utf-8')
        return f"data:image/png;base64,{result_base64}"
    except gr.Error:
        # Re-raise our own, already user-facing errors untouched instead of
        # double-wrapping them as "Error processing image: ...".
        raise
    except Exception as e:
        raise gr.Error(f"❌ Error processing image: {str(e)}")


# ============ Gradio Interfaces ============ #
ui = gr.Interface(
    fn=remove_background_ui,
    inputs=[
        gr.Image(type="numpy", label="Main Image"),
        gr.Image(type="numpy", label="Optional Background"),
    ],
    outputs=gr.Image(type="numpy", label="Result"),
    title="🖼️ Background Remover",
    description="Upload a photo (and optionally a background).",
)

api = gr.Interface(
    fn=remove_background_api,
    inputs=[
        gr.Image(type="filepath", label="Upload Image"),
        gr.Text(label="API Key", type="password"),
    ],
    outputs=gr.Text(label="Base64 PNG"),
    title="🔐 File API Access",
    description="POST to `/run/predict` with file + API key.",
)

api_base64 = gr.Interface(
    fn=remove_background_base64_api,
    inputs=[
        gr.Text(label="Base64 Image String"),
        gr.Text(label="API Key", type="password"),
    ],
    outputs=gr.Text(label="Base64 PNG"),
    title="🔐 Base64 API Access",
    description="POST to `/run/predict` with base64 image + API key.",
)

# Final Gradio app: one tab per access mode.
demo = gr.TabbedInterface([ui, api, api_base64], ["Web UI", "File API", "Base64 API"])

if __name__ == "__main__":
    demo.launch()