File size: 2,773 Bytes
f20a8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a85341
 
f20a8ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import asyncio
import utils

class BaseEngine:
    def __init__(self, name):
        self.lock = asyncio.Lock()
        self.name = name
        self.tts = None
        # Initialize with default, subclass should overwrite or load_model should update it
        self.sample_rate = 24000 

        print(f"Init model {self.name}")
        self.load_model()

    def load_model(self):
        raise NotImplementedError("Subclass must implement abstract method")

    def get_style_safe(self, voice_name: str):
        raise NotImplementedError("Subclass must implement abstract method")
    
    # FIX: Changed from async to sync because it's run in an executor
    # FIX: Fixed typo 'genetrate' -> 'generate'
    def generate(self, chunks: str, voice_name: str, speed: float):
        """
        Should return (audio_float_array, sample_rate)
        This method is CPU blocking, so it stays synchronous.
        """
        raise NotImplementedError("Subclass must implement abstract method")

    # FIX: Added a default preprocessor in case subclass doesn't have one
    def preprocess_text(self, text: str):
        return text
    async def stream_generator(self, text: str, voice_name: str, speed: float, format: str):
        encoder = None

        if format == "wav":
            yield utils.create_wav_header(self.sample_rate)
        elif format == "mp3":
            encoder = utils.create_mp3_encoder(sample_rate=self.sample_rate)

        # Preprocess text and voice
        try:
            voice_name = self.get_style_safe(voice_name)
        except NotImplementedError:
            pass 

        chunks = self.preprocess_text(text)
        
        
        loop = asyncio.get_event_loop()
        
        for i, chunk in enumerate(chunks):
            async with self.lock:
                # Run synchronous generation in executor
                audio_float = await loop.run_in_executor(
                    None, 
                    self.generate, 
                    chunk,
                    voice_name,
                    speed 
                )
            for audio in audio_float:
                if format == "wav":
                    pcm_bytes = utils.float_to_pcm16(audio)
                    yield pcm_bytes
                
                elif format == "mp3":
                    # This now returns 'bytes', so it is safe
                    mp3_bytes = utils.float_to_mp3(audio, encoder)
                    if len(mp3_bytes) > 0:
                        yield mp3_bytes

        # Flush MP3 encoder to get remaining audio frames
        if format == "mp3" and encoder is not None:
            final_data = encoder.flush()
            if len(final_data) > 0:
                yield bytes(final_data) # <--- CRITICAL FIX: Cast to bytes