File size: 5,941 Bytes
750bbe6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
#!/usr/bin/env python3
"""
This is an extra gRPC server of LocalAI for WhisperX transcription
with speaker diarization, word-level timestamps, and forced alignment.
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc

import grpc


# Length of one idle sleep in serve()'s keep-alive loop: a day, in seconds.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60

# Worker-thread count for the gRPC server; read from PYTHON_GRPC_MAX_WORKERS, defaulting to 1.
MAX_WORKERS = int(os.getenv('PYTHON_GRPC_MAX_WORKERS', '1'))

# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service.

    ``LoadModel`` must succeed before ``AudioTranscription`` is called: it
    sets ``self.model``, ``self.device``, ``self.hf_token``,
    ``self.diarize_pipeline`` and ``self.align_cache``, all of which
    transcription reads.
    """
    def Health(self, request, context):
        """Liveness probe; always replies with the bytes ``b"OK"``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Load the WhisperX ASR model named by ``request.Model``.

        Device selection: ``cuda`` when ``request.CUDA`` is set, overridden by
        ``mps`` when PyTorch reports Apple MPS available, otherwise ``cpu``.
        The chosen device is stored in ``self.device`` and used for the
        torch-based alignment and diarization models. The ASR model itself is
        loaded through faster-whisper/CTranslate2, which only supports
        ``cpu`` and ``cuda``, so on MPS it is loaded on the CPU instead.

        Returns:
            backend_pb2.Result: ``success=True`` on success; otherwise
            ``success=False`` with the error description in ``message``.
        """
        import whisperx
        import torch

        device = "cpu"
        if request.CUDA:
            device = "cuda"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"

        # CTranslate2 has no MPS backend: passing "mps" makes model loading
        # fail, so the ASR model runs on CPU unless CUDA is in use.
        asr_device = "cuda" if device == "cuda" else "cpu"
        compute_type = "float16" if asr_device == "cuda" else "int8"

        try:
            print("Preparing WhisperX model, please wait", file=sys.stderr)
            self.model = whisperx.load_model(
                request.Model,
                asr_device,
                compute_type=compute_type,
            )
            self.device = device
            self.model_name = request.Model

            # Store HF token for diarization if available
            self.hf_token = os.environ.get("HF_TOKEN", None)
            self.diarize_pipeline = None

            # Cache for alignment models keyed by language code; reset on
            # every load so cached models match the current device.
            self.align_cache = {}

            print(f"WhisperX model loaded: {request.Model} on {device}", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def _get_align_model(self, language_code):
        """Load or return cached alignment model for a given language."""
        import whisperx
        if language_code not in self.align_cache:
            model_a, metadata = whisperx.load_align_model(
                language_code=language_code,
                device=self.device,
            )
            self.align_cache[language_code] = (model_a, metadata)
        return self.align_cache[language_code]

    def _align(self, transcript, audio):
        """
        Run forced alignment on ``transcript`` for word-level timestamps.

        If no alignment model can be loaded for the detected language
        (``load_align_model`` raises ValueError), the unaligned transcript is
        returned unchanged instead of failing the whole request.
        """
        import whisperx

        language = transcript["language"]
        try:
            model_a, metadata = self._get_align_model(language)
        except ValueError as err:
            print(f"Skipping alignment for language {language!r}: {err}", file=sys.stderr)
            return transcript
        return whisperx.align(
            transcript["segments"],
            model_a,
            metadata,
            audio,
            self.device,
            return_char_alignments=False,
        )

    def _diarize(self, transcript, audio):
        """
        Attach speaker labels to ``transcript`` using the diarization pipeline.

        Without an HF token (read from ``HF_TOKEN`` at load time) the request
        is skipped with a warning and ``transcript`` is returned unchanged.
        The pipeline is created lazily on first use and then reused.
        """
        if not self.hf_token:
            print("Diarization requested but HF_TOKEN is not set; skipping", file=sys.stderr)
            return transcript

        import whisperx
        # Import from the submodule: newer whisperx releases do not export
        # DiarizationPipeline at the package top level.
        from whisperx.diarize import DiarizationPipeline

        if self.diarize_pipeline is None:
            self.diarize_pipeline = DiarizationPipeline(
                use_auth_token=self.hf_token,
                device=self.device,
            )
        diarize_segments = self.diarize_pipeline(audio)
        return whisperx.assign_word_speakers(diarize_segments, transcript)

    @staticmethod
    def _build_segments(segments):
        """
        Convert WhisperX segment dicts into TranscriptSegment messages.

        Returns:
            tuple: (list of backend_pb2.TranscriptSegment, concatenated text
            of all segments).
        """
        result_segments = []
        texts = []
        for idx, seg in enumerate(segments):
            seg_text = seg.get("text", "")
            # NOTE(review): int() truncates the float seconds WhisperX reports;
            # confirm the unit consumers of TranscriptSegment.start/end expect.
            result_segments.append(backend_pb2.TranscriptSegment(
                id=idx,
                start=int(seg.get("start") or 0),
                end=int(seg.get("end") or 0),
                text=seg_text,
                speaker=seg.get("speaker", ""),
            ))
            texts.append(seg_text)
        return result_segments, "".join(texts)

    def AudioTranscription(self, request, context):
        """
        Transcribe the audio file at ``request.dst``.

        The audio is transcribed in batches, then aligned for word-level
        timestamps when an alignment model exists for the detected language.
        Speaker diarization is added when ``request.diarize`` is set.

        Returns:
            backend_pb2.TranscriptResult: the segments and full text. On any
            error the error is logged to stderr and an empty result is
            returned.
        """
        import whisperx

        try:
            audio = whisperx.load_audio(request.dst)

            # Transcribe
            transcript = self.model.transcribe(
                audio,
                batch_size=16,
                language=request.language if request.language else None,
            )

            # Align for word-level timestamps (best effort per language)
            transcript = self._align(transcript, audio)

            # Diarize if requested
            if request.diarize:
                transcript = self._diarize(transcript, audio)

            result_segments, text = self._build_segments(transcript["segments"])
        except Exception as err:
            print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.TranscriptResult(segments=[], text="")

        return backend_pb2.TranscriptResult(segments=result_segments, text=text)

def serve(address):
    """
    Start the gRPC backend server on ``address`` and block until stopped.

    SIGINT and SIGTERM stop the server immediately and exit the process with
    status 0. A KeyboardInterrupt that reaches the keep-alive loop also stops
    the server.

    Raises:
        RuntimeError: if the server could not bind to ``address``.
    """
    max_message_bytes = 50 * 1024 * 1024  # 50MB
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', max_message_bytes),
            ('grpc.max_send_message_length', max_message_bytes),
            ('grpc.max_receive_message_length', max_message_bytes),
        ],
    )
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Older grpcio versions return 0 instead of raising when binding fails;
    # without this check the server would start but never listen.
    if server.add_insecure_port(address) == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {address}")
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        # Log to stderr, like every other message from this backend.
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # server.start() does not block; keep the main thread alive.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)

if __name__ == "__main__":
    # Command-line entry point: parse the bind address and run the server.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    options = cli.parse_args()
    serve(options.addr)