// Sky-Kim's picture
// Initial commit
// 6ac63e1
using System;
using System.Runtime.InteropServices;
using UnityEngine;
/// <summary>
/// Voice-activity detector that drives the native <c>ten_vad</c> library (or, in WebGL
/// player builds, the <c>WebGLTenVad_*</c> bridge functions) and switches to a simple
/// RMS-energy detector when that backend cannot be created or fails while processing.
/// </summary>
public class TenVADRunner : IDisposable
{
#if UNITY_WEBGL && !UNITY_EDITOR
    [DllImport("__Internal")]
    private static extern int WebGLTenVad_Create(int hopSize, float threshold);

    [DllImport("__Internal")]
    private static extern int WebGLTenVad_Process(int instanceId, short[] audioData, int audioDataLength, out float outProbability, out int outFlag);

    [DllImport("__Internal")]
    private static extern int WebGLTenVad_Destroy(int instanceId);

    [DllImport("__Internal")]
    private static extern int WebGLTenVad_GetState();

    // Codes compared against the bridge's return values. Loading/Pending are treated as
    // "not ready yet, try again on a later call"; Error triggers the RMS fallback.
    private const int WebGlStateLoading = 0;
    private const int WebGlStateError = -1;
    private const int WebGlPending = -2;

    // Greater than zero only while a bridge instance obtained from WebGLTenVad_Create is held.
    private int webGlInstanceId;
    private readonly int webGlHopSize;
    private readonly float webGlThreshold;
#else
    private const string DllName = "ten_vad";

