Spaces · Running on L40S
Commit 575f433 · Parent(s): e70b261
remove clear gpu memory

app.py CHANGED
@@ -22,31 +22,6 @@ DEFAULT_MAX_SEQUENCE_LENGTH = 512
 GENERATION_SEED = 0 # could use a random number generator to set this, for more variety
 HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")
 
-def clear_gpu_memory(*args):
-    allocated_before = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_before = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"Before clearing: Allocated={allocated_before:.2f} GB, Reserved={reserved_before:.2f} GB")
-
-    deleted_types = []
-    for arg in args:
-        if arg is not None:
-            deleted_types.append(str(type(arg)))
-            del arg
-
-    if deleted_types:
-        print(f"Deleted objects of types: {', '.join(deleted_types)}")
-    else:
-        print("No objects passed to clear_gpu_memory.")
-
-    gc.collect()
-    if DEVICE == "cuda":
-        torch.cuda.empty_cache()
-
-    allocated_after = torch.cuda.memory_allocated(0) / 1024**3 if DEVICE == "cuda" else 0
-    reserved_after = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
-    print(f"After clearing: Allocated={allocated_after:.2f} GB, Reserved={reserved_after:.2f} GB")
-    print("-" * 20)
-
 CACHED_PIPES = {}
 def load_bf16_pipeline():
     """Loads the original FLUX.1-dev pipeline in BF16 precision."""
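For reference, the deleted helper follows a common PyTorch cleanup recipe: drop Python references to the large objects, force a garbage-collection pass, and then release the blocks held by CUDA's caching allocator. A minimal, self-contained sketch of that recipe (free_cuda_memory and this DEVICE definition are illustrative stand-ins, not the app's exact code):

import gc

import torch

# Stand-in for the DEVICE constant that app.py defines outside this diff.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def free_cuda_memory(*objs):
    """Drop references to objs, then release cached CUDA memory."""
    for obj in objs:
        del obj  # drops only this local name; callers must clear their own references
    gc.collect()  # collect unreachable objects so their tensors can be deallocated
    if DEVICE == "cuda":
        torch.cuda.empty_cache()  # return cached allocator blocks to the driver

Note that torch.cuda.empty_cache() cannot free tensors that are still referenced somewhere, which is why the del and gc.collect() steps come first, and why the commented-out call sites below also set current_pipe = None.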
@@ -120,7 +95,7 @@ def load_bnb_4bit_pipeline():
 
 @spaces.GPU(duration=240)
 def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
-    """Loads original and selected quantized model, generates one image each,
+    """Loads original and selected quantized model, generates one image each, shuffles results."""
     if not prompt:
         return None, {}, gr.update(value="Please enter a prompt.", interactive=False), gr.update(choices=[], value=None)
 
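The context lines above show the generation entry point wrapped in @spaces.GPU(duration=240). On ZeroGPU hardware, this decorator from the Hugging Face spaces package attaches a GPU to the process for at most the given number of seconds per call. A minimal sketch of the shape (generate and its body are placeholders, not the app's code):

import spaces  # Hugging Face spaces package

@spaces.GPU(duration=240)  # request a GPU for up to 240 seconds per call
def generate(prompt: str):
    # Placeholder body; the real generate_images loads pipelines and renders images.
    ...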
@@ -161,12 +136,6 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
         print(f"\n--- Loading {label} Model ---")
         load_start_time = time.time()
         try:
-            # Ensure previous pipe is cleared *before* loading the next
-            # if current_pipe:
-            #     print(f"--- Clearing memory before loading {label} Model ---")
-            #     clear_gpu_memory(current_pipe)
-            #     current_pipe = None
-
             current_pipe = load_func()
             load_end_time = time.time()
             print(f"{label} model loaded in {load_end_time - load_start_time:.2f} seconds.")
@@ -184,22 +153,11 @@ def generate_images(prompt, quantization_choice, progress=gr.Progress(track_tqdm=True)):
 
         except Exception as e:
             print(f"Error during {label} model processing: {e}")
-            # Attempt cleanup
-            if current_pipe:
-                print(f"--- Clearing memory after error with {label} Model ---")
-                clear_gpu_memory(current_pipe)
-                current_pipe = None
             # Return error state to Gradio - update all outputs
             return None, {}, gr.update(value=f"Error processing {label} model: {e}", interactive=False), gr.update(choices=[], value=None)
 
     # No finally block needed here, cleanup happens before next load or after loop
 
-    # Final cleanup after the loop finishes successfully
-    # if current_pipe:
-    #     print(f"--- Clearing memory after last model ({label}) ---")
-    #     clear_gpu_memory(current_pipe)
-    #     current_pipe = None
-
     if len(results) != len(model_configs):
         print("Generation did not complete for all models.")
         # Update all outputs
@@ -275,7 +233,7 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
     generate_button = gr.Button("Generate & Compare", variant="primary", scale=1)
 
     output_gallery = gr.Gallery(
-        label="Generated Images
+        label="Generated Images",
         columns=2,
         height=512,
         object_fit="contain",
@@ -324,5 +282,5 @@ with gr.Blocks(title="FLUX Quantization Challenge", theme=gr.themes.Soft()) as demo:
 
 if __name__ == "__main__":
     # queue()
-    # demo.queue().launch()
-    demo.launch()
+    # demo.queue().launch()
+    demo.launch(share=True)
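The surviving context around the first hunk (CACHED_PIPES = {} directly ahead of load_bf16_pipeline) indicates the Space keeps loaded pipelines cached between requests rather than freeing them. A sketch of what such a memoizing loader might look like, assuming the diffusers FluxPipeline API; the loader's real arguments (token handling, device placement) are not visible in this diff:

import torch
from diffusers import FluxPipeline  # assumption: the standard diffusers FLUX pipeline

CACHED_PIPES = {}

def load_bf16_pipeline():
    """Load FLUX.1-dev in BF16 once, then reuse the cached pipeline."""
    if "bf16" not in CACHED_PIPES:
        CACHED_PIPES["bf16"] = FluxPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            torch_dtype=torch.bfloat16,
        )
    return CACHED_PIPES["bf16"]

Caching this way trades resident GPU memory for much faster repeat runs, which fits this commit's removal of the explicit clear_gpu_memory calls between loads.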