import gradio as gr
import io
import torch
import numpy as np
from PIL import Image
import os
import sys

# Add current directory to path for model files: the model code modules
# (briarmbg, utilities) and the pretrained weights are loaded from /app.
sys.path.append("/app")

# Import model components (resolvable only after the sys.path tweak above).
from briarmbg import BriaRMBG
from utilities import preprocess_image, postprocess_image


class BackgroundRemover:
    """Load the RMBG-1.4 model once and use it to cut subjects out of images.

    Args:
        model_dir: Directory passed to ``BriaRMBG.from_pretrained``.
        model_input_size: (height, width)-style size handed to
            ``preprocess_image`` before inference; defaults to the
            1024x1024 the app has always used.
    """

    def __init__(self, model_dir="/app", model_input_size=(1024, 1024)):
        self.model = None
        self.device = None
        self.model_dir = model_dir
        self.model_input_size = tuple(model_input_size)
        # Exception raised by load_model(), kept so later requests can report
        # *why* the model is unavailable instead of a bare "not loaded".
        self.load_error = None
        self.load_model()

    def load_model(self):
        """Load the RMBG-1.4 model onto CUDA if available, otherwise the CPU.

        Failures are deliberately not re-raised so the UI can still start:
        the error is printed, stored in ``self.load_error``, and
        ``self.model`` is left as ``None``.
        """
        try:
            print("🔄 Loading background removal model...")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model = BriaRMBG.from_pretrained(self.model_dir)
            model.to(self.device)
            model.eval()  # inference mode for dropout/batch-norm layers
            self.model = model
            self.load_error = None
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            self.model = None
            self.load_error = e

    def remove_background(self, image):
        """Return a copy of *image* with its background made transparent.

        Args:
            image: A PIL image in any mode; it is converted to RGB first.

        Returns:
            An RGBA PIL image: the RGB version of the input with the model's
            predicted mask installed as its alpha channel.

        Raises:
            ValueError: If *image* is None (e.g. submit pressed with no upload).
            RuntimeError: If the model failed to load or inference fails.
        """
        if image is None:
            raise ValueError("No image provided. Please upload an image first.")
        if self.model is None:
            reason = f": {self.load_error}" if self.load_error is not None else ""
            raise RuntimeError(f"Model not loaded{reason}") from self.load_error

        try:
            # Convert to RGB if needed (drops any existing alpha channel).
            input_image = image.convert("RGB")

            # Preprocess; a fresh list is passed each call, as before.
            orig_im = np.array(input_image)
            orig_im_size = orig_im.shape[0:2]  # (height, width) of the input
            processed_image = preprocess_image(
                orig_im, list(self.model_input_size)
            ).to(self.device)

            # Inference without autograd bookkeeping.
            with torch.no_grad():
                result = self.model(processed_image)

            # NOTE(review): assumes result[0][0] is the final mask for the
            # single batched image and that postprocess_image resizes it back
            # to orig_im_size as an array Image.fromarray accepts — confirm
            # against briarmbg/utilities.
            mask = postprocess_image(result[0][0], orig_im_size)

            # Create transparent image: RGB copy + mask as alpha -> RGBA.
            pil_mask = Image.fromarray(mask)
            no_bg_image = input_image.copy()
            no_bg_image.putalpha(pil_mask)
            return no_bg_image
        except Exception as e:
            raise RuntimeError(f"Background removal failed: {e}") from e


# Initialize the remover once at import time so every request reuses the model.
remover = BackgroundRemover()


def process_image(image):
    """Gradio callback: remove the background from the uploaded image.

    Any failure is re-raised as ``gr.Error`` so Gradio shows the message in
    the UI instead of a generic error.
    """
    try:
        return remover.remove_background(image)
    except Exception as e:
        raise gr.Error(str(e)) from e


# Create Gradio interface
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="📷 Upload Image"),
    outputs=gr.Image(type="pil", label="🎨 Background Removed"),
    title="🎨 Professional Background Remover",
    description="Upload any image (JPG, PNG, etc) to remove background automatically with AI",
)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)