Tianshuo-Xu committed on
Commit
ce4bbb3
·
1 Parent(s): 46f36ce

Fix float8 noise generation and fix gpu container download cache miss

Browse files
Files changed (2) hide show
  1. app.py +2 -4
  2. src/flux/xflux_pipeline.py +3 -0
app.py CHANGED
@@ -106,11 +106,10 @@ def preload_model_files():
106
  print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
107
  local_dir = None
108
 
109
- # 2. T5 text encoder (includes model-00001-of-00002.safetensors, model-00002-of-00002.safetensors)
110
  try:
111
  snapshot_download(
112
- "XLabs-AI/xflux_text_encoders",
113
- allow_patterns=["*.safetensors", "*.json", "*.txt", "*.safetensors.index.json"],
114
  token=hf_token
115
  )
116
  print("✓ T5 text encoder cached")
@@ -121,7 +120,6 @@ def preload_model_files():
121
  try:
122
  snapshot_download(
123
  "openai/clip-vit-large-patch14",
124
- allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
125
  token=hf_token
126
  )
127
  print("✓ CLIP text encoder cached")
 
106
  print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
107
  local_dir = None
108
 
109
+ # 2. T5 text encoder
110
  try:
111
  snapshot_download(
112
+ "xlabs-ai/xflux_text_encoders",
 
113
  token=hf_token
114
  )
115
  print("✓ T5 text encoder cached")
 
120
  try:
121
  snapshot_download(
122
  "openai/clip-vit-large-patch14",
 
123
  token=hf_token
124
  )
125
  print("✓ CLIP text encoder cached")
src/flux/xflux_pipeline.py CHANGED
@@ -323,6 +323,9 @@ class XFluxPipeline:
323
  else:
324
  # Use model's dtype for efficient inference (fp16/bf16)
325
  inference_dtype = next(self.model.parameters()).dtype
 
 
 
326
 
327
  print(f"Using {inference_dtype} for inference")
328
 
 
323
  else:
324
  # Use model's dtype for efficient inference (fp16/bf16)
325
  inference_dtype = next(self.model.parameters()).dtype
326
+ # PyTorch's torch.randn does not support Float8_e4m3fn
327
+ if getattr(torch, "float8_e4m3fn", None) and inference_dtype == torch.float8_e4m3fn:
328
+ inference_dtype = torch.bfloat16
329
 
330
  print(f"Using {inference_dtype} for inference")
331