File size: 7,968 Bytes
57d19ed
 
 
 
 
 
 
 
 
893f4a7
4d0b07a
57d19ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.transforms.functional import normalize
import gradio as gr
from PIL import Image
from typing import Tuple, Dict, List
import cv2
from pathlib import Path
from briarmbg import MultiTargetBriaRMBG
from briacustom import MultiTargetBriaRMBG, ClothingType, GarmentFeatures

class ImageProcessor:
    """Wraps a MultiTargetBriaRMBG model for background removal and
    clothing segmentation/manipulation.

    The model runs on CUDA when available, otherwise CPU. All public
    methods accept numpy images (H, W, 3) or PIL images and return PIL
    images.
    """

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the pretrained segmentation model.

        Args:
            model_path: HuggingFace model id or local path passed to
                ``MultiTargetBriaRMBG.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        # Fixed square input size the network was trained on.
        self.model_input_size = (1024, 1024)

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Convert an image to a normalized (1, 3, H, W) tensor on self.device.

        Args:
            image: numpy array (H, W, C) or PIL image.

        Returns:
            Float tensor in roughly [-0.5, 0.5] (mean 0.5, std 1.0
            normalization, matching the RMBG reference pipeline).
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Force 3-channel RGB and the fixed model resolution.
        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)

        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Convert a raw model mask to a grayscale PIL image of target_size.

        Args:
            mask: model output of shape (1, 1, h, w).
            target_size: desired output size as a PIL-style
                ``(width, height)`` pair.

        Returns:
            Single-channel PIL image whose ``.size`` equals target_size.
        """
        # BUGFIX: F.interpolate takes (height, width) but callers pass the
        # PIL-style (width, height) from Image.size — flip it, otherwise
        # non-square images produce a transposed mask that cannot be used
        # as a paste mask for the original image.
        mask = F.interpolate(mask, size=target_size[::-1], mode='bilinear')
        mask = torch.squeeze(mask, 0)

        # Min-max normalize to [0, 1]; guard against a constant mask,
        # which would otherwise divide by zero and yield NaNs.
        mask_min, mask_max = mask.min(), mask.max()
        if mask_max > mask_min:
            mask = (mask - mask_min) / (mask_max - mask_min)
        else:
            mask = torch.zeros_like(mask)

        # .detach().cpu().numpy() replaces the deprecated .data access.
        mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    def process_image(self,
                      image: np.ndarray,
                      mode: str = "background_removal",
                      clothing_options: Dict = None) -> Dict[str, Image.Image]:
        """Run the model and assemble output images for the requested mode.

        Args:
            image: input image (numpy array or PIL image).
            mode: "background_removal", "clothing", or "all".
            clothing_options: optional dict with "color", "pattern",
                and/or "style_transfer" keys (see
                apply_clothing_modifications).

        Returns:
            Dict of PIL images; keys depend on mode and may include
            "removed_background", "background", "clothing",
            "modified_clothing".
        """
        # Keep the original image/size so outputs match the input resolution.
        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        orig_size = orig_image.size

        input_tensor = self.preprocess_image(image)

        # NOTE(review): assumes the net returns a dict with "foreground",
        # "background", "clothing" keys depending on mode — defined in
        # briacustom, not visible here.
        results = self.net(input_tensor, mode=mode)

        outputs = {}

        if mode == "background_removal" or mode == "all":
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)

            # Composite the original onto a fully transparent canvas,
            # using the foreground mask as alpha.
            transparent = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            transparent.paste(orig_image, mask=fg_mask)
            outputs["removed_background"] = transparent

            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                background = Image.new("RGBA", orig_size, (0, 0, 0, 0))
                background.paste(orig_image, mask=bg_mask)
                outputs["background"] = background

        if mode == "clothing" or mode == "all":
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)

            if clothing_options:
                modified = self.apply_clothing_modifications(
                    orig_image,
                    clothing_mask,
                    clothing_options
                )
                outputs["modified_clothing"] = modified

            # Always also return the unmodified clothing cut-out.
            clothing = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            clothing.paste(orig_image, mask=clothing_mask)
            outputs["clothing"] = clothing

        return outputs

    def apply_clothing_modifications(self,
                                     image: Image.Image,
                                     mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply the requested clothing edits in a fixed order.

        Args:
            image: original image.
            mask: clothing mask from postprocess_mask.
            options: may contain "color", "pattern", "style_transfer".

        Returns:
            The modified image (color, then pattern, then style transfer).

        NOTE(review): change_clothing_color / apply_pattern /
        transfer_clothing_style are not defined in this file — presumably
        inherited or monkey-patched elsewhere; verify before use.
        """
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])

        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])

        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])

        return image

def create_ui() -> gr.Blocks:
    """Assemble the Gradio interface: one tab for background removal,
    one for clothing manipulation, plus an optional examples gallery."""
    processor = ImageProcessor()

    def _remove_background(img):
        # Shared handler: used by the button and by the examples cache.
        return processor.process_image(img)["removed_background"]

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")

        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")

                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")

            remove_bg_btn.click(
                fn=_remove_background,
                inputs=[input_image],
                outputs=[output_image],
            )

        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")

                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")

                    process_clothing_btn = gr.Button("Process Clothing")

                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            def _handle_clothing(image, color, pattern, style):
                # Only forward the options the user actually set.
                opts = {}
                if color:
                    opts["color"] = color
                if pattern:
                    opts["pattern"] = pattern
                if style is not None:
                    opts["style_transfer"] = style

                result = processor.process_image(
                    image,
                    mode="clothing",
                    clothing_options=opts
                )
                return result["modified_clothing"]

            process_clothing_btn.click(
                fn=_handle_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output],
            )

        # Examples gallery: only files that actually exist on disk.
        examples_dir = Path("./examples")
        examples = []
        for i in range(1, 4):
            candidate = examples_dir / f"example_{i}.jpg"
            if candidate.exists():
                examples.append([str(candidate)])

        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=_remove_background,
                cache_examples=True
            )

    return app

if __name__ == "__main__":
    app = create_ui()
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        debug=True
    )