Spaces:
Running
Running
| import json | |
| import time | |
| from omagent_core.clients.devices.Aaas.schemas import (ConversationEvent, MessageType) | |
| from omagent_core.clients.input_base import InputBase | |
| from omagent_core.engine.http.models.workflow_status import running_status | |
| from omagent_core.engine.orkes.orkes_workflow_client import ( | |
| workflow_client) | |
| from omagent_core.services.connectors.redis import RedisConnector | |
| from omagent_core.utils.logger import logging | |
| from omagent_core.utils.registry import registry | |
class AaasInput(InputBase):
    """Input client that exchanges conversation messages with AaaS via Redis streams.

    Prompts are appended to ``agent_os:conversation:output:{conversation_id}``.
    Replies are polled from ``agent_os:conversation:input:{workflow_instance_id}``.
    """

    redis_stream_client: RedisConnector

    def read_input(self, workflow_instance_id: str, input_prompt=""):
        """Optionally send a prompt, then block until a valid input message arrives.

        Args:
            workflow_instance_id: Composite id of the form
                ``workflow_instance_id|agent_id|conversation_id|chat_id``.
                Missing trailing parts default to ``''``.
            input_prompt: Text sent to the conversation before waiting. When it
                is ``None``, nothing is sent and only input messages whose ids
                are at or after the current time are considered.

        Returns:
            The converted payload of the first accepted input message. Returns
            an empty dict if the workflow stopped running before one arrived.
        """
        ids = self._parse_workflow_instance_id(workflow_instance_id)
        workflow_instance_id = ids.get("workflow_instance_id", "")
        agent_id = ids.get("agent_id", "")
        conversation_id = ids.get("conversation_id", "")
        chat_id = ids.get("chat_id", "")

        stream_name = f"agent_os:conversation:input:{workflow_instance_id}"
        group_name = "OmAaasAgentConsumerGroup"  # consumer group name
        poll_interval: int = 1

        if input_prompt is not None:
            # The id of the message just written to the output stream is reused
            # as the lower bound when reading the input stream. Both ids are
            # generated by the same Redis server, so this acts as a
            # "not older than the prompt" filter.
            start_id = self.send_output_message(
                agent_id, conversation_id, chat_id, input_prompt
            )
        else:
            current_timestamp = int(time.time() * 1000)
            start_id = f"{current_timestamp}-0"

        result = {}

        # Ensure the input stream and its consumer group exist. Messages are
        # read with XREVRANGE below, so the group itself is not consumed here.
        try:
            self.redis_stream_client._client.xgroup_create(
                stream_name, group_name, id="0", mkstream=True
            )
        except Exception as e:
            logging.debug(f"Consumer group may already exist: {e}")

        logging.info(
            f"Listening to Redis stream: {stream_name} in group: {group_name} start_id: {start_id}"
        )

        while True:
            try:
                workflow_status = workflow_client.get_workflow_status(
                    workflow_instance_id
                )
                if workflow_status.status not in running_status:
                    logging.info(
                        f"Workflow {workflow_instance_id} is not running, exiting..."
                    )
                    break

                # Inspect only the newest entry whose id is >= start_id.
                messages = self.redis_stream_client._client.xrevrange(
                    stream_name, max="+", min=start_id, count=1
                )
                logging.info(f"Messages: {messages}")

                accepted = False
                for _message_id, raw_message in messages:
                    # Stream fields come back as bytes; decode them to str.
                    decoded = {
                        key.decode("utf-8"): value.decode("utf-8")
                        for key, value in raw_message.items()
                    }
                    if self.process_message(decoded, result):
                        accepted = True
                        break
                # Stop polling as soon as one message has been accepted.
                if accepted:
                    break

                time.sleep(poll_interval)
            except Exception as e:
                logging.error(f"Error while listening to stream: {e}")
                time.sleep(poll_interval)  # Wait before retrying
        return result

    def process_message(self, message, result):
        """Validate one decoded stream entry and merge its payload into ``result``.

        The entry's ``payload`` field must be a JSON object containing
        ``agent_id`` and a ``messages`` list. Each incoming dialog has a
        ``role`` and a ``contents`` list of ``{contentType, content}`` items.
        Dialogs are converted to this shape before being merged::

            {
                "agent_id": "string",
                "messages": [
                    {
                        "role": "string",
                        "content": [{"type": "string", "data": "string"}]
                    }
                ],
                "kwargs": {}
            }

        Args:
            message: Stream entry with ``str`` keys and values.
            result: Dict updated in place with the converted payload. It is
                only updated when the payload is accepted.

        Returns:
            True if the payload was accepted and merged. False otherwise; the
            reason is logged.
        """
        logging.info(f"Received message: {message}")
        try:
            payload = message.get("payload")
            if not payload:
                logging.error("Payload is empty")
                return False

            # The payload arrives as a JSON string; parse it before touching it.
            try:
                payload_data = json.loads(payload)
            except json.JSONDecodeError as e:
                logging.error(f"Payload is not a valid JSON: {e}")
                return False

            if not isinstance(payload_data, dict):
                logging.error("Payload should be a JSON object")
                return False
            if "agent_id" not in payload_data:
                logging.error("Payload does not contain 'agent_id' key")
                return False
            if "messages" not in payload_data:
                logging.error("Payload does not contain 'messages' key")
                return False
            if not isinstance(payload_data["messages"], list):
                logging.error("'messages' should be a list")
                return False

            converted = self._convert_messages(payload_data["messages"])
            if converted is None:
                return False
            payload_data["messages"] = converted

            result.update(payload_data)
        except Exception as e:
            logging.error(f"Error processing message: {e}")
            return False
        return True

    @staticmethod
    def _convert_messages(dialogs):
        """Convert AaaS dialogs to the agent message schema.

        Each ``{"role", "contents": [{"contentType", "content"}]}`` dialog
        becomes ``{"role", "content": [{"type", "data"}]}``. A missing
        ``contentType`` becomes ``'unknown'``.

        Returns:
            The converted list, or None (after logging) if any dialog or
            content item is malformed.
        """
        converted = []
        for dialog in dialogs:
            if not isinstance(dialog, dict):
                logging.error("Each item in 'messages' should be a dictionary")
                return None
            contents = dialog.get("contents", [])
            if not isinstance(contents, list):
                logging.error("'contents' should be a list")
                return None
            content = []
            for item in contents:
                if not isinstance(item, dict):
                    logging.error("Each item in 'contents' should be a dictionary")
                    return None
                content.append({
                    "type": item.get("contentType", "unknown"),
                    "data": item.get("content"),
                })
            converted.append({"role": dialog.get("role"), "content": content})
        return converted

    @staticmethod
    def _parse_workflow_instance_id(data: str):
        """Split a ``|``-separated composite id into its named parts.

        The parts are, in order: ``workflow_instance_id``, ``agent_id``,
        ``conversation_id``, ``chat_id``. Missing trailing parts are absent
        from the returned dict; parts beyond the fourth are ignored.
        """
        keys = ("workflow_instance_id", "agent_id", "conversation_id", "chat_id")
        # zip stops at the shorter sequence, which truncates extra parts.
        return dict(zip(keys, data.split("|")))

    def _create_output_data(
        self,
        event='',
        conversation_id='',
        chat_id='',
        agent_id='',
        status='',
        contentType='',
        content='',
        type='',
        is_finish=True
    ):
        """Build the stream entry for an outgoing conversation event.

        Returns:
            A dict with a single ``content`` field. Its value is the JSON
            encoding of ``{"event": ..., "data": {...}}``. Non-ASCII text is
            kept as-is (``ensure_ascii=False``).
        """
        data = {
            'content': json.dumps({
                'event': event,
                'data': {
                    'conversationId': conversation_id,
                    'chatId': chat_id,
                    'agentId': agent_id,
                    'createTime': None,
                    'endTime': None,
                    'status': status,
                    'contentType': contentType,
                    'content': content,
                    'type': type,
                    'isFinish': is_finish
                }
            }, ensure_ascii=False)
        }
        return data

    def send_base_message(
        self,
        event='',
        conversation_id='',
        chat_id='',
        agent_id='',
        status='',
        contentType='',
        content='',
        type='',
        is_finish=True
    ):
        """Append an event to the conversation's output stream.

        Returns:
            The stream entry id returned by ``xadd``.
        """
        stream_name = f"agent_os:conversation:output:{conversation_id}"
        group_name = "OmAaasAgentConsumerGroup"  # consumer group on the output stream
        message = self._create_output_data(
            event=event,
            conversation_id=conversation_id,
            chat_id=chat_id,
            agent_id=agent_id,
            status=status,
            contentType=contentType,
            content=content,
            type=type,
            is_finish=is_finish
        )
        return self.send_to_group(stream_name, group_name, message)

    def send_output_message(
        self,
        agent_id,
        conversation_id,
        chat_id,
        msg,
    ):
        """Send ``msg`` as a completed text ``MESSAGE_DELTA`` of type ``ask_complete``.

        Returns:
            The stream entry id of the written message.
        """
        return self.send_base_message(
            event=ConversationEvent.MESSAGE_DELTA.value,
            conversation_id=conversation_id,
            chat_id=chat_id,
            agent_id=agent_id,
            status='completed',
            contentType=MessageType.TEXT.value,
            content=msg,
            type='ask_complete',
            is_finish=True
        )

    def send_to_group(self, stream_name, group_name, data):
        """Append ``data`` to ``stream_name`` and ensure ``group_name`` exists on it.

        Returns:
            The stream entry id returned by ``xadd``.
        """
        logging.info(f"Stream: {stream_name}, Group: {group_name}, Data: {data}")
        message_id = self.redis_stream_client._client.xadd(stream_name, data)
        # The stream exists after xadd, so mkstream is unnecessary. Creating the
        # group at id "0" still lets it see the entry just added.
        try:
            self.redis_stream_client._client.xgroup_create(
                stream_name, group_name, id="0"
            )
        except Exception as e:
            logging.debug(f"Consumer group may already exist: {e}")
        return message_id