Kuangwei Chen commited on
Commit
16ec632
·
1 Parent(s): c828a64

try to fix the "没有可用的音频 I/O backend" error ("no available audio I/O backend").

Browse files
requirements.txt CHANGED
@@ -4,4 +4,5 @@ torch==2.7.0
4
  torchaudio==2.7.0
5
  transformers==4.57.1
6
  safetensors>=0.4.3
 
7
  gradio==6.5.1
 
4
  torchaudio==2.7.0
5
  transformers==4.57.1
6
  safetensors>=0.4.3
7
+ soundfile>=0.13.1
8
  gradio==6.5.1
weights/tts/modeling_nanotts_global_local.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from typing import Any, Iterator, Optional, Sequence, Union
12
 
13
  import numpy as np
 
14
  import torch
15
  import torch.nn as nn
16
  import torchaudio
@@ -1573,7 +1574,17 @@ class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
1573
  target_sample_rate: int,
1574
  target_channels: int,
1575
  ) -> tuple[torch.FloatTensor, int]:
1576
- waveform, sample_rate = torchaudio.load(str(reference_audio_path))
 
 
 
 
 
 
 
 
 
 
1577
  waveform = waveform.to(torch.float32)
1578
  if sample_rate != target_sample_rate:
1579
  waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
@@ -1587,6 +1598,25 @@ class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
1587
  return waveform.mean(dim=0, keepdim=True), sample_rate
1588
  raise ValueError(f"Unsupported reference audio channel conversion: {current_channels} -> {target_channels}")
1589
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1590
  def _decode_local_last_hidden_state(
1591
  self,
1592
  local_inputs_embeds: torch.FloatTensor,
@@ -2193,8 +2223,7 @@ class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
2193
 
2194
  decoded_sample_rate = decoded_sample_rate or target_sample_rate
2195
  output_path = Path(output_audio_path)
2196
- output_path.parent.mkdir(parents=True, exist_ok=True)
2197
- torchaudio.save(str(output_path), waveform, decoded_sample_rate)
2198
 
2199
  yield {
2200
  "type": "result",
@@ -2428,8 +2457,7 @@ class NanoTTSGlobalLocalForCausalLM(NanoTTSPreTrainedModel):
2428
  assert decoded_sample_rate is not None
2429
 
2430
  output_path = Path(output_audio_path)
2431
- output_path.parent.mkdir(parents=True, exist_ok=True)
2432
- torchaudio.save(str(output_path), waveform, decoded_sample_rate)
2433
 
2434
  if was_training:
2435
  self.train()
 
11
  from typing import Any, Iterator, Optional, Sequence, Union
12
 
13
  import numpy as np
14
+ import soundfile as sf
15
  import torch
16
  import torch.nn as nn
17
  import torchaudio
 
1574
  target_sample_rate: int,
1575
  target_channels: int,
1576
  ) -> tuple[torch.FloatTensor, int]:
1577
+ try:
1578
+ waveform, sample_rate = torchaudio.load(str(reference_audio_path))
1579
+ except RuntimeError as exc:
1580
+ logging.warning(
1581
+ "torchaudio.load failed for %s; falling back to soundfile",
1582
+ reference_audio_path,
1583
+ exc_info=True,
1584
+ )
1585
+ audio_array, sample_rate = sf.read(str(reference_audio_path), dtype="float32", always_2d=True)
1586
+ waveform = torch.from_numpy(audio_array.T).contiguous()
1587
+
1588
  waveform = waveform.to(torch.float32)
1589
  if sample_rate != target_sample_rate:
1590
  waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
 
1598
  return waveform.mean(dim=0, keepdim=True), sample_rate
1599
  raise ValueError(f"Unsupported reference audio channel conversion: {current_channels} -> {target_channels}")
1600
 
1601
+ @staticmethod
1602
+ def _save_audio(
1603
+ output_path: Union[str, Path],
1604
+ waveform: torch.Tensor,
1605
+ sample_rate: int,
1606
+ ) -> None:
1607
+ path = Path(output_path)
1608
+ path.parent.mkdir(parents=True, exist_ok=True)
1609
+ try:
1610
+ torchaudio.save(str(path), waveform, sample_rate)
1611
+ except RuntimeError:
1612
+ logging.warning(
1613
+ "torchaudio.save failed for %s; falling back to soundfile",
1614
+ path,
1615
+ exc_info=True,
1616
+ )
1617
+ waveform_np = waveform.detach().cpu().to(torch.float32).numpy().T
1618
+ sf.write(str(path), waveform_np, sample_rate)
1619
+
1620
  def _decode_local_last_hidden_state(
1621
  self,
1622
  local_inputs_embeds: torch.FloatTensor,
 
2223
 
2224
  decoded_sample_rate = decoded_sample_rate or target_sample_rate
2225
  output_path = Path(output_audio_path)
2226
+ self._save_audio(output_path, waveform, decoded_sample_rate)
 
2227
 
2228
  yield {
2229
  "type": "result",
 
2457
  assert decoded_sample_rate is not None
2458
 
2459
  output_path = Path(output_audio_path)
2460
+ self._save_audio(output_path, waveform, decoded_sample_rate)
 
2461
 
2462
  if was_training:
2463
  self.train()