| | import asyncio |
| |
|
| | from ten import TenEnv, Data |
| |
|
| | from amazon_transcribe.auth import StaticCredentialResolver |
| | from amazon_transcribe.client import TranscribeStreamingClient |
| | from amazon_transcribe.handlers import TranscriptResultStreamHandler |
| | from amazon_transcribe.model import ( |
| | TranscriptEvent, |
| | TranscriptResultStream, |
| | StartStreamTranscriptionEventStream, |
| | ) |
| |
|
| | from .transcribe_config import TranscribeConfig |
| |
|
# Property names set on every outgoing "text_data" message.
DATA_OUT_TEXT_DATA_PROPERTY_TEXT = "text"
DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL = "is_final"
DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID = "stream_id"
DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT = "end_of_segment"
| |
|
def create_and_send_data(ten: TenEnv, text_result: str, is_final: bool, stream_id: int = 0):
    """Build a "text_data" message from a transcript and send it through *ten*.

    Args:
        ten: Environment whose ``send_data`` is called with the message.
        text_result: Transcript text stored under the "text" property.
        is_final: Stored under both the "is_final" and "end_of_segment"
            properties.
        stream_id: Stored under the "stream_id" property.
    """
    message = Data.create("text_data")

    # (setter, property name, value) triples, applied in this exact order.
    assignments = (
        (message.set_property_bool, DATA_OUT_TEXT_DATA_PROPERTY_IS_FINAL, is_final),
        (message.set_property_string, DATA_OUT_TEXT_DATA_PROPERTY_TEXT, text_result),
        (message.set_property_int, DATA_OUT_TEXT_DATA_PROPERTY_STREAM_ID, stream_id),
        (message.set_property_bool, DATA_OUT_TEXT_DATA_PROPERTY_END_OF_SEGMENT, is_final),
    )
    for apply_property, property_name, property_value in assignments:
        apply_property(property_name, property_value)

    ten.send_data(message)
