# VoiceFocus / sdk.py
# History: vad (#3), commit 32dcdfe by mariesig
import os
from dataclasses import dataclass

import numpy as np
from dotenv import load_dotenv

import aic_sdk as aic
from constants import MODEL_ID
load_dotenv()
class SDKParams:
def __init__(
self,
sample_rate: int = 16000,
enhancement_level: float = 1.0,
allow_variable_frames: bool = False,
num_channels: int = 1,
sync: bool = True,
num_frames: int | None = None,
):
self.sample_rate = sample_rate
self.enhancement_level = enhancement_level
self.allow_variable_frames = allow_variable_frames
self.num_channels = num_channels
self.sync = sync
self.num_frames = num_frames # to be set after processor init
class SDKWrapper:
    """Wrapper around the aic_sdk model/processor lifecycle.

    Construction downloads (if needed) and loads the model; the processor
    itself is created later via :meth:`init_processor`.
    """

    def __init__(self, model_id: str = MODEL_ID, models_dir: str = "./models"):
        """Load the SDK key from the environment and the model from disk.

        Raises:
            RuntimeError: If the ``AIC_SDK_KEY`` environment variable is unset.
        """
        sdk_key = os.getenv("AIC_SDK_KEY")  # read once instead of twice
        if sdk_key is None:
            raise RuntimeError("Missing AIC_SDK_KEY.")
        self.sdk_key = sdk_key
        # Network/disk I/O: fetches the model into models_dir when not cached.
        model_path = aic.Model.download(model_id, models_dir)
        self.model = aic.Model.from_file(model_path)

    def init_processor(self, sdk_params: SDKParams):
        """Create the processor (sync or async) described by ``sdk_params``.

        Sets ``self.num_frames``, ``self.sample_rate``, ``self.processor``,
        ``self.enhancement_level`` and ``self.vad_context``.
        """
        optimal_frames = self.model.get_optimal_num_frames(sdk_params.sample_rate)
        # Fall back to the model's optimal frame count when none was given.
        self.num_frames = sdk_params.num_frames if sdk_params.num_frames else optimal_frames
        self.sample_rate = sdk_params.sample_rate
        aic_config = aic.ProcessorConfig(
            sample_rate=sdk_params.sample_rate,
            num_channels=sdk_params.num_channels,
            num_frames=self.num_frames,
            allow_variable_frames=sdk_params.allow_variable_frames,
        )
        processor_cls = aic.Processor if sdk_params.sync else aic.ProcessorAsync
        self.processor = processor_cls(self.model, self.sdk_key, aic_config)
        self.processor.get_processor_context().set_parameter(
            aic.ProcessorParameter.EnhancementLevel, float(sdk_params.enhancement_level)
        )
        self.enhancement_level = sdk_params.enhancement_level
        self.vad_context = self.processor.get_vad_context()

    def change_enhancement_level(self, enhancement_level: float):
        """Update the enhancement level on an already-initialized processor.

        Raises:
            ValueError: If :meth:`init_processor` has not been called yet.
        """
        if not hasattr(self, "processor"):
            raise ValueError("Processor not initialized")
        self.processor.get_processor_context().set_parameter(
            aic.ProcessorParameter.EnhancementLevel, float(enhancement_level)
        )
        self.enhancement_level = enhancement_level

    def _check_shape(self, audio: np.ndarray) -> np.ndarray:
        """Normalize ``audio`` to shape (num_channels, samples).

        1-D input is promoted to a single channel. At most 2 channels are
        accepted.

        Raises:
            ValueError: If the (normalized) array is not 2-D with <= 2 rows.
        """
        if audio.ndim == 1:
            audio = audio.reshape(1, -1)
        # Check dimensionality BEFORE indexing shape[0]: the original order
        # raised IndexError (not ValueError) on 0-d input.
        if audio.ndim != 2 or audio.shape[0] > 2:
            raise ValueError("Expected audio with shape (n, frames)")
        return audio

    def process_with_vad(
        self,
        audio: np.ndarray,
    ) -> tuple[np.ndarray, bool]:
        """
        Enhance ``audio`` chunk-by-chunk and report overall voice activity.

        audio: 2D NumPy array with shape (num_channels, samples) containing
            audio data to be enhanced (1-D input is promoted to one channel).

        Returns:
            (enhanced_audio, vad_overall) where ``vad_overall`` is True when
            speech was detected for more than half of the samples.
        """
        audio = self._check_shape(audio)
        out = np.zeros_like(audio)
        # One flag per sample; every channel shares the chunk's VAD verdict,
        # so a 1-D vector gives the same mean as the original 2-D mask.
        vad_flags = np.zeros(audio.shape[1], dtype=bool)
        chunk_size = self.num_frames
        n = audio.shape[1]
        for i in range(0, n, chunk_size):
            chunk = audio[:, i : i + chunk_size]
            valid = chunk.shape[1]
            if valid < chunk_size:
                # Zero-pad the final partial chunk up to the processor's frame
                # size. Bug fix: pad with the input's channel count — the
                # original hardcoded 1 channel and broke on stereo input.
                padded = np.zeros((audio.shape[0], chunk_size), dtype=audio.dtype)
                padded[:, :valid] = chunk
                chunk = padded
            enhanced = self.processor.process(chunk)
            out[:, i : i + valid] = enhanced[:, :valid]
            # Bug fix: the original `break`-ed before this check on the last
            # partial chunk, silently dropping trailing speech from the VAD.
            if self.vad_context.is_speech_detected():
                vad_flags[i : i + valid] = True
        vad_overall = bool(vad_flags.mean() > 0.5)
        return out, vad_overall

    def process_chunk(self, audio: np.ndarray) -> np.ndarray:
        """Enhance a single chunk after normalizing its shape."""
        audio = self._check_shape(audio)
        result = self.processor.process(audio)
        return result