from collections import deque
from concurrent import futures

import grpc
import torch

import text_to_speech_pb2
import text_to_speech_pb2_grpc
from chat_database import save_chat_entry
from models import build_model
from providers.audio_provider import get_audio_bytes, dummy_bytes, generate_audio_from_chunks
from providers.llm_provider import getResponseWithRAG, getResponse
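
# Prefer CUDA when available, otherwise fall back to the CPU.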
device = 'cuda' if torch.cuda.is_available() else 'cpu'
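# Load the Kokoro TTS model once at startup; every request reuses it.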
MODEL = build_model('kokoro-v0_19.pth', device)
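# Voices bundled with Kokoro v0.19; the default is the first entry, 'af'.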
VOICE_NAME = [
    'af',
    'af_bella', 'af_sarah', 'am_adam', 'am_michael',
    'bf_emma', 'bf_isabella', 'bm_george', 'bm_lewis',
    'af_nicole', 'af_sky',
][0]
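# Preload the default voicepack; ProcessText swaps it out when a client
# requests a different voice via metadata.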
VOICEPACK = torch.load(
    f'voices/{VOICE_NAME}.pt', weights_only=True).to(device)


class TextToSpeechServicer(text_to_speech_pb2_grpc.TextToSpeechServiceServicer):
    def ProcessText(self, request_iterator, context):
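        """Bidirectional-streaming RPC. Consumes metadata/text/status messages
        from the client and streams back synthesized audio, one sentence per
        response."""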
        try:
            global VOICEPACK
            print("Received new request")
            # Per-stream state: the sentence queue, sequence counters, and the
            # client-supplied generation settings (with defaults).
            parameters = {
                "processing_active": False,
                "queue": deque(),
                "file_number": 0,
                "session_id": "",
                "interrupt_seq": 0,
                "temperature": 1,
                "activeVoice": "af",
                "maxTokens": 500,
            }
            for request in request_iterator:
                field = request.WhichOneof('request_data')
                if field == 'metadata':
                    meta = request.metadata
                    print("Metadata received:")
                    print(" session_id:", meta.session_id)
                    print(" silenceDuration:", meta.silenceDuration)
                    print(" threshold:", meta.threshold)
                    print(" temperature:", meta.temperature)
                    print(" activeVoice:", meta.activeVoice)
                    print(" maxTokens:", meta.maxTokens)
                    if meta.session_id:
                        parameters["session_id"] = meta.session_id
                    if meta.temperature:
                        parameters["temperature"] = meta.temperature
                    if meta.maxTokens:
                        parameters["maxTokens"] = meta.maxTokens
                    if meta.activeVoice:
                        parameters["activeVoice"] = meta.activeVoice
                        # Reload the voicepack for the requested voice.
                        VOICEPACK = torch.load(
                            f'voices/{parameters["activeVoice"]}.pt', weights_only=True).to(device)
                    continue
                elif field == 'text':
                    text = request.text
                    if not text:
                        continue

                    save_chat_entry(parameters["session_id"], "user", text)
                    # A new utterance supersedes anything still queued.
                    parameters["queue"].clear()
                    # Control frame: dummy audio with sequence_id "-2" echoes
                    # the user transcript back to the client.
                    yield text_to_speech_pb2.ProcessTextResponse(
                        buffer=dummy_bytes(),
                        session_id=parameters["session_id"],
                        sequence_id="-2",
                        transcript=text,
                    )
                    # Stream the LLM reply, synthesizing audio sentence by sentence.
                    response = getResponse(text, parameters["session_id"])
                    complete_response = yield from self._stream_llm(response, parameters)

                    # A reply containing "Let me check" triggers a second pass
                    # through the RAG pipeline; its answer is streamed the same way.
                    if "Let me check" in complete_response:
                        response = getResponseWithRAG(
                            text, parameters["session_id"])
                        yield from self._stream_llm(response, parameters)

                elif field == 'status':
                    # Playback report from the client: persist what was actually
                    # spoken and record where playback was interrupted so stale
                    # queue entries can be skipped.
                    transcript = request.status.transcript
                    played_seq = request.status.played_seq
                    interrupt_seq = request.status.interrupt_seq
                    parameters["interrupt_seq"] = interrupt_seq
                    save_chat_entry(
                        parameters["session_id"], "assistant", transcript)
                    continue
                else:
                    continue
        except Exception as e:
            print("Error in ProcessText:", e)

    def _stream_llm(self, response, parameters):
        """Consume a streamed LLM response, queueing a TTS job for each
        complete sentence, and yield audio responses as they are synthesized.
        Returns the full reply text (delivered through `yield from`)."""
        final_response = ""
        complete_response = ""
        for chunk in response:
            msg = chunk.choices[0].delta.content
            if msg:
                final_response += msg
                complete_response += msg
                # Sentence-final punctuation closes a synthesis chunk, so audio
                # can start playing before the LLM finishes its reply.
                if final_response.endswith(('.', '!', '?')):
                    parameters["file_number"] += 1
                    parameters["queue"].append(
                        (final_response, parameters["file_number"]))
                    final_response = ""
                    if not parameters["processing_active"]:
                        yield from self.process_queue(parameters)
        # Flush trailing text that never reached sentence-final punctuation.
        if final_response:
            parameters["file_number"] += 1
            parameters["queue"].append(
                (final_response, parameters["file_number"]))
            if not parameters["processing_active"]:
                yield from self.process_queue(parameters)
        return complete_response

    def process_queue(self, parameters):
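        """Drain the sentence queue, synthesizing and yielding audio for each
        entry; entries at or below the client's interrupt point are dropped."""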
        global VOICEPACK
        try:
            while True:
                if not parameters["queue"]:
                    parameters["processing_active"] = False
                    break
                parameters["processing_active"] = True
                sentence, file_number = parameters["queue"].popleft()
                # Skip sentences the client has already interrupted past.
                if file_number <= int(parameters["interrupt_seq"]):
                    continue

                # Use the currently active voice so language selection matches
                # the loaded voicepack.
                combined_audio = generate_audio_from_chunks(
                    sentence, MODEL, VOICEPACK, parameters["activeVoice"])
                audio_bytes = get_audio_bytes(combined_audio)
                yield text_to_speech_pb2.ProcessTextResponse(
                    buffer=audio_bytes,
                    session_id=parameters["session_id"],
                    sequence_id=str(file_number),
                    transcript=sentence,
                )
        except Exception as e:
            parameters["processing_active"] = False
            print("Error in process_queue:", e)


def serve():
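    """Start the gRPC server on port 8081 and block until termination."""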
| print("Starting gRPC server...") |
| server = grpc.server(futures.ThreadPoolExecutor(max_workers=1)) |
| text_to_speech_pb2_grpc.add_TextToSpeechServiceServicer_to_server( |
| TextToSpeechServicer(), server) |
| server.add_insecure_port('[::]:8081') |
| server.start() |
| print("gRPC server is running on port 8081") |
| server.wait_for_termination() |
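

# Minimal client sketch (assumes the generated request message is named
# ProcessTextRequest with a Metadata submessage; check the .proto for the
# actual names):
#
#     channel = grpc.insecure_channel('localhost:8081')
#     stub = text_to_speech_pb2_grpc.TextToSpeechServiceStub(channel)
#
#     def requests():
#         yield text_to_speech_pb2.ProcessTextRequest(
#             metadata=text_to_speech_pb2.Metadata(session_id="demo"))
#         yield text_to_speech_pb2.ProcessTextRequest(text="Hello there!")
#
#     for response in stub.ProcessText(requests()):
#         handle_audio(response.buffer)  # hypothetical playback helper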


if __name__ == "__main__":
    serve()