Phips commited on
Commit
1c78228
·
verified ·
1 Parent(s): 52ff101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -23
app.py CHANGED
@@ -114,17 +114,11 @@ MODELS = {
114
  }
115
 
116
  # --- Efficient Model Loading and Caching ---
117
- # Global dictionary to hold models that are already loaded in GPU memory.
118
  LOADED_MODELS_CACHE = {}
119
 
120
  def get_upscaler(model_name: str):
121
- """
122
- Loads a model if it's not already in the cache, and moves it to the GPU.
123
- Returns the cached model.
124
- """
125
  if model_name not in LOADED_MODELS_CACHE:
126
  print(f"Loading model: {model_name}")
127
- # Load the model and immediately move it to the GPU.
128
  LOADED_MODELS_CACHE[model_name] = UpscaleWithModel.from_pretrained(
129
  MODELS[model_name]
130
  ).to("cuda")
@@ -133,40 +127,29 @@ def get_upscaler(model_name: str):
133
  # --- Core Upscaling Function ---
134
  @spaces.GPU
135
  def upscale_image(image, model_selection: str, progress=gr.Progress(track_tqdm=True)):
136
- """
137
- Main function to perform the upscaling. It includes error handling.
138
- """
139
  if image is None:
140
  raise gr.Error("No image uploaded. Please upload an image to upscale.")
141
 
142
  try:
143
  progress(0, desc="Loading image and model...")
144
  original = load_image(image)
145
-
146
- # Get the pre-loaded or newly loaded upscaler from the GPU cache.
147
  upscaler = get_upscaler(model_selection)
148
 
149
  progress(0.5, desc="Upscaling image... (this may take a moment)")
150
- # Perform the upscaling on the GPU.
151
  upscaled_pil_image = upscaler(original, tiling=True, tile_width=1024, tile_height=1024)
152
 
153
  progress(0.9, desc="Saving result...")
154
- # Save the result to a temporary PNG file for lossless download.
155
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
156
  upscaled_pil_image.save(temp_file.name, "PNG")
157
  output_filepath = temp_file.name
158
 
159
- # Return both the images for the slider and the filepath for the download button.
160
  return (original, upscaled_pil_image), output_filepath
161
 
162
  except Exception as e:
163
- # Print the full error to the console for debugging.
164
  print(f"An error occurred: {traceback.format_exc()}")
165
- # Raise a user-friendly error in the Gradio UI.
166
  raise gr.Error(f"An error occurred during processing: {e}")
167
 
168
  def clear_outputs():
169
- """Function to clear all output components."""
170
  return None, None
171
 
172
  # --- Gradio Interface Definition ---
@@ -178,8 +161,6 @@ Tiling is fixed at 1024x1024 for optimal performance. An <a href="https://huggin
178
  </div>
179
  """
180
 
181
- # Best practice for public Spaces: configure automatic cache cleaning.
182
- # This will run every hour and delete any temporary files older than one hour.
183
  with gr.Blocks(delete_cache=(3600, 3600)) as demo:
184
  gr.HTML(title)
185
  with gr.Row():
@@ -197,15 +178,23 @@ with gr.Blocks(delete_cache=(3600, 3600)) as demo:
197
  interactive=False,
198
  label="Compare Original vs. Upscaled",
199
  show_label=True,
 
200
  )
201
- download_output = gr.File(label="Download Upscaled Image (Lossless PNG)")
 
 
 
 
 
 
 
202
 
203
  # --- Event Handling ---
204
  run_button.click(
205
  fn=clear_outputs,
206
  inputs=None,
207
  outputs=[result_slider, download_output],
208
- queue=False # Clearing should be instant, no need to queue.
209
  ).then(
210
  fn=upscale_image,
211
  inputs=[input_image, model_selection],
@@ -213,7 +202,6 @@ with gr.Blocks(delete_cache=(3600, 3600)) as demo:
213
  )
214
 
215
  # --- Pre-load the default model for a faster first-time user experience ---
216
- # This will happen once when the Space starts up.
217
  try:
218
  print("Pre-loading default model...")
219
  get_upscaler("4xNomosWebPhoto_RealPLKSR")
@@ -223,4 +211,6 @@ except Exception as e:
223
 
224
  # Queueing is essential for public-facing apps to handle concurrent users.
225
  demo.queue()
226
- demo.launch(share=False)
 
 
 
114
  }
115
 
116
  # --- Efficient Model Loading and Caching ---
 
117
  LOADED_MODELS_CACHE = {}
118
 
119
  def get_upscaler(model_name: str):
 
 
 
 
120
  if model_name not in LOADED_MODELS_CACHE:
121
  print(f"Loading model: {model_name}")
 
122
  LOADED_MODELS_CACHE[model_name] = UpscaleWithModel.from_pretrained(
123
  MODELS[model_name]
124
  ).to("cuda")
 
127
  # --- Core Upscaling Function ---
128
  @spaces.GPU
129
  def upscale_image(image, model_selection: str, progress=gr.Progress(track_tqdm=True)):
 
 
 
130
  if image is None:
131
  raise gr.Error("No image uploaded. Please upload an image to upscale.")
132
 
133
  try:
134
  progress(0, desc="Loading image and model...")
135
  original = load_image(image)
 
 
136
  upscaler = get_upscaler(model_selection)
137
 
138
  progress(0.5, desc="Upscaling image... (this may take a moment)")
 
139
  upscaled_pil_image = upscaler(original, tiling=True, tile_width=1024, tile_height=1024)
140
 
141
  progress(0.9, desc="Saving result...")
 
142
  with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
143
  upscaled_pil_image.save(temp_file.name, "PNG")
144
  output_filepath = temp_file.name
145
 
 
146
  return (original, upscaled_pil_image), output_filepath
147
 
148
  except Exception as e:
 
149
  print(f"An error occurred: {traceback.format_exc()}")
 
150
  raise gr.Error(f"An error occurred during processing: {e}")
151
 
152
  def clear_outputs():
 
153
  return None, None
154
 
155
  # --- Gradio Interface Definition ---
 
161
  </div>
162
  """
163
 
 
 
164
  with gr.Blocks(delete_cache=(3600, 3600)) as demo:
165
  gr.HTML(title)
166
  with gr.Row():
 
178
  interactive=False,
179
  label="Compare Original vs. Upscaled",
180
  show_label=True,
181
+ show_download_button=False
182
  )
183
+
184
+ # --- THIS IS THE NEW ADDITION ---
185
+ # Add a descriptive note to guide the user about the preview vs. download quality.
186
+ gr.Markdown(
187
+ "<center><i>Note: The slider above shows a web-optimized preview. For the full-quality, lossless PNG, please use the download button below.</i></center>"
188
+ )
189
+
190
+ download_output = gr.File(label="Download Full-Quality Upscaled Image (Lossless PNG)")
191
 
192
  # --- Event Handling ---
193
  run_button.click(
194
  fn=clear_outputs,
195
  inputs=None,
196
  outputs=[result_slider, download_output],
197
+ queue=False
198
  ).then(
199
  fn=upscale_image,
200
  inputs=[input_image, model_selection],
 
202
  )
203
 
204
  # --- Pre-load the default model for a faster first-time user experience ---
 
205
  try:
206
  print("Pre-loading default model...")
207
  get_upscaler("4xNomosWebPhoto_RealPLKSR")
 
211
 
212
  # Queueing is essential for public-facing apps to handle concurrent users.
213
  demo.queue()
214
+ demo.launch(share=False)```
215
+
216
+ Now, the interface will have a clear, centered, and italicized note right between the image slider and the download button, perfectly guiding your users on how to get the best possible quality from your application.