Commit ·
ca0ae99
1
Parent(s): a930359
App: fix app startup issues.
Browse files
See https://huggingface.co/spaces/Lightricks/LTX-2-3/discussions/17 for more details.
app.py
CHANGED
|
@@ -64,6 +64,82 @@ try:
|
|
| 64 |
except Exception as e:
|
| 65 |
print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
|
| 66 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
logging.getLogger().setLevel(logging.INFO)
|
| 68 |
|
| 69 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 64 |
except Exception as e:
|
| 65 |
print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
|
| 66 |
|
| 67 |
+
# xformers selects its FlashAttention-3 kernels whenever the device reports
# `device_capability >= (9, 0)`. Those kernels are Hopper-only (sm_90a), but the
# same check also matches Blackwell (RTX PRO 6000, the ZeroGPU fleet hardware
# since 2026-05-12), where the kernel launch crashes with "invalid argument".
# Turn the FA3 dispatch off up front; this is best-effort, so any failure is
# only reported and startup continues.
try:
    from xformers.ops.fmha import _set_use_fa3

    _set_use_fa3(False)
except Exception as err:
    print(f"[ATTN] FA3 disable FAILED: {type(err).__name__}: {err}")
else:
    print("[ATTN] xformers FA3 dispatch disabled (Blackwell-incompatible)")
|
| 77 |
+
|
| 78 |
+
# FUSE/mmap workaround: SafetensorsStateDictLoader.load uses safetensors.safe_open
# under the hood, which mmaps the file. On bucket FUSE mounts, that triggers a
# page-fault storm and deadlocks loading. Bypass mmap by parsing the safetensors
# header ourselves and reading each tensor's bytes directly.
|
| 82 |
+
import json
|
| 83 |
+
import struct
|
| 84 |
+
|
| 85 |
+
from ltx_core.loader.primitives import StateDict
|
| 86 |
+
from ltx_core.loader.sft_loader import SafetensorsStateDictLoader
|
| 87 |
+
|
| 88 |
+
_SAFETENSORS_DTYPE_MAP = {
|
| 89 |
+
"F64": torch.float64,
|
| 90 |
+
"F32": torch.float32,
|
| 91 |
+
"F16": torch.float16,
|
| 92 |
+
"BF16": torch.bfloat16,
|
| 93 |
+
"F8_E5M2": torch.float8_e5m2,
|
| 94 |
+
"F8_E4M3": torch.float8_e4m3fn,
|
| 95 |
+
"I64": torch.int64,
|
| 96 |
+
"I32": torch.int32,
|
| 97 |
+
"I16": torch.int16,
|
| 98 |
+
"I8": torch.int8,
|
| 99 |
+
"U8": torch.uint8,
|
| 100 |
+
"BOOL": torch.bool,
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _patched_load(self, path, sd_ops, device=None):
    """Load safetensors shard(s) into a StateDict with plain file reads (no mmap).

    Each shard is parsed by hand: an 8-byte little-endian header length, a JSON
    header describing every tensor, then the raw tensor bytes. Every tensor is
    read with seek + readinto rather than through ``safetensors.safe_open``.

    Args:
        path: A single shard path, or a list of shard paths.
        sd_ops: Optional key/value transformer. When given, ``apply_to_key`` may
            rename a key or return ``None`` to drop it, and
            ``apply_to_key_value`` yields the final ``(key, tensor)`` pairs.
            When ``None``, tensors are stored under their header names.
        device: Target device for the tensors; CPU is used when not provided.

    Returns:
        A ``StateDict`` holding the loaded tensors, the target device, the total
        byte size of the stored tensors, and the set of dtypes encountered.

    Raises:
        KeyError: If a tensor's dtype tag is not in ``_SAFETENSORS_DTYPE_MAP``.
        EOFError: If a shard ends before a tensor's declared byte range.
    """
    sd = {}
    size = 0
    dtype = set()
    device = device or torch.device("cpu")
    model_paths = path if isinstance(path, list) else [path]
    for shard_path in model_paths:
        with open(shard_path, "rb") as f:
            header_len = struct.unpack("<Q", f.read(8))[0]
            header = json.loads(f.read(header_len).decode("utf-8"))
            # data_offsets in the header are relative to the end of the header.
            data_base = 8 + header_len
            for name, meta in header.items():
                if name == "__metadata__":
                    continue
                expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                if expected_name is None:
                    continue
                start, end = meta["data_offsets"]
                tensor_dtype = _SAFETENSORS_DTYPE_MAP[meta["dtype"]]
                nbytes = end - start
                if nbytes == 0:
                    # torch.frombuffer rejects empty buffers, so build
                    # zero-element tensors directly from the declared shape.
                    t = torch.empty(meta["shape"], dtype=tensor_dtype)
                else:
                    f.seek(data_base + start)
                    # Read straight into a writable buffer: frombuffer needs
                    # writable memory, and this avoids holding a second copy
                    # of the tensor bytes (bytes -> bytearray).
                    buf = bytearray(nbytes)
                    got = f.readinto(buf)
                    if got != nbytes:
                        raise EOFError(
                            f"{shard_path}: tensor {name!r} is truncated "
                            f"(expected {nbytes} bytes, got {got})"
                        )
                    t = torch.frombuffer(buf, dtype=tensor_dtype).reshape(meta["shape"])
                t = t.to(device=device, non_blocking=True, copy=False)
                kvs = (
                    ((expected_name, t),)
                    if sd_ops is None
                    else sd_ops.apply_to_key_value(expected_name, t)
                )
                for key, v in kvs:
                    size += v.nbytes
                    dtype.add(v.dtype)
                    sd[key] = v
    return StateDict(sd=sd, device=device, size=size, dtype=dtype)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# Swap the loader's `load` method for the mmap-free implementation at class
# level, then log that the replacement is in place.
setattr(SafetensorsStateDictLoader, "load", _patched_load)
print("[FUSE-PATCH] SafetensorsStateDictLoader.load replaced (chunked-read)")
|
| 142 |
+
|
| 143 |
# Set the root logger's threshold to INFO.
logging.root.setLevel(logging.INFO)

# Largest value representable as a signed 32-bit integer
# (presumably the upper bound for user-supplied seeds; confirm against usage).
MAX_SEED = np.iinfo(np.int32).max
|