Nekochu commited on
Commit
9ed24c7
·
1 Parent(s): 13f9406

fix meta tensor crash: force low_cpu_mem_usage=False and float32 for CPU

Browse files
Files changed (1) hide show
  1. app.py +9 -0
app.py CHANGED
@@ -339,6 +339,15 @@ try:
339
  torch.backends.cuda.enable_flash_sdp(False)
340
  os.environ["ATTN_BACKEND"] = "sdpa"
341
 
 
 
 
 
 
 
 
 
 
342
  import torchaudio
343
  _orig = torchaudio.load
344
  def _sf(p, *a, **kw):
 
339
  torch.backends.cuda.enable_flash_sdp(False)
340
  os.environ["ATTN_BACKEND"] = "sdpa"
341
 
342
+ import transformers
343
+ _orig_from_pretrained = transformers.AutoModel.from_pretrained
344
+ def _cpu_from_pretrained(*args, **kwargs):
345
+ kwargs['low_cpu_mem_usage'] = False
346
+ kwargs.setdefault('torch_dtype', torch.float32)
347
+ return _orig_from_pretrained(*args, **kwargs)
348
+ transformers.AutoModel.from_pretrained = _cpu_from_pretrained
349
+ log(" Patched AutoModel.from_pretrained: low_cpu_mem_usage=False, dtype=float32")
350
+
351
  import torchaudio
352
  _orig = torchaudio.load
353
  def _sf(p, *a, **kw):