[Admin maintenance] Support new ZeroGPU hardware

#4
by multimodalart HF Staff - opened
Files changed (3) hide show
  1. README.md +2 -1
  2. app.py +12 -7
  3. requirements.txt +14 -13
README.md CHANGED
@@ -4,7 +4,8 @@ emoji: 🚀
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.31.0
 
8
  app_file: app.py
9
  pinned: false
10
  license: apache-2.0
 
4
  colorFrom: blue
5
  colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
+ python_version: "3.12"
9
  app_file: app.py
10
  pinned: false
11
  license: apache-2.0
app.py CHANGED
@@ -1,15 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  import numpy as np
4
- import os
5
  import torch
6
  import random
7
- import subprocess
8
- subprocess.run(
9
-     "pip install flash-attn --no-build-isolation",
10
-     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
11
-     shell=True,
12
- )
13
 
14
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
15
  from PIL import Image
 
1
+ import os
2
+ import ctypes
3
+
4
+ # Preload CUDA 13 runtime so flash-attn's prebuilt cu13 wheel can find libcudart.so.13
5
+ _CUDA_LIBDIR = "/cuda-image/usr/local/cuda-13.0/lib64"
6
+ if os.path.isdir(_CUDA_LIBDIR):
7
+     os.environ["LD_LIBRARY_PATH"] = _CUDA_LIBDIR + os.pathsep + os.environ.get("LD_LIBRARY_PATH", "")
8
+     try:
9
+         ctypes.CDLL(os.path.join(_CUDA_LIBDIR, "libcudart.so.13"), mode=ctypes.RTLD_GLOBAL)
10
+     except OSError:
11
+         pass
12
+
13
  import spaces
14
  import gradio as gr
15
  import numpy as np
 
16
  import torch
17
  import random
 
 
 
 
 
 
18
 
19
  from accelerate import infer_auto_device_map, load_checkpoint_and_dispatch, init_empty_weights
20
  from PIL import Image
requirements.txt CHANGED
@@ -1,17 +1,18 @@
1
  decord==0.6.0
2
  einops==0.8.1
3
- huggingface_hub==0.29.1
4
- matplotlib==3.7.0
5
- numpy==1.24.4
6
- opencv_python==4.7.0.72
7
- pyarrow==11.0.0
8
- PyYAML==6.0.2
9
- Requests==2.32.3
10
- safetensors==0.4.5
11
- scipy==1.10.1
12
- sentencepiece==0.1.99
13
- torch==2.5.1
14
- torchvision==0.20.1
15
  transformers==4.49.0
16
  accelerate>=0.34.0
17
- wandb
 
 
1
  decord==0.6.0
2
  einops==0.8.1
3
+ huggingface_hub
4
+ matplotlib
5
+ numpy
6
+ opencv_python
7
+ pyarrow
8
+ PyYAML
9
+ Requests
10
+ safetensors
11
+ scipy
12
+ sentencepiece
13
+ torch==2.10.0
14
+ torchvision==0.25.0
15
  transformers==4.49.0
16
  accelerate>=0.34.0
17
+ wandb
18
+ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.1/flash_attn-2.8.1+cu13torch2.10cxx11abiTRUE-cp312-cp312-linux_x86_64.whl