Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import sys
|
| 3 |
-
import subprocess
|
| 4 |
import torch
|
| 5 |
import datetime
|
| 6 |
import numpy as np
|
|
@@ -142,23 +144,6 @@ try:
|
|
| 142 |
from hyvideo.commons.infer_state import initialize_infer_state
|
| 143 |
# Import the specific I2V System Prompt from the repo
|
| 144 |
from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
|
| 145 |
-
|
| 146 |
-
# --- FIX: Force Disable Flash Attention Patch ---
|
| 147 |
-
import hyvideo.models.transformers.modules.attention
|
| 148 |
-
|
| 149 |
-
print("🛠️ Patching Attention Mode to 'torch' (SDPA) to bypass Flash Attn check...")
|
| 150 |
-
|
| 151 |
-
def patched_fallback(attn_mode, infer_state=None, block_idx=None):
|
| 152 |
-
# Always return 'torch' to bypass the flash-attn check
|
| 153 |
-
return "torch"
|
| 154 |
-
|
| 155 |
-
# Patch the source definition in commons
|
| 156 |
-
hyvideo.commons.maybe_fallback_attn_mode = patched_fallback
|
| 157 |
-
|
| 158 |
-
# Patch the reference inside the attention module (crucial for TokenRefiner which imports it)
|
| 159 |
-
hyvideo.models.transformers.modules.attention.maybe_fallback_attn_mode = patched_fallback
|
| 160 |
-
# ------------------------------------------------
|
| 161 |
-
|
| 162 |
except ImportError as e:
|
| 163 |
print(f"CRITICAL ERROR: {e}")
|
| 164 |
sys.exit(1)
|
|
|
|
| 1 |
+
import subprocess
|
| 2 |
+
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 3 |
+
|
| 4 |
import os
|
| 5 |
import sys
|
|
|
|
| 6 |
import torch
|
| 7 |
import datetime
|
| 8 |
import numpy as np
|
|
|
|
| 144 |
from hyvideo.commons.infer_state import initialize_infer_state
|
| 145 |
# Import the specific I2V System Prompt from the repo
|
| 146 |
from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
except ImportError as e:
|
| 148 |
print(f"CRITICAL ERROR: {e}")
|
| 149 |
sys.exit(1)
|