File size: 5,209 Bytes
10e72d3
 
da2ee9a
10e72d3
 
 
 
 
da2ee9a
3784ae7
10e72d3
 
 
 
da2ee9a
10e72d3
 
da2ee9a
 
10e72d3
 
da2ee9a
10e72d3
 
 
da2ee9a
10e72d3
 
da2ee9a
10e72d3
 
da2ee9a
10e72d3
 
da2ee9a
10e72d3
 
 
 
da2ee9a
 
 
 
10e72d3
 
 
 
 
 
da2ee9a
 
 
 
 
 
 
 
 
3784ae7
da2ee9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10e72d3
da2ee9a
10e72d3
 
da2ee9a
10e72d3
 
 
 
 
 
da2ee9a
10e72d3
da2ee9a
10e72d3
da2ee9a
10e72d3
 
da2ee9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10e72d3
da2ee9a
 
 
 
 
10e72d3
da2ee9a
10e72d3
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
import os
import sys
import re
import numpy as np
import torch
import soundfile as sf
import spaces
from config import models_path, results_path, sample_path, BASE_DIR
from sentence_splitter import PersianSentenceSplitter
from persian_numbers import find_and_normalize_numbers

# Model/helper singletons populated by load_models(); all remain None until
# that function has been called successfully.
encoder = None            # encoder inference module (the module object itself is the handle)
synthesizer = None        # Synthesizer instance (text -> mel spectrogram)
vocoder = None            # HiFiGAN vocoder model (mel spectrogram -> waveform)
sentence_splitter = None  # PersianSentenceSplitter used to chunk long input text

def load_models():
    """Load encoder, synthesizer, vocoder and sentence splitter into module globals.

    Appends the bundled ``pmt2`` package to ``sys.path`` and imports the
    model modules from it, then loads the three checkpoint files from
    ``models_path``.

    Returns:
        bool: True if every model loaded successfully, False otherwise
        (the full traceback is printed on failure).
    """
    global encoder, synthesizer, vocoder, sentence_splitter

    try:
        # The pmt2 directory ships the encoder/synthesizer/vocoder packages.
        sys.path.append(os.path.join(BASE_DIR, 'pmt2'))

        from encoder import inference as encoder_module
        from synthesizer.inference import Synthesizer
        from parallel_wavegan.utils import load_model as vocoder_hifigan

        # The encoder module itself acts as the model handle (module-level API).
        encoder = encoder_module

        print("Loading encoder model...")
        encoder.load_model(os.path.join(models_path, 'encoder.pt'))

        print("Loading synthesizer model...")
        synthesizer = Synthesizer(os.path.join(models_path, 'synthesizer.pt'))

        print("Loading HiFiGAN vocoder...")
        vocoder = vocoder_hifigan(os.path.join(models_path, 'vocoder_HiFiGAN.pkl'))
        vocoder.remove_weight_norm()
        vocoder = vocoder.eval().to('cuda' if torch.cuda.is_available() else 'cpu')

        sentence_splitter = PersianSentenceSplitter(max_chars=150, min_chars=30)

        print("Models loaded successfully!")
        return True
    except Exception:
        import traceback
        print(f"Error loading models: {traceback.format_exc()}")
        return False


def normalize_text_for_synthesis(text: str) -> str:
    """Prepare raw Persian text for the synthesizer.

    Unifies Arabic kaf/yeh with their Persian code points, maps underscores
    to zero-width non-joiners, collapses whitespace, and converts digits to
    their spoken Persian form via ``find_and_normalize_numbers``.
    """
    # Replace Arabic character variants with the Persian code points.
    for arabic_char, persian_char in (('ك', 'ک'), ('ي', 'ی')):
        text = text.replace(arabic_char, persian_char)

    # Authors use '_' as a stand-in for the zero-width non-joiner (ZWNJ).
    text = text.replace('_', '\u200c')

    # Collapse whitespace runs to single spaces and trim the ends.
    text = re.sub(r'\s+', ' ', text).strip()

    # Spell out digits so the synthesizer can pronounce them.
    return find_and_normalize_numbers(text)


