File size: 4,880 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#!/usr/bin/env python3
"""
gRPC server for OuteTTS (OuteAI TTS) models.
"""
from concurrent import futures

import argparse
import signal
import sys
import os
import asyncio

import backend_pb2
import backend_pb2_grpc

import grpc
import outetts

_ONE_DAY_IN_SECONDS = 60 * 60 * 24

MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing health checks, model loading and TTS
    generation for OuteTTS (OuteAI) models."""

    def Health(self, request, context):
        """Liveness probe: always replies with the bytes b"OK"."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Load an OuteTTS model and prepare a speaker profile.

        Reads from the request:
          - Model / ModelFile: model identifier; ModelFile wins when it
            exists on disk.
          - Options: "key:value" strings. tokenizer/version/speaker are
            kept as raw strings; every option is also stored in
            ``self.options`` with a best-effort int/float coercion.
          - AudioPath: optional reference audio for speaker cloning;
            relative paths are resolved against request.ModelPath.
          - ContextSize: overrides the max generated tokens when > 0.

        Returns:
            backend_pb2.Result with success=True on success, or
            success=False and a diagnostic message on any exception.
        """
        model_name = request.Model
        if os.path.exists(request.ModelFile):
            model_name = request.ModelFile

        # Generic option map: numeric-looking values are coerced so they
        # can be fed straight into GenerationConfig (temperature, etc.).
        self.options = {}
        for opt in request.Options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            try:
                if "." in value:
                    value = float(value)
                else:
                    value = int(value)
            except ValueError:
                pass
            self.options[key] = value

        MODELNAME = "OuteAI/OuteTTS-0.3-1B"
        TOKENIZER = "OuteAI/OuteTTS-0.3-1B"
        VERSION = "0.3"
        SPEAKER = "en_male_1"
        # These three must stay strings (self.options would coerce e.g.
        # "0.3" to a float), so parse them from the raw option list.
        # BUGFIX: the previous code `break`-ed after the first match, so
        # at most one of tokenizer/version/speaker could ever take
        # effect; split(":", 1) also preserves values containing ":".
        for opt in request.Options:
            if opt.startswith("tokenizer:"):
                TOKENIZER = opt.split(":", 1)[1]
            elif opt.startswith("version:"):
                VERSION = opt.split(":", 1)[1]
            elif opt.startswith("speaker:"):
                SPEAKER = opt.split(":", 1)[1]

        if model_name != "":
            MODELNAME = model_name

        try:
            model_config = outetts.HFModelConfig_v2(
                model_path=MODELNAME,
                tokenizer_path=TOKENIZER
            )
            self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)

            self.interface.print_default_speakers()
            if request.AudioPath:
                # Clone a voice from reference audio; otherwise fall back
                # to one of the model's built-in speakers.
                if os.path.isabs(request.AudioPath):
                    self.AudioPath = request.AudioPath
                else:
                    self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
                self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
            else:
                self.speaker = self.interface.load_default_speaker(name=SPEAKER)

            if request.ContextSize > 0:
                self.max_tokens = request.ContextSize
            else:
                self.max_tokens = self.options.get("max_new_tokens", 512)

        except Exception as err:
            print("Error:", err, file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize request.text and save the audio to request.dst.

        Uses the interface/speaker prepared by LoadModel; generation
        parameters come from self.options with conservative defaults.
        Returns backend_pb2.Result(success=False, ...) on any exception.
        """
        try:
            # Fall back to a fixed demo sentence when no text is given.
            text = request.text if request.text else "Speech synthesis is the artificial production of human speech."
            print("[OuteTTS] generating TTS", file=sys.stderr)
            gen_cfg = outetts.GenerationConfig(
                text=text,
                temperature=self.options.get("temperature", 0.1),
                repetition_penalty=self.options.get("repetition_penalty", 1.1),
                max_length=self.max_tokens,
                speaker=self.speaker,
            )
            output = self.interface.generate(config=gen_cfg)
            print("[OuteTTS] Generated TTS", file=sys.stderr)
            output.save(request.dst)
            print("[OuteTTS] TTS done", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)


async def serve(address):
    """Start the async gRPC server on *address* and block until it stops.

    Installs SIGINT/SIGTERM handlers that trigger a graceful shutdown
    with a 5-second grace period for in-flight RPCs.

    Args:
        address: host:port string to bind an insecure listener on.
    """
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            # Allow payloads up to 50 MiB (e.g. audio data in requests).
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    # get_event_loop() is deprecated inside a coroutine; this function is
    # always entered via asyncio.run(), so a running loop is guaranteed.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    # CLI entry point: parse the bind address and run the async server.
    arg_parser = argparse.ArgumentParser(description="Run the OuteTTS gRPC server.")
    arg_parser.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = arg_parser.parse_args()
    asyncio.run(serve(cli_args.addr))