| |
| |
|
|
| use super::models::*; |
| use super::utils::to_claude_usage; |
| use crate::proxy::mappers::signature_store::store_thought_signature; |
| use bytes::Bytes; |
| use serde_json::json; |
|
|
| |
/// Kind of Claude content block that is currently open in the stream.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BlockType {
    /// No content block is open.
    None,
    /// A `text` content block.
    Text,
    /// A `thinking` content block.
    Thinking,
    /// A `tool_use` content block produced from a Gemini function call.
    Function,
}
|
|
| |
/// Holds at most one thought signature until the enclosing thinking block is closed.
#[derive(Debug, Default)]
pub struct SignatureManager {
    /// Most recently stored signature that has not been consumed yet.
    pending: Option<String>,
}

impl SignatureManager {
    /// Creates a manager with no pending signature.
    pub fn new() -> Self {
        Self { pending: None }
    }

    /// Records `signature` as pending.
    ///
    /// Passing `None` leaves any already-pending signature untouched; a new
    /// `Some` replaces it.
    pub fn store(&mut self, signature: Option<String>) {
        if let Some(sig) = signature {
            self.pending = Some(sig);
        }
    }

    /// Takes the pending signature, leaving nothing pending.
    pub fn consume(&mut self) -> Option<String> {
        self.pending.take()
    }

