File size: 8,297 Bytes
bbb1195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
use axum::{
    Router,
    routing::{get, post},
    extract::DefaultBodyLimit,
    response::{IntoResponse, Response, Json},
};
use tracing::{debug, error};
use tower_http::trace::TraceLayer;
use std::sync::Arc;
use tokio::sync::oneshot;
use crate::proxy::TokenManager;


/// Shared application state handed to the Axum router via `with_state`.
///
/// The type is `Clone`; every heavyweight member sits behind an `Arc`, so a
/// clone only bumps reference counts and all clones observe the same data.
#[derive(Clone)]
pub struct AppState {
    /// Shared handle to the token manager.
    pub token_manager: Arc<TokenManager>,
    /// String-to-string mapping table for the Anthropic protocol. The same
    /// `Arc` is kept by `AxumServer`, which replaces the table wholesale in
    /// `update_mapping`. Presumably keys/values are model names — confirm
    /// against the handlers.
    pub anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Same as `anthropic_mapping`, for the OpenAI protocol.
    pub openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Same as `anthropic_mapping`, for the custom mapping table.
    pub custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    #[allow(dead_code)]
    pub request_timeout: u64,  // API request timeout (seconds)
    #[allow(dead_code)]
    pub thought_signature_map: Arc<tokio::sync::Mutex<std::collections::HashMap<String, String>>>, // Thought-chain signature map (ID -> Signature)
    /// Upstream proxy configuration; the same `Arc` is kept by `AxumServer`
    /// and overwritten in `update_proxy`.
    #[allow(dead_code)]
    pub upstream_proxy: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
    /// Shared upstream client, constructed once in `AxumServer::start` from
    /// the upstream proxy configuration passed at startup.
    pub upstream: Arc<crate::proxy::upstream::client::UpstreamClient>,
}

/// Handle to a running Axum server instance.
///
/// Returned by `AxumServer::start`; it keeps the shutdown trigger plus the
/// shared handles needed to hot-update the mapping tables and the upstream
/// proxy configuration while the server is running.
pub struct AxumServer {
    /// Sender half of the shutdown channel. Wrapped in `Option` so `stop` can
    /// take it out and send the signal.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Anthropic mapping table; the same `Arc` that is stored in `AppState`.
    anthropic_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// OpenAI mapping table; the same `Arc` that is stored in `AppState`.
    openai_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Custom mapping table; the same `Arc` that is stored in `AppState`.
    custom_mapping: Arc<tokio::sync::RwLock<std::collections::HashMap<String, String>>>,
    /// Upstream proxy configuration; the same `Arc` that is stored in
    /// `AppState::upstream_proxy`.
    proxy_state: Arc<tokio::sync::RwLock<crate::proxy::config::UpstreamProxyConfig>>,
}

impl AxumServer {
    /// Replaces the Anthropic, OpenAI and custom mapping tables with the ones
    /// found in `config`.
    ///
    /// The tables are swapped one after another, each under its own write
    /// lock, so the three replacements are not atomic as a group.
    pub async fn update_mapping(&self, config: &crate::proxy::config::ProxyConfig) {
        // Each write guard is a temporary that is released at the end of its
        // statement, so at most one table is locked at any time.
        *self.anthropic_mapping.write().await = config.anthropic_mapping.clone();
        *self.openai_mapping.write().await = config.openai_mapping.clone();
        *self.custom_mapping.write().await = config.custom_mapping.clone();
        tracing::info!("模型映射 (Anthropic/OpenAI/Custom) 已全量热更新");
    }

    /// Overwrites the shared upstream proxy configuration with `new_config`.
    ///
    /// Only the shared configuration value is replaced here; the
    /// `UpstreamClient` built in [`AxumServer::start`] is not touched by this
    /// method.
    pub async fn update_proxy(&self, new_config: crate::proxy::config::UpstreamProxyConfig) {
        *self.proxy_state.write().await = new_config;
        tracing::info!("上游代理配置已热更新");
    }

