AkashKumarave commited on
Commit
9ca6126
·
verified ·
1 Parent(s): 62afffd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -37
app.py CHANGED
@@ -6,41 +6,59 @@ from PIL import Image
6
  import numpy as np
7
  import io
8
  import base64
 
 
 
 
 
 
9
 
10
  # Initialize the segmentation model for background removal
11
- try:
12
- # Using RMBG-1.4 which is specifically designed for background removal
13
- # Added trust_remote_code=True to allow custom code execution
14
- segmenter = pipeline(
15
- "image-segmentation",
16
- model="briaai/RMBG-1.4",
17
- device=0 if torch.cuda.is_available() else -1,
18
- trust_remote_code=True # Add this parameter to allow custom code execution
19
- )
20
- print("Successfully loaded RMBG-1.4 model")
21
- except Exception as e:
22
- print(f"Error loading RMBG model: {e}")
23
- # Fallback to a more standard segmentation model that doesn't require custom code
24
  try:
 
 
25
  segmenter = pipeline(
26
- "image-segmentation",
27
- model="facebook/detr-resnet-50-panoptic",
28
- device=0 if torch.cuda.is_available() else -1
 
29
  )
30
- print("Using fallback model: facebook/detr-resnet-50-panoptic")
31
- except Exception as e2:
32
- print(f"Error loading fallback model: {e2}")
33
- segmenter = None
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  def remove_background(input_image):
36
  """Remove background from an image using segmentation."""
37
- if input_image is None:
38
- return None
39
 
 
40
  if segmenter is None:
41
- print("No segmentation model available")
 
 
 
42
  return input_image
43
 
 
 
 
 
44
  try:
45
  # Convert input image to numpy array if it's not already
46
  if isinstance(input_image, str):
@@ -51,32 +69,30 @@ def remove_background(input_image):
51
 
52
  # Check if image is valid
53
  if input_array.size == 0:
54
- print("Empty input image")
55
  return input_image
56
 
57
- print(f"Processing image of shape {input_array.shape}")
58
 
59
  # Run image segmentation
60
  result = segmenter(input_image)
61
- print(f"Segmentation result type: {type(result)}")
62
 
63
  # For the RMBG model, we directly get the mask
64
  if isinstance(result, dict) and 'mask' in result:
65
  # Direct mask from RMBG model
66
  mask_array = np.array(result['mask'])
67
  mask_array = mask_array / 255.0 # Normalize if needed
68
- print("Using RMBG mask")
69
  elif isinstance(result, list) and len(result) > 0:
70
  # Standard segmentation model output - try to create a foreground mask
71
- # This will work differently depending on the model
72
- # For DETR model, we need to identify person/object segments
73
  foreground_classes = ['person', 'animal', 'vehicle', 'object']
74
 
75
  # Initialize an empty mask
76
  if len(input_array.shape) == 3:
77
  mask_array = np.zeros((input_array.shape[0], input_array.shape[1]), dtype=np.float32)
78
  else:
79
- print("Invalid input image shape")
80
  return input_image
81
 
82
  # Combine all foreground segments
@@ -93,9 +109,9 @@ def remove_background(input_image):
93
  (mask_array.shape[1], mask_array.shape[0])))
94
  # Add this segment to the foreground mask
95
  mask_array = np.maximum(mask_array, segment_mask)
96
- print("Created composite mask from segmentation model")
97
  else:
98
- print("Unexpected model output format")
99
  return input_image
100
 
101
  # Create an RGBA image
@@ -111,19 +127,22 @@ def remove_background(input_image):
111
  # For other models, we may need to invert the mask
112
  rgba[:,:,3] = (mask_array * 255).astype(np.uint8)
113
 
114
- print("Successfully created RGBA image")
115
  return Image.fromarray(rgba)
116
  else:
117
- print(f"Unexpected image format: shape {input_array.shape}")
118
  return input_image
119
 
120
  except Exception as e:
121
- print(f"Error in background removal: {e}")
122
  # Return original image if processing failed
123
  return input_image
124
 
125
- # Create Gradio interface
126
- with gr.Blocks(css="footer {visibility: hidden}") as demo:
 
 
 
