use std::collections::HashMap;
use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;

use axum::{
    Json, Router,
    extract::State,
    http::{Method, StatusCode, header::CONTENT_TYPE},
    response::{IntoResponse, Response, Sse, sse::Event},
    routing::{get, post},
};
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;
use tower_http::cors::{Any, CorsLayer};
use tracing::{error, info};

use crate::{
    client::ChatGptClient,
    utils::{ChatGptError, Result as ChatGptResult},
};
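
/// OpenAI-compatible request body accepted by the chat completions endpoints.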
#[derive(Debug, Deserialize)]
pub struct ChatCompletionRequest {
    #[serde(default)]
    pub model: Option<String>,
    pub messages: Vec<ChatMessage>,
    #[serde(default)]
    pub stream: bool,
    #[serde(default)]
    pub temperature: Option<f32>,
    #[serde(default)]
    pub max_tokens: Option<u32>,
    #[serde(default)]
    pub conversation_id: Option<String>,
    #[serde(default)]
    pub proxy: Option<String>,
}
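
/// A single message in the conversation history.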
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}
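
/// Non-streaming response, mirroring OpenAI's `chat.completion` object.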
#[derive(Debug, Serialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChoice>,
    pub usage: Usage,
}

#[derive(Debug, Serialize)]
pub struct ChatChoice {
    pub index: u32,
    pub message: ChatMessage,
    pub finish_reason: String,
}
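
/// One SSE frame of a streaming response, mirroring OpenAI's
/// `chat.completion.chunk` object.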
#[derive(Debug, Serialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChatChunkChoice>,
}

#[derive(Debug, Serialize)]
pub struct ChatChunkChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Serialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
}
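
/// Token accounting. The upstream client does not report token counts, so
/// these are always zero; the struct exists for OpenAI API compatibility.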
#[derive(Debug, Serialize)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

#[derive(Debug, Serialize)]
pub struct Model {
    pub id: String,
    pub name: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

#[derive(Debug, Serialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<Model>,
}
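
/// JSON body returned for every error response.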
#[derive(Debug, Serialize)]
struct ErrorResponse {
    status: &'static str,
    detail: String,
}
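
/// Internal error type pairing an HTTP status code with a message; converts
/// into an Axum response via `IntoResponse`.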
#[derive(Debug)]
struct ApiError {
    status: StatusCode,
    message: String,
}

impl ApiError {
    fn new(status: StatusCode, message: impl Into<String>) -> Self {
        Self {
            status,
            message: message.into(),
        }
    }

    fn bad_request(message: impl Into<String>) -> Self {
        Self::new(StatusCode::BAD_REQUEST, message)
    }
}
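
/// Map upstream client errors onto the closest HTTP status code.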
impl From<ChatGptError> for ApiError {
    fn from(err: ChatGptError) -> Self {
        let status = match err {
            ChatGptError::InvalidProxy(_) => StatusCode::BAD_REQUEST,
            ChatGptError::Authentication(_) => StatusCode::UNAUTHORIZED,
            ChatGptError::IpFlagged => StatusCode::FORBIDDEN,
            ChatGptError::Network(_) => StatusCode::BAD_GATEWAY,
            ChatGptError::Configuration(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ChatGptError::ChallengeSolve(_)
            | ChatGptError::VmExecution(_)
            | ChatGptError::InvalidResponse(_)
            | ChatGptError::Unknown(_) => StatusCode::INTERNAL_SERVER_ERROR,
            ChatGptError::Json(_)
            | ChatGptError::Base64Decode(_)
            | ChatGptError::Io(_)
            | ChatGptError::Image(_) => StatusCode::UNPROCESSABLE_ENTITY,
        };

        Self::new(status, err.to_string())
    }
}

impl IntoResponse for ApiError {
    fn into_response(self) -> Response {
        let payload = Json(ErrorResponse {
            status: "error",
            detail: self.message,
        });

        (self.status, payload).into_response()
    }
}
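
/// Shared server state: one `ChatGptClient` per conversation, keyed by
/// conversation id, plus an optional proxy applied to newly created clients.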
#[derive(Clone)]
pub struct AppState {
    clients: Arc<RwLock<HashMap<String, Arc<RwLock<ChatGptClient>>>>>,
    default_proxy: Option<String>,
}

impl AppState {
    pub fn new(default_proxy: Option<String>) -> Self {
        Self {
            clients: Arc::new(RwLock::new(HashMap::new())),
            default_proxy,
        }
    }

    /// Look up the client for `conversation_id`, creating one if none exists
    /// yet. A fresh UUID becomes the key when no conversation id is supplied.
    async fn get_or_create_client(
        &self,
        conversation_id: Option<&str>,
        proxy: Option<&str>,
    ) -> Result<(String, Arc<RwLock<ChatGptClient>>), ApiError> {
        // A per-request proxy takes precedence over the server-wide default.
        let proxy_to_use = proxy.or(self.default_proxy.as_deref());

        let client_key = conversation_id
            .map(|id| id.to_string())
            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());

        let mut clients = self.clients.write().await;

        if let Some(client) = clients.get(&client_key) {
            Ok((client_key, Arc::clone(client)))
        } else {
            let client = ChatGptClient::new(proxy_to_use).await.map_err(|err| {
                error!("Failed to create ChatGPT client: {}", err);
                ApiError::from(err)
            })?;

            let client_arc = Arc::new(RwLock::new(client));
            clients.insert(client_key.clone(), Arc::clone(&client_arc));
            Ok((client_key, client_arc))
        }
    }
}
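
/// Build the Axum router with permissive CORS and the OpenAI-compatible
/// routes mounted under `/v1`, `/api/v1`, and `/hf/v1`.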
pub fn router(state: AppState) -> Router {
    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods([Method::POST, Method::GET])
        .allow_headers([CONTENT_TYPE]);

    Router::new()
        .route("/v1/chat/completions", post(chat_completions))
        .route("/api/v1/chat/completions", post(chat_completions))
        .route("/hf/v1/chat/completions", post(chat_completions))
        .route("/v1/models", get(list_models))
        .route("/api/v1/models", get(list_models))
        .route("/hf/v1/models", get(list_models))
        .with_state(state)
        .layer(cors)
}
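
/// Bind to `host:port` and serve the API until the process is stopped.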
pub async fn run(host: &str, port: u16, default_proxy: Option<String>) -> ChatGptResult<()> {
    let addr: SocketAddr = format!("{}:{}", host, port)
        .parse()
        .map_err(|err| ChatGptError::configuration(format!("invalid address: {}", err)))?;

    let state = AppState::new(default_proxy);
    let app = router(state);

    let listener = tokio::net::TcpListener::bind(addr).await?;
    let local_addr = listener.local_addr()?;

    info!("API server listening on http://{}", local_addr);

    axum::serve(listener, app).await?;

    Ok(())
}
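
/// `POST /v1/chat/completions`: validate the request, resolve the backing
/// client, and dispatch to the streaming or non-streaming handler.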
async fn chat_completions(
    State(state): State<AppState>,
    Json(payload): Json<ChatCompletionRequest>,
) -> std::result::Result<Response, ApiError> {
    let ChatCompletionRequest {
        messages,
        stream,
        conversation_id,
        proxy,
        ..
    } = payload;

    if messages.is_empty() {
        return Err(ApiError::bad_request("Messages cannot be empty"));
    }

    // Only the most recent user message is forwarded upstream; earlier turns
    // are tracked by the per-conversation client.
    let last_message = messages
        .iter()
        .rev()
        .find(|m| m.role == "user")
        .ok_or_else(|| ApiError::bad_request("No user message found"))?;

    let message_content = last_message.content.trim();
    if message_content.is_empty() {
        return Err(ApiError::bad_request("Message content cannot be empty"));
    }

    info!(
        "Handling chat completion request (stream: {}, conversation_id: {:?})",
        stream, conversation_id
    );

    let (conv_id, client_arc) = state
        .get_or_create_client(conversation_id.as_deref(), proxy.as_deref())
        .await?;

    // Treat the request as a new conversation when the caller supplied no
    // conversation id or only a single message.
    let is_new_conversation = conversation_id.is_none() || messages.len() <= 1;

    if stream {
        handle_stream_response(client_arc, message_content, is_new_conversation, conv_id).await
    } else {
        handle_non_stream_response(client_arc, message_content, is_new_conversation, conv_id).await
    }
}
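
/// Ask the upstream client for a complete answer and wrap it in an
/// OpenAI-style `chat.completion` response.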
async fn handle_non_stream_response(
    client_arc: Arc<RwLock<ChatGptClient>>,
    message: &str,
    is_new: bool,
    conversation_id: String,
) -> std::result::Result<Response, ApiError> {
    // The write lock serializes requests that share a conversation.
    let mut client = client_arc.write().await;

    let answer = if is_new {
        client.start_conversation(message).await.map_err(|err| {
            error!("ChatGPT conversation failed: {}", err);
            ApiError::from(err)
        })?
    } else {
        client.hold_conversation(message, false).await.map_err(|err| {
            error!("ChatGPT hold_conversation failed: {}", err);
            ApiError::from(err)
        })?
    };

    let response = ChatCompletionResponse {
        id: conversation_id,
        object: "chat.completion".to_string(),
        created: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs(),
        model: "gpt-4".to_string(),
        choices: vec![ChatChoice {
            index: 0,
            message: ChatMessage {
                role: "assistant".to_string(),
                content: answer,
            },
            finish_reason: "stop".to_string(),
        }],
        usage: Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        },
    };

    Ok(Json(response).into_response())
}
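
/// Fetch the complete answer, then replay it to the caller as OpenAI-style
/// SSE chunks. This is pseudo-streaming: the upstream call finishes before
/// the first chunk is sent.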
async fn handle_stream_response(
    client_arc: Arc<RwLock<ChatGptClient>>,
    message: &str,
    is_new: bool,
    conversation_id: String,
) -> std::result::Result<Response, ApiError> {
    let mut client = client_arc.write().await;

    let answer = if is_new {
        client.start_conversation(message).await.map_err(|err| {
            error!("ChatGPT conversation failed: {}", err);
            ApiError::from(err)
        })?
    } else {
        client.hold_conversation(message, false).await.map_err(|err| {
            error!("ChatGPT hold_conversation failed: {}", err);
            ApiError::from(err)
        })?
    };

    // Split the finished answer into 10-character chunks so clients still
    // observe an incremental stream.
    let chunks: Vec<String> = answer
        .chars()
        .collect::<Vec<char>>()
        .chunks(10)
        .map(|chunk| chunk.iter().collect::<String>())
        .collect();

    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs();

    let (tx, rx) = tokio::sync::mpsc::channel(100);

    tokio::spawn(async move {
        for (i, chunk) in chunks.into_iter().enumerate() {
            let chunk_data = ChatCompletionChunk {
                id: conversation_id.clone(),
                object: "chat.completion.chunk".to_string(),
                created,
                model: "gpt-4".to_string(),
                choices: vec![ChatChunkChoice {
                    index: 0,
                    delta: Delta {
                        // OpenAI sends the role only on the first chunk.
                        role: if i == 0 { Some("assistant".to_string()) } else { None },
                        content: Some(chunk),
                    },
                    finish_reason: None,
                }],
            };

            // A send error means the client disconnected; stop streaming.
            let event = Event::default().json_data(chunk_data).unwrap();
            if tx.send(Ok::<_, Infallible>(event)).await.is_err() {
                break;
            }
        }

        // Final chunk carries no delta, only the finish reason.
        let final_chunk = ChatCompletionChunk {
            id: conversation_id.clone(),
            object: "chat.completion.chunk".to_string(),
            created,
            model: "gpt-4".to_string(),
            choices: vec![ChatChunkChoice {
                index: 0,
                delta: Delta {
                    role: None,
                    content: None,
                },
                finish_reason: Some("stop".to_string()),
            }],
        };

        let _ = tx.send(Ok(Event::default().json_data(final_chunk).unwrap())).await;

        // OpenAI-compatible streams terminate with a literal `[DONE]` event.
        let _ = tx.send(Ok(Event::default().data("[DONE]"))).await;
    });

    let stream = ReceiverStream::new(rx);
    Ok(Sse::new(stream).into_response())
}
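
/// `GET /v1/models`: advertise the static model list.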
async fn list_models() -> std::result::Result<Response, ApiError> {
    let created = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs();

    let models = vec![
        Model {
            id: "gpt-4".to_string(),
            name: "GPT 4".to_string(),
            object: "model".to_string(),
            created,
            owned_by: "openai".to_string(),
        },
        Model {
            id: "gpt-3.5-turbo".to_string(),
            name: "GPT 3.5".to_string(),
            object: "model".to_string(),
            created,
            owned_by: "openai".to_string(),
        },
    ];

    let response = ModelsResponse {
        object: "list".to_string(),
        data: models,
    };

    Ok(Json(response).into_response())
}