danielr-ceva commited on
Commit
03cd46e
·
verified ·
1 Parent(s): 4af1e8c

Update run_tflite.py

Browse files
Files changed (1) hide show
  1. run_tflite.py +164 -58
run_tflite.py CHANGED
@@ -1,4 +1,5 @@
1
  import argparse
 
2
  from pathlib import Path
3
  import sys
4
 
@@ -8,12 +9,28 @@ import librosa
8
  from tflite_runtime.interpreter import Interpreter
9
  from tqdm import tqdm
10
 
11
-
12
  TFLITE_DIR = Path('./')
13
 
14
- # ===== STFT / iSTFT params (as in the snippet) =====
15
- WIN_LEN = 320 # 16 kHz: 320
16
- HOP_SIZE = WIN_LEN // 2 # 50% hop
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
 
19
  def vorbis_window(window_len: int) -> np.ndarray:
@@ -29,35 +46,51 @@ def get_wnorm(window_len: int, frame_size: int) -> float:
29
  return 1.0 / (window_len ** 2 / (2 * frame_size))
30
 
31
 
32
- # ---------- Pre/Post processing ----------
33
- _WIN = vorbis_window(WIN_LEN)
34
- _WNORM = get_wnorm(WIN_LEN, HOP_SIZE)
 
 
 
 
 
 
 
 
 
 
 
35
 
36
 
37
- def preprocessing(waveform_16k: np.ndarray) -> np.ndarray:
 
 
 
 
38
  """
39
- waveform_16k: 1D float32 numpy array at 16 kHz, mono, range ~[-1,1]
40
  Returns complex STFT as real/imag split: [B=1, T, F, 2] float32
41
  """
42
  # Librosa returns [F, T]; match original by using center=False here
43
  spec = librosa.stft(
44
- y=waveform_16k.astype(np.float32, copy=False),
45
- n_fft=WIN_LEN,
46
- hop_length=HOP_SIZE,
47
- win_length=WIN_LEN,
48
- window=_WIN,
49
- center=False,
50
- pad_mode="reflect"
51
  ) # [F, T] complex64
52
- spec = (spec.T * _WNORM).astype(np.complex64) # [T, F]
 
53
  spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32) # [T, F, 2]
54
  return spec_ri[None, ...] # [1, T, F, 2]
55
 
56
 
57
- def postprocessing(spec_e: np.ndarray) -> np.ndarray:
58
  """
59
  spec_e: [1, T, F, 2] float32
60
- Returns waveform (1D float32, 16 kHz)
61
  """
62
  # Recreate complex STFT with shape [F, T]
63
  spec_c = spec_e[0].astype(np.float32) # [T, F, 2]
@@ -65,19 +98,26 @@ def postprocessing(spec_e: np.ndarray) -> np.ndarray:
65
 
66
  waveform_e = librosa.istft(
67
  spec,
68
- hop_length=HOP_SIZE,
69
- win_length=WIN_LEN,
70
- window=_WIN,
71
  center=True,
72
  length=None,
73
  ).astype(np.float32)
74
 
75
- waveform_e = waveform_e / _WNORM
76
- waveform_e = np.concatenate([waveform_e[WIN_LEN * 2:], np.zeros(WIN_LEN * 2, dtype=np.float32)])
 
 
 
 
77
  return waveform_e.astype(np.float32)
78
 
79
 
80
- # ---------- Audio utilities ----------
 
 
 
81
  def to_mono(audio: np.ndarray) -> np.ndarray:
82
  if audio.ndim == 1:
83
  return audio
@@ -85,16 +125,22 @@ def to_mono(audio: np.ndarray) -> np.ndarray:
85
  return np.mean(audio, axis=1)
86
 
87
 
88
- def ensure_16k(waveform: np.ndarray, sr: int, target_sr: int = 16000) -> np.ndarray:
89
  if sr == target_sr:
90
  return waveform.astype(np.float32, copy=False)
91
- return librosa.resample(waveform.astype(np.float32, copy=False), orig_sr=sr, target_sr=target_sr)
 
 
92
 
93
 
94
- def resample_back(waveform_16k: np.ndarray, target_sr: int) -> np.ndarray:
95
- if target_sr == 16000:
96
- return waveform_16k
97
- return librosa.resample(waveform_16k.astype(np.float32, copy=False), orig_sr=16000, target_sr=target_sr)
 
 
 
 
98
 
99
 
100
  def pcm16_safe(x: np.ndarray) -> np.ndarray:
@@ -102,32 +148,72 @@ def pcm16_safe(x: np.ndarray) -> np.ndarray:
102
  return (x * 32767.0).astype(np.int16)
103
 
104
 
105
- # ---------- Core processing ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
107
  # Load audio
108
  audio, sr_in = sf.read(str(in_path), always_2d=False)
109
  audio = to_mono(audio)
110
-
111
- # Convert dtypes and resample to 16k for the model
112
  audio = audio.astype(np.float32, copy=False)
113
- audio_16k = ensure_16k(audio, sr_in, 16000)
114
 
115
- # STFT to frames (streaming)
116
- spec = preprocessing(audio_16k) # [1, T, F, 2]
117
- num_frames = spec.shape[1]
118
-
119
- # New interpreter per file ensures stateful models (RNN/LSTM) start clean
120
- interpreter = Interpreter(model_path=str(TFLITE_DIR / (model_name + '.tflite')))
121
- interpreter.allocate_tensors()
122
  input_details = interpreter.get_input_details()
123
  output_details = interpreter.get_output_details()
124
 
 
 
 
 
 
 
 
 
 
 
125
  # Frame-by-frame inference
126
  outputs = []
127
-
128
  for t in tqdm(range(num_frames), desc=f"{in_path.name}", unit="frm", leave=False):
129
- frame = spec[:, t:t + 1] # [1, 1, F, 2]
130
- # Some TFLite builds are picky about contiguity/dtype
131
  frame = np.ascontiguousarray(frame, dtype=np.float32)
132
 
133
  interpreter.set_tensor(input_details[0]["index"], frame)
@@ -138,9 +224,12 @@ def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
138
  # Concatenate along time dimension
139
  spec_e = np.concatenate(outputs, axis=1).astype(np.float32) # [1, T, F, 2]
140
 
141
- # iSTFT to waveform (16 kHz), then back to original SR for saving
142
- enhanced_16k = postprocessing(spec_e)
143
- enhanced = resample_back(enhanced_16k, sr_in)
 
 
 
144
 
145
  # Save as 16-bit PCM WAV, mono, original sample rate
146
  out_path.parent.mkdir(parents=True, exist_ok=True)
@@ -148,28 +237,42 @@ def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
148
 
149
 
150
  def main():
151
- parser = argparse.ArgumentParser(description="Enhance WAV files with a DPDFNet TFLite model (streaming).")
152
- parser.add_argument("--noisy_dir", type=str, required=True, help="Folder with noisy *.wav files (non-recursive).")
153
- parser.add_argument("--enhanced_dir", type=str, required=True, help="Output folder for enhanced WAVs.")
 
 
 
 
 
 
 
 
 
 
 
 
154
  parser.add_argument(
155
  "--model_name",
156
  type=str,
157
  default="dpdfnet8",
158
- choices=["baseline", "dpdfnet2", "dpdfnet4", "dpdfnet8"],
159
  help=(
160
- "Name of the model to use. Options: "
161
- "'baseline', 'dpdfnet2', 'dpdfnet4', 'dpdfnet8'. "
162
- "Default is 'dpdfnet8'."
163
  ),
164
  )
165
- args = parser.parse_args()
166
 
 
167
  noisy_dir = Path(args.noisy_dir)
168
  enhanced_dir = Path(args.enhanced_dir)
169
  model_name = args.model_name
170
 
171
  if not noisy_dir.is_dir():
172
- print(f"ERROR: --noisy_dir does not exist or is not a directory: {noisy_dir}", file=sys.stderr)
 
 
 
173
  sys.exit(1)
174
 
175
  wavs = sorted(p for p in noisy_dir.glob("*.wav") if p.is_file())
@@ -177,13 +280,16 @@ def main():
177
  print(f"No .wav files found in {noisy_dir} (non-recursive).")
178
  sys.exit(0)
179
 
 
180
  print(f"Model: {model_name}")
 
 
181
  print(f"Input : {noisy_dir}")
