"""
Custom Inference Handler for Amebo Premium Voice
Enables HuggingFace Inference API and Dedicated Endpoints
"""
import torch
import numpy as np
from transformers import VitsModel, AutoTokenizer
from scipy import signal
from scipy.ndimage import uniform_filter1d
import base64
import io
import soundfile as sf

class EndpointHandler:
    """Custom inference handler for HuggingFace Inference Endpoints.

    Wraps the facebook/mms-tts-hau VITS model and applies a light
    "warmth" post-processing chain (low-mid EQ boost, presence boost,
    gentle compression, transient smoothing) to the generated speech.
    """

    # signal.filtfilt's default edge padding for the order-2 bandpass is
    # padlen = 3 * max(len(a), len(b)) = 15, and it requires
    # audio.size > padlen. Clips shorter than this skip the EQ stages
    # instead of raising ValueError.
    _MIN_FILTER_SAMPLES = 16

    def __init__(self, path="."):
        """Load the MMS-TTS Hausa model and tokenizer.

        Args:
            path: Model directory (unused; weights are pulled from the Hub).
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # MMS-TTS generates 16 kHz audio.
        self.sample_rate = 16000

        # Load MMS-TTS Hausa in eval mode (inference only).
        self.model = VitsModel.from_pretrained("facebook/mms-tts-hau").to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-hau")
        self.model.eval()

    def add_warmth(self, audio, warmth=0.3, presence=0.2):
        """Apply a warmth/presence EQ, gentle compression, and smoothing.

        Args:
            audio: 1-D float waveform (numpy array or array-like).
            warmth: Strength of the low-mid (<800 Hz) boost; 0 disables it.
            presence: Strength of the 2-4 kHz clarity boost; 0 disables it.

        Returns:
            Processed float32 waveform, same shape as the input, peak
            normalized to 0.95. Empty input is returned unchanged.
        """
        audio = np.asarray(audio, dtype=np.float32)

        # Empty input: nothing to process (np.abs(...).max() would raise).
        if audio.size == 0:
            return audio

        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val

        nyquist = self.sample_rate / 2
        # filtfilt needs more samples than its edge padding; very short
        # clips skip the EQ stages but still get compression/normalization.
        can_filter = audio.size >= self._MIN_FILTER_SAMPLES

        # Low-mid boost for warmth.
        if warmth > 0 and can_filter:
            b_low, a_low = signal.butter(2, 800 / nyquist, btype='low')
            low_content = signal.filtfilt(b_low, a_low, audio)
            audio = audio + warmth * 0.3 * low_content

        # Presence boost (2-4 kHz) for clarity.
        if presence > 0 and can_filter:
            b_mid, a_mid = signal.butter(2, [2000 / nyquist, 4000 / nyquist],
                                         btype='band')
            mid_content = signal.filtfilt(b_mid, a_mid, audio)
            audio = audio + presence * 0.2 * mid_content

        # Gentle compression: 3:1 ratio above a 0.5 threshold, applied as
        # a per-sample gain so the waveform shape below threshold is kept.
        threshold = 0.5
        ratio = 3.0
        audio_abs = np.abs(audio)
        mask = audio_abs > threshold
        if np.any(mask):
            gain_reduction = np.ones_like(audio)
            gain_reduction[mask] = threshold + (audio_abs[mask] - threshold) / ratio
            gain_reduction[mask] = gain_reduction[mask] / audio_abs[mask]
            audio = audio * gain_reduction

        # Smooth transients with a small moving average.
        audio = uniform_filter1d(audio, size=3)

        # Normalize to a 0.95 peak, leaving a little headroom.
        max_val = np.abs(audio).max()
        if max_val > 0:
            audio = audio / max_val * 0.95

        return audio.astype(np.float32)

    def __call__(self, data):
        """Process an inference request.

        Args:
            data: dict with 'inputs' (text to synthesize) and optional
                'parameters' dict supporting 'warmth' (float, default 0.3),
                'presence' (float, default 0.2), and 'format'
                ("base64" [default] for a base64 WAV, anything else for a
                raw float list).

        Returns:
            dict with the audio payload and 'sample_rate', or an
            {'error': ...} dict when no input text is provided.
        """
        # Get input text; empty/missing text is a client error, not a crash.
        inputs = data.get("inputs", "")
        if not inputs:
            return {"error": "No input text provided"}

        # Get post-processing parameters.
        params = data.get("parameters", {})
        warmth = params.get("warmth", 0.3)
        presence = params.get("presence", 0.2)
        return_format = params.get("format", "base64")

        # Tokenize on the model's device.
        tokens = self.tokenizer(inputs, return_tensors="pt").to(self.device)

        # Generate the waveform without tracking gradients.
        with torch.no_grad():
            output = self.model(**tokens).waveform

        audio = output.squeeze().cpu().numpy()

        # Apply the warmth post-processing chain.
        audio = self.add_warmth(audio, warmth=warmth, presence=presence)

        # Return as base64-encoded WAV (default) or as a raw sample list.
        if return_format == "base64":
            buffer = io.BytesIO()
            sf.write(buffer, audio, self.sample_rate, format="WAV")
            buffer.seek(0)
            audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
            return {
                "audio": audio_base64,
                "sample_rate": self.sample_rate,
                "format": "wav",
                "encoding": "base64"
            }
        else:
            return {
                "audio": audio.tolist(),
                "sample_rate": self.sample_rate
            }