| |
|
| |
|
class AsyncTranscribeWrapper:
    """Forward PCM frames from an asyncio queue to a transcription stream.

    Frames are read from ``queue`` and sent to a stream obtained from
    ``TranscribeStreamingClient.start_stream_transcription``. The stream is
    created lazily on the first non-empty frame and closed again when no
    frame arrives within ``QUEUE_IDLE_TIMEOUT_SECONDS``. Transcript events
    are consumed by a ``TranscribeEventHandler`` running as a background
    task on the same event loop.
    """

    # Seconds to wait for a frame before the current stream is treated as
    # idle and closed. The idle log message is derived from this value.
    QUEUE_IDLE_TIMEOUT_SECONDS = 3.0

    def __init__(
        self,
        config: TranscribeConfig,
        queue: asyncio.Queue,
        ten: TenEnv,
        loop: asyncio.BaseEventLoop,
    ):
        """Create the transcription client and bind the wrapper to *loop*.

        Args:
            config: Supplies region, optional access/secret key, language
                code, sample rate and media encoding.
            queue: Source of PCM frames; a ``None`` item ends ``send_frame``.
            ten: Environment used for logging and for sending results.
            loop: Event loop that ``run`` drives to completion and closes.
        """
        self.queue = queue
        self.ten = ten
        self.stopped = False
        self.config = config
        self.loop = loop
        self.stream = None
        self.handler = None
        self.event_handler_task = None

        if config.access_key and config.secret_key:
            # NOTE(review): this writes the access key id to the log;
            # consider masking it.
            ten.log_info(f"init trascribe client with access key: {config.access_key}")
            self.transcribe_client = TranscribeStreamingClient(
                region=config.region,
                credential_resolver=StaticCredentialResolver(
                    access_key_id=config.access_key, secret_access_key=config.secret_key
                ),
            )
        else:
            ten.log_info(
                "init trascribe client without access key, using default credentials provider chain."
            )
            self.transcribe_client = TranscribeStreamingClient(region=config.region)

        asyncio.set_event_loop(self.loop)
        self.reset_stream()

    def reset_stream(self):
        """Forget the current stream, its handler and the handler task."""
        self.stream = None
        self.handler = None
        self.event_handler_task = None

    async def cleanup(self) -> None:
        """End the current stream, wait for its handler, then reset state.

        Errors raised while closing are logged instead of propagated, and
        the stream state is reset in every case. This keeps a failed close
        from leaving a dead stream in place or from failing a second time
        when ``transcribe_loop`` runs its final cleanup.
        """
        stream = self.stream
        handler_task = self.event_handler_task
        try:
            if stream:
                await stream.input_stream.end_stream()
                self.ten.log_info("cleanup: stream ended.")

            if handler_task:
                await handler_task
                self.ten.log_info("cleanup: event handler ended.")
        except Exception as e:
            # The handler may still be pending if ending the stream failed;
            # cancel it so the task is not left running.
            if handler_task and not handler_task.done():
                handler_task.cancel()
            self.ten.log_error(f"cleanup: failed to close stream cleanly: {e}")
        finally:
            self.reset_stream()

    async def create_stream(self, stream_id: int) -> bool:
        """Open a transcription stream and start its event-handler task.

        The instance attributes are assigned only after every step has
        succeeded, so a failure never leaves a stream without a handler.

        Returns:
            True on success; False if any step raised (the error is logged).
        """
        try:
            stream = await self.get_transcribe_stream()
            handler = TranscribeEventHandler(stream.output_stream, self.ten, stream_id)
            handler_task = asyncio.create_task(handler.handle_events())
        except Exception as e:
            self.ten.log_error(str(e))
            return False

        self.stream = stream
        self.handler = handler
        self.event_handler_task = handler_task
        return True

    async def send_frame(self) -> None:
        """Send queued PCM frames to the stream until stopped.

        The loop ends when ``stop`` has been called, when a ``None`` item is
        received, or when an unexpected exception is raised (it is logged
        and re-raised). ``IOError`` is logged and the loop continues. Every
        item taken from the queue is acknowledged with ``task_done``.
        """
        idle_timeout = self.QUEUE_IDLE_TIMEOUT_SECONDS

        while not self.stopped:
            # Set once an item has been dequeued, so task_done() is called
            # exactly once per item and never after a timeout.
            frame_dequeued = False
            try:
                pcm_frame = await asyncio.wait_for(self.queue.get(), timeout=idle_timeout)
                frame_dequeued = True

                if pcm_frame is None:
                    self.ten.log_warn("send_frame: exit due to None value got.")
                    return

                frame_buf = pcm_frame.get_buf()
                if not frame_buf:
                    self.ten.log_warn("send_frame: empty pcm_frame detected.")
                    continue

                stream_id = pcm_frame.get_property_int("stream_id")
                if not self.stream:
                    self.ten.log_info("lazy init stream.")
                    if not await self.create_stream(stream_id):
                        continue

                await self.stream.input_stream.send_audio_event(audio_chunk=frame_buf)
            except asyncio.TimeoutError:
                if self.stream:
                    await self.cleanup()
                    self.ten.log_info(
                        f"send_frame: no data for {idle_timeout:g}s, will close current stream and create a new one when receving new frame."
                    )
                else:
                    self.ten.log_info("send_frame: waiting for pcm frame.")
            except IOError as e:
                self.ten.log_error(f"Error in send_frame: {e}")
            except Exception as e:
                self.ten.log_error(f"Error in send_frame: {e}")
                raise
            finally:
                if frame_dequeued:
                    self.queue.task_done()

        self.ten.log_info("send_frame: exit due to self.stopped == True")

    async def transcribe_loop(self) -> None:
        """Run ``send_frame``, log any error, and always clean up the stream."""
        try:
            await self.send_frame()
        except Exception as e:
            self.ten.log_error(str(e))
        finally:
            await self.cleanup()

    async def get_transcribe_stream(self) -> StartStreamTranscriptionEventStream:
        """Start a streaming transcription using the configured audio settings."""
        return await self.transcribe_client.start_stream_transcription(
            language_code=self.config.lang_code,
            media_sample_rate_hz=self.config.sample_rate,
            media_encoding=self.config.media_encoding,
        )

    def run(self) -> None:
        """Drive ``transcribe_loop`` to completion on the loop, then close it.

        The loop is closed even if ``run_until_complete`` raises.
        """
        try:
            self.loop.run_until_complete(self.transcribe_loop())
        finally:
            self.loop.close()
        self.ten.log_info("async_transcribe_wrapper: thread completed.")

    def stop(self) -> None:
        """Ask ``send_frame`` to exit; it checks the flag on its next iteration."""
        self.stopped = True
| |
|
| |
|
class TranscribeEventHandler(TranscriptResultStreamHandler):
    """Turn transcript events into "text_data" messages sent through *ten*."""

    def __init__(self, transcript_result_stream: TranscriptResultStream, ten: TenEnv, stream_id: int = 0):
        """Remember the environment and the stream id attached to each message."""
        super().__init__(transcript_result_stream)
        self.stream_id = stream_id
        self.ten = ten

    async def handle_transcript_event(self, transcript_event: TranscriptEvent) -> None:
        """Concatenate the event's transcripts and send them downstream.

        The text is every alternative's transcript from every result, joined
        in order. The message is marked final only if no result is partial.
        Nothing is logged or sent when the joined text is empty.
        """
        saw_partial = False
        pieces = []

        for entry in transcript_event.transcript.results:
            if entry.is_partial:
                saw_partial = True
            pieces.extend(candidate.transcript for candidate in entry.alternatives)

        text_result = "".join(pieces)
        if not text_result:
            return

        is_final = not saw_partial
        self.ten.log_info(f"got transcript: [{text_result}], is_final: [{is_final}]")

        create_and_send_data(
            ten=self.ten,
            text_result=text_result,
            is_final=is_final,
            stream_id=self.stream_id,
        )
| |
|