AkashKumarave commited on
Commit
901bf30
·
verified ·
1 Parent(s): 9ca6126

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -157
app.py CHANGED
@@ -1,187 +1,115 @@
1
-
2
  import gradio as gr
3
  import torch
4
- from transformers import pipeline
5
- from PIL import Image
6
  import numpy as np
7
- import io
8
- import base64
9
- import sys
10
-
11
- # Configure error logging
12
- import logging
13
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
- logger = logging.getLogger("BackgroundRemover")
15
 
16
- # Initialize the segmentation model for background removal
17
- segmenter = None
 
 
 
18
 
19
- def init_model():
20
- global segmenter
21
- try:
22
- # Using RMBG-1.4 which is specifically designed for background removal
23
- logger.info("Loading RMBG-1.4 model...")
24
- segmenter = pipeline(
25
- "image-segmentation",
26
- model="briaai/RMBG-1.4",
27
- device=0 if torch.cuda.is_available() else -1,
28
- trust_remote_code=True # Allow custom code execution for the model
29
- )
30
- logger.info("Successfully loaded RMBG-1.4 model")
31
- except Exception as e:
32
- logger.error(f"Error loading RMBG model: {e}")
33
- # Fallback to a more standard segmentation model that doesn't require custom code
34
- try:
35
- logger.info("Attempting to load fallback model...")
36
- segmenter = pipeline(
37
- "image-segmentation",
38
- model="facebook/detr-resnet-50-panoptic",
39
- device=0 if torch.cuda.is_available() else -1
40
- )
41
- logger.info("Using fallback model: facebook/detr-resnet-50-panoptic")
42
- except Exception as e2:
43
- logger.error(f"Error loading fallback model: {e2}")
44
- segmenter = None
45
 
46
- def remove_background(input_image):
47
- """Remove background from an image using segmentation."""
48
- global segmenter
 
 
 
 
 
 
 
 
49
 
50
- # Initialize model if not already done
51
- if segmenter is None:
52
- init_model()
53
 
54
- if segmenter is None:
55
- logger.error("No segmentation model available")
56
- return input_image
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  if input_image is None:
59
- logger.error("No input image provided")
60
  return None
61
 
62
  try:
63
- # Convert input image to numpy array if it's not already
64
- if isinstance(input_image, str):
65
- input_img = Image.open(input_image)
66
- input_array = np.array(input_img)
67
- else:
68
- input_array = np.array(input_image)
69
 
70
- # Check if image is valid
71
- if input_array.size == 0:
72
- logger.error("Empty input image")
73
- return input_image
74
-
75
- logger.info(f"Processing image of shape {input_array.shape}")
76
 
77
- # Run image segmentation
78
- result = segmenter(input_image)
79
- logger.info(f"Segmentation result type: {type(result)}")
80
 
81
- # For the RMBG model, we directly get the mask
82
- if isinstance(result, dict) and 'mask' in result:
83
- # Direct mask from RMBG model
84
- mask_array = np.array(result['mask'])
85
- mask_array = mask_array / 255.0 # Normalize if needed
86
- logger.info("Using RMBG mask")
87
- elif isinstance(result, list) and len(result) > 0:
88
- # Standard segmentation model output - try to create a foreground mask
89
- foreground_classes = ['person', 'animal', 'vehicle', 'object']
90
-
91
- # Initialize an empty mask
92
- if len(input_array.shape) == 3:
93
- mask_array = np.zeros((input_array.shape[0], input_array.shape[1]), dtype=np.float32)
94
- else:
95
- logger.error("Invalid input image shape")
96
- return input_image
97
-
98
- # Combine all foreground segments
99
- for segment in result:
100
- label = segment.get('label', '').lower()
101
- # If it's a foreground class or we don't have specific classes to check
102
- if any(fg_class in label for fg_class in foreground_classes) or not foreground_classes:
103
- segment_mask = segment.get('mask')
104
- if segment_mask is not None:
105
- # Resize mask if needed
106
- segment_mask = np.array(segment_mask)
107
- if segment_mask.shape[:2] != mask_array.shape:
108
- segment_mask = np.array(Image.fromarray(segment_mask).resize(
109
- (mask_array.shape[1], mask_array.shape[0])))
110
- # Add this segment to the foreground mask
111
- mask_array = np.maximum(mask_array, segment_mask)
112
- logger.info("Created composite mask from segmentation model")
113
- else:
114
- logger.error("Unexpected model output format")
115
- return input_image
116
 
