# Source: src/core/adapter.py (commit 82f9be0, "update", by tanbushi)
"""
Redis适配器 - 处理消息发送和接收
"""
import redis
import json
import threading
import queue
import time
import logging
from dataclasses import dataclass
from datetime import datetime
from typing import Optional, Callable, Dict, Any
from .entities import EntityInfo
from .messages import Message
# Configure logging.
# NOTE(review): basicConfig() here configures the root logger as a side effect
# of importing this module (it does nothing if the root logger already has
# handlers); library modules usually leave this to the application — confirm.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Type definitions.
# Signature of the callback RedisAdapter invokes with each received Message.
MessageCallback = Callable[[Message], None]
@dataclass
class RedisConfig:
    """Connection settings for a single Redis server.

    Attributes:
        host: Hostname or IP address of the Redis server.
        port: TCP port of the Redis server.
        db: Index of the Redis logical database to select.
        password: Optional password; left out of the connection parameters
            when falsy (``None`` or an empty string).
    """

    host: str
    port: int
    db: int
    password: Optional[str] = None

    def to_connection_params(self) -> Dict[str, Any]:
        """Return keyword arguments for ``redis.Redis(**params)``.

        The keys are ``host``, ``port`` and ``db``, followed by ``password``
        only when a non-empty password is configured.
        """
        base: Dict[str, Any] = dict(host=self.host, port=self.port, db=self.db)
        auth = {'password': self.password} if self.password else {}
        return {**base, **auth}
