maolin.liu
commited on
Commit
·
7ec99bc
1
Parent(s):
ef7f04e
[bugfix]Do not serialize string.
Browse files- consumer/asr.py +4 -4
- consumer/base.py +6 -4
consumer/asr.py
CHANGED
|
@@ -3,7 +3,7 @@ import io
|
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
from pathlib import Path
|
| 6 |
-
from typing import Literal
|
| 7 |
|
| 8 |
from faster_whisper import WhisperModel
|
| 9 |
from pydantic import BaseModel, Field, ValidationError, model_validator
|
|
@@ -58,7 +58,7 @@ class TranscribeConsumer(BasicMessageReceiver):
|
|
| 58 |
def setup_message_sender(self):
|
| 59 |
self.sender = BasicMessageSender()
|
| 60 |
|
| 61 |
-
def send_message(self, message: dict):
|
| 62 |
routing_key = 'transcribe-output'
|
| 63 |
# headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
|
| 64 |
self.sender.send_message(
|
|
@@ -71,11 +71,11 @@ class TranscribeConsumer(BasicMessageReceiver):
|
|
| 71 |
def send_success_message(self, uuid: str, transcribed_text):
|
| 72 |
message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
|
| 73 |
transcribed_text=transcribed_text)
|
| 74 |
-
self.send_message(message.
|
| 75 |
|
| 76 |
def send_fail_message(self, uuid: str, error: str):
|
| 77 |
message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
|
| 78 |
-
self.send_message(message.
|
| 79 |
|
| 80 |
def consume(self, channel, method, properties, message):
|
| 81 |
body = self.decode_message(message)
|
|
|
|
| 3 |
import logging
|
| 4 |
import os
|
| 5 |
from pathlib import Path
|
| 6 |
+
from typing import Literal, Union
|
| 7 |
|
| 8 |
from faster_whisper import WhisperModel
|
| 9 |
from pydantic import BaseModel, Field, ValidationError, model_validator
|
|
|
|
| 58 |
def setup_message_sender(self):
|
| 59 |
self.sender = BasicMessageSender()
|
| 60 |
|
| 61 |
+
def send_message(self, message: Union[dict, str]):
|
| 62 |
routing_key = 'transcribe-output'
|
| 63 |
# headers = Headers(job_id=f'{uuid.uuid4()}', priority=Priority.NORMAL)
|
| 64 |
self.sender.send_message(
|
|
|
|
| 71 |
def send_success_message(self, uuid: str, transcribed_text):
|
| 72 |
message = TranscribeOutputMessage(uuid=uuid, if_success=True, msg='Transcribe finished.',
|
| 73 |
transcribed_text=transcribed_text)
|
| 74 |
+
self.send_message(message.model_dump_json())
|
| 75 |
|
| 76 |
def send_fail_message(self, uuid: str, error: str):
|
| 77 |
message = TranscribeOutputMessage(uuid=uuid, if_success=False, msg=error)
|
| 78 |
+
self.send_message(message.model_dump_json())
|
| 79 |
|
| 80 |
def consume(self, channel, method, properties, message):
|
| 81 |
body = self.decode_message(message)
|
consumer/base.py
CHANGED
|
@@ -6,7 +6,7 @@ import os
|
|
| 6 |
import ssl
|
| 7 |
import time
|
| 8 |
from enum import Enum
|
| 9 |
-
from typing import Dict, Optional, Literal
|
| 10 |
|
| 11 |
import msgpack
|
| 12 |
import pika
|
|
@@ -140,11 +140,13 @@ class BasicPikaClient:
|
|
| 140 |
class BasicMessageSender(BasicPikaClient):
|
| 141 |
message_encoding_type: Literal['bytes', 'json'] = 'json'
|
| 142 |
|
| 143 |
-
def encode_message(self, body: Dict, encoding_type: str = "bytes"):
|
| 144 |
if encoding_type == "bytes":
|
| 145 |
return msgpack.packb(body)
|
| 146 |
elif encoding_type == "json":
|
| 147 |
-
|
|
|
|
|
|
|
| 148 |
else:
|
| 149 |
raise NotImplementedError
|
| 150 |
|
|
@@ -152,7 +154,7 @@ class BasicMessageSender(BasicPikaClient):
|
|
| 152 |
self,
|
| 153 |
exchange_name: str,
|
| 154 |
routing_key: str,
|
| 155 |
-
body: Dict,
|
| 156 |
headers: Optional[Headers],
|
| 157 |
):
|
| 158 |
body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
|
|
|
|
| 6 |
import ssl
|
| 7 |
import time
|
| 8 |
from enum import Enum
|
| 9 |
+
from typing import Dict, Optional, Literal, Union
|
| 10 |
|
| 11 |
import msgpack
|
| 12 |
import pika
|
|
|
|
| 140 |
class BasicMessageSender(BasicPikaClient):
|
| 141 |
message_encoding_type: Literal['bytes', 'json'] = 'json'
|
| 142 |
|
| 143 |
+
def encode_message(self, body: Union[Dict, str], encoding_type: str = "bytes"):
|
| 144 |
if encoding_type == "bytes":
|
| 145 |
return msgpack.packb(body)
|
| 146 |
elif encoding_type == "json":
|
| 147 |
+
if isinstance(body, dict):
|
| 148 |
+
return json.dumps(body)
|
| 149 |
+
return body
|
| 150 |
else:
|
| 151 |
raise NotImplementedError
|
| 152 |
|
|
|
|
| 154 |
self,
|
| 155 |
exchange_name: str,
|
| 156 |
routing_key: str,
|
| 157 |
+
body: Union[Dict, str],
|
| 158 |
headers: Optional[Headers],
|
| 159 |
):
|
| 160 |
body = self.encode_message(body=body, encoding_type=self.message_encoding_type)
|