Spaces:
Running on Zero
Running on Zero
Tianshuo-Xu committed on
Commit ·
c49775d
1
Parent(s): 2292172
Speed up Space by initializing globally and keeping on GPU, remove manual offload
Browse files
- app.py +10 -5
- src/flux/xflux_pipeline.py +3 -3
app.py
CHANGED
|
@@ -223,7 +223,7 @@ def init_generator():
|
|
| 223 |
generator = CalligraphyGenerator(
|
| 224 |
model_name="flux-dev",
|
| 225 |
device="cuda",
|
| 226 |
-
offload=
|
| 227 |
intern_vlm_path=intern_vlm_path,
|
| 228 |
checkpoint_path=checkpoint_path,
|
| 229 |
font_descriptions_path='dataset/chirography.json',
|
|
@@ -261,9 +261,14 @@ def parse_font_style(font_style: str) -> str:
|
|
| 261 |
return None
|
| 262 |
|
| 263 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
|
| 265 |
-
"""Calculate dynamic GPU duration:
|
| 266 |
-
return
|
| 267 |
|
| 268 |
|
| 269 |
@spaces.GPU(duration=_get_generation_duration)
|
|
@@ -273,8 +278,8 @@ def run_generation(text, font, author, num_steps, start_seed, num_images):
|
|
| 273 |
All in one GPU session to avoid redundant loading.
|
| 274 |
"""
|
| 275 |
# Step 1: Load model
|
| 276 |
-
logger.info("
|
| 277 |
-
gen =
|
| 278 |
|
| 279 |
# Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
|
| 280 |
logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
|
|
|
|
| 223 |
generator = CalligraphyGenerator(
|
| 224 |
model_name="flux-dev",
|
| 225 |
device="cuda",
|
| 226 |
+
offload=False, # Set to False to let ZeroGPU manage CUDA memory directly instead of manual CPU thrashing
|
| 227 |
intern_vlm_path=intern_vlm_path,
|
| 228 |
checkpoint_path=checkpoint_path,
|
| 229 |
font_descriptions_path='dataset/chirography.json',
|
|
|
|
| 261 |
return None
|
| 262 |
|
| 263 |
|
| 264 |
+
# Initialize the generator globally BEFORE zeroGPU functions so weights are memory-mapped
|
| 265 |
+
logger.info("Initializing generator globally...")
|
| 266 |
+
generator = init_generator()
|
| 267 |
+
|
| 268 |
+
|
| 269 |
def _get_generation_duration(text, font, author, num_steps, start_seed, num_images):
|
| 270 |
+
"""Calculate dynamic GPU duration: 20s base + 2s per step per image"""
|
| 271 |
+
return 20 + int(2 * num_steps * num_images)
|
| 272 |
|
| 273 |
|
| 274 |
@spaces.GPU(duration=_get_generation_duration)
|
|
|
|
| 278 |
All in one GPU session to avoid redundant loading.
|
| 279 |
"""
|
| 280 |
# Step 1: Load model
|
| 281 |
+
logger.info("Models are already globally initialized and managed by ZeroGPU.")
|
| 282 |
+
gen = generator
|
| 283 |
|
| 284 |
# Step 2: Since we reverted to bf16 load to avoid PyTorch native dtype mix issues, skip wrapping
|
| 285 |
logger.info("Model weights decompressed to bfloat16 upon load. Skipping dynamic quantization to ensure stability.")
|
src/flux/xflux_pipeline.py
CHANGED
|
@@ -312,9 +312,9 @@ class XFluxPipeline:
|
|
| 312 |
neg_ip_scale=1.0,
|
| 313 |
is_generation=True,
|
| 314 |
):
|
| 315 |
-
#
|
| 316 |
-
torch.backends.cuda.matmul.allow_tf32 = False
|
| 317 |
-
torch.backends.cudnn.allow_tf32 = False
|
| 318 |
|
| 319 |
# Determine inference dtype from model
|
| 320 |
if hasattr(self.model, '_is_quantized') and self.model._is_quantized:
|
|
|
|
| 312 |
neg_ip_scale=1.0,
|
| 313 |
is_generation=True,
|
| 314 |
):
|
| 315 |
+
# Allow TF32 for much faster inference
|
| 316 |
+
# torch.backends.cuda.matmul.allow_tf32 = False
|
| 317 |
+
# torch.backends.cudnn.allow_tf32 = False
|
| 318 |
|
| 319 |
# Determine inference dtype from model
|
| 320 |
if hasattr(self.model, '_is_quantized') and self.model._is_quantized:
|