117
- # Create an RGBA image
118
- if len(input_array.shape) == 3 and input_array.shape[2] >= 3:
119
- rgba = np.zeros((input_array.shape[0], input_array.shape[1], 4), dtype=np.uint8)
120
- rgba[:,:,:3] = input_array[:,:,:3] # Copy RGB channels
121
-
122
- # Apply mask to alpha channel
123
- if 'briaai/RMBG' in str(segmenter.model):
124
- # For RMBG model, use the mask directly
125
- rgba[:,:,3] = (mask_array * 255).astype(np.uint8)
126
- else:
127
- # For other models, we may need to invert the mask
128
- rgba[:,:,3] = (mask_array * 255).astype(np.uint8)
129
-
130
- logger.info("Successfully created RGBA image")
131
- return Image.fromarray(rgba)
132
- else:
133
- logger.error(f"Unexpected image format: shape {input_array.shape}")
134
- return input_image
135
 
 
 
 
 
136
  except Exception as e:
137
- logger.error(f"Error in background removal: {e}")
138
- # Return original image if processing failed
139
- return input_image
140
-
141
- # Initialize model on startup to avoid lazy loading during request
142
- init_model()
143
 
144
- # Create a simpler Gradio interface with minimal components to avoid internal errors
145
- with gr.Blocks(theme=gr.themes.Default(), css="footer {visibility: hidden}") as demo:
146
- gr.Markdown(
147
- """
148
- # Space BG Erase Studio
149
-
150
- Upload an image and the AI will remove its background, giving you a transparent PNG.
151
-
152
- Powered by Hugging Face Transformers.
153
- """
154
- )
155
 
156
  with gr.Row():
157
- with gr.Column():
158
- input_image = gr.Image(type="pil", label="Upload Image")
159
- submit_btn = gr.Button("Remove Background", variant="primary")
160
-
161
- with gr.Column():
162
- output_image = gr.Image(type="pil", label="Result (Transparent Background)")
163
 
164
- # Simple click handler to avoid complex API handling
165
  submit_btn.click(
166
  fn=remove_background,
167
  inputs=input_image,
168
  outputs=output_image
169
  )
170
-
171
- gr.Markdown(
172
- """
173
- ## How it works
174
-
175
- This app uses a machine learning model specifically designed for background removal.
176
- The result is a transparent PNG with only your subject visible.
177
-
178
- ## Tips for best results
179
-
180
- - Use images where the subject is clearly visible
181
- - Good lighting helps the AI separate the subject from background
182
- - The process may take a few seconds depending on image size
183
- """
184
- )
185
 
 
186
  if __name__ == "__main__":
187
- demo.launch()
 
 
1
  import gradio as gr
2
  import torch
3
+ import torch.nn.functional as F
 
4
  import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import os
 
 
 
 
 
8
 
9
+ # Ensure models directory is accessible
10
+ try:
11
+ from models.isnet import ISNetGT
12
+ except ImportError:
13
+ raise ImportError("Could not import ISNetGT from models.isnet. Ensure models/isnet.py is in the Space.")
14
 
15
# Model loading helper.
def load_model(model_path="isnet-general-use.pth"):
    """Instantiate ISNetGT, load its checkpoint, and return (model, device).

    Raises FileNotFoundError when the checkpoint file is missing from the
    Space root directory.
    """
    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model file {model_path} not found. Upload it to the Space root directory.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = ISNetGT()
    # NOTE(review): torch.load without weights_only — fine for a trusted
    # checkpoint, but confirm the file's provenance.
    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state)
    net.to(device)
    net.eval()
    return net, device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
