BoxOfColors Claude Opus 4.6 committed on
Commit
dbba693
Β·
1 Parent(s): 64f71ea

Replace FlashSR with sinc resampling for ZeroGPU compatibility

Browse files

FlashSR and its transitive imports (torchaudio, FastAudioSR) trigger
torch.cuda.is_available() during module import, which violates
ZeroGPU's stateless-GPU rule and aborts all subsequent GPU tasks.

Replace _apply_flashsr with torchaudio.functional.resample (sinc,
CPU-only, no CUDA risk). Output is still 48kHz. Remove FlashSR from
requirements.txt and clean up unused _FLASHSR_MODEL/_FLASHSR_LOCK.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. app.py +11 -64
  2. requirements.txt +0 -1
app.py CHANGED
@@ -506,77 +506,24 @@ def _taro_infer_segment(
506
  # models produce output at the same sample rate.
507
  # Model weights are downloaded once from HF Hub and cached on disk.
508
 
509
- _FLASHSR_MODEL = None # module-level cache β€” loaded once per process
510
- _FLASHSR_LOCK = threading.Lock()
511
-
512
  FLASHSR_SR_IN = 16000
513
  FLASHSR_SR_OUT = 48000
514
 
515
 
516
- def _load_flashsr():
517
- """Load FlashSR SynthesizerTrn on CPU (cached after first call).
518
-
519
- We bypass the FASR wrapper class because it calls
520
- ``torch.cuda.is_available()`` and ``.to(device)`` in ``__init__``,
521
- which initialises CUDA in the main process and violates ZeroGPU's
522
- stateless-GPU rule (aborting all subsequent GPU tasks).
523
- Instead we instantiate the underlying ``SynthesizerTrn`` directly on CPU.
524
- """
525
- global _FLASHSR_MODEL
526
- with _FLASHSR_LOCK:
527
- if _FLASHSR_MODEL is not None:
528
- return _FLASHSR_MODEL
529
- print("[FlashSR] Loading SynthesizerTrn on CPU (bypassing FASR to avoid CUDA init) …")
530
- from huggingface_hub import hf_hub_download
531
- from FastAudioSR.speechsr import SynthesizerTrn
532
-
533
- ckpt_path = hf_hub_download(
534
- repo_id="YatharthS/FlashSR",
535
- filename="upsampler.pth",
536
- local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
537
- )
538
- # Replicate FASR's hps exactly, but stay on CPU
539
- model = SynthesizerTrn(
540
- spec_channels=128, # n_mel_channels
541
- segment_size=9600 // 320, # segment_size // hop_length = 30
542
- resblock="0",
543
- resblock_kernel_sizes=[3, 7, 11],
544
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
545
- upsample_rates=[3],
546
- upsample_initial_channel=32,
547
- upsample_kernel_sizes=[3],
548
- )
549
- checkpoint_dict = torch.load(ckpt_path, map_location="cpu")["model"]
550
- model.load_state_dict(checkpoint_dict)
551
- model.eval()
552
- print("[FlashSR] SynthesizerTrn loaded on CPU (fp32) β€” no CUDA touched")
553
- _FLASHSR_MODEL = model
554
- return model
555
-
556
-
557
  def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
558
- """Upsample a mono 16 kHz numpy array to 48 kHz using FlashSR (CPU).
559
 
560
- Returns a mono float32 numpy array at 48 kHz.
561
- Falls back to torchaudio sinc resampling if FlashSR fails.
 
 
562
  """
563
- try:
564
- model = _load_flashsr()
565
- t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0) # [1, T]
566
- print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz β†’ 48kHz (CPU) …")
567
- with torch.no_grad():
568
- # SynthesizerTrn.forward expects [B, 1, T] β€” add channel dim
569
- out = model(t.unsqueeze(1)) # β†’ [B, 1, T*3]
570
- out = out.squeeze() # β†’ [T*3]
571
- out = out / (torch.abs(out).max() + 1e-8) * 0.999 # normalize like FASR.super_resolution
572
- out = out.cpu().float().numpy()
573
- print(f"[FlashSR] Done β€” output shape {out.shape}, sr={FLASHSR_SR_OUT}")
574
- return out
575
- except Exception as e:
576
- print(f"[FlashSR] ERROR: {e} β€” falling back to sinc resampling")
577
- t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
578
- out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
579
- return out.squeeze().numpy()
580
 
581
 
582
  def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
 
506
  # models produce output at the same sample rate.
507
  # Model weights are downloaded once from HF Hub and cached on disk.
508
 
 
 
 
509
  FLASHSR_SR_IN = 16000
510
  FLASHSR_SR_OUT = 48000
511
 
512
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
513
def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
    """Resample a mono 16 kHz waveform to 48 kHz on the CPU via sinc interpolation.

    The name is kept for historical reasons: FlashSR originally filled this
    role, but importing it (via its dependencies) calls
    ``torch.cuda.is_available()`` at module-import time, breaking ZeroGPU's
    stateless-GPU requirement and killing later GPU tasks. torchaudio's
    sinc resampler is CUDA-free and good enough for foley/ambient audio.

    Returns a mono float32 numpy array at 48 kHz.
    """
    duration_s = len(wav_16k) / FLASHSR_SR_IN
    print(f"[upsample] {duration_s:.2f}s @ 16kHz → 48kHz (sinc, CPU) …")
    # [1, T] float32 tensor, as torchaudio expects batched input
    mono = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
    upsampled = torchaudio.functional.resample(mono, FLASHSR_SR_IN, FLASHSR_SR_OUT)
    wav_48k = upsampled.squeeze().numpy()
    print(f"[upsample] Done — {len(wav_48k)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
    return wav_48k
 
 
 
 
 
 
 
 
 
 
 
527
 
528
 
529
  def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
requirements.txt CHANGED
@@ -21,7 +21,6 @@ loguru
21
  torchdiffeq
22
  open_clip_torch
23
  git+https://github.com/descriptinc/audiotools
24
- git+https://github.com/ysharma3501/FlashSR.git
25
  --extra-index-url https://download.pytorch.org/whl/cu124
26
  torchaudio==2.5.1+cu124
27
  --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
 
21
  torchdiffeq
22
  open_clip_torch
23
  git+https://github.com/descriptinc/audiotools
 
24
  --extra-index-url https://download.pytorch.org/whl/cu124
25
  torchaudio==2.5.1+cu124
26
  --find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html