import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from PIL import Image, ImageColor
from typing import Tuple, Dict, List, Optional
import cv2
from pathlib import Path
from briarmbg import MultiTargetBriaRMBG
# NOTE(review): this import rebinds MultiTargetBriaRMBG, shadowing the briarmbg
# version imported on the previous line; the briacustom class is the one
# ImageProcessor actually instantiates.
from briacustom import MultiTargetBriaRMBG, ClothingType, GarmentFeatures


class ImageProcessor:
    """Run a multi-target BRIA RMBG network and turn its masks into images.

    Attributes:
        device: ``cuda`` when available, otherwise ``cpu``.
        net: The segmentation network, moved to ``device`` and set to eval mode.
        model_input_size: ``(width, height)`` every input is resized to.
    """

    # Spacing, in pixels, of the procedural patterns drawn by apply_pattern.
    _PATTERN_PERIOD = 24
    # Fraction by which pattern pixels are darkened inside the clothing mask.
    _PATTERN_DARKEN = 0.5

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the model from ``model_path`` and move it to the best device."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        # Inference only: make dropout / normalisation layers use eval behaviour.
        self.net.eval()
        self.model_input_size = (1024, 1024)

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Prepare an image for model input.

        Args:
            image: A numpy array or PIL image of any mode.

        Returns:
            A float tensor of shape (1, 3, 1024, 1024) on ``self.device``,
            scaled to [0, 1] and shifted by -0.5 (mean 0.5, std 1.0).
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # The network has a fixed input size; aspect ratio is not preserved.
        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)

        # HWC uint8 -> NCHW float in [0, 1].
        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Convert a model output mask into a grayscale PIL image.

        Args:
            mask: Raw mask from the network. ``F.interpolate`` requires a 4-D
                tensor; presumably (1, 1, h, w) -- TODO confirm against briacustom.
            target_size: PIL-style ``(width, height)``, i.e. ``Image.size`` of
                the original image.

        Returns:
            A mask image of size ``target_size`` holding the min-max
            normalised mask scaled to 0..255.
        """
        width, height = target_size
        # F.interpolate takes (height, width); PIL sizes are (width, height).
        # Passing the PIL size directly transposes non-square masks.
        mask = F.interpolate(mask, size=(height, width), mode='bilinear')
        mask = torch.squeeze(mask, 0)

        # Min-max normalise. Clamping the range avoids 0/0 -> NaN for a
        # constant mask, which would otherwise become garbage after the cast.
        lo, hi = mask.min(), mask.max()
        mask = (mask - lo) / (hi - lo).clamp_min(1e-8)

        mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    @staticmethod
    def _cut_out(image: Image.Image, mask: Image.Image) -> Image.Image:
        """Paste ``image`` onto a transparent RGBA canvas through ``mask``."""
        canvas = Image.new("RGBA", image.size, (0, 0, 0, 0))
        canvas.paste(image, mask=mask)
        return canvas

    def process_image(self, image: np.ndarray, mode: str = "background_removal",
                      clothing_options: Optional[Dict] = None) -> Dict[str, Image.Image]:
        """Run the model and build the output images for ``mode``.

        Args:
            image: Input as a numpy array or PIL image.
            mode: ``"background_removal"``, ``"clothing"`` or ``"all"``; it is
                also forwarded to the network.
            clothing_options: Optional edits for ``apply_clothing_modifications``.

        Returns:
            Dict with ``"removed_background"`` (background_removal/all),
            ``"background"`` (all), ``"clothing"`` (clothing/all) and
            ``"modified_clothing"`` (clothing/all, only when options are given).

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError("No input image was provided.")

        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        # Work in RGB so pasting and the numpy-based clothing edits see 3 channels.
        orig_image = orig_image.convert("RGB")
        orig_size = orig_image.size

        input_tensor = self.preprocess_image(orig_image)
        # No autograd graph is needed for inference; this saves memory.
        with torch.no_grad():
            results = self.net(input_tensor, mode=mode)

        outputs = {}

        if mode in ("background_removal", "all"):
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)
            outputs["removed_background"] = self._cut_out(orig_image, fg_mask)

            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                outputs["background"] = self._cut_out(orig_image, bg_mask)

        if mode in ("clothing", "all"):
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)

            if clothing_options:
                outputs["modified_clothing"] = self.apply_clothing_modifications(
                    orig_image, clothing_mask, clothing_options
                )

            outputs["clothing"] = self._cut_out(orig_image, clothing_mask)

        return outputs

    def apply_clothing_modifications(self, image: Image.Image, mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply the requested edits in order: color, then pattern, then style.

        Args:
            image: Full RGB image.
            mask: Clothing mask of the same size as ``image``.
            options: May contain ``"color"``, ``"pattern"`` and ``"style_transfer"``.

        Raises:
            NotImplementedError: for an unsupported pattern or style transfer.
            ValueError: if ``options["color"]`` cannot be parsed.
        """
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])
        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])
        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])
        return image

    @staticmethod
    def _parse_color(color) -> Tuple[int, int, int]:
        """Return an (r, g, b) tuple from a color string or an RGB(A) sequence."""
        if isinstance(color, str):
            # Accepts "#rrggbb", "rgb(...)", named colors, etc.; raises ValueError.
            return ImageColor.getrgb(color)[:3]
        return tuple(int(channel) for channel in color[:3])

    def change_clothing_color(self, image: Image.Image, mask: Image.Image,
                              color) -> Image.Image:
        """Tint the masked region with ``color`` while keeping its shading.

        Each masked pixel becomes ``luma * color`` (Rec. 601 luma), blended
        with the original pixel by the mask value (255 = fully recolored).
        """
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32) / 255.0
        alpha = np.asarray(mask.convert("L"), dtype=np.float32)[..., None] / 255.0
        target = np.asarray(self._parse_color(color), dtype=np.float32) / 255.0

        # Luma keeps folds and shadows visible under the new color.
        luma = rgb @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
        tinted = luma[..., None] * target
        blended = rgb * (1.0 - alpha) + tinted * alpha
        return Image.fromarray(np.clip(blended * 255.0 + 0.5, 0, 255).astype(np.uint8))

    def apply_pattern(self, image: Image.Image, mask: Image.Image,
                      pattern: str) -> Image.Image:
        """Darken a procedural pattern into the masked region.

        Supported (case-insensitive): ``"Stripes"`` (horizontal bands) and
        ``"Dots"`` (a regular grid of circles).

        Raises:
            NotImplementedError: for any other pattern name (e.g. ``"Floral"``).
        """
        rgb = np.asarray(image.convert("RGB"), dtype=np.float32)
        height, width = rgb.shape[:2]
        period = self._PATTERN_PERIOD
        ys, xs = np.mgrid[0:height, 0:width]

        name = str(pattern).strip().lower()
        if name == "stripes":
            motif = (ys // (period // 2)) % 2 == 0
        elif name == "dots":
            dy = ys % period - period / 2
            dx = xs % period - period / 2
            motif = dx ** 2 + dy ** 2 <= (period / 4) ** 2
        else:
            raise NotImplementedError(f"Pattern {pattern!r} is not supported yet.")

        alpha = np.asarray(mask.convert("L"), dtype=np.float32) / 255.0
        # Only pattern pixels inside the mask are darkened, weighted by the mask.
        shade = 1.0 - self._PATTERN_DARKEN * motif * alpha
        out = rgb * shade[..., None]
        return Image.fromarray(np.clip(out + 0.5, 0, 255).astype(np.uint8))

    def transfer_clothing_style(self, image: Image.Image, mask: Image.Image,
                                style) -> Image.Image:
        """Placeholder for reference-based style transfer.

        Raises:
            NotImplementedError: always; no style-transfer model is wired in.
        """
        raise NotImplementedError("Clothing style transfer is not implemented yet.")


def create_ui() -> gr.Blocks:
    """Build the Gradio app: background removal, clothing edits and examples."""
    processor = ImageProcessor()

    def remove_background(img):
        """Return the background-removed RGBA image, or show a UI error."""
        if img is None:
            raise gr.Error("Please upload an image first.")
        return processor.process_image(img)["removed_background"]

    def process_clothing(image, color, pattern, style):
        """Collect the chosen clothing options and return the edited image."""
        if image is None:
            raise gr.Error("Please upload an image first.")

        options = {}
        if color:
            options["color"] = color
        if pattern:
            options["pattern"] = pattern
        if style is not None:
            options["style_transfer"] = style

        try:
            outputs = processor.process_image(
                image,
                mode="clothing",
                clothing_options=options
            )
        except (NotImplementedError, ValueError) as err:
            # Surface unsupported options / bad colors as a readable UI error.
            raise gr.Error(str(err)) from err

        # "modified_clothing" only exists when at least one option was chosen;
        # fall back to the extracted garment instead of raising KeyError.
        return outputs.get("modified_clothing", outputs["clothing"])

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")

        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")
                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")

            remove_bg_btn.click(
                fn=remove_background,
                inputs=[input_image],
                outputs=[output_image]
            )

        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")
                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")
                    process_clothing_btn = gr.Button("Process Clothing")
                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            process_clothing_btn.click(
                fn=process_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output]
            )

        # Examples section: only files that actually exist are listed.
        examples_dir = Path("./examples")
        examples = [
            [str(examples_dir / f"example_{i}.jpg")]
            for i in range(1, 4)
            if (examples_dir / f"example_{i}.jpg").exists()
        ]

        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=remove_background,
                cache_examples=True
            )

    return app


if __name__ == "__main__":
    app = create_ui()
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True
    )