| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torchvision.transforms.functional import normalize |
| | import gradio as gr |
| | from PIL import Image |
| | from typing import Tuple, Dict, List |
| | import cv2 |
| | from pathlib import Path |
| | from briarmbg import MultiTargetBriaRMBG |
| | from briacustom import MultiTargetBriaRMBG, ClothingType, GarmentFeatures |
| |
|
class ImageProcessor:
    """Wraps the multi-target RMBG model for background removal and clothing
    segmentation, converting between PIL/numpy images and model tensors."""

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the segmentation network onto GPU if available, else CPU.

        Args:
            model_path: Hugging Face model id or local path passed to
                ``MultiTargetBriaRMBG.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        # This class only ever runs inference; eval mode fixes any
        # dropout/batch-norm layers the model may contain.
        self.net.eval()
        self.model_input_size = (1024, 1024)  # (width, height) fed to the net

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Resize, rescale and normalize an image into a 1x3xHxW input tensor.

        Args:
            image: HxWx3 uint8 numpy array or a PIL image.

        Returns:
            Float tensor of shape (1, 3, 1024, 1024) on ``self.device``.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)

        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        # RMBG convention: center at 0.5 with unit scale (not ImageNet stats).
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Resize a model mask to the original image size as a grayscale PIL image.

        Args:
            mask: Raw model output, assumed shape (1, 1, H, W) — TODO confirm
                against ``MultiTargetBriaRMBG``'s actual output shape.
            target_size: Original image size in PIL's (width, height) order.

        Returns:
            8-bit single-channel PIL image of size ``target_size``.
        """
        # F.interpolate expects size as (height, width); PIL reports
        # (width, height). The original code passed it unreversed, which
        # transposed masks for non-square images.
        mask = F.interpolate(mask, size=target_size[::-1], mode='bilinear')
        mask = torch.squeeze(mask, 0)

        # Min-max normalize to [0, 1]; guard the constant-mask case, which
        # previously divided by zero and produced NaNs.
        mn, mx = mask.min(), mask.max()
        if mx > mn:
            mask = (mask - mn) / (mx - mn)
        else:
            mask = torch.zeros_like(mask)

        # .detach() replaces the deprecated .data access and is safe whether
        # or not the tensor carries grad history.
        mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    def process_image(self,
                      image: np.ndarray,
                      mode: str = "background_removal",
                      clothing_options: Dict = None) -> Dict[str, Image.Image]:
        """Run the model and compose the requested output images.

        Args:
            image: Input image (numpy array or PIL image).
            mode: One of "background_removal", "clothing", or "all".
            clothing_options: Optional dict of clothing edits (keys "color",
                "pattern", "style_transfer"); only used in clothing modes.

        Returns:
            Dict mapping output names ("removed_background", "background",
            "clothing", "modified_clothing") to RGBA PIL images; keys present
            depend on ``mode`` and ``clothing_options``.
        """
        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        orig_size = orig_image.size

        input_tensor = self.preprocess_image(image)

        # Pure inference — skip autograd bookkeeping.
        with torch.no_grad():
            results = self.net(input_tensor, mode=mode)

        outputs = {}

        if mode == "background_removal" or mode == "all":
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)

            # Composite the original over transparency using the mask as alpha.
            transparent = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            transparent.paste(orig_image, mask=fg_mask)
            outputs["removed_background"] = transparent

            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                background = Image.new("RGBA", orig_size, (0, 0, 0, 0))
                background.paste(orig_image, mask=bg_mask)
                outputs["background"] = background

        if mode == "clothing" or mode == "all":
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)

            if clothing_options:
                modified = self.apply_clothing_modifications(
                    orig_image,
                    clothing_mask,
                    clothing_options
                )
                outputs["modified_clothing"] = modified

            clothing = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            clothing.paste(orig_image, mask=clothing_mask)
            outputs["clothing"] = clothing

        return outputs

    def apply_clothing_modifications(self,
                                     image: Image.Image,
                                     mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply clothing modifications based on options.

        NOTE(review): change_clothing_color, apply_pattern and
        transfer_clothing_style are not defined anywhere in this file —
        presumably provided by a subclass or mixin from ``briacustom``;
        calling this without them raises AttributeError. Confirm.
        """
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])

        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])

        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])

        return image
| |
|
def create_ui() -> gr.Blocks:
    """Create the Gradio UI with background-removal and clothing tabs.

    Returns:
        A ``gr.Blocks`` app ready for ``.launch()``.
    """
    processor = ImageProcessor()

    def remove_background(img):
        # Single helper shared by the button handler and the cached examples
        # (the original duplicated this lambda in both places).
        return processor.process_image(img)["removed_background"]

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")

        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")

                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")

            remove_bg_btn.click(
                fn=remove_background,
                inputs=[input_image],
                outputs=[output_image]
            )

        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")

                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")

                    process_clothing_btn = gr.Button("Process Clothing")

                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            def process_clothing(image, color, pattern, style):
                options = {}
                if color:
                    options["color"] = color
                if pattern:
                    options["pattern"] = pattern
                if style is not None:
                    options["style_transfer"] = style

                results = processor.process_image(
                    image,
                    mode="clothing",
                    clothing_options=options
                )
                # When no option is selected, options is {} (falsy) and
                # process_image never creates "modified_clothing" — the
                # original unconditional subscript raised KeyError here.
                # Fall back to the plain clothing cutout instead.
                return results.get("modified_clothing", results["clothing"])

            process_clothing_btn.click(
                fn=process_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output]
            )

        # Register example images only for files that actually exist.
        examples_dir = Path("./examples")
        examples = [
            [str(examples_dir / f"example_{i}.jpg")]
            for i in range(1, 4)
            if (examples_dir / f"example_{i}.jpg").exists()
        ]

        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=remove_background,
                cache_examples=True
            )

    return app
| |
|
if __name__ == "__main__":
    # Build the UI and serve it on all interfaces at the default Gradio port.
    demo = create_ui()
    launch_config = {
        "share": False,            # local-only; no public Gradio share link
        "server_name": "0.0.0.0",  # listen on every interface
        "server_port": 7860,
        "debug": True,
    }
    demo.launch(**launch_config)