Bapt120 committed on
Commit 6807791 · verified · 1 Parent(s): 299e18a

Update app.py

Files changed (1)
  app.py +2 -10
app.py CHANGED
@@ -7,14 +7,6 @@ import spaces
 import torch
 
 
-# Install flash-attn for GPU only (after spaces import)
-if torch.cuda.is_available():
-    print("CUDA detected - installing flash-attn for optimal GPU performance...")
-    subprocess.run(
-        "pip install flash-attn --no-build-isolation",
-        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-        shell=True,
-    )
 
 import gradio as gr
 from PIL import Image
@@ -29,9 +21,9 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 # Choose best attention implementation based on device
 if device == "cuda":
-    attn_implementation = "flash_attention_2"  # Best for GPU
+    attn_implementation = "sdpa"
     dtype = torch.bfloat16
-    print("Using flash_attention_2 for GPU")
+    print("Using sdpa for GPU")
 else:
     attn_implementation = "eager"  # Best for CPU
     dtype = torch.float32
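
The diff only touches the setup block: switching to "sdpa" uses PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available since PyTorch 2.0), which is why the runtime "pip install flash-attn" step can be deleted. The chosen attn_implementation and dtype are presumably passed to the model loader further down in app.py. A minimal sketch of that pattern with Hugging Face Transformers (the model ID and loader class are illustrative assumptions, not taken from this commit):

import torch
from transformers import AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Mirror the diff's logic: PyTorch's built-in SDPA kernel on GPU
# (no flash-attn wheel needed at startup), eager attention on CPU.
if device == "cuda":
    attn_implementation = "sdpa"
    dtype = torch.bfloat16
else:
    attn_implementation = "eager"
    dtype = torch.float32

# Hypothetical model ID; the model actually used by app.py is not shown in this diff.
model = AutoModelForCausalLM.from_pretrained(
    "org/model-name",
    torch_dtype=dtype,
    attn_implementation=attn_implementation,
).to(device)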