Spaces:
Running on Zero
Running on Zero
Commit ·
26e8bca
1
Parent(s): ef4f0ff
Fix FlashSR CUDA init: bypass FASR class, load SynthesizerTrn on CPU
Browse files
FASR.__init__ calls torch.cuda.is_available() and .to(device),
which initializes CUDA in the main process and violates ZeroGPU's
stateless-GPU rule — aborting all subsequent GPU tasks.
Now we load SynthesizerTrn directly on CPU, replicating the same
hyperparams and normalization that FASR uses, without touching CUDA.
This allows FlashSR to run safely in the CPU post-processing step
outside @spaces.GPU, saving GPU quota per segment.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
app.py
CHANGED
|
@@ -514,23 +514,42 @@ FLASHSR_SR_OUT = 48000
|
|
| 514 |
|
| 515 |
|
| 516 |
def _load_flashsr():
|
| 517 |
-
"""Load FlashSR
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 518 |
global _FLASHSR_MODEL
|
| 519 |
with _FLASHSR_LOCK:
|
| 520 |
if _FLASHSR_MODEL is not None:
|
| 521 |
return _FLASHSR_MODEL
|
| 522 |
-
print("[FlashSR] Loading
|
| 523 |
from huggingface_hub import hf_hub_download
|
| 524 |
-
from FastAudioSR import
|
|
|
|
| 525 |
ckpt_path = hf_hub_download(
|
| 526 |
repo_id="YatharthS/FlashSR",
|
| 527 |
filename="upsampler.pth",
|
| 528 |
local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
|
| 529 |
)
|
| 530 |
-
#
|
| 531 |
-
|
| 532 |
-
|
| 533 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 534 |
_FLASHSR_MODEL = model
|
| 535 |
return model
|
| 536 |
|
|
@@ -543,16 +562,14 @@ def _apply_flashsr(wav_16k: np.ndarray) -> np.ndarray:
|
|
| 543 |
"""
|
| 544 |
try:
|
| 545 |
model = _load_flashsr()
|
| 546 |
-
|
| 547 |
-
t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0)
|
| 548 |
print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (CPU) …")
|
| 549 |
with torch.no_grad():
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
| 553 |
-
|
| 554 |
-
|
| 555 |
-
out = np.array(out, dtype=np.float32).squeeze()
|
| 556 |
print(f"[FlashSR] Done β output shape {out.shape}, sr={FLASHSR_SR_OUT}")
|
| 557 |
return out
|
| 558 |
except Exception as e:
|
|
|
|
| 514 |
|
| 515 |
|
| 516 |
def _load_flashsr():
|
| 517 |
+
"""Load FlashSR SynthesizerTrn on CPU (cached after first call).
|
| 518 |
+
|
| 519 |
+
We bypass the FASR wrapper class because it calls
|
| 520 |
+
``torch.cuda.is_available()`` and ``.to(device)`` in ``__init__``,
|
| 521 |
+
which initialises CUDA in the main process and violates ZeroGPU's
|
| 522 |
+
stateless-GPU rule (aborting all subsequent GPU tasks).
|
| 523 |
+
Instead we instantiate the underlying ``SynthesizerTrn`` directly on CPU.
|
| 524 |
+
"""
|
| 525 |
global _FLASHSR_MODEL
|
| 526 |
with _FLASHSR_LOCK:
|
| 527 |
if _FLASHSR_MODEL is not None:
|
| 528 |
return _FLASHSR_MODEL
|
| 529 |
+
print("[FlashSR] Loading SynthesizerTrn on CPU (bypassing FASR to avoid CUDA init) …")
|
| 530 |
from huggingface_hub import hf_hub_download
|
| 531 |
+
from FastAudioSR.speechsr import SynthesizerTrn
|
| 532 |
+
|
| 533 |
ckpt_path = hf_hub_download(
|
| 534 |
repo_id="YatharthS/FlashSR",
|
| 535 |
filename="upsampler.pth",
|
| 536 |
local_dir=os.path.join(os.path.dirname(os.path.abspath(__file__)), ".flashsr_cache"),
|
| 537 |
)
|
| 538 |
+
# Replicate FASR's hps exactly, but stay on CPU
|
| 539 |
+
model = SynthesizerTrn(
|
| 540 |
+
spec_channels=128, # n_mel_channels
|
| 541 |
+
segment_size=9600 // 320, # segment_size // hop_length = 30
|
| 542 |
+
resblock="0",
|
| 543 |
+
resblock_kernel_sizes=[3, 7, 11],
|
| 544 |
+
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
| 545 |
+
upsample_rates=[3],
|
| 546 |
+
upsample_initial_channel=32,
|
| 547 |
+
upsample_kernel_sizes=[3],
|
| 548 |
+
)
|
| 549 |
+
checkpoint_dict = torch.load(ckpt_path, map_location="cpu")["model"]
|
| 550 |
+
model.load_state_dict(checkpoint_dict)
|
| 551 |
+
model.eval()
|
| 552 |
+
print("[FlashSR] SynthesizerTrn loaded on CPU (fp32) — no CUDA touched")
|
| 553 |
_FLASHSR_MODEL = model
|
| 554 |
return model
|
| 555 |
|
|
|
|
| 562 |
"""
|
| 563 |
try:
|
| 564 |
model = _load_flashsr()
|
| 565 |
+
t = torch.from_numpy(wav_16k.astype(np.float32)).unsqueeze(0) # [1, T]
|
|
|
|
| 566 |
print(f"[FlashSR] Upsampling {len(wav_16k)/FLASHSR_SR_IN:.2f}s @ 16kHz → 48kHz (CPU) …")
|
| 567 |
with torch.no_grad():
|
| 568 |
+
# SynthesizerTrn.forward expects [B, 1, T] — add channel dim
|
| 569 |
+
out = model(t.unsqueeze(1))  # → [B, 1, T*3]
|
| 570 |
+
out = out.squeeze()  # → [T*3]
|
| 571 |
+
out = out / (torch.abs(out).max() + 1e-8) * 0.999 # normalize like FASR.super_resolution
|
| 572 |
+
out = out.cpu().float().numpy()
|
|
|
|
| 573 |
print(f"[FlashSR] Done β output shape {out.shape}, sr={FLASHSR_SR_OUT}")
|
| 574 |
return out
|
| 575 |
except Exception as e:
|