Spaces:
Running on Zero
Running on Zero
Tianshuo-Xu committed on
Commit ·
ce4bbb3
1
Parent(s): 46f36ce
Fix float8 noise generation and fix gpu container download cache miss
Browse files- app.py +2 -4
- src/flux/xflux_pipeline.py +3 -0
app.py
CHANGED
|
@@ -106,11 +106,10 @@ def preload_model_files():
|
|
| 106 |
print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
|
| 107 |
local_dir = None
|
| 108 |
|
| 109 |
-
# 2. T5 text encoder
|
| 110 |
try:
|
| 111 |
snapshot_download(
|
| 112 |
-
"
|
| 113 |
-
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.safetensors.index.json"],
|
| 114 |
token=hf_token
|
| 115 |
)
|
| 116 |
print("✓ T5 text encoder cached")
|
|
@@ -121,7 +120,6 @@ def preload_model_files():
|
|
| 121 |
try:
|
| 122 |
snapshot_download(
|
| 123 |
"openai/clip-vit-large-patch14",
|
| 124 |
-
allow_patterns=["*.safetensors", "*.json", "*.txt", "*.bin"],
|
| 125 |
token=hf_token
|
| 126 |
)
|
| 127 |
print("✓ CLIP text encoder cached")
|
|
|
|
| 106 |
print(f"Warning: Could not pre-download Unicalli_Pro: {e}")
|
| 107 |
local_dir = None
|
| 108 |
|
| 109 |
+
# 2. T5 text encoder
|
| 110 |
try:
|
| 111 |
snapshot_download(
|
| 112 |
+
"xlabs-ai/xflux_text_encoders",
|
|
|
|
| 113 |
token=hf_token
|
| 114 |
)
|
| 115 |
print("✓ T5 text encoder cached")
|
|
|
|
| 120 |
try:
|
| 121 |
snapshot_download(
|
| 122 |
"openai/clip-vit-large-patch14",
|
|
|
|
| 123 |
token=hf_token
|
| 124 |
)
|
| 125 |
print("✓ CLIP text encoder cached")
|
src/flux/xflux_pipeline.py
CHANGED
|
@@ -323,6 +323,9 @@ class XFluxPipeline:
|
|
| 323 |
else:
|
| 324 |
# Use model's dtype for efficient inference (fp16/bf16)
|
| 325 |
inference_dtype = next(self.model.parameters()).dtype
|
|
|
|
|
|
|
|
|
|
| 326 |
|
| 327 |
print(f"Using {inference_dtype} for inference")
|
| 328 |
|
|
|
|
| 323 |
else:
|
| 324 |
# Use model's dtype for efficient inference (fp16/bf16)
|
| 325 |
inference_dtype = next(self.model.parameters()).dtype
|
| 326 |
+
# PyTorch's torch.randn does not support Float8_e4m3fn
|
| 327 |
+
if getattr(torch, "float8_e4m3fn", None) and inference_dtype == torch.float8_e4m3fn:
|
| 328 |
+
inference_dtype = torch.bfloat16
|
| 329 |
|
| 330 |
print(f"Using {inference_dtype} for inference")
|
| 331 |
|