# VoiceFocus / sdk.py
# Provenance (from file-viewer page metadata): author mariesig, commit 990f149
# "Refactor online streaming functionality and enhance documentation"
import numpy as np
from dotenv import load_dotenv
import aic_sdk as aic
import os
from constants import MODEL_ID
load_dotenv()
class SDKWrapper:
    """Thin wrapper around the aic_sdk lifecycle: model download/load,
    processor construction, and chunked audio enhancement."""

    def __init__(self, model_id: str = MODEL_ID, models_dir: str = "./models"):
        """Load the enhancement model, downloading it into ``models_dir``.

        Args:
            model_id: SDK model identifier to download.
            models_dir: Local directory the model file is stored in.

        Raises:
            RuntimeError: if the AIC_SDK_KEY environment variable is not set.
        """
        # Read the key once (the original called os.getenv twice).
        sdk_key = os.getenv("AIC_SDK_KEY")
        if sdk_key is None:
            raise RuntimeError("Missing AIC_SDK_KEY.")
        self.sdk_key = sdk_key
        model_path = aic.Model.download(model_id, models_dir)
        self.model = aic.Model.from_file(model_path)

    def init_processor(
        self,
        sample_rate: int,
        enhancement_level: float,
        allow_variable_frames: bool = False,
        num_frames: int | None = None,
        num_channels: int = 1,
        sync: bool = True,
    ) -> None:
        """Create and store the processor used by process_sync/process_chunk.

        Args:
            sample_rate: Audio sample rate in Hz.
            enhancement_level: Initial enhancement strength, written to the
                processor context as ``EnhancementLevel``.
            allow_variable_frames: Whether the processor accepts chunks of
                varying length.
            num_frames: Frames per chunk. Falls back to the model's optimal
                frame count for ``sample_rate`` when omitted (or 0).
            num_channels: Channel count of the audio to be processed.
            sync: Build a synchronous ``Processor`` when True, otherwise a
                ``ProcessorAsync``.
        """
        self.processor_sample_rate = sample_rate
        optimal_frames = self.model.get_optimal_num_frames(sample_rate)
        # Truthiness on purpose: num_frames=0 is not a usable chunk size,
        # so it also falls back to the model's optimum.
        self.num_frames = num_frames if num_frames else optimal_frames
        config = aic.ProcessorConfig(
            sample_rate=sample_rate,
            num_channels=num_channels,
            num_frames=self.num_frames,
            allow_variable_frames=allow_variable_frames,
        )
        processor_cls = aic.Processor if sync else aic.ProcessorAsync
        processor = processor_cls(self.model, self.sdk_key, config)
        processor.get_processor_context().set_parameter(
            aic.ProcessorParameter.EnhancementLevel, float(enhancement_level)
        )
        self.processor = processor

    def change_enhancement_level(self, enhancement_level: float) -> None:
        """Update ``EnhancementLevel`` on the live processor.

        Raises:
            ValueError: if init_processor has not been called yet.
        """
        if not hasattr(self, "processor"):
            raise ValueError("Processor not initialized")
        self.processor.get_processor_context().set_parameter(
            aic.ProcessorParameter.EnhancementLevel, float(enhancement_level)
        )

    def _check_shape(self, audio: np.ndarray) -> np.ndarray:
        """Normalize audio to shape (channels, samples); reject >2 channels.

        1-D input is treated as mono and promoted to shape (1, samples).
        Checks ndim before indexing shape so malformed (e.g. 0-d) input
        raises the documented ValueError rather than an IndexError.
        """
        if audio.ndim == 1:
            audio = audio.reshape(1, -1)
        if audio.ndim != 2 or audio.shape[0] > 2:
            raise ValueError("Expected audio with shape (n, frames)")
        return audio

    def process_sync(
        self,
        audio: np.ndarray,
    ) -> np.ndarray:
        """Enhance ``audio`` in fixed-size chunks and return the result.

        Args:
            audio: 2D array with shape (num_channels, samples) containing the
                audio data to be enhanced; 1-D mono input is also accepted.

        Returns:
            Enhanced audio with the same shape as the normalized input.
        """
        audio = self._check_shape(audio)
        out = np.zeros_like(audio)
        chunk_size = self.num_frames
        n = audio.shape[1]
        for i in range(0, n, chunk_size):
            chunk = audio[:, i : i + chunk_size]
            if chunk.shape[1] < chunk_size:
                # Final short chunk: zero-pad to a full chunk, process, and
                # copy back only the valid samples.
                last = chunk.shape[1]
                # BUGFIX: pad with the input's channel count (was hard-coded
                # to 1 channel, which broke 2-channel input on the last chunk).
                padded = np.zeros((audio.shape[0], chunk_size), dtype=audio.dtype)
                padded[:, :last] = chunk
                enhanced = self.processor.process(padded)
                out[:, i : i + last] = enhanced[:, :last]
                break
            enhanced = self.processor.process(chunk)
            out[:, i : i + chunk_size] = enhanced[:, :chunk_size]
        return out

    def process_chunk(self, audio: np.ndarray) -> np.ndarray:
        """Enhance a single (shape-normalized) chunk via the processor."""
        audio = self._check_shape(audio)
        return self.processor.process(audio)