ihabooe commited on
Commit
758eb10
·
verified ·
1 Parent(s): f0b227c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -19
app.py CHANGED
@@ -9,6 +9,7 @@ from PIL import Image
9
  import tempfile
10
  import os
11
  import time
 
12
 
13
  # Load the pre-trained model
14
  print("Loading model...")
@@ -18,6 +19,10 @@ net.to(device)
18
  net.eval()
19
  print(f"Model loaded on {device}")
20
 
 
 
 
 
21
  # Resize the input image for model compatibility
22
  def resize_image(image):
23
  image = image.convert('RGB')
@@ -28,7 +33,7 @@ def resize_image(image):
28
  # Background removal process
29
  def process(image, progress=gr.Progress()):
30
  if image is None:
31
- return None, None
32
 
33
  progress(0, desc="Starting processing...")
34
 
@@ -67,17 +72,26 @@ def process(image, progress=gr.Progress()):
67
  new_im.putalpha(pil_mask)
68
 
69
  progress(0.8, desc="Preparing download...")
70
- # Save the processed image to a temporary file
71
- temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
72
- new_im.save(temp_file.name, format='PNG')
73
- temp_file.close()
 
 
 
74
 
75
  # Convert to numpy array for display
76
  output_array = np.array(new_im.convert("RGBA"))
77
 
78
  progress(1.0, desc="Done!")
79
- # Return both the image for display and the path for download
80
- return output_array, temp_file.name
 
 
 
 
 
 
81
 
82
  # Gradio interface setup
83
  title = "Background Removal"
@@ -86,17 +100,40 @@ description = r"""Background removal model developed by <a href='https://BRIA.AI
86
  examples = [['./input.jpg']]
87
 
88
  # Create the Gradio interface
89
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  gr.Markdown(f"# {title}")
91
  gr.Markdown(description)
92
 
 
 
 
93
  with gr.Row():
94
  with gr.Column(scale=1):
95
  input_image = gr.Image(type="numpy", label="Upload Image")
96
 
97
  with gr.Column(scale=1):
98
  output_image = gr.Image(type="numpy", label="Result")
99
- download_btn = gr.File(label="Download Processed Image", elem_id="download_button")
 
 
 
100
 
101
  # Set up example images
102
  gr.Examples(examples, inputs=input_image)
@@ -105,18 +142,16 @@ with gr.Blocks() as demo:
105
  input_image.change(
106
  fn=process,
107
  inputs=input_image,
108
- outputs=[output_image, download_btn],
109
  show_progress="full"
110
  )
111
-
112
- # Style to ensure download button is visible
113
- gr.HTML("""
114
- <style>
115
- #download_button {
116
- margin-top: 10px;
117
- }
118
- </style>
119
- """)
120
 
121
  if __name__ == "__main__":
122
  demo.launch(share=False)
 
9
  import tempfile
10
  import os
11
  import time
12
+ import uuid
13
 
14
  # Load the pre-trained model
15
  print("Loading model...")
 
19
  net.eval()
20
  print(f"Model loaded on {device}")
21
 
22
+ # Create output directory if it doesn't exist
23
+ OUTPUT_DIR = "output_images"
24
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
25
+
26
  # Resize the input image for model compatibility
27
  def resize_image(image):
28
  image = image.convert('RGB')
 
33
  # Background removal process
34
  def process(image, progress=gr.Progress()):
35
  if image is None:
36
+ return None, None, gr.update(visible=False)
37
 
38
  progress(0, desc="Starting processing...")
39
 
 
72
  new_im.putalpha(pil_mask)
73
 
74
  progress(0.8, desc="Preparing download...")
75
+ # Generate a unique filename
76
+ unique_id = str(uuid.uuid4())[:8]
77
+ filename = f"background_removed_{unique_id}.png"
78
+ filepath = os.path.join(OUTPUT_DIR, filename)
79
+
80
+ # Save the processed image
81
+ new_im.save(filepath, format='PNG')
82
 
83
  # Convert to numpy array for display
84
  output_array = np.array(new_im.convert("RGBA"))
85
 
86
  progress(1.0, desc="Done!")
87
+ # Return image for display and path for download
88
+ return output_array, filepath, gr.update(visible=True)
89
+
90
+ # Function to handle the download button click
91
+ def download_image(filepath):
92
+ if filepath and os.path.exists(filepath):
93
+ return filepath
94
+ return None
95
 
96
  # Gradio interface setup
97
  title = "Background Removal"
 
100
  examples = [['./input.jpg']]
101
 
102
  # Create the Gradio interface
103
+ with gr.Blocks(css="""
104
+ .download-btn {
105
+ background-color: #4CAF50;
106
+ border: none;
107
+ color: white;
108
+ padding: 10px 24px;
109
+ text-align: center;
110
+ text-decoration: none;
111
+ display: inline-block;
112
+ font-size: 16px;
113
+ margin: 4px 2px;
114
+ cursor: pointer;
115
+ border-radius: 4px;
116
+ }
117
+ .download-btn:hover {
118
+ background-color: #45a049;
119
+ }
120
+ """) as demo:
121
  gr.Markdown(f"# {title}")
122
  gr.Markdown(description)
123
 
124
+ # Store the processed image path
125
+ image_path = gr.State(None)
126
+
127
  with gr.Row():
128
  with gr.Column(scale=1):
129
  input_image = gr.Image(type="numpy", label="Upload Image")
130
 
131
  with gr.Column(scale=1):
132
  output_image = gr.Image(type="numpy", label="Result")
133
+ download_btn = gr.Button("Download Image", elem_id="download_button",
134
+ variant="primary", visible=False,
135
+ elem_classes="download-btn")
136
+ download_output = gr.File(visible=False)
137
 
138
  # Set up example images
139
  gr.Examples(examples, inputs=input_image)
 
142
  input_image.change(
143
  fn=process,
144
  inputs=input_image,
145
+ outputs=[output_image, image_path, download_btn],
146
  show_progress="full"
147
  )
148
+
149
+ # Handle download button click
150
+ download_btn.click(
151
+ fn=download_image,
152
+ inputs=[image_path],
153
+ outputs=[download_output]
154
+ )
 
 
155
 
156
  if __name__ == "__main__":
157
  demo.launch(share=False)