Spaces:
Sleeping
Sleeping
File size: 6,991 Bytes
2887ce2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
use super::aiserver::v1::StreamChatResponse;
use flate2::read::GzDecoder;
use prost::Message;
use std::io::Read;
use super::error::{ChatError, StreamError};
/// Decompresses a gzip-encoded buffer.
///
/// Returns the fully decompressed bytes, or `None` if `data` could not be
/// read as a gzip stream. The underlying I/O error is discarded.
fn decompress_gzip(data: &[u8]) -> Option<Vec<u8>> {
    let mut output = Vec::new();
    GzDecoder::new(data)
        .read_to_end(&mut output)
        .ok()
        .map(|_| output)
}
/// Result of parsing one buffer of framed stream data with `parse_stream_data`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StreamMessage {
    /// The buffer ends partway through a frame; more bytes are needed.
    Incomplete,
    /// A decoded response carried no text; holds its `filled_prompt`
    /// (empty if absent).
    Debug(String),
    /// Stream start marker: an empty type-0 frame, `b"\0\0\0\0\0"`.
    StreamStart,
    /// Text collected from the decoded frames, in order.
    Content(Vec<String>),
    /// Stream end marker: `b"\x02\0\0\0\x02{}"`.
    StreamEnd,
}
pub fn parse_stream_data(data: &[u8]) -> Result<StreamMessage, StreamError> {
if data.len() < 5 {
return Err(StreamError::DataLengthLessThan5);
}
// 检查是否为流开始标志
// if data == b"\0\0\0\0\0" {
// return Ok(StreamMessage::StreamStart);
// }
// 检查是否为流结束标志
// if data == b"\x02\0\0\0\x02{}" {
// return Ok(StreamMessage::StreamEnd);
// }
let mut messages = Vec::new();
let mut offset = 0;
while offset + 5 <= data.len() {
// 获取消息类型和长度
let msg_type = data[offset];
let msg_len = u32::from_be_bytes([
data[offset + 1],
data[offset + 2],
data[offset + 3],
data[offset + 4],
]) as usize;
// 流开始
if msg_type == 0 && msg_len == 0 {
return Ok(StreamMessage::StreamStart);
}
// 检查剩余数据长度是否足够
if offset + 5 + msg_len > data.len() {
return Ok(StreamMessage::Incomplete);
}
let msg_data = &data[offset + 5..offset + 5 + msg_len];
match msg_type {
// 文本消息
0 => {
if let Ok(response) = StreamChatResponse::decode(msg_data) {
// crate::debug_println!("[text] StreamChatResponse: {:?}", response);
if !response.text.is_empty() {
messages.push(response.text);
} else {
// println!("[text] StreamChatResponse: {:?}", response);
return Ok(StreamMessage::Debug(
response.filled_prompt.unwrap_or_default(),
// response.is_using_slow_request,
));
}
}
}
// gzip压缩消息
1 => {
if let Some(text) = decompress_gzip(msg_data) {
let response = StreamChatResponse::decode(&text[..]).unwrap_or_default();
// crate::debug_println!("[gzip] StreamChatResponse: {:?}", response);
if !response.text.is_empty() {
messages.push(response.text);
} else {
// println!("[gzip] StreamChatResponse: {:?}", response);
return Ok(StreamMessage::Debug(
response.filled_prompt.unwrap_or_default(),
// response.is_using_slow_request,
));
}
}
}
// JSON字符串
2 => {
if msg_len == 2 {
return Ok(StreamMessage::StreamEnd);
}
if let Ok(text) = String::from_utf8(msg_data.to_vec()) {
// println!("JSON消息: {}", text);
if let Ok(error) = serde_json::from_str::<ChatError>(&text) {
return Err(StreamError::ChatError(error));
}
// 未预计
// messages.push(text);
}
}
// 其他类型暂不处理
t => eprintln!("收到未知消息类型: {},请尝试联系开发者以获取支持", t),
}
offset += 5 + msg_len;
}
if messages.is_empty() {
Err(StreamError::EmptyMessage)
} else {
Ok(StreamMessage::Content(messages))
}
}
/// Replays the captured stream in `tests/data/stream_data.txt` through
/// `parse_stream_data`, one frame at a time, and prints what each frame
/// parses to.
///
/// The fixture is a hex dump (two hex digits per byte); whitespace such as a
/// trailing newline is ignored. The test panics only on a malformed fixture.
/// Parse results are printed, not asserted.
#[test]
fn test_parse_stream_data() {
    // Load the fixture at compile time.
    let stream_data = include_str!("../../tests/data/stream_data.txt");

    // Strip whitespace so a trailing newline does not break hex decoding.
    let hex_digits: Vec<u8> = stream_data
        .bytes()
        .filter(|b| !b.is_ascii_whitespace())
        .collect();
    assert_eq!(
        hex_digits.len() % 2,
        0,
        "fixture must contain an even number of hex digits"
    );
    let bytes: Vec<u8> = hex_digits
        .chunks_exact(2)
        .map(|pair| {
            let hex_str = std::str::from_utf8(pair).expect("fixture must be ASCII hex");
            u8::from_str_radix(hex_str, 16).expect("fixture must be valid hex")
        })
        .collect();

    // Length of the next frame (5-byte header + payload), clamped to the
    // bytes actually available. A truncated frame is then still sliced and
    // reported by the parser as `Incomplete` instead of panicking here.
    fn find_next_message_boundary(bytes: &[u8]) -> usize {
        if bytes.len() < 5 {
            return bytes.len();
        }
        let msg_len = u32::from_be_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]) as usize;
        5usize.saturating_add(msg_len).min(bytes.len())
    }

    // Uppercase hex rendering of a byte slice, for log output.
    fn bytes_to_hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| format!("{:02X}", b)).collect()
    }

    // Parse the stream frame by frame.
    let mut offset = 0;
    while offset < bytes.len() {
        let remaining_bytes = &bytes[offset..];
        let msg_boundary = find_next_message_boundary(remaining_bytes);
        let current_msg_bytes = &remaining_bytes[..msg_boundary];
        let hex_str = bytes_to_hex(current_msg_bytes);
        match parse_stream_data(current_msg_bytes) {
            Ok(StreamMessage::Content(messages)) => {
                print!("消息内容 [hex: {}]:", hex_str);
                for msg in messages {
                    println!(" {}", msg);
                }
                offset += msg_boundary;
            }
            Ok(StreamMessage::Debug(_)) => {
                offset += msg_boundary;
            }
            Ok(StreamMessage::StreamEnd) => {
                println!("流结束 [hex: {}]", hex_str);
                break;
            }
            Ok(StreamMessage::StreamStart) => {
                println!("流开始 [hex: {}]", hex_str);
                offset += msg_boundary;
            }
            Ok(StreamMessage::Incomplete) => {
                println!("数据不完整 [hex: {}]", hex_str);
                break;
            }
            Err(e) => {
                println!("解析错误 [hex: {}]: {}", hex_str, e);
                break;
            }
        }
    }
}
|