app / src-tauri /src /proxy /mappers /openai /streaming.rs
AZILS's picture
Upload 323 files
a21c316 verified
// OpenAI streaming conversion: maps Gemini SSE responses onto OpenAI-compatible SSE streams.
use bytes::{Bytes, BytesMut};
use chrono::Utc;
use futures::{Stream, StreamExt};
use rand::Rng;
use serde_json::{json, Value};
use std::pin::Pin;
use tracing::debug;
use uuid::Uuid;
/// Stores a `thoughtSignature` in the session-scoped signature cache.
///
/// An empty signature is ignored. Otherwise the signature is handed to
/// `SignatureCache::global().cache_session_signature` together with the
/// session id and message count, and a debug line is logged.
pub fn store_thought_signature(sig: &str, session_id: &str, message_count: usize) {
    if !sig.is_empty() {
        let signature_len = sig.len();
        // Session-isolated cache (kept aligned with the Claude protocol handling).
        let cache = crate::proxy::SignatureCache::global();
        cache.cache_session_signature(session_id, sig.to_owned(), message_count);
        tracing::debug!(
            "[ThoughtSig] 存储 Session 签名 (sid: {}, len: {}, msg_count: {})",
            session_id,
            signature_len,
            message_count
        );
    }
}
/// Extracts a Gemini `usageMetadata` object and converts it to the OpenAI usage format.
///
/// * `promptTokenCount`, `candidatesTokenCount` and `totalTokenCount` default to `0`
///   when missing or not an unsigned integer.
/// * `cachedContentTokenCount` is optional; when present it is reported through
///   `prompt_tokens_details.cached_tokens`.
/// * Counters larger than `u32::MAX` are clamped to `u32::MAX` instead of being
///   silently truncated.
///
/// This function always returns `Some`; the `Option` return type is kept so
/// existing callers keep compiling.
fn extract_usage_metadata(u: &Value) -> Option<super::models::OpenAIUsage> {
    use super::models::{OpenAIUsage, PromptTokensDetails};

    // Reads an unsigned counter; the `min` clamp makes the following cast lossless.
    let read_count = |key: &str| -> Option<u32> {
        u.get(key)
            .and_then(|v| v.as_u64())
            .map(|n| n.min(u64::from(u32::MAX)) as u32)
    };

    Some(OpenAIUsage {
        prompt_tokens: read_count("promptTokenCount").unwrap_or(0),
        completion_tokens: read_count("candidatesTokenCount").unwrap_or(0),
        total_tokens: read_count("totalTokenCount").unwrap_or(0),
        prompt_tokens_details: read_count("cachedContentTokenCount").map(|cached| {
            PromptTokensDetails {
                cached_tokens: Some(cached),
            }
        }),
        completion_tokens_details: None,
    })
}
pub fn create_openai_sse_stream<S, E>(
mut gemini_stream: Pin<Box<S>>,
model: String,
session_id: String,
message_count: usize,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, String>> + Send>>
where
S: Stream<Item = Result<Bytes, E>> + Send + ?Sized + 'static,
E: std::fmt::Display + Send + 'static,
{
let mut buffer = BytesMut::new();
let stream_id = format!("chatcmpl-{}", Uuid::new_v4());
let created_ts = Utc::now().timestamp();
let stream = async_stream::stream! {
let mut emitted_tool_calls = std::collections::HashSet::new();
let mut final_usage: Option<super::models::OpenAIUsage> = None;
let mut error_occurred = false;
let mut tool_call_index = 0;
let mut heartbeat_interval = tokio::time::interval(std::time::Duration::from_secs(15));
heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
loop {
tokio::select! {
item = gemini_stream.next() => {
match item {
Some(Ok(bytes)) => {
buffer.extend_from_slice(&bytes);
while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
let line_raw = buffer.split_to(pos + 1);
if let Ok(line_str) = std::str::from_utf8(&line_raw) {
let line = line_str.trim();
if line.is_empty() { continue; }
if line.starts_with("data: ") {
let json_part = line.trim_start_matches("data: ").trim();
if json_part == "[DONE]" { continue; }
if let Ok(mut json) = serde_json::from_str::<Value>(json_part) {
let actual_data = if let Some(inner) = json.get_mut("response").map(|v| v.take()) { inner } else { json };
if let Some(u) = actual_data.get("usageMetadata") {
final_usage = extract_usage_metadata(u);
}
if let Some(candidates) = actual_data.get("candidates").and_then(|c| c.as_array()) {
// [DEBUG] 打印原始 candidate 以排查空回复问题
if candidates.len() > 0 {
tracing::debug!("[Stream-Debug] Raw Candidate: {:?}", candidates[0]);
}
for (idx, candidate) in candidates.iter().enumerate() {
let parts = candidate.get("content").and_then(|c| c.get("parts")).and_then(|p| p.as_array());
let mut content_out = String::new();
let mut thought_out = String::new();
if let Some(parts_list) = parts {
for part in parts_list {
let is_thought_part = part.get("thought").and_then(|v| v.as_bool()).unwrap_or(false);
if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
if is_thought_part { thought_out.push_str(text); }
else { content_out.push_str(text); }
}
if let Some(sig) = part.get("thoughtSignature").or(part.get("thought_signature")).and_then(|s| s.as_str()) {
store_thought_signature(sig, &session_id, message_count);
}
if let Some(img) = part.get("inlineData") {
let mime_type = img.get("mimeType").and_then(|v| v.as_str()).unwrap_or("image/png");
let data = img.get("data").and_then(|v| v.as_str()).unwrap_or("");
if !data.is_empty() {
content_out.push_str(&format!("![image](data:{};base64,{})", mime_type, data));
}
}
if let Some(func_call) = part.get("functionCall") {
let call_key = serde_json::to_string(func_call).unwrap_or_default();
if !emitted_tool_calls.contains(&call_key) {
emitted_tool_calls.insert(call_key);
let name = func_call.get("name").and_then(|v| v.as_str()).unwrap_or("unknown");
let mut args = func_call.get("args").unwrap_or(&json!({})).clone();
// [FIX #1575] 标准化 shell 工具参数名称
// Gemini 可能使用 cmd/code/script 等替代参数名,统一为 command
if name == "shell" || name == "bash" || name == "local_shell" {
if let Some(obj) = args.as_object_mut() {
if !obj.contains_key("command") {
for alt_key in &["cmd", "code", "script", "shell_command"] {
if let Some(val) = obj.remove(*alt_key) {
obj.insert("command".to_string(), val);
debug!("[OpenAI-Stream] Normalized shell arg '{}' -> 'command'", alt_key);
break;
}
}
}
}
}
let args_str = serde_json::to_string(&args).unwrap_or_default();
let mut hasher = std::collections::hash_map::DefaultHasher::new();
use std::hash::{Hash, Hasher};
serde_json::to_string(func_call).unwrap_or_default().hash(&mut hasher);
let call_id = format!("call_{:x}", hasher.finish());
let tool_call_chunk = json!({
"id": &stream_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": &model,
"choices": [{
"index": idx as u32,
"delta": {
"role": "assistant",
"tool_calls": [{
"index": tool_call_index,
"id": call_id,
"type": "function",
"function": { "name": name, "arguments": args_str }
}]
},
"finish_reason": serde_json::Value::Null
}]
});
tool_call_index += 1;
let sse_out = format!("data: {}\n\n", serde_json::to_string(&tool_call_chunk).unwrap_or_default());
yield Ok::<Bytes, String>(Bytes::from(sse_out));
}
}
}
}
if let Some(grounding) = candidate.get("groundingMetadata") {
let mut grounding_text = String::new();
if let Some(queries) = grounding.get("webSearchQueries").and_then(|q| q.as_array()) {
let query_list: Vec<&str> = queries.iter().filter_map(|v| v.as_str()).collect();
if !query_list.is_empty() {
grounding_text.push_str("\n\n---\n**🔍 已为您搜索:** ");
grounding_text.push_str(&query_list.join(", "));
}
}
if let Some(chunks) = grounding.get("groundingChunks").and_then(|c| c.as_array()) {
let mut links = Vec::new();
for (i, chunk) in chunks.iter().enumerate() {
if let Some(web) = chunk.get("web") {
let title = web.get("title").and_then(|v| v.as_str()).unwrap_or("网页来源");
let uri = web.get("uri").and_then(|v| v.as_str()).unwrap_or("#");
links.push(format!("[{}] [{}]({})", i + 1, title, uri));
}
}
if !links.is_empty() {
grounding_text.push_str("\n\n**🌐 来源引文:**\n");
grounding_text.push_str(&links.join("\n"));
}
}
if !grounding_text.is_empty() { content_out.push_str(&grounding_text); }
}
let gemini_finish_reason = candidate.get("finishReason").and_then(|f| f.as_str()).map(|f| match f {
"STOP" => "stop",
"MAX_TOKENS" => "length",
"SAFETY" => "content_filter",
"RECITATION" => "content_filter",
_ => f,
});
// [FIX #1575] 如果发射了工具调用,强制设置为 tool_calls
// 解决 Gemini 返回 STOP 但有工具调用时,OpenAI 客户端认为对话已结束的问题
let finish_reason = if !emitted_tool_calls.is_empty() && gemini_finish_reason.is_some() {
Some("tool_calls")
} else {
gemini_finish_reason
};
if !thought_out.is_empty() {
let reasoning_chunk = json!({
"id": &stream_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": &model,
"choices": [{
"index": idx as u32,
"delta": { "role": "assistant", "content": serde_json::Value::Null, "reasoning_content": thought_out },
"finish_reason": serde_json::Value::Null
}]
});
let sse_out = format!("data: {}\n\n", serde_json::to_string(&reasoning_chunk).unwrap_or_default());
yield Ok::<Bytes, String>(Bytes::from(sse_out));
}
if !content_out.is_empty() || finish_reason.is_some() {
let mut openai_chunk = json!({
"id": &stream_id,
"object": "chat.completion.chunk",
"created": created_ts,
"model": &model,
"choices": [{
"index": idx as u32,
"delta": { "content": content_out },
"finish_reason": finish_reason
}]
});
if finish_reason.is_some() {
if let Some(ref usage) = final_usage {
openai_chunk["usage"] = serde_json::to_value(usage).unwrap();
}
}
if finish_reason.is_some() { final_usage = None; }
let sse_out = format!("data: {}\n\n", serde_json::to_string(&openai_chunk).unwrap_or_default());
yield Ok::<Bytes, String>(Bytes::from(sse_out));
}
}
}
}
}
}
}
}
Some(Err(e)) => {
use crate::proxy::mappers::error_classifier::classify_stream_error;
let (error_type, user_msg, i18n_key) = classify_stream_error(&e);
tracing::error!("OpenAI Stream Error: {}", e);
let error_chunk = json!({
"id": &stream_id, "object": "chat.completion.chunk", "created": created_ts, "model": &model, "choices": [],
"error": { "type": error_type, "message": user_msg, "code": "stream_error", "i18n_key": i18n_key }
});
yield Ok(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&error_chunk).unwrap_or_default())));
yield Ok(Bytes::from("data: [DONE]\n\n"));
error_occurred = true;
break;
}
None => break,
}
}
_ = heartbeat_interval.tick() => {
yield Ok::<Bytes, String>(Bytes::from(": ping\n\n"));
}
}
}
// [FIX #1732] Flush remaining buffer to prevent hang on network fragmentation
if !buffer.is_empty() {
if let Ok(line_str) = std::str::from_utf8(&buffer) {
let line = line_str.trim();
if !line.is_empty() && line.starts_with("data: ") {
let json_part = line.trim_start_matches("data: ").trim();
if json_part != "[DONE]" {
// Re-use logic for processing the last line
// (Note: In a more complex refactor we'd extract this to a function,
// but for a targeted fix, processing the terminal data chunk is safer)
tracing::debug!("[OpenAI-SSE] Flushing remaining {} bytes in buffer", buffer.len());
}
}
}
}
if !error_occurred {
yield Ok::<Bytes, String>(Bytes::from("data: [DONE]\n\n"));
}
};
Box::pin(stream)
}
/// Converts a Gemini SSE byte stream into an OpenAI legacy (`text_completion`) SSE stream.
///
/// * `gemini_stream` - upstream byte stream; it is split on `\n` and every
///   `data: <json>` line is converted into one `text_completion` chunk built from
///   the first candidate only.
/// * `model` - model name echoed in every chunk.
/// * `session_id` / `message_count` - forwarded to [`store_thought_signature`].
///
/// `usage` is attached only to the chunk that carries a finish reason, matching
/// the chat-completions stream. A `: ping` comment is sent whenever the 15-second
/// heartbeat timer fires. On an upstream error an error chunk followed by
/// `data: [DONE]` is emitted; otherwise the stream ends with `data: [DONE]`.
pub fn create_legacy_sse_stream<S, E>(
    mut gemini_stream: Pin<Box<S>>,
    model: String,
    session_id: String,
    message_count: usize,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, String>> + Send>>
where
    S: Stream<Item = Result<Bytes, E>> + Send + ?Sized + 'static,
    E: std::fmt::Display + Send + 'static,
{
    // ASCII alphabet for the random id suffix; byte indexing avoids an O(n) `chars().nth()`.
    const ID_CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

    let mut buffer = BytesMut::new();
    let random_str: String = {
        let mut rng = rand::thread_rng();
        (0..28)
            .map(|_| ID_CHARSET[rng.gen_range(0..ID_CHARSET.len())] as char)
            .collect()
    };
    let stream_id = format!("cmpl-{}", random_str);
    let created_ts = Utc::now().timestamp();

    let stream = async_stream::stream! {
        let mut final_usage: Option<super::models::OpenAIUsage> = None;
        let mut error_occurred = false;
        let mut heartbeat_interval = tokio::time::interval(std::time::Duration::from_secs(15));
        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            tokio::select! {
                item = gemini_stream.next() => {
                    match item {
                        Some(Ok(bytes)) => {
                            buffer.extend_from_slice(&bytes);
                            while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
                                let line_raw = buffer.split_to(pos + 1);
                                let line = match std::str::from_utf8(&line_raw) {
                                    Ok(text) => text.trim(),
                                    Err(_) => continue,
                                };
                                // `strip_prefix` removes the marker exactly once.
                                let json_part = match line.strip_prefix("data: ") {
                                    Some(payload) => payload.trim(),
                                    None => continue,
                                };
                                if json_part == "[DONE]" { continue; }
                                let mut json = match serde_json::from_str::<Value>(json_part) {
                                    Ok(value) => value,
                                    Err(_) => continue,
                                };

                                // Some upstreams wrap the payload in a `response` envelope.
                                let unwrapped = json.get_mut("response").map(|v| v.take());
                                let actual_data = unwrapped.unwrap_or(json);

                                if let Some(u) = actual_data.get("usageMetadata") {
                                    final_usage = extract_usage_metadata(u);
                                }

                                let first_candidate = actual_data
                                    .get("candidates")
                                    .and_then(|c| c.as_array())
                                    .and_then(|c| c.first());

                                let mut content_out = String::new();
                                let parts = first_candidate
                                    .and_then(|c| c.get("content"))
                                    .and_then(|c| c.get("parts"))
                                    .and_then(|p| p.as_array());
                                for part in parts.into_iter().flatten() {
                                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                        content_out.push_str(text);
                                    }
                                    if let Some(sig) = part
                                        .get("thoughtSignature")
                                        .or_else(|| part.get("thought_signature"))
                                        .and_then(|s| s.as_str())
                                    {
                                        store_thought_signature(sig, &session_id, message_count);
                                    }
                                }

                                // Same mapping as the chat-completions stream, including RECITATION.
                                let finish_reason = first_candidate
                                    .and_then(|c| c.get("finishReason"))
                                    .and_then(|f| f.as_str())
                                    .map(|f| match f {
                                        "STOP" => "stop",
                                        "MAX_TOKENS" => "length",
                                        "SAFETY" => "content_filter",
                                        "RECITATION" => "content_filter",
                                        _ => f,
                                    });

                                let mut legacy_chunk = json!({
                                    "id": &stream_id, "object": "text_completion", "created": created_ts, "model": &model,
                                    "choices": [{ "text": content_out, "index": 0, "logprobs": null, "finish_reason": finish_reason }]
                                });
                                // Usage is reported once, on the chunk that carries the finish reason.
                                if finish_reason.is_some() {
                                    if let Some(usage) = final_usage.take() {
                                        if let Ok(usage_value) = serde_json::to_value(&usage) {
                                            legacy_chunk["usage"] = usage_value;
                                        }
                                    }
                                }
                                yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&legacy_chunk).unwrap_or_default())));
                            }
                        }
                        Some(Err(e)) => {
                            use crate::proxy::mappers::error_classifier::classify_stream_error;
                            let (error_type, user_msg, i18n_key) = classify_stream_error(&e);
                            tracing::error!("Legacy Stream Error: {}", e);
                            let error_chunk = json!({
                                "id": &stream_id, "object": "text_completion", "created": created_ts, "model": &model, "choices": [],
                                "error": { "type": error_type, "message": user_msg, "code": "stream_error", "i18n_key": i18n_key }
                            });
                            yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&error_chunk).unwrap_or_default())));
                            yield Ok::<Bytes, String>(Bytes::from("data: [DONE]\n\n"));
                            error_occurred = true;
                            break;
                        }
                        None => break,
                    }
                }
                _ = heartbeat_interval.tick() => {
                    yield Ok::<Bytes, String>(Bytes::from(": ping\n\n"));
                }
            }
        }

        if !error_occurred {
            yield Ok::<Bytes, String>(Bytes::from("data: [DONE]\n\n"));
        }
    };

    Box::pin(stream)
}
/// Converts a Gemini SSE byte stream into a Codex / Responses-style SSE event stream.
///
/// * `gemini_stream` - upstream byte stream; it is split on `\n` and every
///   `data: <json>` line is converted using the first candidate only.
/// * `_model` - unused; kept for signature parity with the other stream builders.
/// * `session_id` / `message_count` - forwarded to [`store_thought_signature`].
///
/// Event order: `response.created`, `response.output_item.added`,
/// `response.content_part.added`, then `response.reasoning.delta` /
/// `response.output_text.delta` events while upstream data arrives, and finally
/// `response.output_text.done`, `response.content_part.done`,
/// `response.output_item.done` and `response.completed`. A `: ping` comment is
/// sent whenever the 15-second heartbeat timer fires.
///
/// An upstream error is logged and ends reading; the closing events are still
/// sent with the text accumulated so far (best effort).
pub fn create_codex_sse_stream<S, E>(
    mut gemini_stream: Pin<Box<S>>,
    _model: String,
    session_id: String,
    message_count: usize,
) -> Pin<Box<dyn Stream<Item = Result<Bytes, String>> + Send>>
where
    S: Stream<Item = Result<Bytes, E>> + Send + ?Sized + 'static,
    E: std::fmt::Display + Send + 'static,
{
    // ASCII alphabet for the random id suffix; byte indexing avoids an O(n) `chars().nth()`.
    const ID_CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789";

    let mut buffer = BytesMut::new();
    let random_str: String = {
        let mut rng = rand::thread_rng();
        (0..24)
            .map(|_| ID_CHARSET[rng.gen_range(0..ID_CHARSET.len())] as char)
            .collect()
    };
    let response_id = format!("resp-{}", random_str);
    // The id is pure ASCII, so slicing at byte 16 is always on a char boundary.
    let item_id = format!("item-{}", &random_str[..16]);

    let stream = async_stream::stream! {
        // 1. response.created
        let created_ev = json!({ "type": "response.created", "response": { "id": &response_id, "object": "response", "status": "in_progress", "output": [] } });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&created_ev).unwrap())));

        // 2. response.output_item.added - tells the client an output item starts
        let output_item_added = json!({
            "type": "response.output_item.added",
            "output_index": 0,
            "item": {
                "id": &item_id,
                "type": "message",
                "role": "assistant",
                "status": "in_progress",
                "content": []
            }
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&output_item_added).unwrap())));

        // 3. response.content_part.added - tells the client a text content part starts
        let content_part_added = json!({
            "type": "response.content_part.added",
            "item_id": &item_id,
            "output_index": 0,
            "content_index": 0,
            "part": {
                "type": "output_text",
                "text": ""
            }
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&content_part_added).unwrap())));

        // NOTE(review): function calls are only recorded here and never forwarded
        // to the client - confirm whether that is intentional.
        let mut emitted_tool_calls = std::collections::HashSet::new();
        let mut accumulated_text = String::new();
        let mut heartbeat_interval = tokio::time::interval(std::time::Duration::from_secs(15));
        heartbeat_interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);

        loop {
            tokio::select! {
                item = gemini_stream.next() => {
                    match item {
                        Some(Ok(bytes)) => {
                            buffer.extend_from_slice(&bytes);
                            while let Some(pos) = buffer.iter().position(|&b| b == b'\n') {
                                let line_raw = buffer.split_to(pos + 1);
                                let line = match std::str::from_utf8(&line_raw) {
                                    Ok(text) => text.trim(),
                                    Err(_) => continue,
                                };
                                // `strip_prefix` removes the marker exactly once.
                                let json_part = match line.strip_prefix("data: ") {
                                    Some(payload) => payload.trim(),
                                    None => continue,
                                };
                                if json_part == "[DONE]" { continue; }
                                let mut json = match serde_json::from_str::<Value>(json_part) {
                                    Ok(value) => value,
                                    Err(_) => continue,
                                };

                                // Some upstreams wrap the payload in a `response` envelope.
                                let unwrapped = json.get_mut("response").map(|v| v.take());
                                let actual_data = unwrapped.unwrap_or(json);

                                let candidate = match actual_data
                                    .get("candidates")
                                    .and_then(|c| c.as_array())
                                    .and_then(|c| c.first())
                                {
                                    Some(first) => first,
                                    None => continue,
                                };
                                tracing::debug!("[Codex-Stream-Debug] Raw Candidate: {:?}", candidate);

                                let parts = candidate
                                    .get("content")
                                    .and_then(|c| c.get("parts"))
                                    .and_then(|p| p.as_array());
                                for part in parts.into_iter().flatten() {
                                    let is_thought = part.get("thought").and_then(|v| v.as_bool()).unwrap_or(false);
                                    if let Some(text) = part.get("text").and_then(|t| t.as_str()) {
                                        if !text.is_empty() {
                                            if is_thought {
                                                // Chain-of-thought text -> response.reasoning.delta
                                                let reasoning_ev = json!({
                                                    "type": "response.reasoning.delta",
                                                    "item_id": &item_id,
                                                    "output_index": 0,
                                                    "content_index": 0,
                                                    "delta": text
                                                });
                                                yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&reasoning_ev).unwrap())));
                                            } else {
                                                accumulated_text.push_str(text);
                                                // 4. response.output_text.delta - text increment
                                                let delta_ev = json!({
                                                    "type": "response.output_text.delta",
                                                    "item_id": &item_id,
                                                    "output_index": 0,
                                                    "content_index": 0,
                                                    "delta": text
                                                });
                                                yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&delta_ev).unwrap())));
                                            }
                                        }
                                    }

                                    if let Some(sig) = part
                                        .get("thoughtSignature")
                                        .or_else(|| part.get("thought_signature"))
                                        .and_then(|s| s.as_str())
                                    {
                                        store_thought_signature(sig, &session_id, message_count);
                                    }

                                    if let Some(func_call) = part.get("functionCall") {
                                        // `insert` is a no-op for a call that was already recorded.
                                        emitted_tool_calls.insert(serde_json::to_string(func_call).unwrap_or_default());
                                    }
                                }

                                // groundingMetadata (search citations) is appended as markdown text.
                                if let Some(grounding) = candidate.get("groundingMetadata") {
                                    let mut grounding_text = String::new();
                                    if let Some(queries) = grounding.get("webSearchQueries").and_then(|q| q.as_array()) {
                                        let query_list: Vec<&str> = queries.iter().filter_map(|v| v.as_str()).collect();
                                        if !query_list.is_empty() {
                                            grounding_text.push_str("\n\n---\n**🔍 已为您搜索:** ");
                                            grounding_text.push_str(&query_list.join(", "));
                                        }
                                    }
                                    if let Some(chunks) = grounding.get("groundingChunks").and_then(|c| c.as_array()) {
                                        let mut links = Vec::new();
                                        for (i, chunk) in chunks.iter().enumerate() {
                                            if let Some(web) = chunk.get("web") {
                                                let title = web.get("title").and_then(|v| v.as_str()).unwrap_or("网页来源");
                                                let uri = web.get("uri").and_then(|v| v.as_str()).unwrap_or("#");
                                                links.push(format!("[{}] [{}]({})", i + 1, title, uri));
                                            }
                                        }
                                        if !links.is_empty() {
                                            grounding_text.push_str("\n\n**🌐 来源引文:**\n");
                                            grounding_text.push_str(&links.join("\n"));
                                        }
                                    }
                                    if !grounding_text.is_empty() {
                                        accumulated_text.push_str(&grounding_text);
                                        let delta_ev = json!({
                                            "type": "response.output_text.delta",
                                            "item_id": &item_id,
                                            "output_index": 0,
                                            "content_index": 0,
                                            "delta": grounding_text
                                        });
                                        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&delta_ev).unwrap())));
                                    }
                                }
                            }
                        }
                        Some(Err(e)) => {
                            // Best effort: record the failure, then fall through to the
                            // closing events so the client is not left waiting.
                            tracing::error!("Codex Stream Error: {}", e);
                            break;
                        }
                        None => break,
                    }
                }
                _ = heartbeat_interval.tick() => {
                    yield Ok::<Bytes, String>(Bytes::from(": ping\n\n"));
                }
            }
        }

        // 5. response.output_text.done - text finished
        let text_done = json!({
            "type": "response.output_text.done",
            "item_id": &item_id,
            "output_index": 0,
            "content_index": 0,
            "text": &accumulated_text
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&text_done).unwrap())));

        // 6. response.content_part.done
        let content_part_done = json!({
            "type": "response.content_part.done",
            "item_id": &item_id,
            "output_index": 0,
            "content_index": 0,
            "part": {
                "type": "output_text",
                "text": &accumulated_text
            }
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&content_part_done).unwrap())));

        // 7. response.output_item.done
        let output_item_done = json!({
            "type": "response.output_item.done",
            "output_index": 0,
            "item": {
                "id": &item_id,
                "type": "message",
                "role": "assistant",
                "status": "completed",
                "content": [{
                    "type": "output_text",
                    "text": &accumulated_text
                }]
            }
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&output_item_done).unwrap())));

        // 8. response.completed
        let completed_ev = json!({
            "type": "response.completed",
            "response": {
                "id": &response_id,
                "object": "response",
                "status": "completed",
                "output": [{
                    "id": &item_id,
                    "type": "message",
                    "role": "assistant",
                    "content": [{
                        "type": "output_text",
                        "text": &accumulated_text
                    }]
                }]
            }
        });
        yield Ok::<Bytes, String>(Bytes::from(format!("data: {}\n\n", serde_json::to_string(&completed_ev).unwrap())));
    };

    Box::pin(stream)
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures::stream;
    use serde_json::json;

    /// Feeds two Gemini chunks (text only, then finish reason plus usage) through
    /// `create_openai_sse_stream` and checks that `usage` appears only on the
    /// last data chunk, together with a `stop` finish reason.
    #[tokio::test]
    async fn test_openai_streaming_usage_only_at_end() {
        // First upstream chunk: partial content, no usage.
        let content_only = json!({
            "candidates": [{
                "content": {
                    "parts": [{ "text": "Hello" }]
                }
            }]
        });
        // Second upstream chunk: finish reason together with usage metadata.
        let finish_with_usage = json!({
            "candidates": [{
                "finishReason": "STOP",
                "content": {
                    "parts": [{ "text": "" }]
                }
            }],
            "usageMetadata": {
                "promptTokenCount": 5,
                "candidatesTokenCount": 2,
                "totalTokenCount": 7
            }
        });

        // Build upstream items with the error type the stream signature expects.
        let upstream: Vec<Result<Bytes, reqwest::Error>> = [content_only, finish_with_usage]
            .iter()
            .map(|payload| Ok(Bytes::from(format!("data: {}\n\n", payload))))
            .collect();

        let mut openai_stream = create_openai_sse_stream(
            Box::pin(stream::iter(upstream)),
            "gemini-1.5-flash".to_string(),
            "test-session".to_string(),
            0,
        );

        // Keep every `data:` line except the terminating [DONE] marker.
        let mut data_lines: Vec<String> = Vec::new();
        while let Some(item) = openai_stream.next().await {
            let bytes = match item {
                Ok(bytes) => bytes,
                Err(_) => continue,
            };
            let text = String::from_utf8_lossy(&bytes);
            data_lines.extend(
                text.lines()
                    .filter(|line| line.starts_with("data: ") && !line.contains("[DONE]"))
                    .map(String::from),
            );
        }

        let parsed: Vec<Value> = data_lines
            .iter()
            .map(|line| serde_json::from_str(line.trim_start_matches("data: ").trim()).unwrap())
            .collect();

        let mut found_usage = false;
        let mut found_finish = false;
        if let Some((last, intermediate)) = parsed.split_last() {
            for (i, chunk) in intermediate.iter().enumerate() {
                assert!(chunk.get("usage").is_none(), "Usage should not be in intermediate chunks. Found in chunk {}", i);
            }

            if let Some(usage) = last.get("usage") {
                found_usage = true;
                assert_eq!(usage["prompt_tokens"], 5);
                assert_eq!(usage["completion_tokens"], 2);
                assert_eq!(usage["total_tokens"], 7);
            }

            let finish_reason = last
                .get("choices")
                .and_then(|choices| choices.get(0))
                .and_then(|choice| choice.get("finish_reason"))
                .and_then(|reason| reason.as_str());
            if finish_reason == Some("stop") {
                found_finish = true;
            }
        }

        assert!(found_usage, "Usage should be found in the last chunk");
        assert!(found_finish, "Finish reason should be strictly 'stop'");
    }
}