    [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
    private static extern int ten_vad_create(out IntPtr handle, UIntPtr hop_size, float threshold);

    [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
    private static extern int ten_vad_process(IntPtr handle, short[] audio_data, UIntPtr audio_data_length, out float out_probability, out int out_flag);

    [DllImport(DllName, CallingConvention = CallingConvention.Cdecl)]
    private static extern int ten_vad_destroy(ref IntPtr handle);

    // Non-zero only while a successfully created native instance is held.
    private IntPtr vadHandle = IntPtr.Zero;
#endif

    private bool isDisposed = false;

    // Once set, every Process call is served by the RMS detector; it is never cleared.
    private bool useFallbackVad = false;
    private readonly float fallbackThreshold;
    private float fallbackSmoothedProbability;

    // Gain applied to the normalised RMS before clamping it to [0, 1].
    // NOTE(review): the platform-specific values look empirically tuned - confirm.
#if UNITY_WEBGL && !UNITY_EDITOR
    private const float FallbackRmsScale = 140f;
#else
    private const float FallbackRmsScale = 20f;
#endif

    // Interpolation weight given to the newest frame when smoothing the fallback probability.
    private const float FallbackSmoothing = 0.35f;

    /// <summary>
    /// Creates the detector. A backend creation failure does not throw; it is logged as a
    /// warning and the instance uses the RMS fallback instead.
    /// </summary>
    /// <param name="hopSize">
    /// Hop size forwarded to the backend's create call (presumably the number of samples
    /// expected per <see cref="Process"/> call - confirm against the ten_vad documentation).
    /// </param>
    /// <param name="threshold">
    /// Detection threshold forwarded to the backend; clamped to [0, 1] for the RMS fallback.
    /// </param>
    public TenVADRunner(UIntPtr hopSize, float threshold)
    {
        fallbackThreshold = Mathf.Clamp01(threshold);
        fallbackSmoothedProbability = 0f;

#if UNITY_WEBGL && !UNITY_EDITOR
        webGlHopSize = Math.Max(1, (int)hopSize.ToUInt64());
        webGlThreshold = threshold;

        // The bridge may still be loading; Process retries the initialization later.
        TryInitializeWebGlVad();
#else
        try
        {
            // Receive the handle in a local so the field only ever holds a handle that the
            // library reported as successfully created. A handle returned together with a
            // failure code is not trusted and is never passed to ten_vad_destroy.
            IntPtr createdHandle;
            int result = ten_vad_create(out createdHandle, hopSize, threshold);
            if (result == 0 && createdHandle != IntPtr.Zero)
            {
                vadHandle = createdHandle;
            }
            else
            {
                EnableFallback($"Failed to create VAD Handle. (Error Code: {result})");
            }
        }
        catch (DllNotFoundException ex)
        {
            EnableFallback($"Native VAD library '{DllName}' was not found. {ex.Message}");
        }
        catch (EntryPointNotFoundException ex)
        {
            EnableFallback($"Native VAD entry point was not found. {ex.Message}");
        }
        catch (Exception ex)
        {
            EnableFallback($"Native VAD initialization failed. {ex.Message}");
        }
#endif
    }

    /// <summary>
    /// Runs voice-activity detection on one buffer of 16-bit PCM samples.
    /// </summary>
    /// <param name="audioData">Samples to analyse.</param>
    /// <param name="probability">
    /// Receives the probability reported by the backend, or the smoothed RMS-based value
    /// when the fallback is used. Set to 0 when <paramref name="audioData"/> is null or empty.
    /// </param>
    /// <param name="flag">
    /// Receives the backend's flag; in fallback mode it is 1 when the probability is at or
    /// above the clamped threshold and 0 otherwise. Set to 0 for null or empty input.
    /// </param>
    /// <returns>
    /// -1 when <paramref name="audioData"/> is null or empty; otherwise the backend's result
    /// code, or 0 when the RMS fallback produced the answer.
    /// </returns>
    /// <exception cref="ObjectDisposedException">The instance has already been disposed.</exception>
    public int Process(short[] audioData, out float probability, out int flag)
    {
        if (isDisposed)
        {
            throw new ObjectDisposedException(nameof(TenVADRunner), "The VAD instance has already been disposed.");
        }

        if (audioData == null || audioData.Length == 0)
        {
            probability = 0;
            flag = 0;
            return -1;
        }

        if (!useFallbackVad)
        {
#if UNITY_WEBGL && !UNITY_EDITOR
            // No instance yet: retry creation, and serve this frame from the RMS detector
            // if the bridge is still not ready (or has just been declared unusable).
            if (webGlInstanceId <= 0 && !TryInitializeWebGlVad())
            {
                return ProcessWithFallback(audioData, out probability, out flag);
            }

            try
            {
                int result = WebGLTenVad_Process(webGlInstanceId, audioData, audioData.Length, out probability, out flag);
                if (result == 0)
                {
                    return result;
                }

                if (result == WebGlPending)
                {
                    // Transient condition: use RMS for this frame without giving up on the bridge.
                    return ProcessWithFallback(audioData, out probability, out flag);
                }

                int state = SafeGetWebGlState();
                EnableFallback($"WebGL ten_vad processing failed. result={result}, state={state}");
            }
            catch (Exception ex)
            {
                EnableFallback($"WebGL ten_vad processing bridge failed. {ex.Message}");
            }
#else
            if (vadHandle != IntPtr.Zero)
            {
                try
                {
                    // The native result code is handed back unchanged, including non-zero codes.
                    return ten_vad_process(vadHandle, audioData, (UIntPtr)audioData.Length, out probability, out flag);
                }
                catch (DllNotFoundException ex)
                {
                    EnableFallback($"Native VAD library '{DllName}' disappeared at runtime. {ex.Message}");
                }
                catch (EntryPointNotFoundException ex)
                {
                    EnableFallback($"Native VAD entry point missing at runtime. {ex.Message}");
                }
                catch (Exception ex)
                {
                    EnableFallback($"Native VAD processing failed. {ex.Message}");
                }
            }
#endif
        }

        return ProcessWithFallback(audioData, out probability, out flag);
    }

    /// <summary>
    /// Releases the backend instance, if one is held. Safe to call more than once; failures
    /// while destroying the instance are ignored.
    /// </summary>
    public void Dispose()
    {
        if (isDisposed)
        {
            return;
        }

        ReleaseNativeInstance();
        isDisposed = true;
    }

    /// <summary>
    /// Permanently switches this instance to the RMS detector, releases any backend instance
    /// that is still held, and logs <paramref name="reason"/> as a warning. Calls after the
    /// first one do nothing.
    /// </summary>
    private void EnableFallback(string reason)
    {
        if (useFallbackVad)
        {
            return;
        }

        useFallbackVad = true;

        // Destroy the instance instead of merely forgetting it: once the handle/id is
        // cleared, Dispose has nothing left to release, so skipping this would leak it.
        ReleaseNativeInstance();

#if UNITY_WEBGL && !UNITY_EDITOR
        Debug.LogWarning($"[TenVADRunner] Falling back to RMS VAD on WebGL. Reason: {reason}");
#else
        Debug.LogWarning($"[TenVADRunner] Falling back to simple RMS VAD. Reason: {reason}");
#endif
    }

    /// <summary>
    /// Best-effort destruction of the backend instance. Always leaves the handle/id cleared
    /// and never throws.
    /// </summary>
    private void ReleaseNativeInstance()
    {
#if UNITY_WEBGL && !UNITY_EDITOR
        if (webGlInstanceId > 0)
        {
            try
            {
                WebGLTenVad_Destroy(webGlInstanceId);
            }
            catch (Exception)
            {
                // Ignore disposal failures for browser teardown.
            }
        }

        webGlInstanceId = 0;
#else
        if (vadHandle != IntPtr.Zero)
        {
            try
            {
                ten_vad_destroy(ref vadHandle);
            }
            catch (Exception)
            {
                // Ignore disposal failures for native plugin teardown.
            }
        }

        vadHandle = IntPtr.Zero;
#endif
    }

#if UNITY_WEBGL && !UNITY_EDITOR
    /// <summary>
    /// Ensures a bridge instance exists, creating one if necessary.
    /// </summary>
    /// <returns>
    /// True when an instance id is held. False when the bridge reported that it is still
    /// loading/pending (a later call may succeed) or when creation failed and the RMS
    /// fallback has been enabled.
    /// </returns>
    private bool TryInitializeWebGlVad()
    {
        if (webGlInstanceId > 0)
        {
            return true;
        }

        int createResult;
        try
        {
            createResult = WebGLTenVad_Create(webGlHopSize, webGlThreshold);
        }
        catch (Exception ex)
        {
            EnableFallback($"WebGL ten_vad bridge is unavailable. {ex.Message}");
            return false;
        }

        if (createResult > 0)
        {
            webGlInstanceId = createResult;
            return true;
        }

        if (createResult == WebGlPending || createResult == WebGlStateLoading)
        {
            // Not ready yet; leave fallback disabled so a later call can retry.
            return false;
        }

        int state = SafeGetWebGlState();
        if (state == WebGlStateError || createResult < 0)
        {
            EnableFallback($"WebGL ten_vad initialization failed. createResult={createResult}, state={state}");
        }

        return false;
    }

    /// <summary>
    /// Queries the bridge state, reporting <see cref="WebGlStateError"/> if the call itself throws.
    /// </summary>
    private int SafeGetWebGlState()
    {
        try
        {
            return WebGLTenVad_GetState();
        }
        catch
        {
            return WebGlStateError;
        }
    }
#endif

    /// <summary>
    /// RMS-energy detector used when the backend is unavailable. Callers must pass a
    /// non-empty buffer (the mean divides by its length).
    /// </summary>
    /// <returns>Always 0.</returns>
    private int ProcessWithFallback(short[] audioData, out float probability, out int flag)
    {
        double sumSquares = 0d;
        for (int i = 0; i < audioData.Length; i++)
        {
            // Normalise 16-bit PCM to [-1, 1) before squaring.
            float sample = audioData[i] / 32768f;
            sumSquares += sample * sample;
        }

        float rms = (float)Math.Sqrt(sumSquares / audioData.Length);
        float rawProbability = Mathf.Clamp01(rms * FallbackRmsScale);

        // Blend with the previous value so a single frame cannot flip the decision abruptly.
        fallbackSmoothedProbability = Mathf.Lerp(fallbackSmoothedProbability, rawProbability, FallbackSmoothing);

        probability = fallbackSmoothedProbability;
        flag = probability >= fallbackThreshold ? 1 : 0;
        return 0;
    }
}