|
|
""" |
|
|
Custom Inference Handler for Amebo Premium Voice |
|
|
Enables HuggingFace Inference API and Dedicated Endpoints |
|
|
""" |
|
|
import torch |
|
|
import numpy as np |
|
|
from transformers import VitsModel, AutoTokenizer |
|
|
from scipy import signal |
|
|
from scipy.ndimage import uniform_filter1d |
|
|
import base64 |
|
|
import io |
|
|
import soundfile as sf |
|
|
|
|
|
class EndpointHandler:
    """Inference handler for the Amebo premium Hausa voice.

    Implements the HuggingFace custom-handler protocol (Inference API /
    Dedicated Endpoints): ``__init__`` loads facebook/mms-tts-hau once,
    ``__call__`` serves one text-to-speech request and applies a light
    mastering chain (warmth/presence EQ, soft compression, smoothing)
    to the generated waveform.
    """

    def __init__(self, path="."):
        """Load the VITS model and tokenizer, preferring GPU.

        Args:
            path: Repository path supplied by the endpoint runtime
                (unused here; weights are pulled from the Hub).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # MMS-TTS models generate audio at 16 kHz.
        self.sample_rate = 16000

        self.model = VitsModel.from_pretrained("facebook/mms-tts-hau").to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
        self.model.eval()

    def _safe_filtfilt(self, b, a, audio):
        """Zero-phase filter that degrades gracefully on short signals.

        ``scipy.signal.filtfilt`` raises when the signal is not longer
        than its default edge padding (3 * max(len(a), len(b)) samples);
        for such short clips, return the input unchanged instead of
        failing the whole request.
        """
        padlen = 3 * max(len(a), len(b))
        if audio.size <= padlen:
            return audio
        return signal.filtfilt(b, a, audio)

    def add_warmth(self, audio, warmth=0.3, presence=0.2):
        """Apply a light mastering chain to a mono waveform.

        Steps: peak-normalize, boost lows (< 800 Hz, "warmth") and mids
        (2-4 kHz, "presence"), soft-compress peaks above 0.5 at 3:1,
        smooth with a 3-tap moving average, then re-normalize to a
        0.95 peak for clipping headroom.

        Args:
            audio: 1-D waveform, any numeric dtype.
            warmth: Low-band boost amount; 0 disables the stage.
            presence: Mid-band boost amount; 0 disables the stage.

        Returns:
            Processed float32 waveform of the same length.
        """
        audio = np.asarray(audio, dtype=np.float32)
        if audio.size == 0:
            # Nothing to process; avoids max()/filter errors on empty input.
            return audio

        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val

        nyquist = self.sample_rate / 2

        # Warmth: mix a low-passed copy (< 800 Hz) back into the signal.
        if warmth > 0:
            b_low, a_low = signal.butter(2, 800 / nyquist, btype='low')
            low_content = self._safe_filtfilt(b_low, a_low, audio)
            audio = audio + warmth * 0.3 * low_content

        # Presence: mix a band-passed copy (2-4 kHz) back in.
        if presence > 0:
            b_mid, a_mid = signal.butter(2, [2000 / nyquist, 4000 / nyquist],
                                         btype='band')
            mid_content = self._safe_filtfilt(b_mid, a_mid, audio)
            audio = audio + presence * 0.2 * mid_content

        # Static compressor: 3:1 ratio on samples above a 0.5 threshold.
        threshold = 0.5
        ratio = 3.0
        audio_abs = np.abs(audio)
        mask = audio_abs > threshold
        if np.any(mask):
            gain_reduction = np.ones_like(audio)
            gain_reduction[mask] = threshold + (audio_abs[mask] - threshold) / ratio
            gain_reduction[mask] = gain_reduction[mask] / audio_abs[mask]
            audio = audio * gain_reduction

        # 3-tap moving average knocks down residual harshness.
        audio = uniform_filter1d(audio, size=3)

        # Final normalization with headroom to avoid clipping.
        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val * 0.95

        return audio.astype(np.float32)

    def __call__(self, data):
        """Process one inference request.

        Args:
            data: dict with 'inputs' (text to synthesize) and optional
                'parameters' supporting 'warmth' (float), 'presence'
                (float), and 'format' ('base64' for WAV, anything else
                for a raw sample list).

        Returns:
            On success, a dict with the audio (base64 WAV or sample
            list) and 'sample_rate'; on missing input, {'error': ...}.
        """
        inputs = data.get("inputs", "")
        if not inputs:
            return {"error": "No input text provided"}

        params = data.get("parameters", {})
        warmth = params.get("warmth", 0.3)
        presence = params.get("presence", 0.2)
        return_format = params.get("format", "base64")

        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.device)

        # Inference only; no gradients needed.
        with torch.no_grad():
            output = self.model(**tokens).waveform

        audio = output.squeeze().cpu().numpy()
        audio = self.add_warmth(audio, warmth=warmth, presence=presence)

        if return_format == "base64":
            buffer = io.BytesIO()
            sf.write(buffer, audio, self.sample_rate, format="WAV")
            buffer.seek(0)
            audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
            return {
                "audio": audio_base64,
                "sample_rate": self.sample_rate,
                "format": "wav",
                "encoding": "base64",
            }
        else:
            return {
                "audio": audio.tolist(),
                "sample_rate": self.sample_rate,
            }
|
|
|