class RedisAdapter:
    """Redis adapter that sends and receives messages for one entity.

    Outgoing messages are queued by :meth:`send_message` and published by a
    background send thread. Incoming messages are read from a Redis pub/sub
    subscription by a background receive thread and handed to the callback
    registered with :meth:`register_callback` (invoked on that thread).
    """

    def __init__(self, entity_info: EntityInfo):
        """Create an adapter for *entity_info*; no connection is opened yet.

        Args:
            entity_info: Identity and Redis settings of the local entity. This
                class reads its ``id``, ``channel``, ``redis_host``,
                ``redis_port`` and ``redis_db`` attributes.
        """
        self.entity_info = entity_info
        # NOTE(review): no password is passed through, so the config's
        # password is always None — confirm whether EntityInfo carries one.
        self.redis_config = RedisConfig(
            host=entity_info.redis_host,
            port=entity_info.redis_port,
            db=entity_info.redis_db,
        )

        # Redis connection, created by connect().
        self.redis_client: Optional[redis.Redis] = None

        # Sending side: filled by send_message(), drained by _send_worker().
        self.send_queue = queue.Queue()
        self.send_thread = None
        self.send_running = False

        # Receiving side: filled by _receive_worker(), drained by
        # _process_receive_queue() into message_callback.
        self.receive_queue = queue.Queue()
        self.receive_thread = None
        self.receive_running = False
        self.message_callback: Optional[MessageCallback] = None

        # Connection state.
        self.pubsub: Optional[redis.client.PubSub] = None
        self.connected = False

    def connect(self) -> bool:
        """Open a connection to the configured Redis server.

        The new client is stored on ``self.redis_client`` only after ``PING``
        succeeds; a client whose ping fails is closed so its connection pool
        is not left dangling.

        Returns:
            True if the server answered ``PING``, False otherwise (the error
            is logged, not raised).
        """
        client = None
        try:
            client = redis.Redis(**self.redis_config.to_connection_params())
            client.ping()  # verify the server is actually reachable
        except Exception as e:
            logger.error("Failed to connect to Redis: %s", e)
            if client is not None:
                try:
                    client.close()
                except Exception:
                    logger.debug("Error closing failed Redis client", exc_info=True)
            return False
        self.redis_client = client
        self.connected = True
        logger.info("Entity %s connected to Redis", self.entity_info.id)
        return True

    def disconnect(self):
        """Stop the worker threads, then close the Redis client if one exists."""
        self.stop()
        if self.redis_client:
            self.redis_client.close()
        self.connected = False
        logger.info("Entity %s disconnected from Redis", self.entity_info.id)

    def start(self) -> bool:
        """Connect to Redis and start the send and receive worker threads.

        Both workers are daemon threads. NOTE(review): calling start() again
        without stop() opens a new client and starts a second pair of workers.

        Returns:
            False if the connection failed (no threads are started), else True.
        """
        if not self.connect():
            return False

        self.send_running = True
        self.receive_running = True

        self.send_thread = threading.Thread(target=self._send_worker, daemon=True)
        self.send_thread.start()

        self.receive_thread = threading.Thread(target=self._receive_worker, daemon=True)
        self.receive_thread.start()

        logger.info("RedisAdapter started for entity %s", self.entity_info.id)
        return True

    def stop(self):
        """Signal both worker threads to exit and wait up to 5 s for each.

        The workers poll with a 1 s timeout, so they notice the cleared flags
        shortly. Messages still waiting in ``send_queue`` are not published.
        """
        self.send_running = False
        self.receive_running = False

        # Message callbacks run on the receive thread. If one of them calls
        # stop()/disconnect(), joining the current thread would raise
        # RuntimeError, so that thread is skipped (its loop exits on its own).
        current = threading.current_thread()
        for worker in (self.send_thread, self.receive_thread):
            if worker is not None and worker is not current and worker.is_alive():
                worker.join(timeout=5)

        logger.info("RedisAdapter stopped for entity %s", self.entity_info.id)

    def send_message(self, receiver_id: str, content: str) -> bool:
        """Queue a message for asynchronous delivery to *receiver_id*.

        The message is stamped with this entity's id and the current local
        time (naive ``datetime.now()``), then published later by the send
        thread.

        Returns:
            True if the message was queued (not necessarily delivered),
            False if creating or queuing it failed.
        """
        try:
            message = Message(
                sender_id=self.entity_info.id,
                receiver_id=receiver_id,
                timestamp=datetime.now(),
                content=content,
            )
            self.send_queue.put(message)
        except Exception as e:
            logger.error("Failed to queue message: %s", e)
            return False
        logger.debug("Message queued for %s", receiver_id)
        return True

    def register_callback(self, callback: MessageCallback):
        """Set the function called for each received message, replacing any previous one."""
        self.message_callback = callback

    def _send_worker(self):
        """Send-thread loop: publish queued messages until ``send_running`` is cleared."""
        logger.info("Send worker thread started")
        while self.send_running:
            try:
                # Poll with a timeout so the loop re-checks send_running.
                message = self.send_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                # Simplification: all entities are assumed to share this Redis
                # instance, so we publish on our own connection.
                # NOTE(review): this publishes on the receiver's *id*, while
                # _receive_worker subscribes to ``entity_info.channel``; the two
                # must coincide for delivery — confirm.
                target_channel = message.receiver_id
                if self.redis_client:
                    self.redis_client.publish(
                        target_channel,
                        json.dumps(message.to_dict()),
                    )
                    logger.debug("Message sent to channel %s", target_channel)
                else:
                    logger.warning(
                        "No Redis client; dropping message for channel %s",
                        target_channel,
                    )
            except Exception as e:
                logger.error("Error in send worker: %s", e)
            finally:
                # Always balance the successful get(); otherwise a failed
                # publish would leave send_queue.join() blocked forever.
                self.send_queue.task_done()
        logger.info("Send worker thread stopped")

    def _receive_worker(self):
        """Receive-thread loop: subscribe to this entity's channel and dispatch messages.

        Runs until ``receive_running`` is cleared. Each decoded message is put
        on ``receive_queue`` and delivered via _process_receive_queue().
        The pub/sub object is closed when the loop ends.
        """
        logger.info("Receive worker thread started")
        if not self.redis_client:
            logger.error("Redis client not available for receiving")
            return

        try:
            self.pubsub = self.redis_client.pubsub()
            self.pubsub.subscribe(self.entity_info.channel)
            while self.receive_running:
                try:
                    raw = self.pubsub.get_message(timeout=1)
                except Exception as e:
                    logger.error("Error processing received message: %s", e)
                    # A broken connection makes get_message() fail right away;
                    # back off so the loop doesn't spin and flood the log.
                    time.sleep(1)
                    continue
                # Skip timeouts (None) and subscribe/unsubscribe confirmations.
                if not raw or raw['type'] != 'message':
                    continue
                try:
                    self.receive_queue.put(self._decode_message(raw['data']))
                    self._process_receive_queue()
                except Exception as e:
                    logger.error("Error processing received message: %s", e)
        except Exception as e:
            logger.error("Error in receive worker: %s", e)
        finally:
            if self.pubsub is not None:
                self.pubsub.close()
                self.pubsub = None
            logger.info("Receive worker thread stopped")

    @staticmethod
    def _decode_message(payload) -> Message:
        """Decode a JSON pub/sub payload (UTF-8 bytes or str) into a Message.

        Propagates any error raised by decoding, ``json.loads`` or
        ``Message.from_dict`` on malformed input.
        """
        if isinstance(payload, bytes):
            payload = payload.decode('utf-8')
        return Message.from_dict(json.loads(payload))

    def _process_receive_queue(self):
        """Drain ``receive_queue``, passing each message to the registered callback.

        Messages drained while no callback is registered are discarded.
        Callback exceptions are logged and do not stop the drain.
        """
        while True:
            try:
                message = self.receive_queue.get_nowait()
            except queue.Empty:
                return
            callback = self.message_callback
            if callback:
                try:
                    callback(message)
                    logger.debug("Message delivered to callback")
                except Exception as e:
                    logger.error("Error in message callback: %s", e)
            self.receive_queue.task_done()