| use tokio::io::{AsyncReadExt, AsyncWriteExt}; |
| use tokio::net::TcpListener; |
| use tokio::sync::mpsc; |
| use tokio::sync::watch; |
| use std::sync::{Mutex, OnceLock}; |
| use tauri::Url; |
| use crate::modules::oauth; |
|
|
| struct OAuthFlowState { |
| auth_url: String, |
| #[allow(dead_code)] |
| redirect_uri: String, |
| state: String, |
| client_key: String, |
| cancel_tx: watch::Sender<bool>, |
| code_tx: mpsc::Sender<Result<String, String>>, |
| code_rx: Option<mpsc::Receiver<Result<String, String>>>, |
| } |
|
|
| static OAUTH_FLOW_STATE: OnceLock<Mutex<Option<OAuthFlowState>>> = OnceLock::new(); |
|
|
| fn get_oauth_flow_state() -> &'static Mutex<Option<OAuthFlowState>> { |
| OAUTH_FLOW_STATE.get_or_init(|| Mutex::new(None)) |
| } |
|
|
/// Raw HTTP/1.1 `200 OK` response (status line, `Content-Type` header, blank
/// line, HTML body) written back to the browser after a successful callback.
/// The page tells the user to return to the app and tries to close itself
/// after two seconds.
fn oauth_success_html() -> &'static str {
    concat!(
        "HTTP/1.1 200 OK\r\n",
        "Content-Type: text/html; charset=utf-8\r\n",
        "\r\n",
        "<html>",
        "<body style='font-family: sans-serif; text-align: center; padding: 50px;'>",
        "<h1 style='color: green;'>✅ Authorization Successful!</h1>",
        "<p>You can close this window and return to the application.</p>",
        "<script>setTimeout(function() { window.close(); }, 2000);</script>",
        "</body>",
        "</html>"
    )
}
|
|
/// Raw HTTP/1.1 `400 Bad Request` response (status line, `Content-Type`
/// header, blank line, HTML body) written back to the browser when the
/// callback did not yield a usable authorization code.
fn oauth_fail_html() -> &'static str {
    concat!(
        "HTTP/1.1 400 Bad Request\r\n",
        "Content-Type: text/html; charset=utf-8\r\n",
        "\r\n",
        "<html>",
        "<body style='font-family: sans-serif; text-align: center; padding: 50px;'>",
        "<h1 style='color: red;'>❌ Authorization Failed</h1>",
        "<p>Failed to obtain Authorization Code. Please return to the app and try again.</p>",
        "</body>",
        "</html>"
    )
}
|
|
| async fn ensure_oauth_flow_prepared(app_handle: Option<tauri::AppHandle>, requested_client_key: Option<String>) -> Result<String, String> { |
| if let Ok(mut state) = get_oauth_flow_state().lock() { |
| if let Some(s) = state.as_mut() { |
| if let Some(requested_key) = requested_client_key.as_ref() { |
| if s.client_key != requested_key.to_ascii_lowercase() { |
| let _ = s.cancel_tx.send(true); |
| *state = None; |
| } |
| } |
| } |
| } |
|
|
| |
| if let Ok(mut state) = get_oauth_flow_state().lock() { |
| if let Some(s) = state.as_mut() { |
| if s.code_rx.is_some() { |
| return Ok(s.auth_url.clone()); |
| } else { |
| |
| |
| let _ = s.cancel_tx.send(true); |
| *state = None; |
| } |
| } |
| } |
|
|
| |
| |
| |
| let mut ipv4_listener: Option<TcpListener> = None; |
| let mut ipv6_listener: Option<TcpListener> = None; |
|
|
| |
| |
| |
| let port: u16; |
| match TcpListener::bind("[::1]:0").await { |
| Ok(l6) => { |
| port = l6 |
| .local_addr() |
| .map_err(|e| format!("failed_to_get_local_port: {}", e))? |
| .port(); |
| ipv6_listener = Some(l6); |
|
|
| match TcpListener::bind(format!("127.0.0.1:{}", port)).await { |
| Ok(l4) => ipv4_listener = Some(l4), |
| Err(e) => { |
| crate::modules::logger::log_warn(&format!( |
| "failed_to_bind_ipv4_callback_port_127_0_0_1:{} (will only listen on IPv6): {}", |
| port, e |
| )); |
| } |
| } |
| } |
| Err(_) => { |
| let l4 = TcpListener::bind("127.0.0.1:0") |
| .await |
| .map_err(|e| format!("failed_to_bind_local_port: {}", e))?; |
| port = l4 |
| .local_addr() |
| .map_err(|e| format!("failed_to_get_local_port: {}", e))? |
| .port(); |
| ipv4_listener = Some(l4); |
|
|
| match TcpListener::bind(format!("[::1]:{}", port)).await { |
| Ok(l6) => ipv6_listener = Some(l6), |
| Err(e) => { |
| crate::modules::logger::log_warn(&format!( |
| "failed_to_bind_ipv6_callback_port_::1:{} (will only listen on IPv4): {}", |
| port, e |
| )); |
| } |
| } |
| } |
| } |
|
|
| let has_ipv4 = ipv4_listener.is_some(); |
| let has_ipv6 = ipv6_listener.is_some(); |
|
|
| let redirect_uri = if has_ipv4 && has_ipv6 { |
| format!("http://localhost:{}/oauth-callback", port) |
| } else if has_ipv4 { |
| format!("http://127.0.0.1:{}/oauth-callback", port) |
| } else { |
| format!("http://[::1]:{}/oauth-callback", port) |
| }; |
|
|
| let state_str = uuid::Uuid::new_v4().to_string(); |
| let (auth_url, resolved_client_key) = oauth::get_auth_url_with_client( |
| &redirect_uri, |
| &state_str, |
| requested_client_key.as_deref(), |
| )?; |
|
|
| |
| let (cancel_tx, cancel_rx) = watch::channel(false); |
| |
| let (code_tx, code_rx) = mpsc::channel::<Result<String, String>>(1); |
|
|
| |
| |
| let app_handle_for_tasks = app_handle.clone(); |
|
|
| if let Some(l4) = ipv4_listener { |
| let tx = code_tx.clone(); |
| let mut rx = cancel_rx.clone(); |
| let app_handle = app_handle_for_tasks.clone(); |
| tokio::spawn(async move { |
| if let Ok((mut stream, _)) = tokio::select! { |
| res = l4.accept() => res.map_err(|e| format!("failed_to_accept_connection: {}", e)), |
| _ = rx.changed() => Err("OAuth cancelled".to_string()), |
| } { |
| |
| |
| let mut buffer = [0u8; 4096]; |
| let bytes_read = stream.read(&mut buffer).await.unwrap_or(0); |
| let request = String::from_utf8_lossy(&buffer[..bytes_read]); |
| |
| |
| let query_params = request |
| .lines() |
| .next() |
| .and_then(|line| { |
| let parts: Vec<&str> = line.split_whitespace().collect(); |
| if parts.len() >= 2 { Some(parts[1]) } else { None } |
| }) |
| .and_then(|path| { |
| |
| Url::parse(&format!("http://localhost{}", path)).ok() |
| }) |
| .map(|url| { |
| let mut code = None; |
| let mut state = None; |
| for (k, v) in url.query_pairs() { |
| if k == "code" { code = Some(v.to_string()); } |
| else if k == "state" { state = Some(v.to_string()); } |
| } |
| (code, state) |
| }); |
|
|
| let (code, received_state) = match query_params { |
| Some((c, s)) => (c, s), |
| None => (None, None), |
| }; |
|
|
| if code.is_none() && bytes_read > 0 { |
| crate::modules::logger::log_error(&format!( |
| "OAuth callback failed to parse code. Raw request (first 512 bytes): {}", |
| &request.chars().take(512).collect::<String>() |
| )); |
| } |
|
|
| |
| let state_valid = { |
| if let Ok(lock) = get_oauth_flow_state().lock() { |
| if let Some(s) = lock.as_ref() { |
| received_state.as_ref() == Some(&s.state) |
| } else { |
| false |
| } |
| } else { |
| false |
| } |
| }; |
|
|
| let (result, response_html) = match (code, state_valid) { |
| (Some(code), true) => { |
| crate::modules::logger::log_info("Successfully captured OAuth code from IPv4 listener"); |
| (Ok(code), oauth_success_html()) |
| }, |
| (Some(_), false) => { |
| crate::modules::logger::log_error("OAuth callback state mismatch (CSRF protection)"); |
| (Err("OAuth state mismatch".to_string()), oauth_fail_html()) |
| }, |
| (None, _) => (Err("Failed to get Authorization Code in callback".to_string()), oauth_fail_html()), |
| }; |
| |
| let _ = stream.write_all(response_html.as_bytes()).await; |
| let _ = stream.flush().await; |
|
|
| if let Some(h) = app_handle { |
| use tauri::Emitter; |
| let _ = h.emit("oauth-callback-received", ()); |
| } |
| let _ = tx.send(result).await; |
| } |
| }); |
| } |
|
|
| if let Some(l6) = ipv6_listener { |
| let tx = code_tx.clone(); |
| let mut rx = cancel_rx; |
| let app_handle = app_handle_for_tasks; |
| tokio::spawn(async move { |
| if let Ok((mut stream, _)) = tokio::select! { |
| res = l6.accept() => res.map_err(|e| format!("failed_to_accept_connection: {}", e)), |
| _ = rx.changed() => Err("OAuth cancelled".to_string()), |
| } { |
| let mut buffer = [0u8; 4096]; |
| let bytes_read = stream.read(&mut buffer).await.unwrap_or(0); |
| let request = String::from_utf8_lossy(&buffer[..bytes_read]); |
| |
| let query_params = request |
| .lines() |
| .next() |
| .and_then(|line| { |
| let parts: Vec<&str> = line.split_whitespace().collect(); |
| if parts.len() >= 2 { Some(parts[1]) } else { None } |
| }) |
| .and_then(|path| { |
| Url::parse(&format!("http://localhost{}", path)).ok() |
| }) |
| .map(|url| { |
| let mut code = None; |
| let mut state = None; |
| for (k, v) in url.query_pairs() { |
| if k == "code" { code = Some(v.to_string()); } |
| else if k == "state" { state = Some(v.to_string()); } |
| } |
| (code, state) |
| }); |
|
|
| let (code, received_state) = match query_params { |
| Some((c, s)) => (c, s), |
| None => (None, None), |
| }; |
|
|
| if code.is_none() && bytes_read > 0 { |
| crate::modules::logger::log_error(&format!( |
| "OAuth callback failed to parse code (IPv6). Raw request: {}", |
| &request.chars().take(512).collect::<String>() |
| )); |
| } |
|
|
| |
| let state_valid = { |
| if let Ok(lock) = get_oauth_flow_state().lock() { |
| if let Some(s) = lock.as_ref() { |
| received_state.as_ref() == Some(&s.state) |
| } else { |
| false |
| } |
| } else { |
| false |
| } |
| }; |
|
|
| let (result, response_html) = match (code, state_valid) { |
| (Some(code), true) => { |
| crate::modules::logger::log_info("Successfully captured OAuth code from IPv6 listener"); |
| (Ok(code), oauth_success_html()) |
| }, |
| (Some(_), false) => { |
| crate::modules::logger::log_error("OAuth callback state mismatch (IPv6 CSRF protection)"); |
| (Err("OAuth state mismatch".to_string()), oauth_fail_html()) |
| }, |
| (None, _) => (Err("Failed to get Authorization Code in callback".to_string()), oauth_fail_html()), |
| }; |
| |
| let _ = stream.write_all(response_html.as_bytes()).await; |
| let _ = stream.flush().await; |
|
|
| if let Some(h) = app_handle { |
| use tauri::Emitter; |
| let _ = h.emit("oauth-callback-received", ()); |
| } |
| let _ = tx.send(result).await; |
| } |
| }); |
| } |
|
|
| |
| if let Ok(mut state) = get_oauth_flow_state().lock() { |
| *state = Some(OAuthFlowState { |
| auth_url: auth_url.clone(), |
| redirect_uri, |
| state: state_str, |
| client_key: resolved_client_key, |
| cancel_tx, |
| code_tx, |
| code_rx: Some(code_rx), |
| }); |
| } |
|
|
| |
| if let Some(h) = app_handle { |
| use tauri::Emitter; |
| let _ = h.emit("oauth-url-generated", &auth_url); |
| } |
|
|
| Ok(auth_url) |
| } |
|
|
| |
| pub async fn prepare_oauth_url(app_handle: Option<tauri::AppHandle>, oauth_client_key: Option<String>) -> Result<String, String> { |
| ensure_oauth_flow_prepared(app_handle, oauth_client_key).await |
| } |
|
|
| |
| pub fn cancel_oauth_flow() { |
| if let Ok(mut state) = get_oauth_flow_state().lock() { |
| if let Some(s) = state.take() { |
| let _ = s.cancel_tx.send(true); |
| crate::modules::logger::log_info("Sent OAuth cancellation signal"); |
| } |
| } |
| } |
|
|
| |
| pub async fn start_oauth_flow(app_handle: Option<tauri::AppHandle>, oauth_client_key: Option<String>) -> Result<oauth::TokenResponse, String> { |
| |
| let auth_url = ensure_oauth_flow_prepared(app_handle.clone(), oauth_client_key).await?; |
|
|
| if let Some(h) = app_handle { |
| |
| use tauri_plugin_opener::OpenerExt; |
| h.opener() |
| .open_url(&auth_url, None::<String>) |
| .map_err(|e| format!("failed_to_open_browser: {}", e))?; |
| } |
|
|
| |
| let (mut code_rx, redirect_uri, client_key) = { |
| let mut lock = get_oauth_flow_state() |
| .lock() |
| .map_err(|_| "OAuth state lock corrupted".to_string())?; |
| let Some(state) = lock.as_mut() else { |
| return Err("OAuth state does not exist".to_string()); |
| }; |
| let rx = state |
| .code_rx |
| .take() |
| .ok_or_else(|| "OAuth authorization already in progress".to_string())?; |
| (rx, state.redirect_uri.clone(), state.client_key.clone()) |
| }; |
|
|
| |
| |
| let code = match code_rx.recv().await { |
| Some(Ok(code)) => code, |
| Some(Err(e)) => return Err(e), |
| None => return Err("OAuth flow channel closed unexpectedly".to_string()), |
| }; |
|
|
| |
| if let Ok(mut lock) = get_oauth_flow_state().lock() { |
| *lock = None; |
| } |
|
|
| oauth::exchange_code_with_client(&code, &redirect_uri, Some(&client_key)).await |
| } |
|
|
| |
| |
| |
| pub async fn complete_oauth_flow(app_handle: Option<tauri::AppHandle>) -> Result<oauth::TokenResponse, String> { |
| |
| let _ = ensure_oauth_flow_prepared(app_handle, None).await?; |
|
|
| |
| let (mut code_rx, redirect_uri, client_key) = { |
| let mut lock = get_oauth_flow_state() |
| .lock() |
| .map_err(|_| "OAuth state lock corrupted".to_string())?; |
| let Some(state) = lock.as_mut() else { |
| return Err("OAuth state does not exist".to_string()); |
| }; |
| let rx = state |
| .code_rx |
| .take() |
| .ok_or_else(|| "OAuth authorization already in progress".to_string())?; |
| (rx, state.redirect_uri.clone(), state.client_key.clone()) |
| }; |
|
|
| let code = match code_rx.recv().await { |
| Some(Ok(code)) => code, |
| Some(Err(e)) => return Err(e), |
| None => return Err("OAuth flow channel closed unexpectedly".to_string()), |
| }; |
|
|
| if let Ok(mut lock) = get_oauth_flow_state().lock() { |
| *lock = None; |
| } |
|
|
| oauth::exchange_code_with_client(&code, &redirect_uri, Some(&client_key)).await |
| } |
|
|
| |
| |
| |
| pub async fn submit_oauth_code(code_input: String, state_input: Option<String>) -> Result<(), String> { |
| let tx = { |
| let lock = get_oauth_flow_state().lock().map_err(|e| e.to_string())?; |
| if let Some(state) = lock.as_ref() { |
| |
| if let Some(provided_state) = state_input { |
| if provided_state != state.state { |
| return Err("OAuth state mismatch (CSRF protection)".to_string()); |
| } |
| } |
| state.code_tx.clone() |
| } else { |
| return Err("No active OAuth flow found".to_string()); |
| } |
| }; |
|
|
| |
| let code = if code_input.starts_with("http") { |
| if let Ok(url) = Url::parse(&code_input) { |
| url.query_pairs() |
| .find(|(k, _)| k == "code") |
| .map(|(_, v)| v.to_string()) |
| .unwrap_or(code_input) |
| } else { |
| code_input |
| } |
| } else { |
| code_input |
| }; |
|
|
| crate::modules::logger::log_info("Received manual OAuth code submission"); |
| |
| |
| tx.send(Ok(code)).await.map_err(|_| "Failed to send code to OAuth flow (receiver dropped)".to_string())?; |
| |
| Ok(()) |
| } |
| |
| |
| pub fn prepare_oauth_flow_manually(redirect_uri: String, state_str: String, oauth_client_key: Option<String>) -> Result<(String, mpsc::Receiver<Result<String, String>>), String> { |
| let (auth_url, resolved_client_key) = |
| oauth::get_auth_url_with_client(&redirect_uri, &state_str, oauth_client_key.as_deref())?; |
| |
| |
| if let Ok(mut lock) = get_oauth_flow_state().lock() { |
| if let Some(s) = lock.as_mut() { |
| |
| |
| |
| let _ = s.cancel_tx.send(true); |
| *lock = None; |
| } |
| } |
|
|
| let (cancel_tx, _cancel_rx) = watch::channel(false); |
| let (code_tx, code_rx) = mpsc::channel(1); |
|
|
| if let Ok(mut state) = get_oauth_flow_state().lock() { |
| *state = Some(OAuthFlowState { |
| auth_url: auth_url.clone(), |
| redirect_uri: redirect_uri.clone(), |
| state: state_str, |
| client_key: resolved_client_key, |
| cancel_tx, |
| code_tx, |
| code_rx: None, |
| }); |
| } |
|
|
| Ok((auth_url, code_rx)) |
| } |
|
|