| |
|
|
| use axum::http::StatusCode; |
| use smg::{ |
| config::RouterConfig, |
| protocols::{ |
| common::{GenerationRequest, ToolChoice, ToolChoiceValue, UsageInfo}, |
| responses::{ |
| ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseTool, ResponseToolType, |
| ResponsesRequest, ServiceTier, Truncation, |
| }, |
| }, |
| routers::{conversations, RouterFactory}, |
| }; |
|
|
| use crate::common::{ |
| mock_mcp_server::MockMCPServer, |
| mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}, |
| }; |
|
|
| #[tokio::test] |
| async fn test_non_streaming_mcp_minimal_e2e_with_persistence() { |
| |
| let mut mcp = MockMCPServer::start().await.expect("start mcp"); |
|
|
| |
| let mcp_yaml = format!( |
| "servers:\n - name: mock\n protocol: streamable\n url: {}\n", |
| mcp.url() |
| ); |
| let dir = tempfile::tempdir().expect("tmpdir"); |
| let cfg_path = dir.path().join("mcp.yaml"); |
| std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); |
|
|
| |
| let mut worker = MockWorker::new(MockWorkerConfig { |
| port: 0, |
| worker_type: WorkerType::Regular, |
| health_status: HealthStatus::Healthy, |
| response_delay_ms: 0, |
| fail_rate: 0.0, |
| }); |
| let worker_url = worker.start().await.expect("start worker"); |
|
|
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec![worker_url]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(5) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(32) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| |
| let ctx = |
| crate::common::create_test_context_with_mcp_config(router_cfg, cfg_path.to_str().unwrap()) |
| .await; |
| let router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let req = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("search something".to_string()), |
| instructions: Some("Be brief".to_string()), |
| max_output_tokens: Some(64), |
| max_tool_calls: None, |
| metadata: None, |
| model: "mock-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(false), |
| temperature: Some(0.2), |
| tool_choice: Some(ToolChoice::default()), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::Mcp, |
| function: None, |
| server_url: Some(mcp.url()), |
| authorization: None, |
| server_label: Some("mock".to_string()), |
| server_description: None, |
| require_approval: None, |
| allowed_tools: None, |
| }]), |
| top_logprobs: Some(0), |
| top_p: None, |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: None, |
| request_id: Some("resp_test_mcp_e2e".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: -1, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| let resp = router |
| .route_responses(None, &req, Some(req.model.as_str())) |
| .await; |
|
|
| assert_eq!(resp.status(), StatusCode::OK); |
|
|
| let body_bytes = axum::body::to_bytes(resp.into_body(), usize::MAX) |
| .await |
| .expect("Failed to read response body"); |
| let body_json: serde_json::Value = |
| serde_json::from_slice(&body_bytes).expect("Failed to parse response JSON"); |
|
|
| let output = body_json |
| .get("output") |
| .and_then(|v| v.as_array()) |
| .expect("response output missing"); |
| assert!(!output.is_empty(), "expected at least one output item"); |
|
|
| |
| let list_tools_item = output |
| .iter() |
| .find(|entry| { |
| entry.get("type") == Some(&serde_json::Value::String("mcp_list_tools".into())) |
| }) |
| .expect("missing mcp_list_tools output item"); |
|
|
| assert_eq!( |
| list_tools_item.get("server_label").and_then(|v| v.as_str()), |
| Some("mock"), |
| "server_label should match" |
| ); |
| let tools_list = list_tools_item |
| .get("tools") |
| .and_then(|v| v.as_array()) |
| .expect("tools array missing in mcp_list_tools"); |
| assert!( |
| !tools_list.is_empty(), |
| "mcp_list_tools should contain at least one tool" |
| ); |
|
|
| |
| let mcp_call_item = output |
| .iter() |
| .find(|entry| entry.get("type") == Some(&serde_json::Value::String("mcp_call".into()))) |
| .expect("missing mcp_call output item"); |
|
|
| assert_eq!( |
| mcp_call_item.get("status").and_then(|v| v.as_str()), |
| Some("completed"), |
| "mcp_call status should be completed" |
| ); |
| assert_eq!( |
| mcp_call_item.get("server_label").and_then(|v| v.as_str()), |
| Some("mock"), |
| "server_label should match" |
| ); |
| assert!( |
| mcp_call_item.get("name").is_some(), |
| "mcp_call should have a tool name" |
| ); |
| assert!( |
| mcp_call_item.get("arguments").is_some(), |
| "mcp_call should have arguments" |
| ); |
| assert!( |
| mcp_call_item.get("output").is_some(), |
| "mcp_call should have output" |
| ); |
|
|
| let final_text = output |
| .iter() |
| .rev() |
| .filter_map(|entry| entry.get("content")) |
| .filter_map(|content| content.as_array()) |
| .flat_map(|parts| parts.iter()) |
| .filter_map(|part| part.get("text")) |
| .filter_map(|v| v.as_str()) |
| .next(); |
|
|
| if let Some(text) = final_text { |
| assert_eq!(text, "Tool result consumed; here is the final answer."); |
| } else { |
| let call_entry = output.iter().find(|entry| { |
| entry.get("type") == Some(&serde_json::Value::String("function_tool_call".into())) |
| }); |
| assert!(call_entry.is_some(), "missing function tool call entry"); |
| if let Some(entry) = call_entry { |
| assert_eq!( |
| entry.get("status").and_then(|v| v.as_str()), |
| Some("in_progress"), |
| "function call should be in progress when no content is returned" |
| ); |
| } |
| } |
|
|
| let tools = body_json |
| .get("tools") |
| .and_then(|v| v.as_array()) |
| .expect("tools array missing"); |
| assert_eq!(tools.len(), 1); |
| let tool = tools.first().unwrap(); |
| assert_eq!(tool.get("type").and_then(|v| v.as_str()), Some("mcp")); |
| assert_eq!( |
| tool.get("server_label").and_then(|v| v.as_str()), |
| Some("mock") |
| ); |
|
|
| |
| worker.stop().await; |
| mcp.stop().await; |
| } |
|
|
| #[tokio::test] |
| async fn test_conversations_crud_basic() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let create_body = serde_json::json!({ "metadata": { "project": "alpha" } }); |
| let create_resp = |
| conversations::create_conversation(&ctx.conversation_storage, create_body.clone()).await; |
| assert_eq!(create_resp.status(), StatusCode::OK); |
| let create_bytes = axum::body::to_bytes(create_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let create_json: serde_json::Value = serde_json::from_slice(&create_bytes).unwrap(); |
| let conv_id = create_json["id"].as_str().expect("id missing"); |
| assert!(conv_id.starts_with("conv_")); |
| assert_eq!(create_json["object"], "conversation"); |
|
|
| |
| let get_resp = conversations::get_conversation(&ctx.conversation_storage, conv_id).await; |
| assert_eq!(get_resp.status(), StatusCode::OK); |
| let get_bytes = axum::body::to_bytes(get_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let get_json: serde_json::Value = serde_json::from_slice(&get_bytes).unwrap(); |
| assert_eq!(get_json["metadata"]["project"], serde_json::json!("alpha")); |
|
|
| |
| let update_body = serde_json::json!({ "metadata": { "owner": "alice" } }); |
| let upd_resp = |
| conversations::update_conversation(&ctx.conversation_storage, conv_id, update_body.clone()) |
| .await; |
| assert_eq!(upd_resp.status(), StatusCode::OK); |
| let upd_bytes = axum::body::to_bytes(upd_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let upd_json: serde_json::Value = serde_json::from_slice(&upd_bytes).unwrap(); |
| assert_eq!(upd_json["metadata"]["project"], serde_json::json!("alpha")); |
| assert_eq!(upd_json["metadata"]["owner"], serde_json::json!("alice")); |
|
|
| |
| let del_resp = conversations::delete_conversation(&ctx.conversation_storage, conv_id).await; |
| assert_eq!(del_resp.status(), StatusCode::OK); |
| let del_bytes = axum::body::to_bytes(del_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let del_json: serde_json::Value = serde_json::from_slice(&del_bytes).unwrap(); |
| assert_eq!(del_json["deleted"], serde_json::json!(true)); |
|
|
| |
| let not_found = conversations::get_conversation(&ctx.conversation_storage, conv_id).await; |
| assert_eq!(not_found.status(), StatusCode::NOT_FOUND); |
| } |
|
|
| #[test] |
| fn test_responses_request_creation() { |
| let request = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("Hello, world!".to_string()), |
| instructions: Some("Be helpful".to_string()), |
| max_output_tokens: Some(100), |
| max_tool_calls: None, |
| metadata: None, |
| model: "test-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: Some(ResponseReasoningParam { |
| effort: Some(ReasoningEffort::Medium), |
| summary: None, |
| }), |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(false), |
| temperature: Some(0.7), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::WebSearchPreview, |
| ..Default::default() |
| }]), |
| top_logprobs: Some(5), |
| top_p: Some(0.9), |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: Some("test-user".to_string()), |
| request_id: Some("resp_test123".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: -1, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| assert!(!request.is_stream()); |
| assert_eq!(request.get_model(), Some("test-model")); |
| let routing_text = request.extract_text_for_routing(); |
| assert_eq!(routing_text, "Hello, world!"); |
| } |
|
|
| #[test] |
| fn test_responses_request_sglang_extensions() { |
| |
| let request = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("Test".to_string()), |
| instructions: None, |
| max_output_tokens: Some(50), |
| max_tool_calls: None, |
| metadata: None, |
| model: "test-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(false), |
| temperature: Some(0.8), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![]), |
| top_logprobs: Some(0), |
| top_p: Some(0.95), |
| truncation: Some(Truncation::Auto), |
| text: None, |
| user: None, |
| request_id: Some("resp_test456".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.1), |
| presence_penalty: Some(0.2), |
| stop: None, |
| |
| top_k: 10, |
| min_p: 0.05, |
| repetition_penalty: 1.1, |
| conversation: None, |
| }; |
|
|
| |
| assert_eq!(request.top_k, 10); |
| assert_eq!(request.min_p, 0.05); |
| assert_eq!(request.repetition_penalty, 1.1); |
|
|
| |
| let json = serde_json::to_string(&request).expect("Serialization should work"); |
| let parsed: ResponsesRequest = |
| serde_json::from_str(&json).expect("Deserialization should work"); |
|
|
| assert_eq!(parsed.top_k, 10); |
| assert_eq!(parsed.min_p, 0.05); |
| assert_eq!(parsed.repetition_penalty, 1.1); |
| } |
|
|
| #[test] |
| fn test_usage_conversion() { |
| |
| let usage_info = UsageInfo { |
| prompt_tokens: 15, |
| completion_tokens: 25, |
| total_tokens: 40, |
| reasoning_tokens: Some(8), |
| prompt_tokens_details: Some(smg::protocols::common::PromptTokenUsageInfo { |
| cached_tokens: 3, |
| }), |
| }; |
| let response_usage = usage_info.to_response_usage(); |
|
|
| assert_eq!(response_usage.input_tokens, 15); |
| assert_eq!(response_usage.output_tokens, 25); |
| assert_eq!(response_usage.total_tokens, 40); |
|
|
| |
| assert!(response_usage.input_tokens_details.is_some()); |
| assert_eq!( |
| response_usage |
| .input_tokens_details |
| .as_ref() |
| .unwrap() |
| .cached_tokens, |
| 3 |
| ); |
|
|
| assert!(response_usage.output_tokens_details.is_some()); |
| assert_eq!( |
| response_usage |
| .output_tokens_details |
| .as_ref() |
| .unwrap() |
| .reasoning_tokens, |
| 8 |
| ); |
|
|
| let back_to_usage = response_usage.to_usage_info(); |
| assert_eq!(back_to_usage.prompt_tokens, 15); |
| assert_eq!(back_to_usage.completion_tokens, 25); |
| assert_eq!(back_to_usage.reasoning_tokens, Some(8)); |
| } |
|
|
| #[test] |
| fn test_reasoning_param_default() { |
| let param = ResponseReasoningParam { |
| effort: Some(ReasoningEffort::Medium), |
| summary: None, |
| }; |
|
|
| let json = serde_json::to_string(¶m).unwrap(); |
| let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap(); |
|
|
| assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium))); |
| } |
|
|
| #[test] |
| fn test_json_serialization() { |
| let request = ResponsesRequest { |
| background: Some(true), |
| include: None, |
| input: ResponseInput::Text("Test input".to_string()), |
| instructions: Some("Test instructions".to_string()), |
| max_output_tokens: Some(200), |
| max_tool_calls: Some(5), |
| metadata: None, |
| model: "gpt-4".to_string(), |
| parallel_tool_calls: Some(false), |
| previous_response_id: None, |
| reasoning: Some(ResponseReasoningParam { |
| effort: Some(ReasoningEffort::High), |
| summary: None, |
| }), |
| service_tier: Some(ServiceTier::Priority), |
| store: Some(false), |
| stream: Some(true), |
| temperature: Some(0.9), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Required)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::CodeInterpreter, |
| ..Default::default() |
| }]), |
| top_logprobs: Some(10), |
| top_p: Some(0.8), |
| truncation: Some(Truncation::Auto), |
| text: None, |
| user: Some("test_user".to_string()), |
| request_id: Some("resp_comprehensive_test".to_string()), |
| priority: 1, |
| frequency_penalty: Some(0.3), |
| presence_penalty: Some(0.4), |
| stop: None, |
| top_k: 50, |
| min_p: 0.1, |
| repetition_penalty: 1.2, |
| conversation: None, |
| }; |
|
|
| let json = serde_json::to_string(&request).expect("Serialization should work"); |
| let parsed: ResponsesRequest = |
| serde_json::from_str(&json).expect("Deserialization should work"); |
|
|
| assert_eq!( |
| parsed.request_id, |
| Some("resp_comprehensive_test".to_string()) |
| ); |
| assert_eq!(parsed.model, "gpt-4"); |
| assert_eq!(parsed.background, Some(true)); |
| assert_eq!(parsed.stream, Some(true)); |
| assert_eq!(parsed.tools.as_ref().map(|t| t.len()), Some(1)); |
| } |
|
|
| #[tokio::test] |
| async fn test_multi_turn_loop_with_mcp() { |
| |
| |
| |
| |
| |
| |
|
|
| |
| let mut mcp = MockMCPServer::start().await.expect("start mcp"); |
|
|
| |
| let mcp_yaml = format!( |
| "servers:\n - name: mock\n protocol: streamable\n url: {}\n", |
| mcp.url() |
| ); |
| let dir = tempfile::tempdir().expect("tmpdir"); |
| let cfg_path = dir.path().join("mcp.yaml"); |
| std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); |
| std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap()); |
|
|
| |
| let mut worker = MockWorker::new(MockWorkerConfig { |
| port: 0, |
| worker_type: WorkerType::Regular, |
| health_status: HealthStatus::Healthy, |
| response_delay_ms: 0, |
| fail_rate: 0.0, |
| }); |
| let worker_url = worker.start().await.expect("start worker"); |
|
|
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec![worker_url]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(5) |
| .worker_startup_check_interval_secs(1) |
| .log_level("info") |
| .max_concurrent_requests(32) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let req = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("search for SGLang".to_string()), |
| instructions: Some("Be helpful".to_string()), |
| max_output_tokens: Some(128), |
| max_tool_calls: None, |
| metadata: None, |
| model: "mock-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(false), |
| temperature: Some(0.7), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::Mcp, |
| server_url: Some(mcp.url()), |
| server_label: Some("mock".to_string()), |
| server_description: Some("Mock MCP server for testing".to_string()), |
| require_approval: Some("never".to_string()), |
| ..Default::default() |
| }]), |
| top_logprobs: Some(0), |
| top_p: Some(1.0), |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: None, |
| request_id: Some("resp_multi_turn_test".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: 50, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| |
| let response = router.route_responses(None, &req, None).await; |
|
|
| |
| assert_eq!(response.status(), StatusCode::OK, "Request should succeed"); |
|
|
| |
| use axum::body::to_bytes; |
| let response_body = response.into_body(); |
| let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap(); |
| let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); |
|
|
| println!( |
| "Multi-turn response: {}", |
| serde_json::to_string_pretty(&response_json).unwrap() |
| ); |
|
|
| |
| assert_eq!(response_json["object"], "response"); |
| assert_eq!(response_json["status"], "completed"); |
| |
| assert!( |
| response_json["id"].is_string(), |
| "Response should have an id" |
| ); |
|
|
| |
| let output = response_json["output"] |
| .as_array() |
| .expect("output should be array"); |
| assert!(!output.is_empty(), "output should not be empty"); |
|
|
| |
| let has_final_text = output.iter().any(|item| { |
| item.get("type") |
| .and_then(|t| t.as_str()) |
| .map(|t| t == "message") |
| .unwrap_or(false) |
| && item |
| .get("content") |
| .and_then(|c| c.as_array()) |
| .map(|arr| { |
| arr.iter().any(|part| { |
| part.get("type") |
| .and_then(|t| t.as_str()) |
| .map(|t| t == "output_text") |
| .unwrap_or(false) |
| }) |
| }) |
| .unwrap_or(false) |
| }); |
|
|
| assert!(has_final_text, "Should have final text output"); |
|
|
| |
| let tools = response_json["tools"] |
| .as_array() |
| .expect("tools should be array"); |
| assert_eq!(tools.len(), 1); |
| assert_eq!(tools[0]["type"], "mcp"); |
| assert_eq!(tools[0]["server_label"], "mock"); |
|
|
| |
| std::env::remove_var("SGLANG_MCP_CONFIG"); |
| worker.stop().await; |
| mcp.stop().await; |
| } |
|
|
| #[tokio::test] |
| async fn test_max_tool_calls_limit() { |
| |
| |
| |
|
|
| let mut mcp = MockMCPServer::start().await.expect("start mcp"); |
| let mcp_yaml = format!( |
| "servers:\n - name: mock\n protocol: streamable\n url: {}\n", |
| mcp.url() |
| ); |
| let dir = tempfile::tempdir().expect("tmpdir"); |
| let cfg_path = dir.path().join("mcp.yaml"); |
| std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); |
| std::env::set_var("SGLANG_MCP_CONFIG", cfg_path.to_str().unwrap()); |
|
|
| let mut worker = MockWorker::new(MockWorkerConfig { |
| port: 0, |
| worker_type: WorkerType::Regular, |
| health_status: HealthStatus::Healthy, |
| response_delay_ms: 0, |
| fail_rate: 0.0, |
| }); |
| let worker_url = worker.start().await.expect("start worker"); |
|
|
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec![worker_url]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(5) |
| .worker_startup_check_interval_secs(1) |
| .log_level("info") |
| .max_concurrent_requests(32) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| let req = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("test max calls".to_string()), |
| instructions: None, |
| max_output_tokens: Some(128), |
| max_tool_calls: Some(1), |
| metadata: None, |
| model: "mock-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(false), |
| stream: Some(false), |
| temperature: Some(0.7), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::Mcp, |
| server_url: Some(mcp.url()), |
| server_label: Some("mock".to_string()), |
| ..Default::default() |
| }]), |
| top_logprobs: Some(0), |
| top_p: Some(1.0), |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: None, |
| request_id: Some("resp_max_calls_test".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: 50, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| let response = router.route_responses(None, &req, None).await; |
| assert_eq!(response.status(), StatusCode::OK); |
|
|
| use axum::body::to_bytes; |
| let response_body = response.into_body(); |
| let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap(); |
| let response_json: serde_json::Value = serde_json::from_slice(&body_bytes).unwrap(); |
|
|
| println!( |
| "Max calls response: {}", |
| serde_json::to_string_pretty(&response_json).unwrap() |
| ); |
|
|
| |
| |
| assert_eq!(response_json["status"], "completed"); |
|
|
| |
| assert!(response_json["id"].is_string()); |
| assert_eq!(response_json["object"], "response"); |
|
|
| |
| let tools = response_json["tools"] |
| .as_array() |
| .expect("tools should be array"); |
| assert_eq!(tools.len(), 1); |
| assert_eq!(tools[0]["type"], "mcp"); |
|
|
| |
| |
|
|
| std::env::remove_var("SGLANG_MCP_CONFIG"); |
| worker.stop().await; |
| mcp.stop().await; |
| } |
|
|
| |
| |
| async fn setup_streaming_mcp_test() -> ( |
| MockMCPServer, |
| MockWorker, |
| Box<dyn smg::routers::RouterTrait>, |
| tempfile::TempDir, |
| ) { |
| let mcp = MockMCPServer::start().await.expect("start mcp"); |
| let mcp_yaml = format!( |
| "servers:\n - name: mock\n protocol: streamable\n url: {}\n", |
| mcp.url() |
| ); |
| let dir = tempfile::tempdir().expect("tmpdir"); |
| let cfg_path = dir.path().join("mcp.yaml"); |
| std::fs::write(&cfg_path, mcp_yaml).expect("write mcp cfg"); |
|
|
| let mut worker = MockWorker::new(MockWorkerConfig { |
| port: 0, |
| worker_type: WorkerType::Regular, |
| health_status: HealthStatus::Healthy, |
| response_delay_ms: 0, |
| fail_rate: 0.0, |
| }); |
| let worker_url = worker.start().await.expect("start worker"); |
|
|
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec![worker_url]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(5) |
| .worker_startup_check_interval_secs(1) |
| .log_level("info") |
| .max_concurrent_requests(32) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = |
| crate::common::create_test_context_with_mcp_config(router_cfg, cfg_path.to_str().unwrap()) |
| .await; |
| let router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| (mcp, worker, router, dir) |
| } |
|
|
| |
| fn parse_sse_events(body: &str) -> Vec<(Option<String>, serde_json::Value)> { |
| let mut events = Vec::new(); |
| let blocks: Vec<&str> = body |
| .split("\n\n") |
| .filter(|s| !s.trim().is_empty()) |
| .collect(); |
|
|
| for block in blocks { |
| let mut event_name: Option<String> = None; |
| let mut data_lines: Vec<String> = Vec::new(); |
|
|
| for line in block.lines() { |
| if let Some(rest) = line.strip_prefix("event:") { |
| event_name = Some(rest.trim().to_string()); |
| } else if let Some(rest) = line.strip_prefix("data:") { |
| let data = rest.trim_start(); |
| |
| if data != "[DONE]" { |
| data_lines.push(data.to_string()); |
| } |
| } |
| } |
|
|
| if !data_lines.is_empty() { |
| let data = data_lines.join("\n"); |
| if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(&data) { |
| events.push((event_name, parsed)); |
| } |
| } |
| } |
|
|
| events |
| } |
|
|
| #[tokio::test] |
| async fn test_streaming_with_mcp_tool_calls() { |
| |
| |
| |
| |
| |
| |
| |
|
|
| let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await; |
|
|
| |
| let req = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("search for something interesting".to_string()), |
| instructions: Some("Use tools when needed".to_string()), |
| max_output_tokens: Some(256), |
| max_tool_calls: Some(3), |
| metadata: None, |
| model: "mock-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(true), |
| temperature: Some(0.7), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::Mcp, |
| server_url: Some(mcp.url()), |
| server_label: Some("mock".to_string()), |
| server_description: Some("Mock MCP for streaming test".to_string()), |
| require_approval: Some("never".to_string()), |
| ..Default::default() |
| }]), |
| top_logprobs: Some(0), |
| top_p: Some(1.0), |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: None, |
| request_id: Some("resp_streaming_mcp_test".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: 50, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| let response = router.route_responses(None, &req, None).await; |
|
|
| |
| assert_eq!( |
| response.status(), |
| StatusCode::OK, |
| "Streaming request should succeed" |
| ); |
|
|
| |
| let content_type = response |
| .headers() |
| .get("content-type") |
| .and_then(|v| v.to_str().ok()); |
| assert_eq!( |
| content_type, |
| Some("text/event-stream"), |
| "Should have SSE content type" |
| ); |
|
|
| |
| use axum::body::to_bytes; |
| let response_body = response.into_body(); |
| let body_bytes = to_bytes(response_body, usize::MAX).await.unwrap(); |
| let body_text = String::from_utf8_lossy(&body_bytes); |
|
|
| println!("Streaming SSE response:\n{}", body_text); |
|
|
| |
| let events = parse_sse_events(&body_text); |
|
|
| assert!(!events.is_empty(), "Should have at least one SSE event"); |
| println!("Total parsed SSE events: {}", events.len()); |
|
|
| |
| let has_done_marker = body_text.contains("data: [DONE]"); |
| assert!(has_done_marker, "Stream should end with [DONE] marker"); |
|
|
| |
| let mut found_mcp_list_tools = false; |
| let mut found_mcp_list_tools_in_progress = false; |
| let mut found_mcp_list_tools_completed = false; |
| let mut found_response_created = false; |
| let mut found_mcp_call_added = false; |
| let mut found_mcp_call_in_progress = false; |
| let mut found_mcp_call_arguments = false; |
| let mut found_mcp_call_arguments_done = false; |
| let mut found_mcp_call_done = false; |
| let mut found_response_completed = false; |
|
|
| for (event_name, data) in &events { |
| let event_type = data.get("type").and_then(|v| v.as_str()).unwrap_or(""); |
|
|
| match event_type { |
| "response.output_item.added" => { |
| |
| if let Some(item) = data.get("item") { |
| if item.get("type").and_then(|v| v.as_str()) == Some("mcp_list_tools") { |
| found_mcp_list_tools = true; |
| println!("✓ Found mcp_list_tools added event"); |
|
|
| |
| assert!( |
| item.get("tools").is_some(), |
| "mcp_list_tools should have tools array" |
| ); |
| } else if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") { |
| found_mcp_call_added = true; |
| println!("✓ Found mcp_call added event"); |
|
|
| |
| assert!(item.get("name").is_some(), "mcp_call should have name"); |
| assert_eq!( |
| item.get("server_label").and_then(|v| v.as_str()), |
| Some("mock"), |
| "mcp_call should have server_label" |
| ); |
| } |
| } |
| } |
| "response.mcp_list_tools.in_progress" => { |
| found_mcp_list_tools_in_progress = true; |
| println!("✓ Found mcp_list_tools.in_progress event"); |
|
|
| |
| assert!( |
| data.get("output_index").is_some(), |
| "mcp_list_tools.in_progress should have output_index" |
| ); |
| assert!( |
| data.get("item_id").is_some(), |
| "mcp_list_tools.in_progress should have item_id" |
| ); |
| } |
| "response.mcp_list_tools.completed" => { |
| found_mcp_list_tools_completed = true; |
| println!("✓ Found mcp_list_tools.completed event"); |
|
|
| |
| assert!( |
| data.get("output_index").is_some(), |
| "mcp_list_tools.completed should have output_index" |
| ); |
| assert!( |
| data.get("item_id").is_some(), |
| "mcp_list_tools.completed should have item_id" |
| ); |
| } |
| "response.mcp_call.in_progress" => { |
| found_mcp_call_in_progress = true; |
| println!("✓ Found mcp_call.in_progress event"); |
|
|
| |
| assert!( |
| data.get("output_index").is_some(), |
| "mcp_call.in_progress should have output_index" |
| ); |
| assert!( |
| data.get("item_id").is_some(), |
| "mcp_call.in_progress should have item_id" |
| ); |
| } |
| "response.mcp_call_arguments.delta" => { |
| found_mcp_call_arguments = true; |
| println!("✓ Found mcp_call_arguments.delta event"); |
|
|
| |
| assert!( |
| data.get("delta").is_some(), |
| "mcp_call_arguments.delta should include delta text" |
| ); |
| } |
| "response.mcp_call_arguments.done" => { |
| found_mcp_call_arguments_done = true; |
| println!("✓ Found mcp_call_arguments.done event"); |
|
|
| assert!( |
| data.get("arguments").is_some(), |
| "mcp_call_arguments.done should include full arguments" |
| ); |
| } |
| "response.output_item.done" => { |
| if let Some(item) = data.get("item") { |
| if item.get("type").and_then(|v| v.as_str()) == Some("mcp_call") { |
| found_mcp_call_done = true; |
| println!("✓ Found mcp_call done event"); |
|
|
| |
| assert!( |
| item.get("output").is_some(), |
| "mcp_call done should have output" |
| ); |
| } |
| } |
| } |
| "response.created" => { |
| found_response_created = true; |
| println!("✓ Found response.created event"); |
|
|
| |
| assert!( |
| data.get("response").is_some(), |
| "response.created should have response object" |
| ); |
| } |
| "response.completed" => { |
| found_response_completed = true; |
| println!("✓ Found response.completed event"); |
| } |
| _ => { |
| println!(" Other event: {}", event_type); |
| } |
| } |
|
|
| if let Some(name) = event_name { |
| println!(" Event name: {}", name); |
| } |
| } |
|
|
| |
| println!("\n=== Event Summary ==="); |
| println!("MCP list_tools added: {}", found_mcp_list_tools); |
| println!( |
| "MCP list_tools in_progress: {}", |
| found_mcp_list_tools_in_progress |
| ); |
| println!( |
| "MCP list_tools completed: {}", |
| found_mcp_list_tools_completed |
| ); |
| println!("Response created: {}", found_response_created); |
| println!("MCP call added: {}", found_mcp_call_added); |
| println!("MCP call in_progress: {}", found_mcp_call_in_progress); |
| println!("MCP call arguments delta: {}", found_mcp_call_arguments); |
| println!("MCP call arguments done: {}", found_mcp_call_arguments_done); |
| println!("MCP call done: {}", found_mcp_call_done); |
| println!("Response completed: {}", found_response_completed); |
|
|
| |
| assert!( |
| found_mcp_list_tools, |
| "Should send mcp_list_tools added event at the start" |
| ); |
| assert!( |
| found_mcp_list_tools_in_progress, |
| "Should send mcp_list_tools.in_progress event" |
| ); |
| assert!( |
| found_mcp_list_tools_completed, |
| "Should send mcp_list_tools.completed event" |
| ); |
| assert!(found_response_created, "Should send response.created event"); |
| assert!(found_mcp_call_added, "Should send mcp_call added event"); |
| assert!( |
| found_mcp_call_in_progress, |
| "Should send mcp_call.in_progress event" |
| ); |
| assert!(found_mcp_call_done, "Should send mcp_call done event"); |
|
|
| assert!( |
| found_mcp_call_arguments, |
| "Should send mcp_call_arguments.delta event" |
| ); |
| assert!( |
| found_mcp_call_arguments_done, |
| "Should send mcp_call_arguments.done event" |
| ); |
|
|
| |
| let has_error = body_text.contains("event: error"); |
| assert!(!has_error, "Should not have error events"); |
|
|
| worker.stop().await; |
| mcp.stop().await; |
| } |
|
|
| #[tokio::test] |
| async fn test_streaming_multi_turn_with_mcp() { |
| |
| let (mut mcp, mut worker, router, _dir) = setup_streaming_mcp_test().await; |
|
|
| let req = ResponsesRequest { |
| background: Some(false), |
| include: None, |
| input: ResponseInput::Text("complex query requiring multiple tool calls".to_string()), |
| instructions: Some("Be thorough".to_string()), |
| max_output_tokens: Some(512), |
| max_tool_calls: Some(5), |
| metadata: None, |
| model: "mock-model".to_string(), |
| parallel_tool_calls: Some(true), |
| previous_response_id: None, |
| reasoning: None, |
| service_tier: Some(ServiceTier::Auto), |
| store: Some(true), |
| stream: Some(true), |
| temperature: Some(0.8), |
| tool_choice: Some(ToolChoice::Value(ToolChoiceValue::Auto)), |
| tools: Some(vec![ResponseTool { |
| r#type: ResponseToolType::Mcp, |
| server_url: Some(mcp.url()), |
| server_label: Some("mock".to_string()), |
| ..Default::default() |
| }]), |
| top_logprobs: Some(0), |
| top_p: Some(1.0), |
| truncation: Some(Truncation::Disabled), |
| text: None, |
| user: None, |
| request_id: Some("resp_streaming_multiturn_test".to_string()), |
| priority: 0, |
| frequency_penalty: Some(0.0), |
| presence_penalty: Some(0.0), |
| stop: None, |
| top_k: 50, |
| min_p: 0.0, |
| repetition_penalty: 1.0, |
| conversation: None, |
| }; |
|
|
| let response = router.route_responses(None, &req, None).await; |
| assert_eq!(response.status(), StatusCode::OK); |
|
|
| use axum::body::to_bytes; |
| let body_bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap(); |
| let body_text = String::from_utf8_lossy(&body_bytes); |
|
|
| println!("Multi-turn streaming response:\n{}", body_text); |
|
|
| |
| assert!(body_text.contains("data: [DONE]")); |
| assert!(!body_text.contains("event: error")); |
|
|
| |
| let event_count = body_text |
| .split("\n\n") |
| .filter(|s| !s.trim().is_empty()) |
| .count(); |
| println!("Total events in multi-turn stream: {}", event_count); |
|
|
| assert!(event_count > 0, "Should have received streaming events"); |
|
|
| worker.stop().await; |
| mcp.stop().await; |
| } |
|
|
| #[tokio::test] |
| async fn test_conversation_items_create_and_get() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let create_conv = serde_json::json!({}); |
| let conv_resp = |
| conversations::create_conversation(&ctx.conversation_storage, create_conv).await; |
| assert_eq!(conv_resp.status(), StatusCode::OK); |
| let conv_bytes = axum::body::to_bytes(conv_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_json: serde_json::Value = serde_json::from_slice(&conv_bytes).unwrap(); |
| let conv_id = conv_json["id"].as_str().unwrap(); |
|
|
| |
| let create_items = serde_json::json!({ |
| "items": [ |
| { |
| "type": "message", |
| "role": "user", |
| "content": [{"type": "input_text", "text": "Hello"}] |
| }, |
| { |
| "type": "message", |
| "role": "assistant", |
| "content": [{"type": "output_text", "text": "Hi there!"}] |
| } |
| ] |
| }); |
|
|
| let items_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| create_items, |
| ) |
| .await; |
| assert_eq!(items_resp.status(), StatusCode::OK); |
| let items_bytes = axum::body::to_bytes(items_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let items_json: serde_json::Value = serde_json::from_slice(&items_bytes).unwrap(); |
|
|
| |
| assert_eq!(items_json["object"], "list"); |
| assert!(items_json["data"].is_array()); |
|
|
| |
| let item_id = items_json["data"][0]["id"].as_str().unwrap(); |
| let get_resp = conversations::get_conversation_item( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| item_id, |
| None, |
| ) |
| .await; |
| assert_eq!(get_resp.status(), StatusCode::OK); |
| let get_bytes = axum::body::to_bytes(get_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let get_json: serde_json::Value = serde_json::from_slice(&get_bytes).unwrap(); |
|
|
| |
| assert_eq!(get_json["id"], item_id); |
| assert_eq!(get_json["type"], "message"); |
| assert_eq!(get_json["role"], "user"); |
| } |
|
|
| #[tokio::test] |
| async fn test_conversation_items_delete() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let create_conv = serde_json::json!({}); |
| let conv_resp = |
| conversations::create_conversation(&ctx.conversation_storage, create_conv).await; |
| let conv_bytes = axum::body::to_bytes(conv_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_json: serde_json::Value = serde_json::from_slice(&conv_bytes).unwrap(); |
| let conv_id = conv_json["id"].as_str().unwrap(); |
|
|
| |
| let create_items = serde_json::json!({ |
| "items": [ |
| { |
| "type": "message", |
| "role": "user", |
| "content": [{"type": "input_text", "text": "Test"}] |
| } |
| ] |
| }); |
|
|
| let items_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| create_items, |
| ) |
| .await; |
| let items_bytes = axum::body::to_bytes(items_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let items_json: serde_json::Value = serde_json::from_slice(&items_bytes).unwrap(); |
| let item_id = items_json["data"][0]["id"].as_str().unwrap(); |
|
|
| |
| let list_resp = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_bytes = axum::body::to_bytes(list_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_json: serde_json::Value = serde_json::from_slice(&list_bytes).unwrap(); |
| assert_eq!(list_json["data"].as_array().unwrap().len(), 1); |
|
|
| |
| let del_resp = conversations::delete_conversation_item( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| item_id, |
| ) |
| .await; |
| assert_eq!(del_resp.status(), StatusCode::OK); |
|
|
| |
| let list_resp2 = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_bytes2 = axum::body::to_bytes(list_resp2.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_json2: serde_json::Value = serde_json::from_slice(&list_bytes2).unwrap(); |
| assert_eq!(list_json2["data"].as_array().unwrap().len(), 0); |
|
|
| |
| let get_resp = conversations::get_conversation_item( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| item_id, |
| None, |
| ) |
| .await; |
| assert_eq!(get_resp.status(), StatusCode::NOT_FOUND); |
| } |
|
|
| #[tokio::test] |
| async fn test_conversation_items_max_limit() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let create_conv = serde_json::json!({}); |
| let conv_resp = |
| conversations::create_conversation(&ctx.conversation_storage, create_conv).await; |
| let conv_bytes = axum::body::to_bytes(conv_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_json: serde_json::Value = serde_json::from_slice(&conv_bytes).unwrap(); |
| let conv_id = conv_json["id"].as_str().unwrap(); |
|
|
| |
| let mut items = Vec::new(); |
| for i in 0..21 { |
| items.push(serde_json::json!({ |
| "type": "message", |
| "role": "user", |
| "content": [{"type": "input_text", "text": format!("Message {}", i)}] |
| })); |
| } |
| let create_items = serde_json::json!({ "items": items }); |
|
|
| let items_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| create_items, |
| ) |
| .await; |
| assert_eq!(items_resp.status(), StatusCode::BAD_REQUEST); |
|
|
| let items_bytes = axum::body::to_bytes(items_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let items_text = String::from_utf8_lossy(&items_bytes); |
| assert!(items_text.contains("Cannot add more than 20 items")); |
| } |
|
|
| #[tokio::test] |
| async fn test_conversation_items_unsupported_type() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let create_conv = serde_json::json!({}); |
| let conv_resp = |
| conversations::create_conversation(&ctx.conversation_storage, create_conv).await; |
| let conv_bytes = axum::body::to_bytes(conv_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_json: serde_json::Value = serde_json::from_slice(&conv_bytes).unwrap(); |
| let conv_id = conv_json["id"].as_str().unwrap(); |
|
|
| |
| let create_items = serde_json::json!({ |
| "items": [ |
| { |
| "type": "totally_invalid_type", |
| "content": [] |
| } |
| ] |
| }); |
|
|
| let items_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_id, |
| create_items, |
| ) |
| .await; |
| assert_eq!(items_resp.status(), StatusCode::BAD_REQUEST); |
|
|
| let items_bytes = axum::body::to_bytes(items_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let items_text = String::from_utf8_lossy(&items_bytes); |
| assert!(items_text.contains("Unsupported item type")); |
| } |
|
|
| #[tokio::test] |
| async fn test_conversation_items_multi_conversation_sharing() { |
| |
| let router_cfg = RouterConfig::builder() |
| .openai_mode(vec!["http://localhost".to_string()]) |
| .random_policy() |
| .host("127.0.0.1") |
| .port(0) |
| .max_payload_size(8 * 1024 * 1024) |
| .request_timeout_secs(60) |
| .worker_startup_timeout_secs(1) |
| .worker_startup_check_interval_secs(1) |
| .log_level("warn") |
| .max_concurrent_requests(8) |
| .queue_timeout_secs(5) |
| .build_unchecked(); |
|
|
| let ctx = crate::common::create_test_context(router_cfg).await; |
| let _router = RouterFactory::create_router(&ctx).await.expect("router"); |
|
|
| |
| let conv_a_resp = |
| conversations::create_conversation(&ctx.conversation_storage, serde_json::json!({})).await; |
| let conv_a_bytes = axum::body::to_bytes(conv_a_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_a_json: serde_json::Value = serde_json::from_slice(&conv_a_bytes).unwrap(); |
| let conv_a_id = conv_a_json["id"].as_str().unwrap(); |
|
|
| let conv_b_resp = |
| conversations::create_conversation(&ctx.conversation_storage, serde_json::json!({})).await; |
| let conv_b_bytes = axum::body::to_bytes(conv_b_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let conv_b_json: serde_json::Value = serde_json::from_slice(&conv_b_bytes).unwrap(); |
| let conv_b_id = conv_b_json["id"].as_str().unwrap(); |
|
|
| |
| let create_items = serde_json::json!({ |
| "items": [ |
| { |
| "type": "message", |
| "role": "user", |
| "content": [{"type": "input_text", "text": "Shared message"}] |
| } |
| ] |
| }); |
|
|
| let items_a_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_a_id, |
| create_items, |
| ) |
| .await; |
| let items_a_bytes = axum::body::to_bytes(items_a_resp.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let items_a_json: serde_json::Value = serde_json::from_slice(&items_a_bytes).unwrap(); |
| let item_id = items_a_json["data"][0]["id"].as_str().unwrap(); |
|
|
| |
| let reference_items = serde_json::json!({ |
| "items": [ |
| { |
| "type": "item_reference", |
| "id": item_id |
| } |
| ] |
| }); |
|
|
| let items_b_resp = conversations::create_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_b_id, |
| reference_items, |
| ) |
| .await; |
| assert_eq!(items_b_resp.status(), StatusCode::OK); |
|
|
| |
| let list_a = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_a_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_a_bytes = axum::body::to_bytes(list_a.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_a_json: serde_json::Value = serde_json::from_slice(&list_a_bytes).unwrap(); |
| assert_eq!(list_a_json["data"].as_array().unwrap().len(), 1); |
|
|
| let list_b = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_b_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_b_bytes = axum::body::to_bytes(list_b.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_b_json: serde_json::Value = serde_json::from_slice(&list_b_bytes).unwrap(); |
| assert_eq!(list_b_json["data"].as_array().unwrap().len(), 1); |
|
|
| |
| conversations::delete_conversation_item( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_a_id, |
| item_id, |
| ) |
| .await; |
|
|
| |
| let list_a2 = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_a_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_a2_bytes = axum::body::to_bytes(list_a2.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_a2_json: serde_json::Value = serde_json::from_slice(&list_a2_bytes).unwrap(); |
| assert_eq!(list_a2_json["data"].as_array().unwrap().len(), 0); |
|
|
| |
| let list_b2 = conversations::list_conversation_items( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_b_id, |
| None, |
| None, |
| None, |
| ) |
| .await; |
| let list_b2_bytes = axum::body::to_bytes(list_b2.into_body(), usize::MAX) |
| .await |
| .unwrap(); |
| let list_b2_json: serde_json::Value = serde_json::from_slice(&list_b2_bytes).unwrap(); |
| assert_eq!(list_b2_json["data"].as_array().unwrap().len(), 1); |
|
|
| |
| let get_resp = conversations::get_conversation_item( |
| &ctx.conversation_storage, |
| &ctx.conversation_item_storage, |
| conv_b_id, |
| item_id, |
| None, |
| ) |
| .await; |
| assert_eq!(get_resp.status(), StatusCode::OK); |
| } |
|
|