aeb56 committed on
Commit 2900b36 · 1 Parent(s): b705945

Monkey-patch transformers to disable flash attention via wrapper script
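
Besides the monkey-patch, the new wrapper adds `attn_implementation=eager` to `--model_args`; lm_eval's `hf` backend generally forwards such extra model_args to `transformers.AutoModelForCausalLM.from_pretrained`, which is what actually selects the eager attention path. A minimal sketch of that mechanism outside the harness (not part of this commit; `MODEL_NAME` below is a hypothetical placeholder for the value used in `app.py`):

```python
# Sketch only: load the checkpoint with eager attention, the same effect the
# attn_implementation=eager entry in --model_args requests from lm_eval.
import torch
from transformers import AutoModelForCausalLM

MODEL_NAME = "some-org/some-model"  # hypothetical placeholder

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="eager",  # bypass flash-attn kernels entirely
)

# The selected attention backend is recorded on the config; "eager" confirms the fallback.
print(model.config._attn_implementation)
```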

Files changed (1)
  1. app.py +31 -57
app.py CHANGED
@@ -215,80 +215,54 @@ class ChatBot:
         logs += f"⏱️ Estimated time: 30-60 minutes\n\n"
         yield status_table, logs
 
-        # Create a fake flash_attn package to avoid import errors
-        # This will fallback to standard PyTorch attention
-        fake_flash_dir = f"/tmp/flash_attn_{timestamp}"
-        os.makedirs(fake_flash_dir, exist_ok=True)
-
-        with open(os.path.join(fake_flash_dir, "__init__.py"), 'w') as f:
-            f.write("""
-# Fake flash_attn module that falls back to standard PyTorch attention
-import torch
+        # Create a wrapper script that disables flash attention before running lm_eval
+        wrapper_script = f"/tmp/run_eval_{timestamp}.py"
+        with open(wrapper_script, 'w') as f:
+            f.write(f"""
+import sys
+import os
 
-def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
-    '''Fallback to standard PyTorch attention (slower but works without flash-attn)'''
-    if softmax_scale is None:
-        softmax_scale = 1.0 / (q.size(-1) ** 0.5)
-
-    # Standard attention: softmax(Q @ K.T) @ V
-    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
-
-    if causal:
-        seq_len = attn_weights.size(-1)
-        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=attn_weights.device), diagonal=1).bool()
-        attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
-
-    attn_weights = torch.softmax(attn_weights, dim=-1)
-
-    if dropout_p > 0:
-        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p)
-
-    output = torch.matmul(attn_weights, v)
-    return output, None  # Return None for attention weights
+# Monkey-patch transformers to disable flash attention
+import transformers.modeling_flash_attention_utils as flash_utils
+
+def disabled_lazy_import(*args, **kwargs):
+    raise ImportError("Flash attention disabled - using eager attention")
+
+flash_utils.lazy_import_flash_attention = disabled_lazy_import
 
-def flash_attn_varlen_func(*args, **kwargs):
-    return flash_attn_func(*args, **kwargs)
+# Now run lm_eval
+sys.argv = [
+    'lm_eval',
+    '--model', 'hf',
+    '--model_args', 'pretrained={MODEL_NAME},trust_remote_code=True,dtype=bfloat16,low_cpu_mem_usage=True,parallelize=True,attn_implementation=eager',
+    '--tasks', '{task_string}',
+    '--batch_size', '1',
+    '--output_path', '{output_dir}',
+    '--log_samples'
+]
 
-__version__ = "2.5.0"
+from lm_eval.__main__ import cli_evaluate
+cli_evaluate()
 """)
 
-        # Add fake package to Python path for subprocess
-        import sys
-        if f"/tmp" not in sys.path:
-            sys.path.insert(0, "/tmp")
-
-        # Set PYTHONPATH environment variable so subprocess can find fake flash_attn
-        env = os.environ.copy()
-        pythonpath = env.get('PYTHONPATH', '')
-        env['PYTHONPATH'] = f"/tmp:{pythonpath}" if pythonpath else "/tmp"
-
-        logs += "⚠️ **Note:** Using fallback PyTorch attention (slower than flash-attn)\n\n"
+        logs += "⚠️ **Note:** Flash attention disabled, using eager attention (slower but compatible)\n\n"
         yield status_table, logs
 
-        # Run lm_eval
-        cmd = [
-            "lm_eval",
-            "--model", "hf",
-            "--model_args", f"pretrained={MODEL_NAME},trust_remote_code=True,dtype=bfloat16,low_cpu_mem_usage=True,parallelize=True",
-            "--tasks", task_string,
-            "--batch_size", "1",
-            "--output_path", output_dir,
-            "--log_samples"
-        ]
+        # Run lm_eval via wrapper script
+        cmd = ["python3", wrapper_script]
 
         status_table = self._create_status_table(tasks_to_run, "🔄 Running")
-        logs += f"🔄 **Running lm_eval...**\n\nCommand: `{' '.join(cmd)}`\n\n"
+        logs += f"🔄 **Running lm_eval...**\n\nTasks: {task_string}\n\n"
         logs += "---\n\n### 📜 Live Logs (last 15 lines):\n\n```\n"
         yield status_table, logs
 
-        # Run evaluation with custom environment
+        # Run evaluation
         process = subprocess.Popen(
             cmd,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
             text=True,
-            bufsize=1,
-            env=env  # Pass custom environment with PYTHONPATH
+            bufsize=1
         )
 
         output_lines = []
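
A caveat on the patch itself: `lazy_import_flash_attention` is an internal symbol of `transformers.modeling_flash_attention_utils`, so whether the assignment in the wrapper script takes effect depends on the installed transformers version. A defensive variant (a sketch, not part of the commit) would guard the attribute before patching:

```python
# Sketch only: apply the wrapper script's patch defensively. The attribute name
# comes from the diff above and may not exist in every transformers release.
import transformers.modeling_flash_attention_utils as flash_utils


def _disabled_lazy_import(*args, **kwargs):
    # Make the flash-attn import path fail so callers fall back to eager attention.
    raise ImportError("Flash attention disabled - using eager attention")


if hasattr(flash_utils, "lazy_import_flash_attention"):
    flash_utils.lazy_import_flash_attention = _disabled_lazy_import
```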