K00B404 committed on
Commit
57d19ed
·
verified ·
1 Parent(s): c2fe20e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torchvision.transforms.functional import normalize
5
+ import gradio as gr
6
+ from PIL import Image
7
+ from typing import Tuple, Dict, List
8
+ import cv2
9
+ from pathlib import Path
10
+
11
+ from advanced_briarmb import MultiTargetBriaRMBG, ClothingType, GarmentFeatures
12
+
13
class ImageProcessor:
    """Run MultiTargetBriaRMBG inference for background removal and
    clothing segmentation/modification.

    The model consumes fixed-size (1024x1024) inputs; output masks are
    resized back to the original image size during post-processing.
    """

    def __init__(self, model_path: str = "briaai/RMBG-1.4"):
        """Load the segmentation model onto GPU if available, else CPU.

        Args:
            model_path: Hugging Face model id or local path for
                ``MultiTargetBriaRMBG.from_pretrained``.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = MultiTargetBriaRMBG.from_pretrained(model_path)
        self.net.to(self.device)
        # Inference-only usage: disable dropout / batch-norm updates.
        self.net.eval()
        self.model_input_size = (1024, 1024)

    def preprocess_image(self, image: np.ndarray) -> torch.Tensor:
        """Prepare an image for model input.

        Args:
            image: HxWxC uint8 numpy array or a PIL image.

        Returns:
            A normalized 1x3x1024x1024 float32 tensor on ``self.device``.
        """
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Force 3-channel RGB and the model's fixed input resolution.
        image = image.convert('RGB')
        image = image.resize(self.model_input_size, Image.LANCZOS)

        # HWC uint8 -> 1x3xHxW float in [0, 1].
        im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
        im_tensor = torch.unsqueeze(im_tensor, 0)
        im_tensor = torch.divide(im_tensor, 255.0)
        # Shift to [-0.5, 0.5]; presumably matches the RMBG training
        # preprocessing — NOTE(review): confirm against the model card.
        im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])

        return im_tensor.to(self.device)

    def postprocess_mask(self, mask: torch.Tensor, target_size: Tuple[int, int]) -> Image.Image:
        """Convert a raw model mask to an 8-bit grayscale PIL image at the
        original image size.

        Args:
            mask: model output of shape (1, 1, H, W).
            target_size: original image size as PIL's ``(width, height)``.

        Returns:
            A mode-"L" PIL image sized ``target_size``.
        """
        # BUG FIX: F.interpolate expects size as (height, width), but PIL's
        # .size is (width, height) — reverse it so non-square images are not
        # resized transposed.
        mask = F.interpolate(mask, size=target_size[::-1], mode='bilinear')
        mask = torch.squeeze(mask, 0)

        # Min-max normalize to [0, 1]; guard the flat-mask case (max == min),
        # which would otherwise divide by zero and yield NaNs.
        mask_min, mask_max = mask.min(), mask.max()
        if mask_max > mask_min:
            mask = (mask - mask_min) / (mask_max - mask_min)
        else:
            mask = torch.zeros_like(mask)

        # .detach() before .numpy(): the tensor may carry autograd history.
        mask_np = (mask * 255).detach().cpu().numpy().astype(np.uint8)
        return Image.fromarray(mask_np[0])

    def process_image(self,
                      image: np.ndarray,
                      mode: str = "background_removal",
                      clothing_options: Dict = None) -> Dict[str, Image.Image]:
        """Run inference and assemble per-mode outputs.

        Args:
            image: input image (numpy array or PIL image).
            mode: "background_removal", "clothing", or "all".
            clothing_options: optional dict of clothing modifications
                (keys: "color", "pattern", "style_transfer").

        Returns:
            Dict mapping output names ("removed_background", "background",
            "clothing", "modified_clothing") to PIL images; keys present
            depend on ``mode`` and ``clothing_options``.
        """
        orig_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
        orig_size = orig_image.size

        input_tensor = self.preprocess_image(image)

        # Inference only — avoid building an autograd graph.
        with torch.no_grad():
            results = self.net(input_tensor, mode=mode)

        outputs = {}

        if mode in ("background_removal", "all"):
            fg_mask = self.postprocess_mask(results["foreground"], orig_size)

            # Composite the original over a fully transparent canvas, using
            # the foreground mask as the alpha source.
            transparent = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            transparent.paste(orig_image, mask=fg_mask)
            outputs["removed_background"] = transparent

            # The isolated background is only produced in "all" mode.
            if mode == "all":
                bg_mask = self.postprocess_mask(results["background"], orig_size)
                background = Image.new("RGBA", orig_size, (0, 0, 0, 0))
                background.paste(orig_image, mask=bg_mask)
                outputs["background"] = background

        if mode in ("clothing", "all"):
            clothing_mask = self.postprocess_mask(results["clothing"], orig_size)

            if clothing_options:
                modified = self.apply_clothing_modifications(
                    orig_image,
                    clothing_mask,
                    clothing_options
                )
                outputs["modified_clothing"] = modified

            # Always expose the unmodified clothing cut-out as well.
            clothing = Image.new("RGBA", orig_size, (0, 0, 0, 0))
            clothing.paste(orig_image, mask=clothing_mask)
            outputs["clothing"] = clothing

        return outputs

    def apply_clothing_modifications(self,
                                     image: Image.Image,
                                     mask: Image.Image,
                                     options: Dict) -> Image.Image:
        """Apply the requested clothing modifications in sequence.

        Args:
            image: original image.
            mask: clothing-region mask (grayscale PIL image).
            options: dict with any of "color", "pattern", "style_transfer".

        Returns:
            The modified image.
        """
        # NOTE(review): the three helpers below (change_clothing_color,
        # apply_pattern, transfer_clothing_style) are not defined anywhere in
        # this file — calling this method with any option will raise
        # AttributeError unless they are provided elsewhere. Confirm they
        # exist (e.g. in advanced_briarmb) or implement them.
        if "color" in options:
            image = self.change_clothing_color(image, mask, options["color"])

        if "pattern" in options:
            image = self.apply_pattern(image, mask, options["pattern"])

        if "style_transfer" in options:
            image = self.transfer_clothing_style(image, mask, options["style_transfer"])

        return image
119
+
120
def create_ui() -> gr.Blocks:
    """Create the Gradio UI.

    Two tabs: plain background removal, and clothing manipulation with
    optional color / pattern / style-transfer options. An examples gallery
    is attached when sample images exist under ``./examples``.

    Returns:
        The assembled (un-launched) ``gr.Blocks`` app.
    """
    processor = ImageProcessor()

    with gr.Blocks() as app:
        gr.Markdown("# Advanced Background and Clothing Removal")

        with gr.Tab("Background Removal"):
            with gr.Row():
                with gr.Column():
                    input_image = gr.Image(label="Input Image", type="numpy")
                    remove_bg_btn = gr.Button("Remove Background")

                with gr.Column():
                    output_image = gr.Image(label="Result", type="pil")

            remove_bg_btn.click(
                fn=lambda img: processor.process_image(img)["removed_background"],
                inputs=[input_image],
                outputs=[output_image]
            )

        with gr.Tab("Clothing Manipulation"):
            with gr.Row():
                with gr.Column():
                    cloth_input = gr.Image(label="Input Image", type="numpy")

                    with gr.Accordion("Clothing Options"):
                        color_picker = gr.ColorPicker(label="New Color")
                        pattern_choice = gr.Dropdown(
                            choices=["Stripes", "Dots", "Floral"],
                            label="Pattern"
                        )
                        style_image = gr.Image(label="Style Reference", type="numpy")

                    process_clothing_btn = gr.Button("Process Clothing")

                with gr.Column():
                    cloth_output = gr.Image(label="Modified Clothing", type="pil")

            def process_clothing(image, color, pattern, style):
                """Collect the non-empty options and run clothing processing."""
                options = {}
                if color:
                    options["color"] = color
                if pattern:
                    options["pattern"] = pattern
                if style is not None:
                    options["style_transfer"] = style

                results = processor.process_image(
                    image,
                    mode="clothing",
                    clothing_options=options
                )
                # BUG FIX: "modified_clothing" is only present when options
                # were supplied; fall back to the plain clothing cut-out so
                # submitting with no options selected does not KeyError.
                return results.get("modified_clothing", results["clothing"])

            process_clothing_btn.click(
                fn=process_clothing,
                inputs=[cloth_input, color_picker, pattern_choice, style_image],
                outputs=[cloth_output]
            )

        # Examples section: only wire up sample images that actually exist.
        examples_dir = Path("./examples")
        examples = [
            [str(examples_dir / f"example_{i}.jpg")]
            for i in range(1, 4)
            if (examples_dir / f"example_{i}.jpg").exists()
        ]

        if examples:
            gr.Examples(
                examples=examples,
                inputs=[input_image],
                outputs=[output_image],
                fn=lambda img: processor.process_image(img)["removed_background"],
                cache_examples=True
            )

    return app
199
+
200
if __name__ == "__main__":
    # Build the interface and serve it on every network interface at :7860.
    app = create_ui()
    launch_config = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "debug": True,
    }
    app.launch(**launch_config)