    /// Starts the Axum server.
    ///
    /// Builds the shared [`AppState`], registers the OpenAI, Claude and Gemini
    /// routes plus `/healthz`, binds a TCP listener on `host:port`, and spawns
    /// a background task that accepts connections and serves each one over
    /// HTTP/1 in its own task.
    ///
    /// # Parameters
    /// * `host`, `port` - address to bind.
    /// * `token_manager` - shared token manager stored in the state.
    /// * `anthropic_mapping`, `openai_mapping`, `custom_mapping` - initial
    ///   mapping tables; they can later be replaced via
    ///   [`AxumServer::update_mapping`].
    /// * `_request_timeout` - currently ignored; the state's `request_timeout`
    ///   is fixed at 300 seconds.
    /// * `upstream_proxy` - initial upstream proxy configuration, also used to
    ///   construct the upstream client.
    ///
    /// # Returns
    /// The server handle together with the `JoinHandle` of the accept-loop
    /// task.
    ///
    /// # Errors
    /// Returns a formatted message when binding the listener fails.
    pub async fn start(
        host: String,
        port: u16,
        token_manager: Arc<TokenManager>,
        anthropic_mapping: std::collections::HashMap<String, String>,
        openai_mapping: std::collections::HashMap<String, String>,
        custom_mapping: std::collections::HashMap<String, String>,
        _request_timeout: u64,
        upstream_proxy: crate::proxy::config::UpstreamProxyConfig,
    ) -> Result<(Self, tokio::task::JoinHandle<()>), String> {
        let mapping_state = Arc::new(tokio::sync::RwLock::new(anthropic_mapping));
        let openai_mapping_state = Arc::new(tokio::sync::RwLock::new(openai_mapping));
        let custom_mapping_state = Arc::new(tokio::sync::RwLock::new(custom_mapping));
        let proxy_state = Arc::new(tokio::sync::RwLock::new(upstream_proxy.clone()));

        // The state and the returned server handle share the same `Arc`s, so
        // hot updates made through the handle are visible to request handlers.
        let state = AppState {
            token_manager,
            anthropic_mapping: mapping_state.clone(),
            openai_mapping: openai_mapping_state.clone(),
            custom_mapping: custom_mapping_state.clone(),
            request_timeout: 300, // 5-minute timeout
            thought_signature_map: Arc::new(tokio::sync::Mutex::new(
                std::collections::HashMap::new(),
            )),
            upstream_proxy: proxy_state.clone(),
            upstream: Arc::new(crate::proxy::upstream::client::UpstreamClient::new(Some(
                upstream_proxy,
            ))),
        };

        // Build the router on top of the handlers from `crate::proxy::handlers`.
        use crate::proxy::handlers;
        let app = Router::new()
            // OpenAI protocol
            .route("/v1/models", get(handlers::openai::handle_list_models))
            .route("/v1/chat/completions", post(handlers::openai::handle_chat_completions))
            .route("/v1/completions", post(handlers::openai::handle_completions))
            // Codex CLI compatibility: served by the completions handler.
            .route("/v1/responses", post(handlers::openai::handle_completions))
            // Claude protocol
            .route("/v1/messages", post(handlers::claude::handle_messages))
            .route("/v1/messages/count_tokens", post(handlers::claude::handle_count_tokens))
            .route("/v1/models/claude", get(handlers::claude::handle_list_models))
            // Gemini protocol (native)
            .route("/v1beta/models", get(handlers::gemini::handle_list_models))
            // GET (model info) and POST (generateContent, with the colon inside
            // the path segment) share the same route.
            .route(
                "/v1beta/models/:model",
                get(handlers::gemini::handle_get_model).post(handlers::gemini::handle_generate),
            )
            // More specific route for token counting.
            .route("/v1beta/models/:model/countTokens", post(handlers::gemini::handle_count_tokens))
            .route("/healthz", get(health_check_handler))
            // Request body limit: 100 MiB.
            .layer(DefaultBodyLimit::max(100 * 1024 * 1024))
            .layer(TraceLayer::new_for_http())
            .layer(axum::middleware::from_fn(crate::proxy::middleware::auth_middleware))
            .layer(crate::proxy::middleware::cors_layer())
            .with_state(state);

        // Bind the listening address.
        let addr = format!("{}:{}", host, port);
        let listener = tokio::net::TcpListener::bind(&addr)
            .await
            .map_err(|e| format!("地址 {} 绑定失败: {}", addr, e))?;

        tracing::info!("反代服务器启动在 http://{}", addr);

        // Shutdown channel: `stop` sends on it, the accept loop listens on it.
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>();

        let server_instance = Self {
            shutdown_tx: Some(shutdown_tx),
            anthropic_mapping: mapping_state,
            openai_mapping: openai_mapping_state,
            custom_mapping: custom_mapping_state,
            proxy_state,
        };

        // Run the accept loop in its own task.
        let handle = tokio::spawn(async move {
            use hyper::server::conn::http1;
            use hyper_util::rt::TokioIo;
            use hyper_util::service::TowerToHyperService;

            // Pause applied after an accept failure that is not tied to a
            // single connection, so a persistent error cannot busy-spin.
            const ACCEPT_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_secs(1);

            loop {
                tokio::select! {
                    res = listener.accept() => {
                        match res {
                            Ok((stream, _)) => {
                                let io = TokioIo::new(stream);
                                let service = TowerToHyperService::new(app.clone());

                                // One task per connection; its handle is
                                // dropped, so the task runs detached.
                                tokio::task::spawn(async move {
                                    if let Err(err) = http1::Builder::new()
                                        .serve_connection(io, service)
                                        .with_upgrades() // allow protocol upgrades (e.g. WebSocket, if needed later)
                                        .await
                                    {
                                        debug!("连接处理结束或出错: {:?}", err);
                                    }
                                });
                            }
                            Err(e) => {
                                error!("接收连接失败: {:?}", e);
                                // Errors that belong to one incoming connection
                                // are transient: retry right away. Anything else
                                // (e.g. running out of file descriptors) would
                                // make `accept` fail again immediately, so back
                                // off instead of spinning and flooding the log.
                                let per_connection = matches!(
                                    e.kind(),
                                    std::io::ErrorKind::ConnectionRefused
                                        | std::io::ErrorKind::ConnectionAborted
                                        | std::io::ErrorKind::ConnectionReset
                                );
                                if !per_connection {
                                    // A shutdown signal arriving during the
                                    // pause is picked up on the next iteration.
                                    tokio::time::sleep(ACCEPT_ERROR_BACKOFF).await;
                                }
                            }
                        }
                    }
                    // Fires both when `stop` sends the signal and when the
                    // sender is dropped without sending.
                    _ = &mut shutdown_rx => {
                        tracing::info!("反代服务器停止监听");
                        break;
                    }
                }
            }
        });

        Ok((server_instance, handle))
    }

    /// Stops the server by signalling the accept loop to exit.
    ///
    /// Consumes the handle. A failed send is ignored because it only means the
    /// accept loop has already terminated. Connection tasks that were spawned
    /// earlier are not cancelled by this call.
    pub fn stop(mut self) {
        if let Some(tx) = self.shutdown_tx.take() {
            let _ = tx.send(());
        }
    }
}

// ===== API handlers (legacy code removed; now handled by src/proxy/handlers/*) =====

/// Health-check handler: always answers with the JSON body `{"status": "ok"}`.
async fn health_check_handler() -> Response {
    let payload = serde_json::json!({ "status": "ok" });
    Json(payload).into_response()
}