127
  gr.Markdown(
128
  """
129
  # Space BG Erase Studio
@@ -142,6 +161,7 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
142
  with gr.Column():
143
  output_image = gr.Image(type="pil", label="Result (Transparent Background)")
144
 
 
145
  submit_btn.click(
146
  fn=remove_background,
147
  inputs=input_image,
 
6
  import numpy as np
7
  import io
8
  import base64
9
+ import sys
10
+
11
+ # Configure error logging
12
+ import logging
13
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
14
+ logger = logging.getLogger("BackgroundRemover")
15
 
16
  # Initialize the segmentation model for background removal
17
+ segmenter = None
18
+
19
+ def init_model():
20
+ global segmenter
 
 
 
 
 
 
 
 
 
21
  try:
22
+ # Using RMBG-1.4 which is specifically designed for background removal
23
+ logger.info("Loading RMBG-1.4 model...")
24
  segmenter = pipeline(
25
+ "image-segmentation",
26
+ model="briaai/RMBG-1.4",
27
+ device=0 if torch.cuda.is_available() else -1,
28
+ trust_remote_code=True # Allow custom code execution for the model
29
  )
30
+ logger.info("Successfully loaded RMBG-1.4 model")
31
+ except Exception as e:
32
+ logger.error(f"Error loading RMBG model: {e}")
33
+ # Fallback to a more standard segmentation model that doesn't require custom code
34
+ try:
35
+ logger.info("Attempting to load fallback model...")
36
+ segmenter = pipeline(
37
+ "image-segmentation",
38
+ model="facebook/detr-resnet-50-panoptic",
39
+ device=0 if torch.cuda.is_available() else -1
40
+ )
41
+ logger.info("Using fallback model: facebook/detr-resnet-50-panoptic")
42
+ except Exception as e2:
43
+ logger.error(f"Error loading fallback model: {e2}")
44
+ segmenter = None
45
 
46
  def remove_background(input_image):
47
  """Remove background from an image using segmentation."""
48
+ global segmenter
 
49
 
50
+ # Initialize model if not already done
51
  if segmenter is None:
52
+ init_model()
53
+
54
+ if segmenter is None:
55
+ logger.error("No segmentation model available")
56
  return input_image
57
 
58
+ if input_image is None:
59
+ logger.error("No input image provided")
60
+ return None
61
+
62
  try:
63
  # Convert input image to numpy array if it's not already
64
  if isinstance(input_image, str):
 
69
 
70
  # Check if image is valid
71
  if input_array.size == 0:
72
+ logger.error("Empty input image")
73
  return input_image
74
 
75
+ logger.info(f"Processing image of shape {input_array.shape}")
76
 
77
  # Run image segmentation
78
  result = segmenter(input_image)
79
+ logger.info(f"Segmentation result type: {type(result)}")
80
 
81
  # For the RMBG model, we directly get the mask
82
  if isinstance(result, dict) and 'mask' in result:
83
  # Direct mask from RMBG model
84
  mask_array = np.array(result['mask'])
85
  mask_array = mask_array / 255.0 # Normalize if needed
86
+ logger.info("Using RMBG mask")
87
  elif isinstance(result, list) and len(result) > 0:
88
  # Standard segmentation model output - try to create a foreground mask
 
 
89
  foreground_classes = ['person', 'animal', 'vehicle', 'object']
90
 
91
  # Initialize an empty mask
92
  if len(input_array.shape) == 3:
93
  mask_array = np.zeros((input_array.shape[0], input_array.shape[1]), dtype=np.float32)
94
  else:
95
+ logger.error("Invalid input image shape")
96
  return input_image
97
 
98
  # Combine all foreground segments
 
109
  (mask_array.shape[1], mask_array.shape[0])))
110
  # Add this segment to the foreground mask
111
  mask_array = np.maximum(mask_array, segment_mask)
112
+ logger.info("Created composite mask from segmentation model")
113
  else:
114
+ logger.error("Unexpected model output format")
115
  return input_image
116
 
117
  # Create an RGBA image
 
127
  # For other models, we may need to invert the mask
128
  rgba[:,:,3] = (mask_array * 255).astype(np.uint8)
129
 
130
+ logger.info("Successfully created RGBA image")
131
  return Image.fromarray(rgba)
132
  else:
133
+ logger.error(f"Unexpected image format: shape {input_array.shape}")
134
  return input_image
135
 
136
  except Exception as e:
137
+ logger.error(f"Error in background removal: {e}")
138
  # Return original image if processing failed
139
  return input_image
140
 
141
+ # Initialize model on startup to avoid lazy loading during request
142
+ init_model()
143
+
144
+ # Create a simpler Gradio interface with minimal components to avoid internal errors
145
+ with gr.Blocks(theme=gr.themes.Default(), css="footer {visibility: hidden}") as demo:
146
  gr.Markdown(
147
  """
148
  # Space BG Erase Studio
 
161
  with gr.Column():
162
  output_image = gr.Image(type="pil", label="Result (Transparent Background)")
163
 
164
+ # Simple click handler to avoid complex API handling
165
  submit_btn.click(
166
  fn=remove_background,
167
  inputs=input_image,