File size: 4,230 Bytes
81e51dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a524b41
6e81e30
81e51dc
 
 
 
 
f117bd8
 
 
 
 
 
81e51dc
 
 
 
 
 
 
 
 
 
 
 
41b9b54
 
7062f00
 
 
 
81e51dc
41b9b54
 
 
 
 
 
 
 
 
 
 
 
81e51dc
 
 
7062f00
 
 
 
81e51dc
7062f00
81e51dc
 
 
41b9b54
81e51dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df516aa
81e51dc
 
 
 
6e81e30
 
 
81e51dc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import base64
import io
import os
import re
import tempfile
import traceback
from typing import Dict, Any

import numpy as np
import soundfile as sf
import torchaudio
from f5_tts.infer.utils_infer import (
    load_vocoder,
    load_model,
    preprocess_ref_audio_text,
    infer_process,
    remove_silence_for_generated_wav,
)
from f5_tts.model import DiT
from huggingface_hub import hf_hub_download
from num2words import num2words

class EndpointHandler:
    """Callable request handler that synthesizes speech with F5-TTS Spanish.

    The vocoder and the DiT model are loaded once in ``__init__``; each call
    takes a request dict and returns a response dict holding either the
    generated audio (base64-encoded WAV) or an error message.
    """

    # Letter/digit boundary patterns. ``[^\W\d_]`` matches any Unicode letter,
    # so numbers glued to accented Spanish letters ("está5", "ñ3") are split
    # as well; an ASCII-only ``[A-Za-z]`` class would miss them.
    _LETTER_THEN_DIGIT = re.compile(r'([^\W\d_])(\d)')
    _DIGIT_THEN_LETTER = re.compile(r'(\d)([^\W\d_])')
    _STANDALONE_NUMBER = re.compile(r'\b\d+\b')

    def __init__(self, path=""):
        """Load the vocoder and the F5-Spanish checkpoint.

        Args:
            path: Not used by this handler; kept so the constructor signature
                stays compatible with whatever instantiates it.
        """
        self.vocoder = load_vocoder()
        model_cfg = dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4)

        model_path = hf_hub_download(
            repo_id="jpgallegoar/F5-Spanish",
            filename="model_1200000.safetensors"
        )

        self.ema_model = load_model(DiT, model_cfg, model_path)

    def traducir_numero_a_texto(self, texto):
        """Replace every standalone integer in ``texto`` with Spanish words.

        Digits that touch a letter are first separated from it with a space
        (``"abc123"`` -> ``"abc 123"``) so the number becomes a standalone
        token, then each run of digits is spelled out with
        ``num2words(..., lang='es')``.

        Args:
            texto: Input text.

        Returns:
            The text with digit runs replaced by their Spanish spelling.
        """
        texto_separado = self._LETTER_THEN_DIGIT.sub(r'\1 \2', texto)
        texto_separado = self._DIGIT_THEN_LETTER.sub(r'\1 \2', texto_separado)

        def reemplazar_numero(match):
            return num2words(int(match.group()), lang='es')

        return self._STANDALONE_NUMBER.sub(reemplazar_numero, texto_separado)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Generate speech for ``data["gen_text"]`` using a reference audio clip.

        Args:
            data: Request dict. Keys read:
                ``ref_audio`` (required): base64-encoded reference audio.
                ``gen_text`` (required): text to synthesize.
                ``ref_text`` (optional, default ``""``): text passed together
                    with the reference audio to ``preprocess_ref_audio_text``
                    (presumably its transcript).
                ``remove_silence`` (optional, default ``True``).
                ``cross_fade_duration`` (optional, default ``0.15``).
                ``speed`` (optional, default ``1.0``).

        Returns:
            ``{"success": True, "audio_base64": <str>}`` on success, otherwise
            ``{"success": False, "error": <str>}``. Exceptions are caught and
            reported in the response rather than propagated.
        """
        # Every temp file created by this call; all are removed in `finally`,
        # so early returns and exceptions do not leak files on disk.
        temp_paths = []
        try:
            ref_audio_base64 = data.get("ref_audio")
            if not ref_audio_base64:
                return {
                    "success": False,
                    "error": "Missing required field: 'ref_audio'"
                }

            # Decode base64 audio and write to temp file
            try:
                audio_bytes = base64.b64decode(ref_audio_base64)
                if not audio_bytes:
                    # Lenient base64 decoding turns garbage such as "!!!" into
                    # b""; reject it here instead of failing later on an
                    # empty audio file.
                    raise ValueError("decoded 'ref_audio' is empty")
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_audio_file:
                    temp_audio_path = temp_audio_file.name
                    temp_paths.append(temp_audio_path)
                    temp_audio_file.write(audio_bytes)
            except Exception as e:
                return {
                    "success": False,
                    "error": f"Invalid audio data: {type(e).__name__}: {str(e)}"
                }

            ref_text = data.get("ref_text", "")
            gen_text = data.get("gen_text", "")
            if not gen_text:
                return {
                    "success": False,
                    "error": "Missing required field: 'gen_text'"
                }

            remove_silence = data.get("remove_silence", True)
            cross_fade_duration = data.get("cross_fade_duration", 0.15)
            speed = data.get("speed", 1.0)

            ref_audio, ref_text = preprocess_ref_audio_text(temp_audio_path, ref_text, show_info=print)

            # Pad the text with a leading space and a trailing ". ", then
            # lowercase it and spell out numbers before inference.
            if not gen_text.startswith(" "):
                gen_text = " " + gen_text
            if not gen_text.endswith(". "):
                gen_text += ". "
            gen_text = self.traducir_numero_a_texto(gen_text.lower())

            final_wave, final_sample_rate, _ = infer_process(
                ref_audio,
                ref_text,
                gen_text,
                self.ema_model,
                self.vocoder,
                cross_fade_duration=cross_fade_duration,
                speed=speed,
                show_info=print,
            )

            if remove_silence:
                # remove_silence_for_generated_wav works on a file path, so
                # round-trip the waveform through a scratch WAV file. The
                # handle is closed before the path is reused.
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
                    silence_path = f.name
                    temp_paths.append(silence_path)
                sf.write(silence_path, final_wave, final_sample_rate)
                remove_silence_for_generated_wav(silence_path)
                final_wave, _ = torchaudio.load(silence_path)
                final_wave = final_wave.squeeze().cpu().numpy()

            with io.BytesIO() as buffer:
                sf.write(buffer, final_wave, final_sample_rate, format="WAV")
                buffer.seek(0)
                encoded_audio = base64.b64encode(buffer.read()).decode("utf-8")

            return {
                "success": True,
                "audio_base64": encoded_audio
                }

        except Exception as e:
            print("==== Exception Traceback ====")
            traceback.print_exc()
            print("==== End Traceback ====")
            return {
                "success": False,
                "error": f"{type(e).__name__}: {str(e)}"
                }
        finally:
            for temp_path in temp_paths:
                try:
                    os.remove(temp_path)
                except OSError:
                    # Best-effort cleanup: a file that is already gone or
                    # cannot be removed must not mask the real response.
                    pass