# Image preprocessing function
def preprocess_image(image, target_size=(1024, 1024)):
    """Fit *image* (PIL) inside target_size preserving aspect ratio, pad, tensorize.

    Returns:
        image_tensor  : 1x3xHxW float tensor in [0, 1] (H, W = target_size)
        (new_h, new_w): size of the image content inside the padded canvas
        (h, w)        : original image size
    """
    # Bug fix: force 3-channel RGB first. RGBA or grayscale uploads made the
    # RGB2BGR conversion below fail on a non-3-channel array.
    image = np.array(image.convert("RGB"))
    # NOTE(review): channel order fed to the model is BGR here, as in the
    # original code — confirm this matches what IS-Net was trained on.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # Resize image while preserving aspect ratio.
    h, w = image.shape[:2]
    scale = min(target_size[0] / h, target_size[1] / w)
    new_h, new_w = int(h * scale), int(w * scale)
    image_resized = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # Pad to target size (content sits in the top-left corner; the rest is zeros).
    padded_image = np.zeros((target_size[0], target_size[1], 3), dtype=np.uint8)
    padded_image[:new_h, :new_w] = image_resized

    # Normalize to [0, 1] and convert HWC -> 1xCxHxW.
    image_tensor = torch.from_numpy(padded_image).permute(2, 0, 1).float() / 255.0
    image_tensor = image_tensor.unsqueeze(0)  # Add batch dimension

    return image_tensor, (new_h, new_w), (h, w)
47
+
48
# Inference function
def inference(model, image_tensor, device):
    """Run the segmentation model and return an HxW probability map (numpy).

    The model's first output head is upsampled to the input resolution and
    squashed through a sigmoid.
    """
    batch = image_tensor.to(device)
    with torch.no_grad():
        logits = model(batch)[0]  # first head of the segmentation output
        logits = F.interpolate(logits, size=batch.shape[2:], mode='bilinear', align_corners=True)
        probs = torch.sigmoid(logits)
    return probs.cpu().numpy()[0, 0]
56
+
57
# Post-processing function
def postprocess_output(output, original_size, resized_size):
    """Map the model's padded-square probability map back to a binary mask.

    Args:
        output        : HxW float map at the padded model resolution
        original_size : (h, w) of the uploaded image
        resized_size  : (new_h, new_w) of the content region inside the pad

    Returns a uint8 mask (0 or 255) with shape original_size.
    """
    # Bug fix: the content occupies only the top-left resized_size region of
    # the padded canvas (see the padding in preprocessing). The original
    # resized the WHOLE padded map, squeezing the zero padding into the mask
    # and misaligning it for any non-square image. Crop the content first.
    mask = output[:resized_size[0], :resized_size[1]]
    # cv2.resize expects (width, height), hence the [::-1].
    mask = cv2.resize(mask, original_size[::-1], interpolation=cv2.INTER_LANCZOS4)
    mask = (mask > 0.5).astype(np.uint8) * 255  # Binarize at p = 0.5
    return mask
64
+
65
# Background removal function
def remove_background(input_image):
    """Remove the background from *input_image* (PIL) and return an RGBA image.

    Returns None when no image was provided or when processing fails — a
    gr.Image output component cannot render an error string.
    """
    if input_image is None:
        return None

    try:
        # Perf fix: load the checkpoint once and cache it on the function
        # object instead of reloading it on every single request.
        if not hasattr(remove_background, "_cached"):
            remove_background._cached = load_model()
        model, device = remove_background._cached

        # Preprocess image
        image_tensor, resized_size, original_size = preprocess_image(input_image)

        # Run inference
        mask = inference(model, image_tensor, device)

        # Post-process mask back to the original resolution
        mask = postprocess_output(mask, original_size, resized_size)

        # Bug fix: force 3 channels. An RGBA or grayscale upload would break
        # the rgba[..., :3] assignment below.
        input_array = np.array(input_image.convert("RGB"))
        rgba = np.zeros((input_array.shape[0], input_array.shape[1], 4), dtype=np.uint8)
        rgba[..., :3] = input_array
        rgba[..., 3] = mask  # model mask becomes the alpha channel

        return Image.fromarray(rgba, mode='RGBA')

    except Exception as e:
        # Bug fix: the original returned an error *string* here, which the
        # gr.Image output cannot display. Log it and return None instead.
        print(f"Error: {str(e)}")
        return None
 
 
 
 
 
96
 
97
# Build the Gradio UI: one upload, one transparent-PNG result, one button.
with gr.Blocks(title="DIS Background Remover") as demo:
    gr.Markdown("## DIS Background Remover")
    gr.Markdown("Upload an image to remove its background using the IS-Net model from xuebinqin/DIS.")

    with gr.Row():
        uploaded = gr.Image(type="pil", label="Upload Image")
        result = gr.Image(type="pil", label="Image with Background Removed")

    trigger = gr.Button("Remove Background")
    trigger.click(fn=remove_background, inputs=uploaded, outputs=result)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
# Launch the app (only when executed directly, not when imported as a module).
if __name__ == "__main__":
    demo.launch()