danielr-ceva commited on
Commit
1a80d48
·
verified ·
1 Parent(s): 3985ab8

Delete run_tflite.py

Browse files
Files changed (1) hide show
  1. run_tflite.py +0 -298
run_tflite.py DELETED
@@ -1,298 +0,0 @@
1
- import argparse
2
- from dataclasses import dataclass
3
- from pathlib import Path
4
- import sys
5
-
6
- import numpy as np
7
- import soundfile as sf
8
- import librosa
9
- from tflite_runtime.interpreter import Interpreter
10
- from tqdm import tqdm
11
-
12
# Directory containing the *.tflite model files (current working directory).
TFLITE_DIR = Path('./')

# -----------------------------------------------------------------------------
# Model registry
# -----------------------------------------------------------------------------
# Maps model name -> STFT settings for that model.
# 16 kHz models: WIN_LEN=320 (20 ms)
# 48 kHz models: WIN_LEN=960 (20 ms)

MODEL_CONFIG = {
    # 16 kHz models
    "baseline": {"sr": 16000, "win_len": 320},
    "dpdfnet2": {"sr": 16000, "win_len": 320},
    "dpdfnet4": {"sr": 16000, "win_len": 320},
    "dpdfnet8": {"sr": 16000, "win_len": 320},

    # 48 kHz models
    "dpdfnet2_48khz_hr": {"sr": 48000, "win_len": 960},
}
30
-
31
-
32
def vorbis_window(window_len: int) -> np.ndarray:
    """Return the Vorbis (sin-of-sin-squared) window of length ``window_len``.

    w[i] = sin(pi/2 * sin^2(pi * (i + 0.5) / window_len)), returned as float32.
    """
    half_len = window_len / 2
    phase = 0.5 * np.pi * (np.arange(window_len) + 0.5) / half_len
    inner = np.sin(phase)
    return np.sin(0.5 * np.pi * inner * inner).astype(np.float32)
38
-
39
-
40
def get_wnorm(window_len: int, frame_size: int) -> float:
    """Return the normalization factor applied to the forward STFT.

    window_len: number of samples of the window; frame_size: hop size.
    Mathematically equal to (2 * frame_size) / window_len**2.
    """
    denom = window_len ** 2 / (2 * frame_size)
    return 1.0 / denom
43
-
44
-
45
@dataclass(frozen=True)
class STFTConfig:
    """Immutable bundle of STFT parameters shared by pre/post processing."""

    sr: int  # sample rate the model expects (Hz)
    win_len: int  # analysis window length in samples (20 ms at sr)
    hop_size: int  # hop between frames in samples (win_len // 2, i.e. 50% overlap)
    win: np.ndarray  # window function, produced by vorbis_window(win_len)
    wnorm: float  # STFT normalization factor, produced by get_wnorm()
52
-
53
-
54
def make_stft_config(sr: int, win_len: int) -> STFTConfig:
    """Build the STFTConfig for a model: 50% hop, Vorbis window, matching wnorm."""
    half = win_len // 2  # 50% hop
    return STFTConfig(
        sr=sr,
        win_len=win_len,
        hop_size=half,
        win=vorbis_window(win_len),
        wnorm=get_wnorm(win_len, half),
    )
59
-
60
-
61
- # -----------------------------------------------------------------------------
62
- # Pre/Post processing
63
- # -----------------------------------------------------------------------------
64
-
65
def preprocessing(waveform: np.ndarray, cfg: STFTConfig) -> np.ndarray:
    """
    waveform: 1D float32 numpy array at cfg.sr, mono, range ~[-1,1]
    Returns complex STFT as real/imag split: [B=1, T, F, 2] float32
    """
    # Librosa returns [F, T] complex frames. NOTE(review): a previous comment
    # here claimed center=False, but the call actually uses center=True with
    # reflect padding — the comment was wrong, not the code.
    spec = librosa.stft(
        y=waveform.astype(np.float32, copy=False),
        n_fft=cfg.win_len,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.win,
        center=True,
        pad_mode="reflect",
    )  # [F, T] complex64

    # Transpose to [T, F] and apply the window-normalization factor the model
    # was trained with; postprocessing() divides it back out.
    spec = (spec.T * cfg.wnorm).astype(np.complex64)  # [T, F]
    spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32)  # [T, F, 2]
    return spec_ri[None, ...]  # [1, T, F, 2]
84
-
85
-
86
def postprocessing(spec_e: np.ndarray, cfg: STFTConfig) -> np.ndarray:
    """
    spec_e: [1, T, F, 2] float32
    Returns waveform (1D float32, cfg.sr)
    """
    # Recreate complex STFT with shape [F, T]
    spec_c = spec_e[0].astype(np.float32)  # [T, F, 2]
    spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64)  # [F, T]

    # Inverse STFT with the same window/hop as preprocessing().
    waveform_e = librosa.istft(
        spec,
        hop_length=cfg.hop_size,
        win_length=cfg.win_len,
        window=cfg.win,
        center=True,
        length=None,
    ).astype(np.float32)

    # Undo the wnorm scaling applied in preprocessing().
    waveform_e = waveform_e / cfg.wnorm

    # Keep the legacy alignment compensation behavior, scaled by win_len:
    # drop the first 2*win_len samples and zero-pad the tail so the total
    # length is unchanged. NOTE(review): presumably compensates pipeline
    # latency — confirm against the original (reference) implementation.
    waveform_e = np.concatenate(
        [waveform_e[cfg.win_len * 2 :], np.zeros(cfg.win_len * 2, dtype=np.float32)]
    )
    return waveform_e.astype(np.float32)
111
-
112
-
113
- # -----------------------------------------------------------------------------
114
- # Audio utilities
115
- # -----------------------------------------------------------------------------
116
-
117
def to_mono(audio: np.ndarray) -> np.ndarray:
    """Collapse a multi-channel signal to mono by averaging across channels.

    A 1-D input is returned untouched; a [samples, channels] input is
    averaged over axis 1.
    """
    return audio if audio.ndim == 1 else np.mean(audio, axis=1)
122
-
123
-
124
def ensure_sr(waveform: np.ndarray, sr: int, target_sr: int) -> np.ndarray:
    """Return `waveform` as float32 at `target_sr`, resampling via librosa if needed."""
    wav32 = waveform.astype(np.float32, copy=False)
    if sr == target_sr:
        return wav32
    return librosa.resample(wav32, orig_sr=sr, target_sr=target_sr)
130
-
131
-
132
def resample_back(waveform_model_sr: np.ndarray, model_sr: int, target_sr: int) -> np.ndarray:
    """Resample the model-rate waveform back to `target_sr` (pass-through if equal)."""
    if target_sr == model_sr:
        return waveform_model_sr
    wav32 = waveform_model_sr.astype(np.float32, copy=False)
    return librosa.resample(wav32, orig_sr=model_sr, target_sr=target_sr)
140
-
141
-
142
def pcm16_safe(x: np.ndarray) -> np.ndarray:
    """Clip samples to [-1, 1] and convert to 16-bit signed PCM (int16)."""
    return (np.clip(x, -1.0, 1.0) * 32767.0).astype(np.int16)
145
-
146
-
147
- # -----------------------------------------------------------------------------
148
- # Core processing
149
- # -----------------------------------------------------------------------------
150
-
151
def _load_model_and_cfg(model_name: str) -> tuple[Interpreter, STFTConfig]:
    """Create a TFLite interpreter and its matching STFT configuration.

    Returns:
        (interpreter, STFTConfig) for `model_name`.

    Raises:
        ValueError: unknown model name, or the model's input frequency-bin
            count does not match the configured win_len.
        FileNotFoundError: the .tflite file is missing from TFLITE_DIR.
    """
    if model_name not in MODEL_CONFIG:
        raise ValueError(
            f"Unknown model '{model_name}'. Add it to MODEL_CONFIG or pass a valid --model_name."
        )

    model_path = TFLITE_DIR / f"{model_name}.tflite"
    if not model_path.exists():
        raise FileNotFoundError(f"TFLite model not found: {model_path}")

    interpreter = Interpreter(model_path=str(model_path))
    interpreter.allocate_tensors()

    cfg_dict = MODEL_CONFIG[model_name]
    cfg = make_stft_config(sr=int(cfg_dict["sr"]), win_len=int(cfg_dict["win_len"]))

    # Optional sanity-check: infer expected F from the model input and compare.
    # Only the shape *probe* is best-effort; in the previous version the
    # mismatch `raise` also sat inside this try and was silently swallowed by
    # `except Exception: pass`, so the check could never fire. The raise now
    # lives outside the try so a genuine mismatch propagates to the caller.
    shape = None
    try:
        input_details = interpreter.get_input_details()
        shape = input_details[0].get("shape", None)
    except Exception:
        # Do not hard-fail on odd/unknown shapes; the runtime error will be informative.
        shape = None

    # Expect [1, 1, F, 2] (or [1, T, F, 2] for non-streaming)
    if shape is not None and len(shape) >= 3:
        F = int(shape[-2])  # ... F, 2
        expected_F = cfg.win_len // 2 + 1
        if F != expected_F:
            raise ValueError(
                f"Model '{model_name}' input F={F} does not match win_len={cfg.win_len} "
                f"(expected F={expected_F}). Update MODEL_CONFIG for this model."
            )

    return interpreter, cfg
186
-
187
-
188
def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
    """Enhance one WAV file with the named TFLite model and write the result.

    Reads `in_path` (any sample rate / channel count), runs frame-by-frame
    inference at the model's sample rate, and writes a mono 16-bit PCM WAV at
    the input's original sample rate to `out_path`.
    """
    # Load audio
    audio, sr_in = sf.read(str(in_path), always_2d=False)
    audio = to_mono(audio)
    audio = audio.astype(np.float32, copy=False)

    # Load model and its expected SR/STFT config
    interpreter, cfg = _load_model_and_cfg(model_name)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    # Resample to model SR
    audio_model_sr = ensure_sr(audio, sr_in, cfg.sr)

    # Alignment compensation #1: zero-pad the tail so content shifted left by
    # postprocessing()'s 2*win_len trim is not truncated.
    audio_pad = np.pad(audio_model_sr, (0, cfg.win_len), mode='constant', constant_values=0)

    # STFT to frames (streaming)
    spec = preprocessing(audio_pad, cfg)  # [1, T, F, 2]
    num_frames = spec.shape[1]

    # Frame-by-frame inference: one [1, 1, F, 2] frame in, one frame out.
    # NOTE(review): presumably the model keeps internal streaming state
    # between invoke() calls — confirm against the model's export.
    outputs = []
    for t in tqdm(range(num_frames), desc=f"{in_path.name}", unit="frm", leave=False):
        frame = spec[:, t : t + 1]  # [1, 1, F, 2]
        frame = np.ascontiguousarray(frame, dtype=np.float32)

        interpreter.set_tensor(input_details[0]["index"], frame)
        interpreter.invoke()
        y = interpreter.get_tensor(output_details[0]["index"])  # expected [1,1,F,2]
        outputs.append(np.ascontiguousarray(y, dtype=np.float32))

    # Concatenate along time dimension
    spec_e = np.concatenate(outputs, axis=1).astype(np.float32)  # [1, T, F, 2]

    # iSTFT to waveform (model SR), then back to original SR for saving
    enhanced_model_sr = postprocessing(spec_e, cfg)
    enhanced = resample_back(enhanced_model_sr, cfg.sr, sr_in)

    # Alignment compensation #2: trim back to the original input length
    enhanced = enhanced[: audio.size]

    # Save as 16-bit PCM WAV, mono, original sample rate
    out_path.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out_path), pcm16_safe(enhanced), sr_in, subtype="PCM_16")
233
-
234
-
235
def main():
    """Command-line entry point: enhance every *.wav in --noisy_dir."""
    parser = argparse.ArgumentParser(
        description="Enhance WAV files with a DPDFNet TFLite model (streaming)."
    )
    parser.add_argument(
        "--noisy_dir",
        type=str,
        required=True,
        help="Folder with noisy *.wav files (non-recursive).",
    )
    parser.add_argument(
        "--enhanced_dir",
        type=str,
        required=True,
        help="Output folder for enhanced WAVs.",
    )
    parser.add_argument(
        "--model_name",
        type=str,
        default="dpdfnet8",
        choices=sorted(MODEL_CONFIG.keys()),
        help=(
            "Name of the model to use. The script will automatically use the correct "
            "sample-rate/STFT settings based on MODEL_CONFIG."
        ),
    )
    opts = parser.parse_args()

    in_dir = Path(opts.noisy_dir)
    out_dir = Path(opts.enhanced_dir)
    chosen = opts.model_name

    # Fail fast on a bad input directory.
    if not in_dir.is_dir():
        print(
            f"ERROR: --noisy_dir does not exist or is not a directory: {in_dir}",
            file=sys.stderr,
        )
        sys.exit(1)

    # Non-recursive scan, deterministic order.
    wav_files = sorted(p for p in in_dir.glob("*.wav") if p.is_file())
    if not wav_files:
        print(f"No .wav files found in {in_dir} (non-recursive).")
        sys.exit(0)

    # Summary banner before processing starts.
    model_cfg = MODEL_CONFIG.get(chosen, None)
    print(f"Model: {chosen}")
    if model_cfg is not None:
        print(f"Model SR: {model_cfg['sr']} Hz | win_len: {model_cfg['win_len']} | hop: {model_cfg['win_len']//2}")
    print(f"Input : {in_dir}")
    print(f"Output: {out_dir}")
    print(f"Found {len(wav_files)} file(s). Enhancing...\n")

    # Best-effort batch: a failure on one file skips it, not the whole run.
    for noisy_wav in wav_files:
        target = out_dir / (noisy_wav.stem + f"_{chosen}.wav")
        try:
            enhance_file(noisy_wav, target, chosen)
        except Exception as e:
            print(f"[SKIP] {noisy_wav.name} due to error: {e}", file=sys.stderr)

    print("\nProcessing complete. Outputs saved in:", out_dir)


if __name__ == "__main__":
    main()