Spaces:
Running on Zero
Running on Zero
Commit ·
dbba693
1
Parent(s): 64f71ea
Replace FlashSR with sinc resampling for ZeroGPU compatibility
Browse files
FlashSR and its transitive imports (torchaudio, FastAudioSR) trigger
torch.cuda.is_available() during module import, which violates
ZeroGPU's stateless-GPU rule and aborts all subsequent GPU tasks.
Replace _apply_flashsr with torchaudio.functional.resample (sinc,
CPU-only, no CUDA risk). Output is still 48kHz. Remove FlashSR from
requirements.txt and clean up unused _FLASHSR_MODEL/_FLASHSR_LOCK.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app.py +11 -64
- requirements.txt +0 -1
app.py
CHANGED
|
@@ -506,77 +506,24 @@ def _taro_infer_segment(
|
|
| 506 |
# models produce output at the same sample rate.
|
| 507 |
# Model weights are downloaded once from HF Hub and cached on disk.
|
| 508 |
|
| 509 |
-
_FLASHSR_MODEL = None # module-level cache β loaded once per process
|
| 510 |
-
_FLASHSR_LOCK = threading.Lock()
|
| 511 |
-
|
| 512 |
FLASHSR_SR_IN = 16000
|
| 513 |
FLASHSR_SR_OUT = 48000
|
| 514 |
|
| 515 |
|
| 516 |
-
def _load_flashsr():
|
| 517 |
-
"""Load FlashSR SynthesizerTrn on CPU (cached after first call).
|
| 518 |
-
|
| 519 |
-
We bypass the FASR wrapper class because it calls
|
| 520 |
-
``torch.cuda.is_available()`` and ``.to(device)`` in ``__init__``,
|
| 521 |
-
which initialises CUDA in the main process and violates ZeroGPU's
|
| 522 |
-
stateless-GPU rule (aborting all subsequent GPU tasks).
|
| 523 |
-
Instead we instantiate the underlying ``SynthesizerTrn`` directly on CPU.
|
| 524 |
-
"""
|
| 525 |
-
global _FLASHSR_MODEL
|
| 526 |
-
with _FLASHSR_LOCK:
|
| 527 |
-
if _FLASHSR_MODEL is not None:
|
| 528 |
-
return _FLASHSR_MODEL
|
| 529 |
-
print("[FlashSR] Loading SynthesizerTrn on CPU (bypassing FASR to avoid CUDA init) β¦")
|
| 530 |
-
from huggingface_hub import hf_hub_download
|
| 531 |
-
from FastAudioSR.speechsr import SynthesizerTrn
|
| 532 |
-
|
| 533 |
-
ckpt_path = hf_hub_download(
|
| 534 |
-
repo_id="YatharthS/FlashSR",
|
| 535 |
-
filename="upsampler.pth",
|
| 536 |
-
local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
|
| 537 |
-
)
|
| 538 |
-
# Replicate FASR's hps exactly, but stay on CPU
|
| 539 |
-
model = SynthesizerTrn(
|
| 540 |
-
spec_channels=128, # n_mel_channels
|
| 541 |
-
segment_size=9600 // 320, # segment_size // hop_length = 30
|
| 542 |
-
resblock="0",
|
| 543 |
-
resblock_kernel_sizes=[3, 7, 11],
|
| 544 |
-
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 545 |
-
upsample_rates=[3],
|
| 546 |
-
upsample_initial_channel=32,
|
| 547 |
-
upsample_kernel_sizes=[3],
|
| 548 |
-
)
|
| 549 |
-
checkpoint_dict = torch.load(ckpt_path, map_location="cpu")["model"]
|
| 550 |
-
model.load_state_dict(checkpoint_dict)
|
| 551 |
-
model.eval()
|
| 552 |
-
print("[FlashSR] SynthesizerTrn loaded on CPU (fp32) β no CUDA touched")
|
| 553 |
-
_FLASHSR_MODEL = model
|
| 554 |
-
return model
|
| 555 |
-
|
| 556 |
-
|
| 557 |
def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
|
| 558 |
-
"""Upsample a mono 16 kHz numpy array to 48 kHz using
|
| 559 |
|
| 560 |
-
|
| 561 |
-
|
|
|
|
|
|
|
| 562 |
"""
|
| 563 |
-
|
| 564 |
-
|
| 565 |
-
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
|
| 569 |
-
out = model(t.unsqueeze(1)) # β [B, 1, T*3]
|
| 570 |
-
out = out.squeeze() # β [T*3]
|
| 571 |
-
out = out / (torch.abs(out).max() + 1e-8) * 0.999 # normalize like FASR.super_resolution
|
| 572 |
-
out = out.cpu().float().numpy()
|
| 573 |
-
print(f"[FlashSR] Done β output shape {out.shape}, sr={FLASHSR_SR_OUT}")
|
| 574 |
-
return out
|
| 575 |
-
except Exception as e:
|
| 576 |
-
print(f"[FlashSR] ERROR: {e} β falling back to sinc resampling")
|
| 577 |
-
t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
|
| 578 |
-
out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
|
| 579 |
-
return out.squeeze().numpy()
|
| 580 |
|
| 581 |
|
| 582 |
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
|
|
|
|
| 506 |
# models produce output at the same sample rate.
|
| 507 |
# Model weights are downloaded once from HF Hub and cached on disk.
|
| 508 |
|
|
|
|
|
|
|
|
|
|
| 509 |
FLASHSR_SR_IN = 16000
|
| 510 |
FLASHSR_SR_OUT = 48000
|
| 511 |
|
| 512 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 513 |
def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
    """Upsample a mono 16 kHz numpy array to 48 kHz using sinc resampling (CPU).

    FlashSR was attempted but its dependencies trigger torch.cuda.is_available()
    on import, which violates ZeroGPU's stateless-GPU rule and aborts subsequent
    GPU tasks. High-quality sinc resampling via torchaudio is ZeroGPU-safe and
    produces clean 16→48 kHz output for foley/ambient audio.

    Parameters
    ----------
    wav_16k : np.ndarray
        Mono waveform sampled at FLASHSR_SR_IN (16 kHz); any float/int dtype
        (cast to float32 before resampling).

    Returns
    -------
    np.ndarray
        float32 waveform at FLASHSR_SR_OUT (48 kHz), 1-D.
    """
    print(f"[upsample] {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (sinc, CPU) …")
    # Add a batch dim: resample expects (..., time).
    t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
    out = torchaudio.functional.resample(t, FLASHSR_SR_IN, FLASHSR_SR_OUT)
    # squeeze(0) (not squeeze()) so a pathological 1-sample output stays 1-D
    # and len(result) below cannot raise on a 0-d array.
    result = out.squeeze(0).numpy()
    print(f"[upsample] Done → {len(result)/FLASHSR_SR_OUT:.2f}s @ {FLASHSR_SR_OUT}Hz")
    return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
|
| 528 |
|
| 529 |
def _stitch_wavs(wavs: list[np.ndarray], crossfade_s: float, db_boost: float,
|
requirements.txt
CHANGED
|
@@ -21,7 +21,6 @@ loguru
|
|
| 21 |
torchdiffeq
|
| 22 |
open_clip_torch
|
| 23 |
git+https://github.com/descriptinc/audiotools
|
| 24 |
-
git+https://github.com/ysharma3501/FlashSR.git
|
| 25 |
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 26 |
torchaudio==2.5.1+cu124
|
| 27 |
--find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
|
|
|
|
| 21 |
torchdiffeq
|
| 22 |
open_clip_torch
|
| 23 |
git+https://github.com/descriptinc/audiotools
|
|
|
|
| 24 |
--extra-index-url https://download.pytorch.org/whl/cu124
|
| 25 |
torchaudio==2.5.1+cu124
|
| 26 |
--find-links https://download.openmmlab.com/mmcv/dist/cu121/torch2.4.0/index.html
|