def synthesize_segment(text_segment: str, embed: np.ndarray) -> np.ndarray:
    """Synthesize one text segment into a waveform using the loaded models.

    Args:
        text_segment: Raw Persian text for a single sentence-sized segment.
        embed: Speaker embedding produced by the encoder.

    Returns:
        A 1-D float waveform array, or None if synthesis failed (the
        traceback is printed so the caller can skip the segment).
    """
    try:
        text_segment = normalize_text_for_synthesis(text_segment)

        # Text + speaker embedding -> mel spectrogram.
        specs = synthesizer.synthesize_spectrograms([text_segment], [embed])
        spec = specs[0]

        # Vocoder consumes the transposed spectrogram on the model's device.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        x = torch.from_numpy(spec.T).to(device)

        with torch.no_grad():
            wav = vocoder.inference(x)

        wav = wav.cpu().numpy()

        # Drop any singleton batch/channel dimensions.
        if wav.ndim > 1:
            wav = wav.squeeze()

        return wav

    except Exception:
        import traceback
        print(f"Error synthesizing segment '{text_segment[:50]}...': {traceback.format_exc()}")
        return None


def add_silence(duration_ms: int = 300) -> np.ndarray:
    """Return *duration_ms* milliseconds of silence at the synthesizer rate.

    Requires load_models() to have run, since the sample rate is read from
    the global synthesizer.
    """
    length = synthesizer.sample_rate * duration_ms / 1000
    return np.zeros(int(length), dtype=np.float32)


@spaces.GPU(duration=120)
def generate_speech(text, reference_audio=None, add_pauses: bool = True):
    """Generate a speech WAV file for *text*, cloning a reference voice.

    Args:
        text: Persian text to synthesize; empty/whitespace input yields None.
        reference_audio: Optional ``(sample_rate, samples)`` tuple (e.g. from
            a Gradio Audio component). Falls back to the bundled sample file.
        add_pauses: Insert a 300 ms silence between sentence segments.

    Returns:
        Path of the written WAV file, or None on failure.
    """
    if not text or text.strip() == "":
        return None

    try:
        # Resolve the reference voice: user-supplied audio or bundled sample.
        if reference_audio is None:
            ref_wav_path = sample_path
        else:
            ref_wav_path = os.path.join(results_path, "reference_audio.wav")
            sf.write(ref_wav_path, reference_audio[1], reference_audio[0])

        print(f"Using reference audio: {ref_wav_path}")

        wav = synthesizer.load_preprocess_wav(ref_wav_path)

        # Speaker embedding for voice cloning.
        encoder_wav = encoder.preprocess_wav(wav)
        embed, partial_embeds, _ = encoder.embed_utterance(encoder_wav, return_partials=True)

        # Split long text into synthesizer-sized sentence chunks.
        text_segments = sentence_splitter.split(text)

        print(f"Split text into {len(text_segments)} segments:")
        for i, segment in enumerate(text_segments, 1):
            print(f"  Segment {i}: {segment[:60]}{'...' if len(segment) > 60 else ''}")

        audio_segments = []
        silence = add_silence(300) if add_pauses else None  # 300ms pause

        for i, segment in enumerate(text_segments):
            print(f"Processing segment {i+1}/{len(text_segments)}...")

            segment_wav = synthesize_segment(segment, embed)

            if segment_wav is None:
                # Best effort: skip failed segments rather than abort the run.
                print(f"Warning: Failed to synthesize segment {i+1}")
                continue

            if segment_wav.ndim > 1:
                segment_wav = segment_wav.flatten()
            audio_segments.append(segment_wav)

            # Pause between segments, but not after the last one.
            if add_pauses and i < len(text_segments) - 1:
                audio_segments.append(silence)

        if not audio_segments:
            print("Error: No audio segments were generated successfully")
            return None

        audio_segments = [seg.flatten() if seg.ndim > 1 else seg for seg in audio_segments]

        final_wav = np.concatenate(audio_segments)

        # Peak-normalize to 0.97 full scale; guard against an all-zero signal
        # (would otherwise divide by zero and write a NaN-filled file).
        peak = np.abs(final_wav).max()
        if peak > 0:
            final_wav = final_wav / peak * 0.97

        output_filename = f"generated_{abs(hash(text)) % 100000}.wav"
        output_path = os.path.join(results_path, output_filename)
        sf.write(output_path, final_wav, synthesizer.sample_rate)

        print(f"✓ Successfully generated speech: {output_path}")
        print(f"  Total duration: {len(final_wav) / synthesizer.sample_rate:.2f} seconds")

        return output_path

    except Exception:
        import traceback
        error_details = traceback.format_exc()
        print(f"Error generating speech: {error_details}")
        return None