| use axum::{ |
| body::Body, |
| extract::Request, |
| http::{header::CONTENT_TYPE, StatusCode}, |
| }; |
| use serde_json::json; |
| use smg::{config::RouterConfig, routers::RouterFactory}; |
| use tower::ServiceExt; |
|
|
| use crate::common::{ |
| mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType}, |
| AppTestContext, |
| }; |
|
|
/// Liveness, readiness and health-check endpoints.
#[cfg(test)]
mod health_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A body-less `GET` request for `uri`.
    fn get(uri: &str) -> Request {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// `/liveness` answers 200 even when no workers are registered.
    #[tokio::test]
    async fn test_liveness_endpoint() {
        let ctx = AppTestContext::new(Vec::new()).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/liveness")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        ctx.shutdown().await;
    }

    /// `/readiness` answers 200 once a healthy worker is registered.
    #[tokio::test]
    async fn test_readiness_with_healthy_workers() {
        let ctx = AppTestContext::new(vec![healthy_worker(18001)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/readiness")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        ctx.shutdown().await;
    }

    /// `/readiness` answers 503 when the router has no workers at all.
    ///
    /// NOTE(review): despite the name, no unhealthy worker is configured;
    /// the context is created with an empty worker list.
    #[tokio::test]
    async fn test_readiness_with_unhealthy_workers() {
        let ctx = AppTestContext::new(Vec::new()).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/readiness")).await.unwrap();
        assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE);

        ctx.shutdown().await;
    }

    /// `/health` answers 200 with two healthy workers registered.
    ///
    /// Only the status code is checked; the response body is not inspected.
    #[tokio::test]
    async fn test_health_endpoint_details() {
        let ctx = AppTestContext::new(vec![healthy_worker(18003), healthy_worker(18004)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/health")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        ctx.shutdown().await;
    }

    /// `/health_generate` answers 200 with a JSON object body.
    #[tokio::test]
    async fn test_health_generate_endpoint() {
        let ctx = AppTestContext::new(vec![healthy_worker(18005)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/health_generate")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let body: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert!(body.is_object());

        ctx.shutdown().await;
    }
}
|
|
/// Generation routing through `/generate` and `/v1/chat/completions`.
#[cfg(test)]
mod generation_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A `POST` to `uri` carrying `payload` serialized as a JSON body.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Buffers the whole response body and parses it as JSON, panicking if
    /// either step fails.
    async fn read_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// A non-streaming `/generate` returns 200 with `text`, and with
    /// `meta_info.finish_reason.type` equal to `"stop"`.
    #[tokio::test]
    async fn test_generate_success() {
        let ctx = AppTestContext::new(vec![healthy_worker(18101)]).await;
        let app = ctx.create_app().await;

        let payload = json!({ "text": "Hello, world!", "stream": false });
        let response = app.oneshot(post_json("/generate", &payload)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert!(body.get("text").is_some());
        assert!(body.get("meta_info").is_some());

        let meta_info = &body["meta_info"];
        assert!(meta_info.get("finish_reason").is_some());
        assert_eq!(meta_info["finish_reason"]["type"], "stop");

        ctx.shutdown().await;
    }

    /// A streaming `/generate` request is accepted with 200.
    ///
    /// Only the status is checked; the streamed body is neither consumed nor
    /// parsed.
    #[tokio::test]
    async fn test_generate_streaming() {
        let ctx = AppTestContext::new(vec![healthy_worker(18102)]).await;
        let app = ctx.create_app().await;

        let payload = json!({ "text": "Stream test", "stream": true });
        let response = app.oneshot(post_json("/generate", &payload)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        ctx.shutdown().await;
    }

    /// When the only worker fails every request (`fail_rate: 1.0`), the
    /// router answers 500.
    #[tokio::test]
    async fn test_generate_with_worker_failure() {
        let failing_worker = MockWorkerConfig {
            fail_rate: 1.0,
            ..healthy_worker(18103)
        };
        let ctx = AppTestContext::new(vec![failing_worker]).await;
        let app = ctx.create_app().await;

        let payload = json!({ "text": "This should fail", "stream": false });
        let response = app.oneshot(post_json("/generate", &payload)).await.unwrap();
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);

        ctx.shutdown().await;
    }

    /// A non-streaming `/v1/chat/completions` returns 200 with a `choices`
    /// field.
    #[tokio::test]
    async fn test_v1_chat_completions_success() {
        let ctx = AppTestContext::new(vec![healthy_worker(18104)]).await;
        let app = ctx.create_app().await;

        let payload = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Hello!"}
            ],
            "stream": false
        });
        let response = app
            .oneshot(post_json("/v1/chat/completions", &payload))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert!(body.get("choices").is_some());

        ctx.shutdown().await;
    }
}
|
|
/// Server/model metadata endpoints: `/get_server_info`, `/get_model_info`
/// and `/v1/models`.
#[cfg(test)]
mod model_info_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A body-less `GET` request for `uri`.
    fn get(uri: &str) -> Request {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// Buffers the whole response body and parses it as JSON, panicking if
    /// either step fails.
    async fn read_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// `/get_server_info` returns a JSON object carrying the expected fields.
    #[tokio::test]
    async fn test_get_server_info() {
        let ctx = AppTestContext::new(vec![healthy_worker(18201)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/get_server_info")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert!(body.is_object());
        for key in [
            "version",
            "model_path",
            "tokenizer_path",
            "port",
            "max_num_batched_tokens",
            "schedule_policy",
        ] {
            assert!(
                body.get(key).is_some(),
                "server info is missing `{key}`: {body}"
            );
        }

        ctx.shutdown().await;
    }

    /// `/get_model_info` returns the mock worker's model/tokenizer paths,
    /// `is_generation: true` and a `preferred_sampling_params` field.
    #[tokio::test]
    async fn test_get_model_info() {
        let ctx = AppTestContext::new(vec![healthy_worker(18202)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/get_model_info")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert!(body.is_object());
        assert_eq!(
            body.get("model_path").and_then(|v| v.as_str()),
            Some("mock-model-path")
        );
        assert_eq!(
            body.get("tokenizer_path").and_then(|v| v.as_str()),
            Some("mock-tokenizer-path")
        );
        assert_eq!(
            body.get("is_generation").and_then(|v| v.as_bool()),
            Some(true)
        );
        assert!(body.get("preferred_sampling_params").is_some());

        ctx.shutdown().await;
    }

    /// `/v1/models` returns an OpenAI-style `list` whose first entry is the
    /// mock model.
    #[tokio::test]
    async fn test_v1_models() {
        let ctx = AppTestContext::new(vec![healthy_worker(18203)]).await;
        let app = ctx.create_app().await;

        let response = app.oneshot(get("/v1/models")).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert_eq!(body.get("object").and_then(|v| v.as_str()), Some("list"));

        let models = body
            .get("data")
            .and_then(|v| v.as_array())
            .expect("`data` should be an array of models");
        let first_model = models
            .first()
            .expect("at least one model should be listed");

        assert_eq!(
            first_model.get("id").and_then(|v| v.as_str()),
            Some("mock-model")
        );
        assert_eq!(
            first_model.get("object").and_then(|v| v.as_str()),
            Some("model")
        );
        assert!(first_model.get("created").is_some());
        assert_eq!(
            first_model.get("owned_by").and_then(|v| v.as_str()),
            Some("organization-owner")
        );

        ctx.shutdown().await;
    }

    /// With no workers registered, every metadata endpoint must still answer
    /// with one of a small set of tolerated status codes.
    #[tokio::test]
    async fn test_model_info_with_no_workers() {
        // The exact code is deliberately not pinned down; any of these passes.
        const ACCEPTED: [StatusCode; 4] = [
            StatusCode::OK,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::NOT_FOUND,
            StatusCode::INTERNAL_SERVER_ERROR,
        ];

        let ctx = AppTestContext::new(Vec::new()).await;
        let app = ctx.create_app().await;

        for uri in ["/get_server_info", "/get_model_info", "/v1/models"] {
            let status = app.clone().oneshot(get(uri)).await.unwrap().status();
            assert!(
                ACCEPTED.contains(&status),
                "Unexpected status code for {uri}: {status:?}"
            );
        }

        ctx.shutdown().await;
    }

    /// Repeated `/get_model_info` calls over two workers all succeed with
    /// consistent data.
    ///
    /// Which worker served each call is not observed.
    #[tokio::test]
    async fn test_model_info_with_multiple_workers() {
        let ctx = AppTestContext::new(vec![healthy_worker(18204), healthy_worker(18205)]).await;
        let app = ctx.create_app().await;

        for _ in 0..5 {
            let response = app.clone().oneshot(get("/get_model_info")).await.unwrap();
            assert_eq!(response.status(), StatusCode::OK);

            let body = read_json(response).await;
            assert_eq!(
                body.get("model_path").and_then(|v| v.as_str()),
                Some("mock-model-path")
            );
        }

        ctx.shutdown().await;
    }

    /// A worker that fails every request (`fail_rate: 1.0`) makes
    /// `/get_model_info` answer 500 or 503.
    ///
    /// NOTE(review): the worker's `health_status` is still `Healthy`; it is
    /// "unhealthy" only in the sense that all of its requests fail.
    #[tokio::test]
    async fn test_model_info_with_unhealthy_worker() {
        let failing_worker = MockWorkerConfig {
            fail_rate: 1.0,
            ..healthy_worker(18206)
        };
        let ctx = AppTestContext::new(vec![failing_worker]).await;
        let app = ctx.create_app().await;

        let status = app.oneshot(get("/get_model_info")).await.unwrap().status();
        assert!(
            status == StatusCode::INTERNAL_SERVER_ERROR
                || status == StatusCode::SERVICE_UNAVAILABLE,
            "Expected error status for always-failing worker, got: {:?}",
            status
        );

        ctx.shutdown().await;
    }
}
|
|
/// Routing of `/generate` requests to registered workers.
#[cfg(test)]
mod router_policy_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A `POST` to `uri` carrying `payload` serialized as a JSON body.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Ten `/generate` requests over two healthy workers must all succeed.
    ///
    /// NOTE(review): the policy in effect is whatever `AppTestContext::new`
    /// configures by default (presumably random, given the test name;
    /// confirm). Which worker served each request is not checked.
    #[tokio::test]
    async fn test_random_policy() {
        let ctx = AppTestContext::new(vec![healthy_worker(18801), healthy_worker(18802)]).await;
        let app = ctx.create_app().await;

        for i in 0..10 {
            let payload = json!({ "text": format!("Request {i}"), "stream": false });
            let response = app
                .clone()
                .oneshot(post_json("/generate", &payload))
                .await
                .unwrap();
            assert_eq!(response.status(), StatusCode::OK);
        }

        ctx.shutdown().await;
    }

    /// With a single healthy worker registered, a `/generate` request must be
    /// routed to it and succeed.
    #[tokio::test]
    async fn test_worker_selection() {
        let ctx = AppTestContext::new(vec![healthy_worker(18207)]).await;
        let app = ctx.create_app().await;

        let payload = json!({ "text": "Test selection", "stream": false });
        let response = app.oneshot(post_json("/generate", &payload)).await.unwrap();
        assert_eq!(
            response.status(),
            StatusCode::OK,
            "the only registered worker should have served the request"
        );

        ctx.shutdown().await;
    }
}
|
|
/// Responses API endpoints under `/v1/responses`: create, get, cancel,
/// delete and `input_items`.
#[cfg(test)]
mod responses_endpoint_tests {
    use reqwest::Client as HttpClient;

    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A body-less request with HTTP `method` for `uri`.
    fn empty_request(method: &str, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// A `POST` to `uri` carrying `payload` serialized as a JSON body.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Buffers the whole response body and parses it as JSON, panicking if
    /// either step fails.
    async fn read_json(response: axum::response::Response) -> serde_json::Value {
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Create-request payload with `store` and `background` set and a
    /// caller-chosen `request_id`, so the response can be addressed later.
    fn background_payload(request_id: &str) -> serde_json::Value {
        json!({
            "input": "Hello Responses API",
            "model": "mock-model",
            "stream": false,
            "store": true,
            "background": true,
            "request_id": request_id
        })
    }

    /// A non-streaming create returns a `response` object with status
    /// `completed`.
    #[tokio::test]
    async fn test_v1_responses_non_streaming() {
        let ctx = AppTestContext::new(vec![healthy_worker(18950)]).await;
        let app = ctx.create_app().await;

        let payload = json!({
            "input": "Hello Responses API",
            "model": "mock-model",
            "stream": false
        });
        let response = app.oneshot(post_json("/v1/responses", &payload)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = read_json(response).await;
        assert_eq!(body["object"], "response");
        assert_eq!(body["status"], "completed");

        ctx.shutdown().await;
    }

    /// A streaming create answers 200 with a `text/event-stream` content type.
    ///
    /// The event stream itself is not consumed.
    #[tokio::test]
    async fn test_v1_responses_streaming() {
        let ctx = AppTestContext::new(vec![healthy_worker(18951)]).await;
        let app = ctx.create_app().await;

        let payload = json!({
            "input": "Hello Responses API",
            "model": "mock-model",
            "stream": true
        });
        let response = app.oneshot(post_json("/v1/responses", &payload)).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let content_type = response
            .headers()
            .get(CONTENT_TYPE)
            .and_then(|v| v.to_str().ok())
            .unwrap_or("");
        assert!(
            content_type.contains("text/event-stream"),
            "unexpected content-type: {content_type:?}"
        );

        ctx.shutdown().await;
    }

    /// A stored background response can be fetched back by its `request_id`.
    #[tokio::test]
    async fn test_v1_responses_get() {
        let ctx = AppTestContext::new(vec![healthy_worker(18952)]).await;
        let app = ctx.create_app().await;

        let resp_id = "test-get-resp-id-123";
        let created = app
            .clone()
            .oneshot(post_json("/v1/responses", &background_payload(resp_id)))
            .await
            .unwrap();
        assert_eq!(created.status(), StatusCode::OK);

        let fetched = app
            .oneshot(empty_request("GET", &format!("/v1/responses/{resp_id}")))
            .await
            .unwrap();
        assert_eq!(fetched.status(), StatusCode::OK);

        let fetched_json = read_json(fetched).await;
        assert_eq!(fetched_json["object"], "response");

        ctx.shutdown().await;
    }

    /// Cancelling a stored background response reports `status: cancelled`.
    #[tokio::test]
    async fn test_v1_responses_cancel() {
        let ctx = AppTestContext::new(vec![healthy_worker(18953)]).await;
        let app = ctx.create_app().await;

        let resp_id = "test-cancel-resp-id-456";
        let created = app
            .clone()
            .oneshot(post_json("/v1/responses", &background_payload(resp_id)))
            .await
            .unwrap();
        assert_eq!(created.status(), StatusCode::OK);

        let cancelled = app
            .oneshot(empty_request(
                "POST",
                &format!("/v1/responses/{resp_id}/cancel"),
            ))
            .await
            .unwrap();
        assert_eq!(cancelled.status(), StatusCode::OK);

        let cancel_json = read_json(cancelled).await;
        assert_eq!(cancel_json["status"], "cancelled");

        ctx.shutdown().await;
    }

    /// `DELETE /v1/responses/{id}` answers 501 Not Implemented.
    #[tokio::test]
    async fn test_v1_responses_delete_not_implemented() {
        let ctx = AppTestContext::new(vec![healthy_worker(18954)]).await;
        let app = ctx.create_app().await;

        let resp_id = "resp-test-123";
        let response = app
            .oneshot(empty_request("DELETE", &format!("/v1/responses/{resp_id}")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);

        ctx.shutdown().await;
    }

    /// `input_items` lists one entry per element of a stored response's
    /// `input` array.
    #[tokio::test]
    async fn test_v1_responses_input_items() {
        use data_connector::{ResponseId, StoredResponse};

        const RESPONSE_ID: &str = "resp_test_input_items";

        // OpenAI mode with a placeholder upstream URL and no mock workers; the
        // response under test is seeded straight into storage below instead
        // of being created through the API.
        let config = RouterConfig::builder()
            .openai_mode(vec!["http://dummy.local".to_string()])
            .random_policy()
            .host("127.0.0.1")
            .port(3002)
            .max_payload_size(256 * 1024 * 1024)
            .request_timeout_secs(600)
            .worker_startup_timeout_secs(1)
            .worker_startup_check_interval_secs(1)
            .max_concurrent_requests(64)
            .queue_size(0)
            .queue_timeout_secs(60)
            .build_unchecked();

        let ctx = AppTestContext::new_with_config(config, vec![]).await;
        let app = ctx.create_app().await;

        let mut stored_response = StoredResponse::new(None);
        stored_response.id = ResponseId::from(RESPONSE_ID);
        stored_response.input = json!([
            {"id": "item_1", "content": "hello", "role": "user"},
            {"id": "item_2", "content": "hi there", "role": "assistant"}
        ]);
        stored_response.output = json!([
            {"type": "message", "role": "assistant", "content": [{"type": "output_text", "text": "test response"}]}
        ]);

        ctx.app_context
            .response_storage
            .store_response(stored_response)
            .await
            .expect("Failed to store response");

        let response = app
            .oneshot(empty_request(
                "GET",
                &format!("/v1/responses/{RESPONSE_ID}/input_items"),
            ))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let items_json = read_json(response).await;
        assert_eq!(items_json["object"], "list");
        let items = items_json["data"]
            .as_array()
            .expect("`data` should be an array");
        // Two items were seeded into `input` above.
        assert_eq!(items.len(), 2);

        ctx.shutdown().await;
    }

    /// A background response created through the router can be fetched back
    /// through the router, and is stored on exactly one of the two workers.
    #[tokio::test]
    async fn test_v1_responses_get_multi_worker_fanout() {
        let worker_ports = [18960, 18961];
        let ctx = AppTestContext::new(Vec::from(worker_ports.map(healthy_worker))).await;
        let app = ctx.create_app().await;

        let rid = "resp_18960";
        let payload = json!({
            "input": "Hello Responses API",
            "model": "mock-model",
            "background": true,
            "store": true,
            "request_id": rid
        });
        let created = app
            .clone()
            .oneshot(post_json("/v1/responses", &payload))
            .await
            .unwrap();
        assert_eq!(created.status(), StatusCode::OK);

        let fetched = app
            .oneshot(empty_request("GET", &format!("/v1/responses/{rid}")))
            .await
            .unwrap();
        assert_eq!(fetched.status(), StatusCode::OK);

        // Query each mock worker directly, bypassing the router, to find where
        // the response was actually stored.
        let client = HttpClient::new();
        let mut ok_count = 0usize;
        for port in worker_ports {
            let url = format!("http://127.0.0.1:{port}/v1/responses/{rid}");
            let res = client.get(url).send().await.unwrap();
            if res.status() == StatusCode::OK {
                ok_count += 1;
            }
        }
        assert_eq!(ok_count, 1, "exactly one worker should store the response");

        ctx.shutdown().await;
    }
}
|
|
/// Error responses: unknown routes, wrong methods, malformed bodies and
/// unknown models.
#[cfg(test)]
mod error_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A body-less `GET` request for `uri`.
    fn get(uri: &str) -> Request {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// A `POST` to `uri` that declares a JSON content type and sends `body`
    /// verbatim (it need not be valid JSON).
    fn post_json_body(uri: &str, body: Body) -> Request {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(body)
            .unwrap()
    }

    /// Unknown paths answer 404 for both GET and POST.
    #[tokio::test]
    async fn test_404_not_found() {
        let ctx = AppTestContext::new(vec![healthy_worker(18401)]).await;
        let app = ctx.create_app().await;

        let response = app.clone().oneshot(get("/unknown_endpoint")).await.unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        let body = Body::from(serde_json::to_string(&json!({"text": "test"})).unwrap());
        let response = app
            .oneshot(post_json_body("/api/v2/generate", body))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::NOT_FOUND);

        ctx.shutdown().await;
    }

    /// Known paths answer 405 when called with the wrong HTTP method.
    #[tokio::test]
    async fn test_method_not_allowed() {
        let ctx = AppTestContext::new(vec![healthy_worker(18402)]).await;
        let app = ctx.create_app().await;

        // GET is rejected on `/generate`.
        let response = app.clone().oneshot(get("/generate")).await.unwrap();
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);

        // POST is rejected on `/health`.
        let response = app
            .oneshot(post_json_body("/health", Body::from("{}")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::METHOD_NOT_ALLOWED);

        ctx.shutdown().await;
    }

    /// Builds a context whose config caps payloads at 1 KiB.
    ///
    /// NOTE(review): no app is created and no request is sent, so the
    /// `max_payload_size` limit is never exercised. The test covers only
    /// context setup and teardown. Confirm whether the app returned by
    /// `AppTestContext::create_app` applies this limit before adding an
    /// oversized-request assertion.
    #[tokio::test]
    async fn test_payload_too_large() {
        let config = RouterConfig::builder()
            .regular_mode(vec![])
            .random_policy()
            .host("127.0.0.1")
            .port(3010)
            .max_payload_size(1024)
            .request_timeout_secs(600)
            .worker_startup_timeout_secs(1)
            .worker_startup_check_interval_secs(1)
            .max_concurrent_requests(64)
            .queue_timeout_secs(60)
            .build_unchecked();

        let ctx = AppTestContext::new_with_config(config, vec![healthy_worker(18403)]).await;

        ctx.shutdown().await;
    }

    /// Malformed and empty JSON bodies on `/generate` answer 400.
    #[tokio::test]
    async fn test_invalid_json_payload() {
        let ctx = AppTestContext::new(vec![healthy_worker(18404)]).await;
        let app = ctx.create_app().await;

        let response = app
            .clone()
            .oneshot(post_json_body("/generate", Body::from("{invalid json}")))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        let response = app
            .oneshot(post_json_body("/generate", Body::empty()))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        ctx.shutdown().await;
    }

    /// A chat completion naming an unknown model either succeeds or is
    /// rejected with 400. No other status is accepted.
    #[tokio::test]
    async fn test_invalid_model() {
        let ctx = AppTestContext::new(vec![healthy_worker(18406)]).await;
        let app = ctx.create_app().await;

        let payload = json!({
            "model": "invalid-model-name-that-does-not-exist",
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": false
        });
        let body = Body::from(serde_json::to_string(&payload).unwrap());
        let response = app
            .oneshot(post_json_body("/v1/chat/completions", body))
            .await
            .unwrap();

        let status = response.status();
        assert!(
            status.is_success() || status == StatusCode::BAD_REQUEST,
            "unexpected status for an unknown model: {status:?}"
        );

        ctx.shutdown().await;
    }
}
|
|
/// Cache-flush and load-reporting endpoints.
#[cfg(test)]
mod cache_tests {
    use super::*;

    /// A regular, healthy mock worker on `port` that answers without delay
    /// and never injects failures.
    fn healthy_worker(port: u16) -> MockWorkerConfig {
        MockWorkerConfig {
            port,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }
    }

    /// A body-less request with HTTP `method` for `uri`.
    fn empty_request(method: &str, uri: &str) -> Request {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// `POST /flush_cache` answers 200. If the body is JSON, it must be an
    /// object carrying `message` or `status`.
    #[tokio::test]
    async fn test_flush_cache() {
        let ctx = AppTestContext::new(vec![healthy_worker(18501)]).await;
        let app = ctx.create_app().await;

        let response = app
            .oneshot(empty_request("POST", "/flush_cache"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        // Empty or non-JSON bodies are tolerated: an empty slice never parses
        // as JSON, so only a real JSON body reaches the checks below.
        if let Ok(parsed) = serde_json::from_slice::<serde_json::Value>(&raw) {
            assert!(parsed.is_object());
            assert!(parsed.get("message").is_some() || parsed.get("status").is_some());
        }

        ctx.shutdown().await;
    }

    /// `GET /get_loads` over two workers answers 200 with a JSON object.
    ///
    /// The per-worker contents of that object are not checked.
    #[tokio::test]
    async fn test_get_loads() {
        let ctx = AppTestContext::new(vec![healthy_worker(18502), healthy_worker(18503)]).await;
        let app = ctx.create_app().await;

        let response = app
            .oneshot(empty_request("GET", "/get_loads"))
            .await
            .unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .unwrap();
        let loads: serde_json::Value = serde_json::from_slice(&raw).unwrap();
        assert!(loads.is_object());

        ctx.shutdown().await;
    }

    /// With no workers registered, `POST /flush_cache` answers either 200
    /// or 503.
    #[tokio::test]
    async fn test_flush_cache_no_workers() {
        let ctx = AppTestContext::new(Vec::new()).await;
        let app = ctx.create_app().await;

        let response = app
            .oneshot(empty_request("POST", "/flush_cache"))
            .await
            .unwrap();
        let status = response.status();
        assert!([StatusCode::OK, StatusCode::SERVICE_UNAVAILABLE].contains(&status));

        ctx.shutdown().await;
    }
}
|
|
#[cfg(test)]
mod load_balancing_tests {
    use super::*;

    /// Sends ten non-streaming `/generate` requests through a router backed by two healthy
    /// regular workers and requires every one of them to come back `200 OK`.
    ///
    /// Only the number of successful responses is checked; how the requests were spread across
    /// the two workers is not inspected.
    #[tokio::test]
    async fn test_request_distribution() {
        let workers = [18601, 18602]
            .iter()
            .map(|&port| MockWorkerConfig {
                port,
                worker_type: WorkerType::Regular,
                health_status: HealthStatus::Healthy,
                response_delay_ms: 0,
                fail_rate: 0.0,
            })
            .collect::<Vec<_>>();
        let ctx = AppTestContext::new(workers).await;
        let app = ctx.create_app().await;

        let mut ok_responses = 0;
        for i in 0..10 {
            let body = serde_json::to_string(&json!({
                "text": format!("Request {}", i),
                "stream": false
            }))
            .unwrap();
            let req = Request::builder()
                .method("POST")
                .uri("/generate")
                .header(CONTENT_TYPE, "application/json")
                .body(Body::from(body))
                .unwrap();

            // `app` is cloned because `oneshot` consumes the service.
            let status = app.clone().oneshot(req).await.unwrap().status();
            if status == StatusCode::OK {
                ok_responses += 1;
            }
        }

        assert_eq!(ok_responses, 10);

        ctx.shutdown().await;
    }
}
|
|
#[cfg(test)]
mod pd_mode_tests {
    use super::*;

    /// Smoke test for constructing a router in prefill/decode (PD) mode.
    ///
    /// Starts one prefill and one decode mock worker, builds a PD-mode `RouterConfig` pointing
    /// at them, and calls `RouterFactory::create_router`. The `Ok`/`Err` outcome is
    /// intentionally not asserted, so the test only checks that the call returns without
    /// panicking.
    ///
    /// NOTE(review): confirm whether PD router creation is expected to succeed against these
    /// mock workers; if so, assert `is_ok()` and exercise an actual routed request.
    #[tokio::test]
    async fn test_pd_mode_routing() {
        let mut prefill_worker = MockWorker::new(MockWorkerConfig {
            port: 18701,
            worker_type: WorkerType::Prefill,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        });
        let mut decode_worker = MockWorker::new(MockWorkerConfig {
            port: 18702,
            worker_type: WorkerType::Decode,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        });

        let prefill_url = prefill_worker.start().await.unwrap();
        let decode_url = decode_worker.start().await.unwrap();

        // Give both mock servers a moment before the router is pointed at them.
        tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;

        // The prefill entry is paired with the port taken from the URL's trailing `:<port>`.
        // Fail loudly if it cannot be parsed instead of silently substituting an unrelated
        // port. Without a real assertion below, such a fallback would go unnoticed.
        let prefill_port = prefill_url
            .split(':')
            .next_back()
            .and_then(|p| p.trim_end_matches('/').parse::<u16>().ok())
            .expect("mock prefill worker URL should end with a numeric port");

        let config = RouterConfig::builder()
            .prefill_decode_mode(vec![(prefill_url, Some(prefill_port))], vec![decode_url])
            .random_policy()
            .host("127.0.0.1")
            .port(3011)
            .max_payload_size(256 * 1024 * 1024)
            .request_timeout_secs(600)
            .worker_startup_timeout_secs(1)
            .worker_startup_check_interval_secs(1)
            .max_concurrent_requests(64)
            .queue_timeout_secs(60)
            .build_unchecked();

        let app_context = crate::common::create_test_context(config).await;

        // Held (not inspected) until the end of the test so the router, if one was built, is
        // released only after the workers are stopped. Either Ok or Err is accepted. Reaching
        // the end of this function means `create_router` returned without panicking.
        let _router_result = RouterFactory::create_router(&app_context).await;

        prefill_worker.stop().await;
        decode_worker.stop().await;
    }
}
|
|
#[cfg(test)]
mod request_id_tests {
    use super::*;

    /// Builds a `POST` request to `uri` whose body is `payload` serialized as JSON.
    ///
    /// `Content-Type: application/json` is always set first. Each `(name, value)` pair in
    /// `extra_headers` is then appended in order.
    fn json_post(
        uri: &str,
        payload: &serde_json::Value,
        extra_headers: &[(&str, &str)],
    ) -> Request {
        let mut builder = Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json");
        for &(name, value) in extra_headers {
            builder = builder.header(name, value);
        }
        builder
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Checks request-ID handling with the default configuration:
    /// - IDs generated by the router carry an endpoint-specific prefix (`gnt-` for
    ///   `/generate`, `chatcmpl-` for `/v1/chat/completions`).
    /// - An ID supplied by the client via `x-request-id` or `x-correlation-id` is echoed
    ///   back in the `x-request-id` response header.
    #[tokio::test]
    async fn test_request_id_generation() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18901,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "text": "Test request",
            "stream": false
        });

        // No client-supplied ID on /generate: a `gnt-`-prefixed ID must be generated.
        let resp = app
            .clone()
            .oneshot(json_post("/generate", &payload, &[]))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let request_id = resp.headers().get("x-request-id");
        assert!(
            request_id.is_some(),
            "Response should have x-request-id header"
        );
        let id_value = request_id.unwrap().to_str().unwrap();
        assert!(
            id_value.starts_with("gnt-"),
            "Generate endpoint should have gnt- prefix"
        );
        assert!(
            id_value.len() > 4,
            "Request ID should have content after prefix"
        );

        // A client-supplied x-request-id is returned unchanged.
        let custom_id = "custom-request-id-123";
        let resp = app
            .clone()
            .oneshot(json_post(
                "/generate",
                &payload,
                &[("x-request-id", custom_id)],
            ))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let response_id = resp.headers().get("x-request-id");
        assert!(response_id.is_some());
        assert_eq!(response_id.unwrap(), custom_id);

        // Chat completions get a `chatcmpl-`-prefixed generated ID.
        let chat_payload = json!({
            "messages": [{"role": "user", "content": "Hello"}],
            "model": "test-model"
        });
        let resp = app
            .clone()
            .oneshot(json_post("/v1/chat/completions", &chat_payload, &[]))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let request_id = resp.headers().get("x-request-id");
        assert!(request_id.is_some());
        assert!(request_id
            .unwrap()
            .to_str()
            .unwrap()
            .starts_with("chatcmpl-"));

        // x-correlation-id is also honoured as the request ID.
        let resp = app
            .clone()
            .oneshot(json_post(
                "/generate",
                &payload,
                &[("x-correlation-id", "correlation-123")],
            ))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        let response_id = resp.headers().get("x-request-id");
        assert!(response_id.is_some());
        assert_eq!(response_id.unwrap(), "correlation-123");

        ctx.shutdown().await;
    }

    /// With `request_id_headers` configured as `custom-id` / `trace-id`, a value sent in
    /// `custom-id` must be echoed back as the response's `x-request-id`.
    #[tokio::test]
    async fn test_request_id_with_custom_headers() {
        let config = RouterConfig::builder()
            .regular_mode(vec![])
            .random_policy()
            .host("127.0.0.1")
            .port(3002)
            .max_payload_size(256 * 1024 * 1024)
            .request_timeout_secs(600)
            .worker_startup_timeout_secs(1)
            .worker_startup_check_interval_secs(1)
            .request_id_headers(vec!["custom-id".to_string(), "trace-id".to_string()])
            .max_concurrent_requests(64)
            .queue_timeout_secs(60)
            .build_unchecked();

        let ctx = AppTestContext::new_with_config(
            config,
            vec![MockWorkerConfig {
                port: 18902,
                worker_type: WorkerType::Regular,
                health_status: HealthStatus::Healthy,
                response_delay_ms: 0,
                fail_rate: 0.0,
            }],
        )
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "text": "Test request",
            "stream": false
        });

        let resp = app
            .clone()
            .oneshot(json_post(
                "/generate",
                &payload,
                &[("custom-id", "my-custom-id")],
            ))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let response_id = resp.headers().get("x-request-id");
        assert!(response_id.is_some());
        assert_eq!(response_id.unwrap(), "my-custom-id");

        ctx.shutdown().await;
    }
}
|
|
#[cfg(test)]
mod rerank_tests {
    use super::*;

    /// Builds a `POST` request to `uri` whose body is `payload` serialized as JSON.
    fn post_json(uri: &str, payload: &serde_json::Value) -> Request {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header(CONTENT_TYPE, "application/json")
            .body(Body::from(serde_json::to_string(payload).unwrap()))
            .unwrap()
    }

    /// Buffers a whole response body and parses it as JSON, panicking if either step fails.
    async fn read_json(body: Body) -> serde_json::Value {
        let bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    /// Asserts that `results` is ordered by non-increasing `score`.
    ///
    /// Panics if any entry lacks a numeric `score`.
    fn assert_sorted_by_score(results: &[serde_json::Value]) {
        for pair in results.windows(2) {
            let first = pair[0]["score"].as_f64().unwrap();
            let second = pair[1]["score"].as_f64().unwrap();
            assert!(
                first >= second,
                "rerank results must be sorted by descending score: {first} < {second}"
            );
        }
    }

    /// A well-formed `/rerank` request returns every document (top_k = 2 of 2), echoes the
    /// requested model, and orders results by descending score.
    #[tokio::test]
    async fn test_rerank_success() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18105,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "machine learning algorithms",
            "documents": [
                "Introduction to machine learning concepts",
                "Deep learning neural networks tutorial"
            ],
            "model": "test-rerank-model",
            "top_k": 2,
            "return_documents": true,
            "rid": "test-request-123"
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;
        assert!(body_json.get("results").is_some());
        assert!(body_json.get("model").is_some());
        assert_eq!(body_json["model"], "test-rerank-model");

        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 2);
        assert_sorted_by_score(results);

        ctx.shutdown().await;
    }

    /// `top_k: 1` over three documents must truncate the result list to a single entry.
    #[tokio::test]
    async fn test_rerank_with_top_k() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18106,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": [
                "Document 1",
                "Document 2",
                "Document 3"
            ],
            "model": "test-model",
            "top_k": 1,
            "return_documents": true
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;
        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 1);

        ctx.shutdown().await;
    }

    /// With `return_documents: false`, no result may carry a `document` field.
    #[tokio::test]
    async fn test_rerank_without_documents() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18107,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": ["Document 1", "Document 2"],
            "model": "test-model",
            "return_documents": false
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;
        let results = body_json["results"].as_array().unwrap();
        // Without this, an empty `results` array would make the loop below pass vacuously.
        // No `top_k` was sent, so both documents are expected back.
        assert_eq!(
            results.len(),
            2,
            "expected one result per document when top_k is not set"
        );
        for result in results {
            assert!(result.get("document").is_none());
        }

        ctx.shutdown().await;
    }

    /// A worker configured with `fail_rate: 1.0` must surface as `500 Internal Server Error`.
    #[tokio::test]
    async fn test_rerank_worker_failure() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18108,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 1.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "test query",
            "documents": ["Document 1"],
            "model": "test-model"
        });

        let resp = app.oneshot(post_json("/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);

        ctx.shutdown().await;
    }

    /// `/v1/rerank` accepts a request with only `query` and `documents`. It must report the
    /// model as `"unknown"` and return all documents, sorted by score, each including its
    /// `document`.
    #[tokio::test]
    async fn test_v1_rerank_compatibility() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18110,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let payload = json!({
            "query": "machine learning algorithms",
            "documents": [
                "Introduction to machine learning concepts",
                "Deep learning neural networks tutorial",
                "Statistical learning theory basics"
            ]
        });

        let resp = app.oneshot(post_json("/v1/rerank", &payload)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        let body_json = read_json(resp.into_body()).await;
        assert!(body_json.get("results").is_some());
        assert!(body_json.get("model").is_some());
        // No `model` was sent, so the response is expected to report "unknown".
        assert_eq!(body_json["model"], "unknown");

        let results = body_json["results"].as_array().unwrap();
        assert_eq!(results.len(), 3);
        assert_sorted_by_score(results);
        for result in results {
            assert!(result.get("document").is_some());
        }

        ctx.shutdown().await;
    }

    /// Each malformed request must be rejected with `400 Bad Request`: an empty query, a
    /// whitespace-only query, an empty document list, and `top_k: 0`.
    #[tokio::test]
    async fn test_rerank_invalid_request() {
        let ctx = AppTestContext::new(vec![MockWorkerConfig {
            port: 18111,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;
        let app = ctx.create_app().await;

        let cases = [
            (
                "empty query",
                json!({
                    "query": "",
                    "documents": ["Document 1", "Document 2"],
                    "model": "test-model"
                }),
            ),
            (
                "whitespace-only query",
                json!({
                    "query": " ",
                    "documents": ["Document 1", "Document 2"],
                    "model": "test-model"
                }),
            ),
            (
                "empty documents",
                json!({
                    "query": "test query",
                    "documents": [],
                    "model": "test-model"
                }),
            ),
            (
                "top_k of zero",
                json!({
                    "query": "test query",
                    "documents": ["Document 1", "Document 2"],
                    "model": "test-model",
                    "top_k": 0
                }),
            ),
        ];

        for (label, payload) in &cases {
            // `oneshot` consumes the service, so each case gets its own clone.
            let resp = app
                .clone()
                .oneshot(post_json("/rerank", payload))
                .await
                .unwrap();
            assert_eq!(
                resp.status(),
                StatusCode::BAD_REQUEST,
                "rerank request with {label} should be rejected"
            );
        }

        ctx.shutdown().await;
    }
}
|
|