H-Liu1997 committed on
Commit
bb7e158
·
1 Parent(s): 3042363

fix: patch flash_attention with SDPA fallback for T4 (no flash-attn)

Browse files
Files changed (1) hide show
  1. model_manager.py +55 -1
model_manager.py CHANGED
@@ -87,11 +87,65 @@ class ModelManager:
87
 
88
  print("ModelManager initialized successfully")
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def _load_models(self, model_name):
91
  """Load VAE and diffusion models from HF Hub"""
92
  torch.set_float32_matmul_precision("high")
93
 
94
- print(f"Loading model from HF Hub: {model_name}")
 
 
 
 
 
 
 
 
95
  from transformers import AutoModel
96
 
97
  hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
 
87
 
88
  print("ModelManager initialized successfully")
89
 
90
+ def _patch_attention_sdpa(self, model_name):
91
+ """Patch flash_attention() to include SDPA fallback for GPUs without flash-attn (e.g., T4)."""
92
+ import glob
93
+ import os
94
+
95
+ hf_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
96
+ patterns = [
97
+ os.path.join(
98
+ hf_cache, "hub", "models--" + model_name.replace("/", "--"),
99
+ "snapshots", "*", "ldf_models", "tools", "attention.py",
100
+ ),
101
+ os.path.join(
102
+ hf_cache, "modules", "transformers_modules", model_name,
103
+ "*", "ldf_models", "tools", "attention.py",
104
+ ),
105
+ ]
106
+
107
+ target = ' assert q.device.type == "cuda" and q.size(-1) <= 256'
108
+ sdpa_fallback = target + "\n" + (
109
+ "\n"
110
+ " # SDPA fallback when flash-attn is not available (e.g., T4 GPU)\n"
111
+ " if not FLASH_ATTN_2_AVAILABLE and not FLASH_ATTN_3_AVAILABLE:\n"
112
+ " if q_lens is not None or k_lens is not None:\n"
113
+ ' warnings.warn("Padding mask disabled with scaled_dot_product_attention")\n'
114
+ " q = q.transpose(1, 2).to(dtype)\n"
115
+ " k = k.transpose(1, 2).to(dtype)\n"
116
+ " v = v.transpose(1, 2).to(dtype)\n"
117
+ " out = torch.nn.functional.scaled_dot_product_attention(\n"
118
+ " q, k, v, attn_mask=None, is_causal=causal, dropout_p=dropout_p\n"
119
+ " )\n"
120
+ " return out.transpose(1, 2).contiguous()\n"
121
+ )
122
+
123
+ for pattern in patterns:
124
+ for filepath in glob.glob(pattern):
125
+ with open(filepath, "r") as f:
126
+ content = f.read()
127
+ if "SDPA fallback" in content:
128
+ print(f"Already patched: {filepath}")
129
+ continue
130
+ if target in content:
131
+ content = content.replace(target, sdpa_fallback, 1)
132
+ with open(filepath, "w") as f:
133
+ f.write(content)
134
+ print(f"Patched with SDPA fallback: {filepath}")
135
+
136
  def _load_models(self, model_name):
137
  """Load VAE and diffusion models from HF Hub"""
138
  torch.set_float32_matmul_precision("high")
139
 
140
+ # Pre-download model files to hub cache
141
+ print(f"Downloading model from HF Hub: {model_name}")
142
+ from huggingface_hub import snapshot_download
143
+ snapshot_download(model_name)
144
+
145
+ # Patch flash_attention with SDPA fallback for T4 (no flash-attn)
146
+ self._patch_attention_sdpa(model_name)
147
+
148
+ print("Loading model...")
149
  from transformers import AutoModel
150
 
151
  hf_model = AutoModel.from_pretrained(model_name, trust_remote_code=True)