    /// Returns `true` if a signature is waiting to be consumed.
    pub fn has_pending(&self) -> bool {
        self.pending.is_some()
    }
}
|
|
| |
| pub struct StreamingState { |
| block_type: BlockType, |
| pub block_index: usize, |
| pub message_start_sent: bool, |
| pub message_stop_sent: bool, |
| used_tool: bool, |
| signatures: SignatureManager, |
| trailing_signature: Option<String>, |
| pub web_search_query: Option<String>, |
| pub grounding_chunks: Option<Vec<serde_json::Value>>, |
| } |
|
|
| impl StreamingState { |
| pub fn new() -> Self { |
| Self { |
| block_type: BlockType::None, |
| block_index: 0, |
| message_start_sent: false, |
| message_stop_sent: false, |
| used_tool: false, |
| signatures: SignatureManager::new(), |
| trailing_signature: None, |
| web_search_query: None, |
| grounding_chunks: None, |
| } |
| } |
|
|
| |
| pub fn emit(&self, event_type: &str, data: serde_json::Value) -> Bytes { |
| let sse = format!( |
| "event: {}\ndata: {}\n\n", |
| event_type, |
| serde_json::to_string(&data).unwrap_or_default() |
| ); |
| Bytes::from(sse) |
| } |
|
|
| |
| pub fn emit_message_start(&mut self, raw_json: &serde_json::Value) -> Bytes { |
| if self.message_start_sent { |
| return Bytes::new(); |
| } |
|
|
| let usage = raw_json |
| .get("usageMetadata") |
| .and_then(|u| serde_json::from_value::<UsageMetadata>(u.clone()).ok()) |
| .map(|u| to_claude_usage(&u)); |
|
|
| let mut message = json!({ |
| "id": raw_json.get("responseId") |
| .and_then(|v| v.as_str()) |
| .unwrap_or_else(|| "msg_unknown"), |
| "type": "message", |
| "role": "assistant", |
| "content": [], |
| "model": raw_json.get("modelVersion") |
| .and_then(|v| v.as_str()) |
| .unwrap_or(""), |
| "stop_reason": null, |
| "stop_sequence": null, |
| }); |
|
|
| if let Some(u) = usage { |
| message["usage"] = json!(u); |
| } |
|
|
| let result = self.emit( |
| "message_start", |
| json!({ |
| "type": "message_start", |
| "message": message |
| }), |
| ); |
|
|
| self.message_start_sent = true; |
| result |
| } |
|
|
| |
| pub fn start_block( |
| &mut self, |
| block_type: BlockType, |
| content_block: serde_json::Value, |
| ) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
| if self.block_type != BlockType::None { |
| chunks.extend(self.end_block()); |
| } |
|
|
| chunks.push(self.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.block_index, |
| "content_block": content_block |
| }), |
| )); |
|
|
| self.block_type = block_type; |
| chunks |
| } |
|
|
| |
| pub fn end_block(&mut self) -> Vec<Bytes> { |
| if self.block_type == BlockType::None { |
| return vec![]; |
| } |
|
|
| let mut chunks = Vec::new(); |
|
|
| |
| if self.block_type == BlockType::Thinking && self.signatures.has_pending() { |
| if let Some(signature) = self.signatures.consume() { |
| chunks.push(self.emit_delta("signature_delta", json!({ "signature": signature }))); |
| } |
| } |
|
|
| chunks.push(self.emit( |
| "content_block_stop", |
| json!({ |
| "type": "content_block_stop", |
| "index": self.block_index |
| }), |
| )); |
|
|
| self.block_index += 1; |
| self.block_type = BlockType::None; |
|
|
| chunks |
| } |
|
|
| |
| pub fn emit_delta(&self, delta_type: &str, delta_content: serde_json::Value) -> Bytes { |
| let mut delta = json!({ "type": delta_type }); |
| if let serde_json::Value::Object(map) = delta_content { |
| for (k, v) in map { |
| delta[k] = v; |
| } |
| } |
|
|
| self.emit( |
| "content_block_delta", |
| json!({ |
| "type": "content_block_delta", |
| "index": self.block_index, |
| "delta": delta |
| }), |
| ) |
| } |
|
|
| |
| pub fn emit_finish( |
| &mut self, |
| finish_reason: Option<&str>, |
| usage_metadata: Option<&UsageMetadata>, |
| ) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
|
|
| |
| chunks.extend(self.end_block()); |
|
|
| |
| if let Some(signature) = self.trailing_signature.take() { |
| chunks.push(self.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.block_index, |
| "content_block": { "type": "thinking", "thinking": "" } |
| }), |
| )); |
| chunks.push(self.emit_delta("thinking_delta", json!({ "thinking": "" }))); |
| chunks.push(self.emit_delta("signature_delta", json!({ "signature": signature }))); |
| chunks.push(self.emit( |
| "content_block_stop", |
| json!({ |
| "type": "content_block_stop", |
| "index": self.block_index |
| }), |
| )); |
| self.block_index += 1; |
| } |
|
|
| |
| if self.web_search_query.is_some() || self.grounding_chunks.is_some() { |
| let mut grounding_text = String::new(); |
| |
| |
| if let Some(query) = &self.web_search_query { |
| if !query.is_empty() { |
| grounding_text.push_str("\n\n---\n**🔍 已为您搜索:** "); |
| grounding_text.push_str(query); |
| } |
| } |
|
|
| |
| if let Some(chunks) = &self.grounding_chunks { |
| let mut links = Vec::new(); |
| for (i, chunk) in chunks.iter().enumerate() { |
| if let Some(web) = chunk.get("web") { |
| let title = web.get("title").and_then(|v| v.as_str()).unwrap_or("网页来源"); |
| let uri = web.get("uri").and_then(|v| v.as_str()).unwrap_or("#"); |
| links.push(format!("[{}] [{}]({})", i + 1, title, uri)); |
| } |
| } |
| |
| if !links.is_empty() { |
| grounding_text.push_str("\n\n**🌐 来源引文:**\n"); |
| grounding_text.push_str(&links.join("\n")); |
| } |
| } |
|
|
| if !grounding_text.is_empty() { |
| |
| chunks.push(self.emit("content_block_start", json!({ |
| "type": "content_block_start", |
| "index": self.block_index, |
| "content_block": { "type": "text", "text": "" } |
| }))); |
| chunks.push(self.emit_delta("text_delta", json!({ "text": grounding_text }))); |
| chunks.push(self.emit("content_block_stop", json!({ "type": "content_block_stop", "index": self.block_index }))); |
| self.block_index += 1; |
| } |
| } |
|
|
| |
| let stop_reason = if self.used_tool { |
| "tool_use" |
| } else if finish_reason == Some("MAX_TOKENS") { |
| "max_tokens" |
| } else { |
| "end_turn" |
| }; |
|
|
| let usage = usage_metadata |
| .map(|u| to_claude_usage(u)) |
| .unwrap_or(Usage { |
| input_tokens: 0, |
| output_tokens: 0, |
| server_tool_use: None, |
| }); |
|
|
| chunks.push(self.emit( |
| "message_delta", |
| json!({ |
| "type": "message_delta", |
| "delta": { "stop_reason": stop_reason, "stop_sequence": null }, |
| "usage": usage |
| }), |
| )); |
|
|
| if !self.message_stop_sent { |
| chunks.push(Bytes::from( |
| "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n", |
| )); |
| self.message_stop_sent = true; |
| } |
|
|
| chunks |
| } |
|
|
| |
| pub fn mark_tool_used(&mut self) { |
| self.used_tool = true; |
| } |
|
|
| |
| pub fn current_block_type(&self) -> BlockType { |
| self.block_type |
| } |
|
|
| |
| pub fn current_block_index(&self) -> usize { |
| self.block_index |
| } |
|
|
| |
| pub fn store_signature(&mut self, signature: Option<String>) { |
| self.signatures.store(signature); |
| } |
|
|
| |
| pub fn set_trailing_signature(&mut self, signature: Option<String>) { |
| self.trailing_signature = signature; |
| } |
|
|
| |
| pub fn has_trailing_signature(&self) -> bool { |
| self.trailing_signature.is_some() |
| } |
| } |
|
|
| |
| pub struct PartProcessor<'a> { |
| state: &'a mut StreamingState, |
| } |
|
|
| impl<'a> PartProcessor<'a> { |
| pub fn new(state: &'a mut StreamingState) -> Self { |
| Self { state } |
| } |
|
|
| |
| pub fn process(&mut self, part: &GeminiPart) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
| let signature = part.thought_signature.clone(); |
|
|
| |
| if let Some(fc) = &part.function_call { |
| |
| if self.state.has_trailing_signature() { |
| chunks.extend(self.state.end_block()); |
| if let Some(trailing_sig) = self.state.trailing_signature.take() { |
| chunks.push(self.state.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.state.current_block_index(), |
| "content_block": { "type": "thinking", "thinking": "" } |
| }), |
| )); |
| chunks.push( |
| self.state |
| .emit_delta("thinking_delta", json!({ "thinking": "" })), |
| ); |
| chunks.push( |
| self.state |
| .emit_delta("signature_delta", json!({ "signature": trailing_sig })), |
| ); |
| chunks.extend(self.state.end_block()); |
| } |
| } |
|
|
| chunks.extend(self.process_function_call(fc, signature)); |
| return chunks; |
| } |
|
|
| |
| if let Some(text) = &part.text { |
| if part.thought.unwrap_or(false) { |
| |
| chunks.extend(self.process_thinking(text, signature)); |
| } else { |
| |
| chunks.extend(self.process_text(text, signature)); |
| } |
| } |
|
|
| |
| if let Some(img) = &part.inline_data { |
| let mime_type = &img.mime_type; |
| let data = &img.data; |
| if !data.is_empty() { |
| let markdown_img = format!("", mime_type, data); |
| chunks.extend(self.process_text(&markdown_img, None)); |
| } |
| } |
|
|
| chunks |
| } |
|
|
| |
| fn process_thinking(&mut self, text: &str, signature: Option<String>) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
|
|
| |
| if self.state.has_trailing_signature() { |
| chunks.extend(self.state.end_block()); |
| if let Some(trailing_sig) = self.state.trailing_signature.take() { |
| chunks.push(self.state.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.state.current_block_index(), |
| "content_block": { "type": "thinking", "thinking": "" } |
| }), |
| )); |
| chunks.push( |
| self.state |
| .emit_delta("thinking_delta", json!({ "thinking": "" })), |
| ); |
| chunks.push( |
| self.state |
| .emit_delta("signature_delta", json!({ "signature": trailing_sig })), |
| ); |
| chunks.extend(self.state.end_block()); |
| } |
| } |
|
|
| |
| if self.state.current_block_type() != BlockType::Thinking { |
| chunks.extend(self.state.start_block( |
| BlockType::Thinking, |
| json!({ "type": "thinking", "thinking": "" }), |
| )); |
| } |
|
|
| if !text.is_empty() { |
| chunks.push( |
| self.state |
| .emit_delta("thinking_delta", json!({ "thinking": text })), |
| ); |
| } |
|
|
| |
| self.state.store_signature(signature); |
|
|
| chunks |
| } |
|
|
| |
| fn process_text(&mut self, text: &str, signature: Option<String>) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
|
|
| |
| if text.is_empty() { |
| if signature.is_some() { |
| self.state.set_trailing_signature(signature); |
| } |
| return chunks; |
| } |
|
|
| |
| if self.state.has_trailing_signature() { |
| chunks.extend(self.state.end_block()); |
| if let Some(trailing_sig) = self.state.trailing_signature.take() { |
| chunks.push(self.state.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.state.current_block_index(), |
| "content_block": { "type": "thinking", "thinking": "" } |
| }), |
| )); |
| chunks.push( |
| self.state |
| .emit_delta("thinking_delta", json!({ "thinking": "" })), |
| ); |
| chunks.push( |
| self.state |
| .emit_delta("signature_delta", json!({ "signature": trailing_sig })), |
| ); |
| chunks.extend(self.state.end_block()); |
| } |
| } |
|
|
| |
| if signature.is_some() { |
| |
| chunks.extend( |
| self.state |
| .start_block(BlockType::Text, json!({ "type": "text", "text": "" })), |
| ); |
| chunks.push(self.state.emit_delta("text_delta", json!({ "text": text }))); |
| chunks.extend(self.state.end_block()); |
|
|
| |
| chunks.push(self.state.emit( |
| "content_block_start", |
| json!({ |
| "type": "content_block_start", |
| "index": self.state.current_block_index(), |
| "content_block": { "type": "thinking", "thinking": "" } |
| }), |
| )); |
| chunks.push( |
| self.state |
| .emit_delta("thinking_delta", json!({ "thinking": "" })), |
| ); |
| chunks.push(self.state.emit_delta( |
| "signature_delta", |
| json!({ "signature": signature.unwrap() }), |
| )); |
| chunks.extend(self.state.end_block()); |
|
|
| return chunks; |
| } |
|
|
| |
| if self.state.current_block_type() != BlockType::Text { |
| chunks.extend( |
| self.state |
| .start_block(BlockType::Text, json!({ "type": "text", "text": "" })), |
| ); |
| } |
|
|
| chunks.push(self.state.emit_delta("text_delta", json!({ "text": text }))); |
|
|
| chunks |
| } |
|
|
| |
| fn process_function_call( |
| &mut self, |
| fc: &FunctionCall, |
| signature: Option<String>, |
| ) -> Vec<Bytes> { |
| let mut chunks = Vec::new(); |
|
|
| self.state.mark_tool_used(); |
|
|
| let tool_id = fc.id.clone().unwrap_or_else(|| { |
| format!( |
| "{}-{}", |
| fc.name, |
| crate::proxy::common::utils::generate_random_id() |
| ) |
| }); |
|
|
| |
| let mut tool_use = json!({ |
| "type": "tool_use", |
| "id": tool_id, |
| "name": fc.name, |
| "input": {} |
| }); |
|
|
| if let Some(ref sig) = signature { |
| tool_use["signature"] = json!(sig); |
| |
| store_thought_signature(sig); |
| tracing::info!( |
| "[Claude-SSE] Captured thought_signature for function call (length: {})", |
| sig.len() |
| ); |
| } |
|
|
| chunks.extend(self.state.start_block(BlockType::Function, tool_use)); |
|
|
| |
| if let Some(args) = &fc.args { |
| let json_str = serde_json::to_string(args).unwrap_or_else(|_| "{}".to_string()); |
| chunks.push( |
| self.state |
| .emit_delta("input_json_delta", json!({ "partial_json": json_str })), |
| ); |
| } |
|
|
| |
| chunks.extend(self.state.end_block()); |
|
|
| chunks |
| } |
| } |
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Joins emitted SSE chunks into one UTF-8 string so assertions can search it.
    fn render(chunks: &[Bytes]) -> String {
        chunks
            .iter()
            .map(|chunk| String::from_utf8(chunk.to_vec()).unwrap())
            .collect()
    }

    #[test]
    fn test_signature_manager() {
        let mut manager = SignatureManager::new();
        assert!(!manager.has_pending());

        manager.store(Some("sig123".to_string()));
        assert!(manager.has_pending());

        let taken = manager.consume();
        assert_eq!(taken, Some("sig123".to_string()));
        assert!(!manager.has_pending());
    }

    #[test]
    fn test_streaming_state_emit() {
        let state = StreamingState::new();
        let text = render(&[state.emit("test_event", json!({"foo": "bar"}))]);

        assert!(text.contains("event: test_event"));
        assert!(text.contains("\"foo\":\"bar\""));
    }

    #[test]
    fn test_process_function_call_deltas() {
        let mut state = StreamingState::new();
        let mut processor = PartProcessor::new(&mut state);

        let call = FunctionCall {
            name: "test_tool".to_string(),
            args: Some(json!({"arg": "value"})),
            id: Some("call_123".to_string()),
        };
        let part = GeminiPart {
            text: None,
            function_call: Some(call),
            inline_data: None,
            thought: None,
            thought_signature: None,
            function_response: None,
        };

        let output = render(&processor.process(&part));

        let expected = [
            // Block start announces the tool with an empty input object.
            r#""type":"content_block_start""#,
            r#""name":"test_tool""#,
            r#""input":{}"#,
            // Arguments arrive as one serialized input_json_delta.
            r#""type":"content_block_delta""#,
            r#""type":"input_json_delta""#,
            r#"partial_json":"{\"arg\":\"value\"}"#,
            // The tool block is closed.
            r#""type":"content_block_stop""#,
        ];
        for fragment in expected.iter() {
            assert!(
                output.contains(*fragment),
                "missing {} in output:\n{}",
                fragment,
                output
            );
        }
    }
}
|
|