182
  print(f"Output: {enhanced_dir}")
183
  print(f"Found {len(wavs)} file(s). Enhancing...\n")
184
 
185
  for wav in wavs:
186
- out_path = enhanced_dir / (wav.stem + f'_{model_name}.wav')
187
  try:
188
  enhance_file(wav, out_path, model_name)
189
  except Exception as e:
 
1
  import argparse
2
+ from dataclasses import dataclass
3
  from pathlib import Path
4
  import sys
5
 
 
9
  from tflite_runtime.interpreter import Interpreter
10
  from tqdm import tqdm
11
 
 
12
  TFLITE_DIR = Path('./')
13
 
14
# -----------------------------------------------------------------------------
# Model registry
# -----------------------------------------------------------------------------
# Maps each model name to the sample rate it expects and the STFT window
# length used during training/export. Both families use a 20 ms window:
# 320 samples at 16 kHz, 960 samples at 48 kHz.
#
# To register a new model, add its name with {"sr": ..., "win_len": ...}.
MODEL_CONFIG = {
    # 16 kHz models (win_len = 320 -> 20 ms)
    name: {"sr": 16000, "win_len": 320}
    for name in ("baseline", "dpdfnet2", "dpdfnet4", "dpdfnet8")
}
# 48 kHz models (win_len = 960 -> 20 ms)
MODEL_CONFIG["dpdfnet2_48khz_hr"] = {"sr": 48000, "win_len": 960}
34
 
35
 
36
  def vorbis_window(window_len: int) -> np.ndarray:
 
46
  return 1.0 / (window_len ** 2 / (2 * frame_size))
47
 
48
 
49
@dataclass(frozen=True)
class STFTConfig:
    """Immutable bundle of STFT parameters derived from one model's registry entry."""

    sr: int            # sample rate the model expects (Hz)
    win_len: int       # STFT window length in samples
    hop_size: int      # hop between frames (50% of win_len)
    win: np.ndarray    # analysis window (vorbis window of length win_len)
    wnorm: float       # spectrum normalization factor


def make_stft_config(sr: int, win_len: int) -> STFTConfig:
    """Build an STFTConfig: derive hop, window, and normalization from win_len."""
    hop = win_len // 2  # 50% overlap
    return STFTConfig(
        sr=sr,
        win_len=win_len,
        hop_size=hop,
        win=vorbis_window(win_len),
        wnorm=get_wnorm(win_len, hop),
    )
63
 
64
 
65
+ # -----------------------------------------------------------------------------
66
+ # Pre/Post processing
67
+ # -----------------------------------------------------------------------------
68
+
69
def preprocessing(waveform: np.ndarray, cfg: STFTConfig) -> np.ndarray:
    """
    Compute the model-input spectrogram for a mono waveform.

    waveform: 1D float32 numpy array at cfg.sr, mono, range ~[-1, 1]
    Returns complex STFT as real/imag split: [B=1, T, F, 2] float32
    """
    # Librosa returns [F, T]. NOTE: the code uses center=True (frames are
    # centered on t * hop_size with reflect padding); an earlier comment
    # claiming center=False was stale and contradicted the call below.
    spec = librosa.stft(
        y=waveform.astype(np.float32, copy=False),
        n_fft=cfg.win_len,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.win,
        center=True,
        pad_mode="reflect",
    )  # [F, T] complex64

    # Transpose to time-major and apply the model's spectrum normalization.
    spec = (spec.T * cfg.wnorm).astype(np.complex64)  # [T, F]
    spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32)  # [T, F, 2]
    return spec_ri[None, ...]  # [1, T, F, 2]
88
 
89
 
90
+ def postprocessing(spec_e: np.ndarray, cfg: STFTConfig) -> np.ndarray:
91
  """
92
  spec_e: [1, T, F, 2] float32
93
+ Returns waveform (1D float32, cfg.sr)
94
  """
95
  # Recreate complex STFT with shape [F, T]
96
  spec_c = spec_e[0].astype(np.float32) # [T, F, 2]
 
98
 
99
  waveform_e = librosa.istft(
100
  spec,
101
+ hop_length=cfg.hop_size,
102
+ win_length=cfg.win_len,
103
+ window=cfg.win,
104
  center=True,
105
  length=None,
106
  ).astype(np.float32)
