BoxOfColors Claude Opus 4.6 committed on
Commit
26e8bca
·
1 Parent(s): ef4f0ff

Fix FlashSR CUDA init: bypass FASR class, load SynthesizerTrn on CPU

Browse files

FASR.__init__ calls torch.cuda.is_available() and .to(device),
which initializes CUDA in the main process and violates ZeroGPU
stateless-GPU rule — aborting all subsequent GPU tasks.

Now we load SynthesizerTrn directly on CPU, replicating the same
hyperparams and normalization that FASR uses, without touching CUDA.
This allows FlashSR to run safely in the CPU post-processing step
outside @spaces.GPU, saving GPU quota per segment.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +32 -15
app.py CHANGED
@@ -514,23 +514,42 @@ FLASHSR_SR_OUT = 48000
514
 
515
 
516
  def _load_flashsr():
517
- """Load FlashSR model (cached after first call). Returns FASR instance."""
 
 
 
 
 
 
 
518
  global _FLASHSR_MODEL
519
  with _FLASHSR_LOCK:
520
  if _FLASHSR_MODEL is not None:
521
  return _FLASHSR_MODEL
522
- print("[FlashSR] Loading model weights from HF Hub …")
523
  from huggingface_hub import hf_hub_download
524
- from FastAudioSR import FASR
 
525
  ckpt_path = hf_hub_download(
526
  repo_id="YatharthS/FlashSR",
527
  filename="upsampler.pth",
528
  local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
529
  )
530
- # Always load on CPU — ZeroGPU forbids CUDA init outside @spaces.GPU.
531
- # FlashSR is tiny (1.72 MB) and fast enough on CPU for post-processing.
532
- model = FASR(ckpt_path)
533
- print("[FlashSR] Model loaded on CPU (fp32)")
 
 
 
 
 
 
 
 
 
 
 
534
  _FLASHSR_MODEL = model
535
  return model
536
 
@@ -543,16 +562,14 @@ def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
543
  """
544
  try:
545
  model = _load_flashsr()
546
- # Keep on CPU — no CUDA outside @spaces.GPU in ZeroGPU environment
547
- t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
548
 print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (CPU) …")
549
  with torch.no_grad():
550
- out = model.run(t)
551
- # out is a tensor or numpy array — normalise to numpy float32 cpu
552
- if isinstance(out, torch.Tensor):
553
- out = out.float().cpu().squeeze().numpy()
554
- else:
555
- out = np.array(out, dtype=np.float32).squeeze()
556
 print(f"[FlashSR] Done — output shape {out.shape}, sr={FLASHSR_SR_OUT}")
557
  return out
558
  except Exception as e:
 
514
 
515
 
516
  def _load_flashsr():
517
+ """Load FlashSR SynthesizerTrn on CPU (cached after first call).
518
+
519
+ We bypass the FASR wrapper class because it calls
520
+ ``torch.cuda.is_available()`` and ``.to(device)`` in ``__init__``,
521
+ which initialises CUDA in the main process and violates ZeroGPU's
522
+ stateless-GPU rule (aborting all subsequent GPU tasks).
523
+ Instead we instantiate the underlying ``SynthesizerTrn`` directly on CPU.
524
+ """
525
  global _FLASHSR_MODEL
526
  with _FLASHSR_LOCK:
527
  if _FLASHSR_MODEL is not None:
528
  return _FLASHSR_MODEL
529
+ print("[FlashSR] Loading SynthesizerTrn on CPU (bypassing FASR to avoid CUDA init) …")
530
  from huggingface_hub import hf_hub_download
531
+ from FastAudioSR.speechsr import SynthesizerTrn
532
+
533
  ckpt_path = hf_hub_download(
534
  repo_id="YatharthS/FlashSR",
535
  filename="upsampler.pth",
536
  local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
537
  )
538
+ # Replicate FASR's hps exactly, but stay on CPU
539
+ model = SynthesizerTrn(
540
+ spec_channels=128, # n_mel_channels
541
+ segment_size=9600 // 320, # segment_size // hop_length = 30
542
+ resblock="0",
543
+ resblock_kernel_sizes=[3, 7, 11],
544
+ resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
545
+ upsample_rates=[3],
546
+ upsample_initial_channel=32,
547
+ upsample_kernel_sizes=[3],
548
+ )
549
+ checkpoint_dict = torch.load(ckpt_path, map_location="cpu")["model"]
550
+ model.load_state_dict(checkpoint_dict)
551
+ model.eval()
552
+ print("[FlashSR] SynthesizerTrn loaded on CPU (fp32) β€” no CUDA touched")
553
  _FLASHSR_MODEL = model
554
  return model
555
 
 
562
  """
563
  try:
564
  model = _load_flashsr()
565
+ t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0) # [1, T]
 
566
 print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (CPU) …")
567
  with torch.no_grad():
568
+ # SynthesizerTrn.forward expects [B, 1, T] — add channel dim
569
+ out = model(t.unsqueeze(1))  # → [B, 1, T*3]
570
+ out = out.squeeze()  # → [T*3]
571
+ out = out / (torch.abs(out).max() + 1e-8) * 0.999 # normalize like FASR.super_resolution
572
+ out = out.cpu().float().numpy()
 
573
 print(f"[FlashSR] Done — output shape {out.shape}, sr={FLASHSR_SR_OUT}")
574
  return out
575
  except Exception as e: