Spaces:
Sleeping
Sleeping
| use super::aiserver::v1::StreamChatResponse; | |
| use flate2::read::GzDecoder; | |
| use prost::Message; | |
| use std::io::Read; | |
| use super::error::{ChatError, StreamError}; | |
| // 解压gzip数据 | |
| fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> { | |
| let mut decoder = GzDecoder::new(data); | |
| let mut decompressed = Vec::new(); | |
| match decoder.read_to_end(&mut decompressed) { | |
| Ok(_) => Some(decompressed), | |
| Err(_) => { | |
| // println!("gzip解压失败: {}", e); | |
| None | |
| } | |
| } | |
| } | |
/// Result of parsing one buffer of framed stream data.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so values can be logged,
/// copied and compared in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMessage {
    /// The buffer ends before a frame's declared payload is fully available.
    Incomplete,
    /// Debug text (the parser fills this from a response's `filled_prompt`).
    Debug(String),
    /// Stream-start marker: `b"\0\0\0\0\0"`.
    StreamStart,
    /// Message content: the text fragments collected from the buffer.
    Content(Vec<String>),
    /// Stream-end marker: `b"\x02\0\0\0\x02{}"`.
    StreamEnd,
}
| pub fn parse_stream_data(data: &[u8]) -> Result<StreamMessage, StreamError> { | |
| if data.len() < 5 { | |
| return Err(StreamError::DataLengthLessThan5); | |
| } | |
| // 检查是否为流开始标志 | |
| // if data == b"\0\0\0\0\0" { | |
| // return Ok(StreamMessage::StreamStart); | |
| // } | |
| // 检查是否为流结束标志 | |
| // if data == b"\x02\0\0\0\x02{}" { | |
| // return Ok(StreamMessage::StreamEnd); | |
| // } | |
| let mut messages = Vec::new(); | |
| let mut offset = 0; | |
| while offset + 5 <= data.len() { | |
| // 获取消息类型和长度 | |
| let msg_type = data[offset]; | |
| let msg_len = u32::from_be_bytes([ | |
| data[offset + 1], | |
| data[offset + 2], | |
| data[offset + 3], | |
| data[offset + 4], | |
| ]) as usize; | |
| // 流开始 | |
| if msg_type == 0 && msg_len == 0 { | |
| return Ok(StreamMessage::StreamStart); | |
| } | |
| // 检查剩余数据长度是否足够 | |
| if offset + 5 + msg_len > data.len() { | |
| return Ok(StreamMessage::Incomplete); | |
| } | |
| let msg_data = &data[offset + 5..offset + 5 + msg_len]; | |
| match msg_type { | |
| // 文本消息 | |
| 0 => { | |
| if let Ok(response) = StreamChatResponse::decode(msg_data) { | |
| // crate::debug_println!("[text] StreamChatResponse: {:?}", response); | |
| if !response.text.is_empty() { | |
| messages.push(response.text); | |
| } else { | |
| // println!("[text] StreamChatResponse: {:?}", response); | |
| return Ok(StreamMessage::Debug( | |
| response.filled_prompt.unwrap_or_default(), | |
| // response.is_using_slow_request, | |
| )); | |
| } | |
| } | |
| } | |
| // gzip压缩消息 | |
| 1 => { | |
| if let Some(text) = decompress_gzip(msg_data) { | |
| let response = StreamChatResponse::decode(&text[..]).unwrap_or_default(); | |
| // crate::debug_println!("[gzip] StreamChatResponse: {:?}", response); | |
| if !response.text.is_empty() { | |
| messages.push(response.text); | |
| } else { | |
| // println!("[gzip] StreamChatResponse: {:?}", response); | |
| return Ok(StreamMessage::Debug( | |
| response.filled_prompt.unwrap_or_default(), | |
| // response.is_using_slow_request, | |
| )); | |
| } | |
| } | |
| } | |
| // JSON字符串 | |
| 2 => { | |
| if msg_len == 2 { | |
| return Ok(StreamMessage::StreamEnd); | |
| } | |
| if let Ok(text) = String::from_utf8(msg_data.to_vec()) { | |
| // println!("JSON消息: {}", text); | |
| if let Ok(error) = serde_json::from_str::<ChatError>(&text) { | |
| return Err(StreamError::ChatError(error)); | |
| } | |
| // 未预计 | |
| // messages.push(text); | |
| } | |
| } | |
| // 其他类型暂不处理 | |
| t => eprintln!("收到未知消息类型: {},请尝试联系开发者以获取支持", t), | |
| } | |
| offset += 5 + msg_len; | |
| } | |
| if messages.is_empty() { | |
| Err(StreamError::EmptyMessage) | |
| } else { | |
| Ok(StreamMessage::Content(messages)) | |
| } | |
| } | |
/// Smoke test: replays the hex-encoded capture in
/// `tests/data/stream_data.txt` through `parse_stream_data` one frame at a
/// time and prints what each frame parsed to.
///
/// Marked `#[test]` so that `cargo test` actually runs it and so that the
/// fixture is only required in test builds.
#[test]
fn test_parse_stream_data() {
    // Load the test fixture with include_str!.
    let stream_data = include_str!("../../tests/data/stream_data.txt");

    // Drop whitespace (e.g. a trailing newline) before decoding, then turn
    // every pair of hex digits into one byte.
    let hex_digits: Vec<u8> = stream_data
        .bytes()
        .filter(|b| !b.is_ascii_whitespace())
        .collect();
    let bytes: Vec<u8> = hex_digits
        .chunks(2)
        .map(|chunk| {
            let hex_str = std::str::from_utf8(chunk).unwrap();
            u8::from_str_radix(hex_str, 16).unwrap()
        })
        .collect();

    // Helper: length of the first frame (header + payload), clamped to the
    // available bytes so a truncated final frame cannot cause an
    // out-of-bounds slice.
    fn find_next_message_boundary(bytes: &[u8]) -> usize {
        if bytes.len() < 5 {
            return bytes.len();
        }
        let msg_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
        msg_len.saturating_add(5).min(bytes.len())
    }

    // Helper: render bytes as an upper-case hex string.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02X}", b)).collect()
    }

    // Parse the capture frame by frame.
    let mut offset = 0;
    while offset < bytes.len() {
        let remaining_bytes = &bytes[offset..];
        let msg_boundary = find_next_message_boundary(remaining_bytes);
        let current_msg_bytes = &remaining_bytes[..msg_boundary];
        let hex_str = bytes_to_hex(current_msg_bytes);

        match parse_stream_data(current_msg_bytes) {
            Ok(StreamMessage::Content(messages)) => {
                print!("消息内容 [hex: {}]:", hex_str);
                for msg in messages {
                    println!(" {}", msg);
                }
                offset += msg_boundary;
            }
            Ok(StreamMessage::Debug(_)) => {
                offset += msg_boundary;
            }
            Ok(StreamMessage::StreamEnd) => {
                println!("流结束 [hex: {}]", hex_str);
                break;
            }
            Ok(StreamMessage::StreamStart) => {
                println!("流开始 [hex: {}]", hex_str);
                offset += msg_boundary;
            }
            Ok(StreamMessage::Incomplete) => {
                println!("数据不完整 [hex: {}]", hex_str);
                break;
            }
            Err(e) => {
                println!("解析错误 [hex: {}]: {}", hex_str, e);
                break;
            }
        }
    }
}