"""
Streaming packet processor.

Processes streaming protobuf packets, with support for real-time parsing
and WebSocket push.
"""

from typing import List, Dict, Any, Optional
from datetime import datetime

from .logging import logger
from .protobuf_utils import protobuf_to_dict


class StreamProcessor:
    """Streaming packet processor: creates, tracks, and closes stream sessions."""

    def __init__(self, websocket_manager=None):
        self.websocket_manager = websocket_manager
        self.active_streams: Dict[str, StreamSession] = {}

    async def create_stream_session(self, stream_id: str, message_type: str = "warp.multi_agent.v1.Response") -> 'StreamSession':
        """Create a streaming session."""
        session = StreamSession(stream_id, message_type, self.websocket_manager)
        self.active_streams[stream_id] = session

        logger.info(f"Created stream session: {stream_id}, message type: {message_type}")
        return session

    async def get_stream_session(self, stream_id: str) -> Optional['StreamSession']:
        """Look up a streaming session by ID."""
        return self.active_streams.get(stream_id)

    async def close_stream_session(self, stream_id: str):
        """Close a streaming session and discard its buffers."""
        if stream_id in self.active_streams:
            session = self.active_streams[stream_id]
            await session.close()
            del self.active_streams[stream_id]
            logger.info(f"Closed stream session: {stream_id}")

    async def process_stream_chunk(self, stream_id: str, chunk_data: bytes) -> Dict[str, Any]:
        """Process a single chunk for an existing session."""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"Stream session does not exist: {stream_id}")

        return await session.process_chunk(chunk_data)

    async def finalize_stream(self, stream_id: str) -> Dict[str, Any]:
        """Finalize a stream, then close its session."""
        session = await self.get_stream_session(stream_id)
        if not session:
            raise ValueError(f"Stream session does not exist: {stream_id}")

        result = await session.finalize()
        await self.close_stream_session(stream_id)
        return result


class StreamSession:
    """A single streaming session that buffers and incrementally parses chunks."""

    def __init__(self, session_id: str, message_type: str, websocket_manager=None):
        self.session_id = session_id
        self.message_type = message_type
        self.websocket_manager = websocket_manager

        self.chunks: List[bytes] = []
        self.chunk_count = 0
        self.total_size = 0
        self.start_time = datetime.now()

        self.parsed_chunks: List[Dict] = []
        self.complete_message: Optional[Dict] = None

    async def process_chunk(self, chunk_data: bytes) -> Dict[str, Any]:
        """Buffer a single chunk, attempt to parse it, and push the result."""
        self.chunk_count += 1
        self.total_size += len(chunk_data)
        self.chunks.append(chunk_data)

        logger.debug(f"Stream session {self.session_id}: processing chunk {self.chunk_count}, {len(chunk_data)} bytes")

        chunk_result = {
            "chunk_index": self.chunk_count - 1,
            "size": len(chunk_data),
            "timestamp": datetime.now().isoformat()
        }

        try:
            chunk_json = protobuf_to_dict(chunk_data, self.message_type)
            chunk_result["json_data"] = chunk_json
            chunk_result["parsed_successfully"] = True

            self.parsed_chunks.append(chunk_json)

            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_parsed",
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })

        except Exception as e:
            chunk_result["error"] = str(e)
            chunk_result["parsed_successfully"] = False
            logger.warning(f"Failed to parse chunk: {e}")

            if self.websocket_manager:
                await self.websocket_manager.broadcast({
                    "event": "stream_chunk_error",
                    "stream_id": self.session_id,
                    "chunk": chunk_result
                })

        return chunk_result

    async def finalize(self) -> Dict[str, Any]:
        """Finish streaming and try to assemble the complete message."""
        duration = (datetime.now() - self.start_time).total_seconds()

        logger.info(f"Stream session {self.session_id} finished: {self.chunk_count} chunks, {self.total_size} bytes total, {duration:.2f}s elapsed")

        result = {
            "session_id": self.session_id,
            "chunk_count": self.chunk_count,
            "total_size": self.total_size,
            "duration_seconds": duration,
            "chunks": []
        }

        for i, chunk in enumerate(self.chunks):
            chunk_info = {
                "index": i,
                "size": len(chunk),
                "hex_preview": chunk[:32].hex()
            }

            # NOTE: parsed_chunks holds only successfully parsed chunks, so this
            # index pairing assumes no mid-stream parse failures.
            if i < len(self.parsed_chunks):
                chunk_info["parsed_data"] = self.parsed_chunks[i]

            result["chunks"].append(chunk_info)

        try:
            complete_data = b''.join(self.chunks)
            complete_json = protobuf_to_dict(complete_data, self.message_type)

            result["complete_message"] = {
                "size": len(complete_data),
                "json_data": complete_json,
                "assembly_successful": True
            }

            self.complete_message = complete_json

            logger.info(f"Assembled complete streamed message: {len(complete_data)} bytes")

        except Exception as e:
            result["complete_message"] = {
                "error": str(e),
                "assembly_successful": False
            }
            logger.warning(f"Failed to assemble streamed message: {e}")

        if self.websocket_manager:
            await self.websocket_manager.broadcast({
                "event": "stream_completed",
                "stream_id": self.session_id,
                "result": result
            })

        return result

    async def close(self):
        """Release buffered data."""
        self.chunks.clear()
        self.parsed_chunks.clear()
        self.complete_message = None

        logger.debug(f"Stream session {self.session_id} closed")
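

# Shape sketch of a single process_chunk result (values are illustrative,
# not taken from a real stream):
#
#     {
#         "chunk_index": 0,
#         "size": 128,
#         "timestamp": "2024-01-01T00:00:00",
#         "parsed_successfully": True,
#         "json_data": {...}   # present only when parsing succeeded;
#                              # on failure, "error" appears instead
#     }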


class StreamPacketAnalyzer:
    """Static analysis helpers for streamed packet chunks."""

    @staticmethod
    def analyze_chunk_patterns(chunks: List[bytes]) -> Dict[str, Any]:
        """Summarize chunk sizes and look for shared byte patterns."""
        if not chunks:
            return {"error": "no chunks"}

        analysis = {
            "total_chunks": len(chunks),
            "size_distribution": {},
            "size_stats": {},
            "pattern_analysis": {}
        }

        sizes = [len(chunk) for chunk in chunks]
        analysis["size_stats"] = {
            "min": min(sizes),
            "max": max(sizes),
            "avg": sum(sizes) / len(sizes),
            "total": sum(sizes)
        }

        # Bucket chunk sizes into fixed ranges (upper bound exclusive).
        size_ranges = [(0, 100), (100, 500), (500, 1000), (1000, 5000), (5000, float('inf'))]
        for start, end in size_ranges:
            range_name = f"{start}-{end if end != float('inf') else '∞'}"
            count = sum(1 for size in sizes if start <= size < end)
            analysis["size_distribution"][range_name] = count

        if len(chunks) >= 2:
            first_bytes = [chunk[:4].hex() for chunk in chunks[:5]]
            analysis["pattern_analysis"]["first_bytes_samples"] = first_bytes

            # Find the longest prefix (up to 10 bytes) shared by every chunk.
            common_prefix_len = 0
            first_chunk = chunks[0]
            for i in range(min(len(first_chunk), 10)):
                if all(len(chunk) > i and chunk[i] == first_chunk[i] for chunk in chunks[1:]):
                    common_prefix_len = i + 1
                else:
                    break

            if common_prefix_len > 0:
                analysis["pattern_analysis"]["common_prefix_length"] = common_prefix_len
                analysis["pattern_analysis"]["common_prefix_hex"] = first_chunk[:common_prefix_len].hex()

        return analysis

    @staticmethod
    def extract_streaming_deltas(parsed_chunks: List[Dict]) -> List[Dict]:
        """Derive per-chunk content deltas from a sequence of parsed chunks."""
        if not parsed_chunks:
            return []

        deltas = []
        previous_content = ""

        for i, chunk in enumerate(parsed_chunks):
            delta = {
                "chunk_index": i,
                "timestamp": datetime.now().isoformat()
            }

            current_content = StreamPacketAnalyzer._extract_text_content(chunk)

            if current_content and current_content != previous_content:
                if previous_content and current_content.startswith(previous_content):
                    # The new content extends the old: report only the suffix.
                    delta["content_delta"] = current_content[len(previous_content):]
                    delta["delta_type"] = "append"
                else:
                    # Content changed in a non-append way: report it wholesale.
                    delta["content_delta"] = current_content
                    delta["delta_type"] = "replace"

                delta["total_content_length"] = len(current_content)
                previous_content = current_content
            else:
                delta["content_delta"] = ""
                delta["delta_type"] = "no_change"

            if i > 0:
                delta["field_changes"] = StreamPacketAnalyzer._compare_dicts(parsed_chunks[i - 1], chunk)

            deltas.append(delta)

        return deltas

    @staticmethod
    def _extract_text_content(data: Dict) -> str:
        """Return the first string found at any of several known text paths."""
        text_paths = [
            ["content"],
            ["text"],
            ["message"],
            ["agent_output", "text"],
            ["choices", 0, "delta", "content"],
            ["choices", 0, "message", "content"]
        ]

        for path in text_paths:
            try:
                current = data
                for key in path:
                    if isinstance(current, dict) and key in current:
                        current = current[key]
                    elif isinstance(current, list) and isinstance(key, int) and 0 <= key < len(current):
                        current = current[key]
                    else:
                        break
                else:
                    # The whole path resolved; accept it only if it is a string.
                    if isinstance(current, str):
                        return current
            except Exception:
                continue

        return ""

    @staticmethod
    def _compare_dicts(dict1: Dict, dict2: Dict, prefix: str = "") -> List[str]:
        """List up to 10 field-level differences between two dicts."""
        changes = []

        all_keys = set(dict1.keys()) | set(dict2.keys())

        for key in all_keys:
            current_path = f"{prefix}.{key}" if prefix else key

            if key not in dict1:
                changes.append(f"added: {current_path}")
            elif key not in dict2:
                changes.append(f"removed: {current_path}")
            elif dict1[key] != dict2[key]:
                if isinstance(dict1[key], dict) and isinstance(dict2[key], dict):
                    changes.extend(StreamPacketAnalyzer._compare_dicts(dict1[key], dict2[key], current_path))
                else:
                    changes.append(f"modified: {current_path}")

        return changes[:10]
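

# Illustrative sketch of the delta extraction above, with made-up parsed
# chunks. When the text field grows from "Hel" to "Hello", the second delta
# is expected to be an "append" of just the suffix:
#
#     deltas = StreamPacketAnalyzer.extract_streaming_deltas(
#         [{"text": "Hel"}, {"text": "Hello"}]
#     )
#     assert deltas[1]["delta_type"] == "append"
#     assert deltas[1]["content_delta"] == "lo"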


_global_processor: Optional[StreamProcessor] = None


def get_stream_processor() -> StreamProcessor:
    """Return the process-wide StreamProcessor, creating it on first use."""
    global _global_processor
    if _global_processor is None:
        _global_processor = StreamProcessor()
    return _global_processor


def set_websocket_manager(manager):
    """Attach a WebSocket manager to the global processor."""
    processor = get_stream_processor()
    processor.websocket_manager = manager
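

# Minimal smoke-test sketch. The relative imports above require this file to
# live in a package, so run it as a module (e.g. `python -m yourpackage.stream_processor`;
# the package name is a placeholder). The byte strings below are arbitrary, so
# protobuf parsing may fail and simply exercise the error path.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        processor = get_stream_processor()
        await processor.create_stream_session("demo-stream")
        for chunk in (b"\x0a\x03abc", b"\x0a\x05hello"):
            await processor.process_stream_chunk("demo-stream", chunk)
        summary = await processor.finalize_stream("demo-stream")
        print(summary["chunk_count"], "chunks,", summary["total_size"], "bytes")

    asyncio.run(_demo())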