primerz commited on
Commit
f6a0f97
·
verified ·
1 Parent(s): 352ce47

Update models.py

Browse files
Files changed (1) hide show
  1. models.py +26 -0
models.py CHANGED
@@ -4,6 +4,7 @@ UPDATED VERSION with proper InstantID pipeline support
4
  """
5
  import torch
6
  import time
 
7
  from diffusers import (
8
  ControlNetModel,
9
  AutoencoderKL,
@@ -208,6 +209,31 @@ def optimize_pipeline(pipe):
208
  print(" [OK] xformers enabled")
209
  except Exception as e:
210
  print(f" [INFO] xformers not available: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
211
 
212
 
213
  def load_caption_model():
 
4
  """
5
  import torch
6
  import time
7
+ import os
8
  from diffusers import (
9
  ControlNetModel,
10
  AutoencoderKL,
 
209
  print(" [OK] xformers enabled")
210
  except Exception as e:
211
  print(f" [INFO] xformers not available: {e}")
212
+
213
+ # Additional optimizations for memory efficiency
214
+ if hasattr(pipe, 'enable_vae_slicing'):
215
+ pipe.enable_vae_slicing()
216
+ print(" [OK] VAE slicing enabled for memory efficiency")
217
+
218
+ if hasattr(pipe, 'enable_vae_tiling'):
219
+ pipe.enable_vae_tiling()
220
+ print(" [OK] VAE tiling enabled for memory efficiency")
221
+
222
+ # For Zero GPU environments (HuggingFace Spaces)
223
+ if os.environ.get("SPACE_ID"):
224
+ print(" [INFO] Detected HuggingFace Spaces environment")
225
+ if hasattr(pipe, 'enable_model_cpu_offload'):
226
+ # Don't enable CPU offload if using Zero GPU
227
+ # as it conflicts with @spaces.GPU decorator
228
+ print(" [INFO] Zero GPU environment - skipping CPU offload (handled by @spaces.GPU)")
229
+
230
+ # Ensure dtype consistency
231
+ if hasattr(pipe, 'vae') and pipe.vae is not None:
232
+ pipe.vae = pipe.vae.to(dtype=dtype)
233
+ if hasattr(pipe, 'text_encoder') and pipe.text_encoder is not None:
234
+ pipe.text_encoder = pipe.text_encoder.to(dtype=dtype)
235
+ if hasattr(pipe, 'text_encoder_2') and pipe.text_encoder_2 is not None:
236
+ pipe.text_encoder_2 = pipe.text_encoder_2.to(dtype=dtype)
237
 
238
 
239
  def load_caption_model():