import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from PIL import Image
from typing import Tuple, Dict, List
import cv2
from pathlib import Path
from briarmbg import MultiTargetBriaRMBG
from briacustom import MultiTargetBriaRMBG, ClothingType, GarmentFeatures
class ImageProcessor:
    """Runs MultiTargetBriaRMBG to segment images and extract/modify regions.

    Supports background removal, background extraction, and clothing
    manipulation, depending on the ``mode`` passed to :meth:`process_image`.
    """

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the pretrained segmentation model, on GPU when available."""
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        # FIX: inference-only usage — switch off dropout / freeze batch-norm
        # statistics so outputs are deterministic.
        self.net.eval()
        # (width, height) the network expects; inputs are resized to this.
        self.model_input_size = (1024, 1024)

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Prepare an image for model input.

        Accepts an RGB numpy array (or a PIL image) and returns a normalized
        1x3xHxW float tensor on the model's device.
        """
        # Convert numpy array to PIL
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        # Convert to RGB and resize to the fixed network input size
        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)
        # Convert to a batched CHW tensor in [0, 1]
        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        # Shift [0, 1] -> [-0.5, 0.5] (mean 0.5, std 1.0), matching RMBG's
        # expected input normalization.
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Convert a 1x1xHxW model mask into a grayscale PIL image.

        ``target_size`` is a PIL-style (width, height) pair.
        """
        # FIX: F.interpolate expects size as (height, width) but PIL sizes
        # are (width, height); reverse the pair so non-square images are not
        # resized to a transposed mask (which breaks Image.paste downstream).
        mask = F.interpolate(mask, size=target_size[::-1], mode='bilinear')
        mask = torch.squeeze(mask, 0)
        # FIX: guard against a constant mask (max == min), which previously
        # produced NaNs from a 0/0 division.
        lo, hi = mask.min(), mask.max()
        if hi > lo:
            mask = (mask - lo) / (hi - lo)
        else:
            mask = torch.zeros_like(mask)
        # Scale to 8-bit and take the single channel -> PIL 'L' image
        mask_np = (mask * 255).cpu().data.numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    def process_image(self,
                      image: np.ndarray,
                      mode: str = "background_removal",
                      clothing_options: Dict = None) -> Dict[str, Image.Image]:
        """Run the model and build the outputs requested by ``mode``.

        mode: "background_removal", "clothing", or "all".
        Returns a dict whose keys depend on mode: "removed_background",
        "background" (mode "all" only), "clothing", "modified_clothing"
        (only when ``clothing_options`` is non-empty).
        """
        # Keep the original image/size so outputs match the input resolution
        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        orig_size = orig_image.size
        input_tensor = self.preprocess_image(image)
        # FIX: inference under no_grad — avoids building an autograd graph
        # and substantially reduces memory use.
        with torch.no_grad():
            results = self.net(input_tensor, mode=mode)
        outputs = {}
        if mode in ("background_removal", "all"):
            # Composite the original over transparency using the foreground mask
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)
            transparent = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            transparent.paste(orig_image, mask=fg_mask)
            outputs["removed_background"] = transparent
            # Extract background if requested
            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                background = Image.new("RGBA", orig_size, (0, 0, 0, 0))
                background.paste(orig_image, mask=bg_mask)
                outputs["background"] = background
        if mode in ("clothing", "all"):
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)
            if clothing_options:
                # Apply requested clothing modifications on top of the original
                modified = self.apply_clothing_modifications(
                    orig_image,
                    clothing_mask,
                    clothing_options
                )
                outputs["modified_clothing"] = modified
            # Always also extract the unmodified clothing region
            clothing = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            clothing.paste(orig_image, mask=clothing_mask)
            outputs["clothing"] = clothing
        return outputs

    def apply_clothing_modifications(self,
                                     image: Image.Image,
                                     mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply clothing edits in a fixed order: color, then pattern, then style.

        NOTE(review): delegates to change_clothing_color / apply_pattern /
        transfer_clothing_style, which are not defined in this file — confirm
        they exist on this class (or a mixin) before shipping.
        """
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])
        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])
        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])
        return image
def create_ui() -> gr.Blocks:
    """Create the Gradio UI.

    Builds two tabs (background removal, clothing manipulation) plus an
    optional examples gallery, all backed by a single shared ImageProcessor.
    """
    processor = ImageProcessor()

    def remove_background(img):
        """Run background removal and return the transparent composite."""
        return processor.process_image(img)["removed_background"]

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")
        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")
                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")
            remove_bg_btn.click(
                fn=remove_background,
                inputs=[input_image],
                outputs=[output_image]
            )
        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")
                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")
                    process_clothing_btn = gr.Button("Process Clothing")
                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            def process_clothing(image, color, pattern, style):
                """Collect the selected options and run clothing processing."""
                options = {}
                if color:
                    options["color"] = color
                if pattern:
                    options["pattern"] = pattern
                if style is not None:
                    options["style_transfer"] = style
                results = processor.process_image(
                    image,
                    mode="clothing",
                    clothing_options=options
                )
                # FIX: when no options are selected, process_image never
                # creates "modified_clothing" and the old code raised
                # KeyError; fall back to the plain extracted clothing.
                return results.get("modified_clothing", results["clothing"])

            process_clothing_btn.click(
                fn=process_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output]
            )
        # Examples section: include only example files that actually exist
        examples_dir = Path("./examples")
        examples = [
            [str(examples_dir / f"example_{i}.jpg")]
            for i in range(1, 4)
            if (examples_dir / f"example_{i}.jpg").exists()
        ]
        if examples:
            # NOTE: cache_examples=True runs the model on every example at
            # startup — confirm this cost is acceptable for deployment.
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=remove_background,
                cache_examples=True
            )
    return app
if __name__ == "__main__":
    # Serve on every interface at port 7860, in debug mode, without
    # creating a public share link.
    create_ui().launch(
        debug=True,
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )