prithivMLmods commited on
Commit
08077e1
Β·
verified Β·
1 Parent(s): 740cd8f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -1
app.py CHANGED
@@ -11,6 +11,67 @@ import spaces
11
  import numpy as np
12
  import torch
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  # ──────────────────────────────────────────────
15
  # Paths
16
  # ──────────────────────────────────────────────
@@ -460,4 +521,4 @@ Checkpoints are auto-downloaded (~50 GB) from `nvidia/Lyra-2.0` on HuggingFace a
460
 
461
  if __name__ == "__main__":
462
  demo = build_app()
463
- demo.launch(css=CSS, ssr_mode=False)
 
11
  import numpy as np
12
  import torch
13
 
14
+ # ──────────────────────────────────────────────
15
+ # flash_attn install β€” must happen before any
16
+ # lyra_2 import that pulls in flash_attn.
17
+ # We pick the pre-built wheel that matches the
18
+ # running torch+CUDA version automatically.
19
+ # ──────────────────────────────────────────────
20
+
21
+ def _install_flash_attn():
22
+ """
23
+ Install flash-attn from the pre-built wheels hosted on GitHub.
24
+ Matches the wheel to the running torch + CUDA version so the
25
+ .so symbols line up β€” which is exactly what caused the
26
+ 'undefined symbol: _ZN3c104cuda...' error.
27
+ """
28
+ try:
29
+ import flash_attn # already installed and importable β†’ done
30
+ return
31
+ except ImportError:
32
+ pass
33
+
34
+ import torch, platform
35
+
36
+ torch_ver = torch.__version__.split("+")[0].replace(".", "") # e.g. "240"
37
+ cuda_ver = torch.version.cuda.replace(".", "") # e.g. "121"
38
+ py_ver = f"cp{sys.version_info.major}{sys.version_info.minor}" # e.g. "cp310"
39
+ arch = platform.machine() # "x86_64"
40
+
41
+ # Official pre-built wheel index from the flash-attn GitHub releases.
42
+ # Pattern: flash_attn-<fa_ver>+pt<torch>cu<cuda>-<py>-<py>-linux_<arch>.whl
43
+ # We try the newest FA2 release first then fall back to pip --no-build-isolation.
44
+ wheel_url = (
45
+ f"https://github.com/Dao-AILab/flash-attention/releases/download/"
46
+ f"v2.7.4.post1/"
47
+ f"flash_attn-2.7.4.post1+pt{torch_ver}cu{cuda_ver}-{py_ver}-{py_ver}"
48
+ f"-linux_{arch}.whl"
49
+ )
50
+
51
+ print(f"[Lyra] Installing flash-attn wheel: {wheel_url}")
52
+ result = subprocess.run(
53
+ [sys.executable, "-m", "pip", "install", wheel_url, "--no-deps", "-q"],
54
+ capture_output=True, text=True,
55
+ )
56
+ if result.returncode != 0:
57
+ print(f"[Lyra] Pre-built wheel not found for this env, "
58
+ f"falling back to pip install flash-attn --no-build-isolation ...")
59
+ # This compiles from source β€” slow (~20 min) but always works.
60
+ subprocess.run(
61
+ [sys.executable, "-m", "pip", "install", "flash-attn",
62
+ "--no-build-isolation", "-q"],
63
+ check=True,
64
+ )
65
+
66
+ try:
67
+ import flash_attn
68
+ print(f"[Lyra] flash_attn {flash_attn.__version__} ready.")
69
+ except ImportError as e:
70
+ raise RuntimeError(f"flash_attn install succeeded but import still fails: {e}")
71
+
72
+
73
+ _install_flash_attn()
74
+
75
  # ──────────────────────────────────────────────
76
  # Paths
77
  # ──────────────────────────────────────────────
 
521
 
522
  if __name__ == "__main__":
523
  demo = build_app()
524
+ demo.launch(css=CSS)