107
 
108
+ waveform_e = waveform_e / cfg.wnorm
109
+
110
+ # Keep the legacy alignment compensation behavior, scaled by win_len.
111
+ waveform_e = np.concatenate(
112
+ [waveform_e[cfg.win_len * 2 :], np.zeros(cfg.win_len * 2, dtype=np.float32)]
113
+ )
114
  return waveform_e.astype(np.float32)
115
 
116
 
117
+ # -----------------------------------------------------------------------------
118
+ # Audio utilities
119
+ # -----------------------------------------------------------------------------
120
+
121
  def to_mono(audio: np.ndarray) -> np.ndarray:
122
  if audio.ndim == 1:
123
  return audio
 
125
  return np.mean(audio, axis=1)
126
 
127
 
128
def ensure_sr(waveform: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
    """Return *waveform* as float32 at target_sr, resampling only when needed."""
    wav = waveform.astype(np.float32, copy=False)
    if sr != target_sr:
        wav = librosa.resample(wav, orig_sr=sr, target_sr=target_sr)
    return wav
134
 
135
 
136
def resample_back(waveform_model_sr: np.ndarray, model_sr: int, target_sr: int) -> np.ndarray:
    """Resample audio from the model rate back to target_sr (no-op when equal)."""
    if model_sr == target_sr:
        return waveform_model_sr
    wav32 = waveform_model_sr.astype(np.float32, copy=False)
    return librosa.resample(wav32, orig_sr=model_sr, target_sr=target_sr)
144
 
145
 
146
  def pcm16_safe(x: np.ndarray) -> np.ndarray:
 
148
  return (x * 32767.0).astype(np.int16)
149
 
150
 
151
+ # -----------------------------------------------------------------------------
152
+ # Core processing
153
+ # -----------------------------------------------------------------------------
154
+
155
def _load_model_and_cfg(model_name: str) -> tuple[Interpreter, STFTConfig]:
    """Create the TFLite interpreter and return (interpreter, STFTConfig) for this model.

    Raises:
        ValueError: unknown model name, or the model's input F does not match
            the win_len registered in MODEL_CONFIG.
        FileNotFoundError: the .tflite file is missing from TFLITE_DIR.
    """
    if model_name not in MODEL_CONFIG:
        raise ValueError(
            f"Unknown model '{model_name}'. Add it to MODEL_CONFIG or pass a valid --model_name."
        )

    model_path = TFLITE_DIR / f"{model_name}.tflite"
    if not model_path.exists():
        raise FileNotFoundError(f"TFLite model not found: {model_path}")

    interpreter = Interpreter(model_path=str(model_path))
    interpreter.allocate_tensors()

    cfg_dict = MODEL_CONFIG[model_name]
    cfg = make_stft_config(sr=int(cfg_dict["sr"]), win_len=int(cfg_dict["win_len"]))

    # Optional sanity check: infer expected F from the model's input shape.
    # BUG FIX: previously the mismatch ValueError was raised *inside* a
    # try/except Exception: pass, which immediately swallowed it and made the
    # check a no-op. Only shape introspection is best-effort now; the mismatch
    # error is raised outside the try so it actually propagates.
    model_F = None
    try:
        input_details = interpreter.get_input_details()
        shape = input_details[0].get("shape", None)
        # Expect [1, 1, F, 2] (or [1, T, F, 2] for non-streaming)
        if shape is not None and len(shape) >= 3:
            model_F = int(shape[-2])  # ..., F, 2
    except Exception:
        # Do not hard-fail on odd/unknown shapes; the runtime error will be informative.
        model_F = None

    expected_F = cfg.win_len // 2 + 1
    if model_F is not None and model_F != expected_F:
        raise ValueError(
            f"Model '{model_name}' input F={model_F} does not match win_len={cfg.win_len} "
            f"(expected F={expected_F}). Update MODEL_CONFIG for this model."
        )

    return interpreter, cfg
190
+
191
+
192
  def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
193
  # Load audio
194
  audio, sr_in = sf.read(str(in_path), always_2d=False)
195
  audio = to_mono(audio)
 
 
196
  audio = audio.astype(np.float32, copy=False)
 
197
 
198
+ # Load model and its expected SR/STFT config
199
+ interpreter, cfg = _load_model_and_cfg(model_name)
 
 
 
 
 
200
  input_details = interpreter.get_input_details()
201
  output_details = interpreter.get_output_details()
202
 
203
+ # Resample to model SR
204
+ audio_model_sr = ensure_sr(audio, sr_in, cfg.sr)
205
+
206
+ # Alignment compensation #1
207
+ audio_pad = np.pad(audio_model_sr, (0, cfg.win_len), mode='constant', constant_values=0)
208
+
209
+ # STFT to frames (streaming)
210
+ spec = preprocessing(audio_pad, cfg) # [1, T, F, 2]
211
+ num_frames = spec.shape[1]
212
+
213
  # Frame-by-frame inference
214
  outputs = []
 
215
  for t in tqdm(range(num_frames), desc=f"{in_path.name}", unit="frm", leave=False):
216
+ frame = spec[:, t : t + 1] # [1, 1, F, 2]
 
217
  frame = np.ascontiguousarray(frame, dtype=np.float32)
218
 
219
  interpreter.set_tensor(input_details[0]["index"], frame)
 
224
  # Concatenate along time dimension
225
  spec_e = np.concatenate(outputs, axis=1).astype(np.float32) # [1, T, F, 2]
226
 
227
+ # iSTFT to waveform (model SR), then back to original SR for saving
228
+ enhanced_model_sr = postprocessing(spec_e, cfg)
229
+ enhanced = resample_back(enhanced_model_sr, cfg.sr, sr_in)
230
+
231
+ # Alignment compensation #2
232
+ enhanced = enhanced[: audio.size]
233
 
234
  # Save as 16-bit PCM WAV, mono, original sample rate
235
  out_path.parent.mkdir(parents=True, exist_ok=True)
 
237
 
238
 
239
  def main():
240
+ parser = argparse.ArgumentParser(
241
+ description="Enhance WAV files with a DPDFNet TFLite model (streaming)."
242
+ )
243
+ parser.add_argument(
244
+ "--noisy_dir",
245
+ type=str,
246
+ required=True,
247
+ help="Folder with noisy *.wav files (non-recursive).",
248
+ )
249
+ parser.add_argument(
250
+ "--enhanced_dir",
251
+ type=str,
252
+ required=True,
253
+ help="Output folder for enhanced WAVs.",
254
+ )
255
  parser.add_argument(
256
  "--model_name",
257
  type=str,
258
  default="dpdfnet8",
259
+ choices=sorted(MODEL_CONFIG.keys()),
260
  help=(
261
+ "Name of the model to use. The script will automatically use the correct "
262
+ "sample-rate/STFT settings based on MODEL_CONFIG."
 
263
  ),
264
  )
 
265
 
266
+ args = parser.parse_args()
267
  noisy_dir = Path(args.noisy_dir)
268
  enhanced_dir = Path(args.enhanced_dir)
269
  model_name = args.model_name
270
 
271
  if not noisy_dir.is_dir():
272
+ print(
273
+ f"ERROR: --noisy_dir does not exist or is not a directory: {noisy_dir}",
274
+ file=sys.stderr,
275
+ )
276
  sys.exit(1)
277
 
278
  wavs = sorted(p for p in noisy_dir.glob("*.wav") if p.is_file())
 
280
  print(f"No .wav files found in {noisy_dir} (non-recursive).")
281
  sys.exit(0)
282
 
283
+ cfg = MODEL_CONFIG.get(model_name, None)
284
  print(f"Model: {model_name}")
285
+ if cfg is not None:
286
+ print(f"Model SR: {cfg['sr']} Hz | win_len: {cfg['win_len']} | hop: {cfg['win_len']//2}")
287
  print(f"Input : {noisy_dir}")
288
  print(f"Output: {enhanced_dir}")
289
  print(f"Found {len(wavs)} file(s). Enhancing...\n")
290
 
291
  for wav in wavs:
292
+ out_path = enhanced_dir / (wav.stem + f"_{model_name}.wav")
293
  try:
294
  enhance_file(wav, out_path, model_name)
295
  except Exception as e: