danielr-ceva commited on
Commit
0e03407
·
verified ·
1 Parent(s): 3ecf5cb

Upload run_tflite.py

Browse files
Files changed (1) hide show
  1. run_tflite.py +196 -0
run_tflite.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from pathlib import Path
3
+ import sys
4
+
5
+ import numpy as np
6
+ import soundfile as sf
7
+ import librosa
8
+ from tflite_runtime.interpreter import Interpreter
9
+ from tqdm import tqdm
10
+
11
+
12
+ TFLITE_DIR = Path('./')
13
+
14
+ # ===== STFT / iSTFT params (as in the snippet) =====
15
+ WIN_LEN = 320 # 16 kHz: 320
16
+ HOP_SIZE = WIN_LEN // 2 # 50% hop
17
+
18
+
19
+ def vorbis_window(window_len: int) -> np.ndarray:
20
+ window_size_h = window_len / 2
21
+ indices = np.arange(window_len)
22
+ sin = np.sin(0.5 * np.pi * (indices + 0.5) / window_size_h)
23
+ window = np.sin(0.5 * np.pi * sin * sin)
24
+ return window.astype(np.float32)
25
+
26
+
27
+ def get_wnorm(window_len: int, frame_size: int) -> float:
28
+ # window_len - #samples of the window; frame_size - hop size
29
+ return 1.0 / (window_len ** 2 / (2 * frame_size))
30
+
31
+
32
+ # ---------- Pre/Post processing ----------
33
+ _WIN = vorbis_window(WIN_LEN)
34
+ _WNORM = get_wnorm(WIN_LEN, HOP_SIZE)
35
+
36
+
37
+ def preprocessing(waveform_16k: np.ndarray) -> np.ndarray:
38
+ """
39
+ waveform_16k: 1D float32 numpy array at 16 kHz, mono, range ~[-1,1]
40
+ Returns complex STFT as real/imag split: [B=1, T, F, 2] float32
41
+ """
42
+ # Librosa returns [F, T]; match original by using center=False here
43
+ spec = librosa.stft(
44
+ y=waveform_16k.astype(np.float32, copy=False),
45
+ n_fft=WIN_LEN,
46
+ hop_length=HOP_SIZE,
47
+ win_length=WIN_LEN,
48
+ window=_WIN,
49
+ center=False,
50
+ pad_mode="reflect"
51
+ ) # [F, T] complex64
52
+ spec = (spec.T * _WNORM).astype(np.complex64) # [T, F]
53
+ spec_ri = np.stack([spec.real, spec.imag], axis=-1).astype(np.float32) # [T, F, 2]
54
+ return spec_ri[None, ...] # [1, T, F, 2]
55
+
56
+
57
+ def postprocessing(spec_e: np.ndarray) -> np.ndarray:
58
+ """
59
+ spec_e: [1, T, F, 2] float32
60
+ Returns waveform (1D float32, 16 kHz)
61
+ """
62
+ # Recreate complex STFT with shape [F, T]
63
+ spec_c = spec_e[0].astype(np.float32) # [T, F, 2]
64
+ spec = (spec_c[..., 0] + 1j * spec_c[..., 1]).T.astype(np.complex64) # [F, T]
65
+
66
+ waveform_e = librosa.istft(
67
+ spec,
68
+ hop_length=HOP_SIZE,
69
+ win_length=WIN_LEN,
70
+ window=_WIN,
71
+ center=True,
72
+ length=None,
73
+ ).astype(np.float32)
74
+
75
+ waveform_e = waveform_e / _WNORM
76
+ waveform_e = np.concatenate([waveform_e[WIN_LEN * 2:], np.zeros(WIN_LEN * 2, dtype=np.float32)])
77
+ return waveform_e.astype(np.float32)
78
+
79
+
80
+ # ---------- Audio utilities ----------
81
+ def to_mono(audio: np.ndarray) -> np.ndarray:
82
+ if audio.ndim == 1:
83
+ return audio
84
+ # Average channels to mono
85
+ return np.mean(audio, axis=1)
86
+
87
+
88
+ def ensure_16k(waveform: np.ndarray, sr: int, target_sr: int = 16000) -> np.ndarray:
89
+ if sr == target_sr:
90
+ return waveform.astype(np.float32, copy=False)
91
+ return librosa.resample(waveform.astype(np.float32, copy=False), orig_sr=sr, target_sr=target_sr)
92
+
93
+
94
+ def resample_back(waveform_16k: np.ndarray, target_sr: int) -> np.ndarray:
95
+ if target_sr == 16000:
96
+ return waveform_16k
97
+ return librosa.resample(waveform_16k.astype(np.float32, copy=False), orig_sr=16000, target_sr=target_sr)
98
+
99
+
100
+ def pcm16_safe(x: np.ndarray) -> np.ndarray:
101
+ x = np.clip(x, -1.0, 1.0)
102
+ return (x * 32767.0).astype(np.int16)
103
+
104
+
105
+ # ---------- Core processing ----------
106
+ def enhance_file(in_path: Path, out_path: Path, model_name: str) -> None:
107
+ # Load audio
108
+ audio, sr_in = sf.read(str(in_path), always_2d=False)
109
+ audio = to_mono(audio)
110
+
111
+ # Convert dtypes and resample to 16k for the model
112
+ audio = audio.astype(np.float32, copy=False)
113
+ audio_16k = ensure_16k(audio, sr_in, 16000)
114
+
115
+ # STFT to frames (streaming)
116
+ spec = preprocessing(audio_16k) # [1, T, F, 2]
117
+ num_frames = spec.shape[1]
118
+
119
+ # New interpreter per file ensures stateful models (RNN/LSTM) start clean
120
+ interpreter = Interpreter(model_path=str(TFLITE_DIR / (model_name + '.tflite')))
121
+ interpreter.allocate_tensors()
122
+ input_details = interpreter.get_input_details()
123
+ output_details = interpreter.get_output_details()
124
+
125
+ # Frame-by-frame inference
126
+ outputs = []
127
+
128
+ for t in tqdm(range(num_frames), desc=f"{in_path.name}", unit="frm", leave=False):
129
+ frame = spec[:, t:t + 1] # [1, 1, F, 2]
130
+ # Some TFLite builds are picky about contiguity/dtype
131
+ frame = np.ascontiguousarray(frame, dtype=np.float32)
132
+
133
+ interpreter.set_tensor(input_details[0]["index"], frame)
134
+ interpreter.invoke()
135
+ y = interpreter.get_tensor(output_details[0]["index"]) # expected [1,1,F,2]
136
+ outputs.append(np.ascontiguousarray(y, dtype=np.float32))
137
+
138
+ # Concatenate along time dimension
139
+ spec_e = np.concatenate(outputs, axis=1).astype(np.float32) # [1, T, F, 2]
140
+
141
+ # iSTFT to waveform (16 kHz), then back to original SR for saving
142
+ enhanced_16k = postprocessing(spec_e)
143
+ enhanced = resample_back(enhanced_16k, sr_in)
144
+
145
+ # Save as 16-bit PCM WAV, mono, original sample rate
146
+ out_path.parent.mkdir(parents=True, exist_ok=True)
147
+ sf.write(str(out_path), pcm16_safe(enhanced), sr_in, subtype="PCM_16")
148
+
149
+
150
+ def main():
151
+ parser = argparse.ArgumentParser(description="Enhance WAV files with a DPDFNet TFLite model (streaming).")
152
+ parser.add_argument("--noisy_dir", type=str, required=True, help="Folder with noisy *.wav files (non-recursive).")
153
+ parser.add_argument("--enhanced_dir", type=str, required=True, help="Output folder for enhanced WAVs.")
154
+ parser.add_argument(
155
+ "--model_name",
156
+ type=str,
157
+ default="dpdfnet8",
158
+ choices=["baseline", "dpdfnet2", "dpdfnet4", "dpdfnet8"],
159
+ help=(
160
+ "Name of the model to use. Options: "
161
+ "'baseline', 'dpdfnet2', 'dpdfnet4', 'dpdfnet8'. "
162
+ "Default is 'dpdfnet8'."
163
+ ),
164
+ )
165
+ args = parser.parse_args()
166
+
167
+ noisy_dir = Path(args.noisy_dir)
168
+ enhanced_dir = Path(args.enhanced_dir)
169
+ model_name = args.model_name
170
+
171
+ if not noisy_dir.is_dir():
172
+ print(f"ERROR: --noisy_dir does not exist or is not a directory: {noisy_dir}", file=sys.stderr)
173
+ sys.exit(1)
174
+
175
+ wavs = sorted(p for p in noisy_dir.glob("*.wav") if p.is_file())
176
+ if not wavs:
177
+ print(f"No .wav files found in {noisy_dir} (non-recursive).")
178
+ sys.exit(0)
179
+
180
+ print(f"Model: {model_name}")
181
+ print(f"Input : {noisy_dir}")
182
+ print(f"Output: {enhanced_dir}")
183
+ print(f"Found {len(wavs)} file(s). Enhancing...\n")
184
+
185
+ for wav in wavs:
186
+ out_path = enhanced_dir / (wav.stem + f'_{model_name}.wav')
187
+ try:
188
+ enhance_file(wav, out_path, model_name)
189
+ except Exception as e:
190
+ print(f"[SKIP] {wav.name} due to error: {e}", file=sys.stderr)
191
+
192
+ print("\nProcessing complete. Outputs saved in:", enhanced_dir)
193
+
194
+
195
+ if __name__ == "__main__":
196
+ main()