File size: 1,900 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
// FireEcho Preprocessing — pybind11 bindings (SpeechLib-matched)
// JIT-compiled via torch.utils.cpp_extension.load()

#include <torch/extension.h>

// Forward declarations from fireecho_preproc_cuda.cu
// These signatures must match the definitions in that translation unit
// exactly (parameter types, including the by-value torch::Tensor arguments).
// A mismatch would declare a different overload, and the binding below would
// then refer to a symbol that is never defined. Python-side default values
// (e.g. preemph_coeff) are supplied by the py::arg bindings in
// PYBIND11_MODULE, not by these declarations.

// Batched STFT entry point (defined in fireecho_preproc_cuda.cu).
// Per its Python binding docstring it applies per-frame pre-emphasis and
// 32768 scaling, and uses cuFFT.
//   audio         : input waveform tensor; shape/dtype/device are not visible
//                   here — presumably a CUDA tensor, TODO confirm
//   window        : analysis window tensor (presumably win_length samples —
//                   TODO confirm)
//   n_fft         : STFT FFT size
//   win_length    : STFT window length (units assumed to be samples — TODO confirm)
//   hop_length    : STFT hop between frames (units assumed to be samples — TODO confirm)
//   preemph_coeff : pre-emphasis coefficient (the Python binding defaults it to 0.97)
// Returns the spectrum tensor computed by the CUDA implementation. The
// sibling cuda_mel_filterbank_impl takes a "power_spec", which suggests this
// returns a power spectrum. NOTE(review): confirm against the .cu definition.
torch::Tensor cuda_stft_impl(
    torch::Tensor audio,
    torch::Tensor window,
    int64_t n_fft,
    int64_t win_length,
    int64_t hop_length,
    double preemph_coeff
);

// Mel filterbank entry point (defined in fireecho_preproc_cuda.cu).
// Per its Python binding docstring it applies a pre-computed, SpeechLib-matched
// mel matrix and fuses a clip + log step into the same call.
//   power_spec : spectrum tensor — presumably the output of cuda_stft_impl;
//                shape/layout not visible here, TODO confirm
//   mel_matrix : pre-computed filterbank matrix; the expected orientation
//                (mels x bins vs bins x mels) is not visible here — TODO confirm
// Returns the clipped, log-scaled mel features tensor (per the binding docstring).
torch::Tensor cuda_mel_filterbank_impl(
    torch::Tensor power_spec,
    torch::Tensor mel_matrix
);

// Full audio front end in one call (defined in fireecho_preproc_cuda.cu).
// Per its Python binding docstring it runs the STFT followed by the mel
// filterbank. Its parameters are the union of the cuda_stft_impl and
// cuda_mel_filterbank_impl inputs. Whether it reuses those two functions
// internally is not visible from this file.
//   audio, window, n_fft, win_length, hop_length, preemph_coeff :
//                same meaning as for cuda_stft_impl (preemph_coeff defaults to
//                0.97 in the Python binding)
//   mel_matrix : same meaning as for cuda_mel_filterbank_impl
torch::Tensor cuda_audio_pipeline_impl(
    torch::Tensor audio,
    torch::Tensor window,
    torch::Tensor mel_matrix,
    int64_t n_fft,
    int64_t win_length,
    int64_t hop_length,
    double preemph_coeff
);

// Image preprocessing entry point (defined in fireecho_preproc_cuda.cu).
// Per its Python binding docstring it fuses a bicubic resize, normalization to
// [-1, 1] and conversion to bf16.
//   image_rgb : input RGB image tensor; layout (HWC vs CHW), dtype and value
//               range are not visible here — TODO confirm
//   crop_size : target output size (presumably a square crop_size x crop_size
//               output — TODO confirm)
torch::Tensor cuda_image_preprocess_impl(
    torch::Tensor image_rgb,
    int64_t crop_size
);

// Python module definition. Each binding exposes one of the CUDA entry points
// declared above under its name without the "_impl" suffix. The keyword
// argument names match the C++ parameter names, so Python callers may pass
// any argument by keyword.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Enables the "name"_a shorthand for pybind11::arg keyword arguments.
    using namespace pybind11::literals;

    // Default pre-emphasis coefficient shared by the STFT and full-pipeline bindings.
    constexpr double kDefaultPreemph = 0.97;

    m.doc() = "FireEcho CUDA-accelerated preprocessing (Phase 5, SpeechLib-matched)";

    // ---- Audio: individual stages ---------------------------------------
    m.def("cuda_stft",
          &cuda_stft_impl,
          "Batched STFT with per-frame pre-emphasis + 32768 scaling via cuFFT",
          "audio"_a,
          "window"_a,
          "n_fft"_a,
          "win_length"_a,
          "hop_length"_a,
          "preemph_coeff"_a = kDefaultPreemph);

    m.def("cuda_mel_filterbank",
          &cuda_mel_filterbank_impl,
          "Mel filterbank with pre-computed SpeechLib matrix + fused clip+log",
          "power_spec"_a,
          "mel_matrix"_a);

    // ---- Audio: fused STFT + mel ----------------------------------------
    m.def("cuda_audio_pipeline",
          &cuda_audio_pipeline_impl,
          "Full audio pipeline: STFT + mel in single call",
          "audio"_a,
          "window"_a,
          "mel_matrix"_a,
          "n_fft"_a,
          "win_length"_a,
          "hop_length"_a,
          "preemph_coeff"_a = kDefaultPreemph);

    // ---- Vision -----------------------------------------------------------
    m.def("cuda_image_preprocess",
          &cuda_image_preprocess_impl,
          "Fused bicubic resize + normalize [-1,1] + bf16",
          "image_rgb"_a,
          "crop_size"_a);
}