Spaces:
Running
Running
File size: 4,821 Bytes
750bbe6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
#!/usr/bin/env python3
"""
gRPC server of LocalAI for NVIDIA NEMO Toolkit ASR.
"""
from concurrent import futures
import time
import argparse
import signal
import sys
import os
import backend_pb2
import backend_pb2_grpc
import torch
import nemo.collections.asr as nemo_asr
import grpc
def is_float(s):
    """Report whether ``float(s)`` succeeds.

    Returns False only when the conversion raises ``ValueError``. Integer
    strings (``"3"``), ``"nan"`` and ``"inf"`` all count as floats; any other
    exception type raised by ``float`` propagates to the caller.
    """
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True
def is_int(s):
    """Report whether ``int(s)`` succeeds.

    Returns False only when the conversion raises ``ValueError`` (e.g. for
    ``"1.5"`` or ``"abc"``); any other exception type raised by ``int``
    propagates to the caller.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
# One day expressed in seconds.
_ONE_DAY_IN_SECONDS = 24 * 60 * 60
# Number of gRPC worker threads, read from PYTHON_GRPC_MAX_WORKERS (default 1).
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing a NeMo ASR model through the LocalAI Backend API.

    Implements ``Health``, ``LoadModel`` and ``AudioTranscription``.
    ``LoadModel`` must succeed before ``AudioTranscription`` can produce text.
    """

    def Health(self, request, context):
        """Liveness probe: always replies with the message bytes ``b"OK"``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    @staticmethod
    def _parse_option_value(value):
        """Coerce a raw option value string to int, float or bool when possible.

        Integers are tested before floats: every integer literal also parses
        as a float, so testing float first would turn "4" into 4.0 and make
        the int branch unreachable. Anything else is returned unchanged.
        """
        if is_int(value):
            return int(value)
        if is_float(value):
            return float(value)
        lowered = value.lower()
        if lowered in ("true", "false"):
            return lowered == "true"
        return value

    def LoadModel(self, request, context):
        """Pick a torch device, parse ``request.Options`` and load the ASR model.

        Args:
            request: LoadModel request; reads ``CUDA``, ``Options`` and ``Model``.
                ``Options`` entries are ``"key:value"`` strings; entries without
                a ``:`` are ignored. An empty ``Model`` falls back to
                ``nvidia/parakeet-tdt-0.6b-v3``.
            context: gRPC servicer context (unused).

        Returns:
            ``backend_pb2.Result`` with ``success=False`` and a message when CUDA
            was requested but is unavailable, or when model loading raises;
            otherwise ``success=True``.
        """
        cuda_available = torch.cuda.is_available()
        device = "cuda" if cuda_available else "cpu"
        # MPS takes precedence whenever torch reports it as available.
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if request.CUDA and not cuda_available:
            return backend_pb2.Result(success=False, message="CUDA is not available")
        # NOTE(review): self.device is recorded but not passed to from_pretrained,
        # so model placement is left to NeMo's own defaults — confirm intended.
        self.device = device

        self.options = {}
        for opt in request.Options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            self.options[key] = self._parse_option_value(value)

        model_name = request.Model or "nvidia/parakeet-tdt-0.6b-v3"
        try:
            print(f"Loading NEMO ASR model from {model_name}", file=sys.stderr)
            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
            print("NEMO ASR model loaded successfully", file=sys.stderr)
        except Exception as err:
            print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return backend_pb2.Result(success=False, message=str(err))
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    @staticmethod
    def _result_text(results):
        """Return the transcript text of the first entry of a ``transcribe`` result.

        Depending on the NeMo version and model type, ``transcribe`` may return
        a list of plain strings, a list of ``Hypothesis`` objects exposing
        ``.text``, or (older RNNT/TDT APIs) a ``(best_hypotheses,
        all_hypotheses)`` tuple. Returns ``""`` when no usable text is found.
        """
        if isinstance(results, tuple):
            # (best_hypotheses, all_hypotheses): keep only the best ones.
            results = results[0] if results else None
        if not results:
            return ""
        first = results[0]
        text = getattr(first, "text", first)
        if text is None:
            return ""
        return text if isinstance(text, str) else str(text)

    def AudioTranscription(self, request, context):
        """Transcribe the audio file at ``request.dst`` with the loaded model.

        Returns a ``TranscriptResult`` carrying the full text and, when the text
        is non-empty, one segment holding it (start/end are 0; no timestamps
        are computed). Every failure — missing file, model not loaded,
        transcription error — is logged to stderr and yields an empty result.
        """
        empty = backend_pb2.TranscriptResult(segments=[], text="")
        try:
            audio_path = request.dst
            if not audio_path or not os.path.exists(audio_path):
                print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
                return empty
            model = getattr(self, "model", None)
            if model is None:
                print("Error in AudioTranscription: model not loaded", file=sys.stderr)
                return empty
            # transcribe takes a list of paths; only one path is passed here.
            text = self._result_text(model.transcribe([audio_path]))
            if not text:
                return empty
            segment = backend_pb2.TranscriptSegment(id=0, start=0, end=0, text=text)
            return backend_pb2.TranscriptResult(segments=[segment], text=text)
        except Exception as err:
            print(f"Error in AudioTranscription: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return empty
def serve(address):
    """Start the Backend gRPC server on *address* and block until shut down.

    Args:
        address: ``host:port`` string to bind the insecure listener to.

    Raises:
        RuntimeError: if the address cannot be bound.

    SIGINT/SIGTERM stop the server immediately and exit the process with
    status 0. All diagnostics go to stderr.
    """
    max_msg_bytes = 50 * 1024 * 1024
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', max_msg_bytes),
            ('grpc.max_send_message_length', max_msg_bytes),
            ('grpc.max_receive_message_length', max_msg_bytes),
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Some grpcio versions report a bind failure by returning port 0 instead of
    # raising; fail loudly rather than run a server nobody can reach.
    if server.add_insecure_port(address) == 0:
        raise RuntimeError(f"Failed to bind gRPC server to {address}")
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # stderr, like every other diagnostic this backend emits.
        print("Received termination signal. Shutting down...", file=sys.stderr)
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    try:
        # Keep the main thread alive; RPCs are served by the executor threads.
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
if __name__ == "__main__":
    # CLI entry point: parse the bind address and run the server.
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr",
        default="localhost:50051",
        help="The address to bind the server to.",
    )
    cli_args = cli.parse_args()
    serve(cli_args.addr)
|