StatusReport committed on
Commit
ca0ae99
·
1 Parent(s): a930359

App: fix app startup issues.

Browse files

See https://huggingface.co/spaces/Lightricks/LTX-2-3/discussions/17 for more details.

Files changed (1) hide show
  1. app.py +76 -0
app.py CHANGED
@@ -64,6 +64,82 @@ try:
64
  except Exception as e:
65
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
  logging.getLogger().setLevel(logging.INFO)
68
 
69
  MAX_SEED = np.iinfo(np.int32).max
 
64
  except Exception as e:
65
  print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
66
 
67
# Turn off the FlashAttention-3 path in xformers. Its dispatcher enables FA3
# whenever `device_capability >= (9, 0)`, but the FA3 kernels are Hopper-only
# (sm_90a). Blackwell (RTX PRO 6000, the ZeroGPU fleet hardware since
# 2026-05-12) also satisfies that check and then crashes at kernel launch with
# "invalid argument".
# Best-effort: any failure (e.g. xformers missing or the private hook renamed)
# is reported and startup continues.
try:
    from xformers.ops.fmha import _set_use_fa3

    _set_use_fa3(False)
    print("[ATTN] xformers FA3 dispatch disabled (Blackwell-incompatible)")
except Exception as exc:
    failure = f"{type(exc).__name__}: {exc}"
    print(f"[ATTN] FA3 disable FAILED: {failure}")
77
+
78
+ # FUSE/mmap workaround: SafetensorsStateDictLoader.load uses safetensors.safe_open
79
+ # under the hood, which mmaps the file. On bucket FUSE mounts that triggers a
80
+ # page-fault storm and deadlocks loading. Bypass mmap by parsing the safetensors
81
+ # header ourselves and reading each tensor's bytes directly.
82
+ import json
83
+ import struct
84
+
85
+ from ltx_core.loader.primitives import StateDict
86
+ from ltx_core.loader.sft_loader import SafetensorsStateDictLoader
87
+
88
# Lookup table from the dtype tag stored in a safetensors header entry
# (the entry's "dtype" field) to the torch dtype used to reinterpret that
# tensor's raw bytes. `_patched_load` indexes it directly, so a tag that is
# not listed here surfaces as a KeyError at load time.
_SAFETENSORS_DTYPE_MAP = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "F8_E5M2": torch.float8_e5m2,
    "F8_E4M3": torch.float8_e4m3fn,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
102
+
103
+
104
def _read_exact(f, offset, nbytes, name):
    """Read exactly ``nbytes`` from ``f`` starting at ``offset``.

    The bytes land directly in a single writable ``bytearray`` (no intermediate
    ``bytes`` object), which ``torch.frombuffer`` can wrap without copying.

    Raises:
        ValueError: if the file ends before ``nbytes`` bytes were read.
    """
    buf = bytearray(nbytes)
    view = memoryview(buf)
    f.seek(offset)
    filled = 0
    while filled < nbytes:
        got = f.readinto(view[filled:])
        if not got:
            raise ValueError(
                f"safetensors file truncated while reading {name!r}: "
                f"expected {nbytes} bytes, got {filled}"
            )
        filled += got
    return buf


def _read_tensor(f, data_base, name, meta):
    """Build a CPU tensor for one safetensors header entry without using mmap.

    ``meta`` is the header entry for ``name``; its "dtype", "shape" and
    "data_offsets" fields are read. Offsets are relative to ``data_base``.

    Raises:
        KeyError: if the entry's dtype tag is not in ``_SAFETENSORS_DTYPE_MAP``.
        ValueError: if the tensor's bytes cannot be read in full.
    """
    torch_dtype = _SAFETENSORS_DTYPE_MAP[meta["dtype"]]
    shape = meta["shape"]
    start, end = meta["data_offsets"]
    nbytes = end - start
    if nbytes == 0 and 0 in shape:
        # torch.frombuffer refuses empty buffers, but a zero-element tensor is
        # a legitimate entry; materialize it directly instead of failing.
        return torch.empty(shape, dtype=torch_dtype)
    buf = _read_exact(f, data_base + start, nbytes, name)
    return torch.frombuffer(buf, dtype=torch_dtype).reshape(shape)


def _patched_load(self, path, sd_ops=None, device=None):
    """Load one or more safetensors shards with plain reads instead of mmap.

    Each shard is parsed by hand: an 8-byte little-endian header length, a JSON
    header, then the raw tensor data. Every tensor is read with explicit
    seek/read calls, so the file is never memory-mapped.

    Args:
        path: a single shard path, or a list/tuple of shard paths.
        sd_ops: optional object exposing ``apply_to_key(name)`` (returning the
            key to use, or ``None`` to skip the tensor) and
            ``apply_to_key_value(name, tensor)`` (returning an iterable of
            ``(key, tensor)`` pairs). When ``None``, tensors are stored under
            their header names unchanged.
        device: target device for the loaded tensors; defaults to CPU.

    Returns:
        ``StateDict`` built from the loaded tensors, the device, the summed
        ``nbytes`` of all stored tensors, and the set of dtypes encountered.

    Raises:
        KeyError: if a tensor uses a dtype tag missing from
            ``_SAFETENSORS_DTYPE_MAP``.
        ValueError: if a shard is truncated.
    """
    sd = {}
    size = 0
    dtype = set()
    device = device or torch.device("cpu")
    model_paths = path if isinstance(path, (list, tuple)) else [path]
    for shard_path in model_paths:
        with open(shard_path, "rb") as f:
            (header_len,) = struct.unpack("<Q", f.read(8))
            header = json.loads(f.read(header_len).decode("utf-8"))
            # Tensor offsets in the header are relative to the end of the header.
            data_base = 8 + header_len
            for name, meta in header.items():
                if name == "__metadata__":
                    # Free-form metadata entry, not a tensor.
                    continue
                expected_name = name if sd_ops is None else sd_ops.apply_to_key(name)
                if expected_name is None:
                    continue
                t = _read_tensor(f, data_base, name, meta)
                t = t.to(device=device, non_blocking=True, copy=False)
                if sd_ops is None:
                    kvs = ((expected_name, t),)
                else:
                    kvs = sd_ops.apply_to_key_value(expected_name, t)
                for key, v in kvs:
                    size += v.nbytes
                    dtype.add(v.dtype)
                    sd[key] = v
    return StateDict(sd=sd, device=device, size=size, dtype=dtype)
138
+
139
+
140
# Install the override at class level so calls to
# SafetensorsStateDictLoader.load go through the mmap-free reader.
# This must run before any checkpoint is loaded for the workaround to apply.
SafetensorsStateDictLoader.load = _patched_load
# NOTE(review): "(chunked-read)" presumably refers to the per-tensor
# seek/read calls in _patched_load — confirm the wording is still accurate.
print("[FUSE-PATCH] SafetensorsStateDictLoader.load replaced (chunked-read)")
142
+
143
  logging.getLogger().setLevel(logging.INFO)
144
 
145
  MAX_SEED = np.iinfo(np.int32).max