MarisUK's picture
Maris AI model sync
f440f03 verified
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
Json,
};
use std::net::IpAddr;
use serde::{Deserialize, Serialize};
use crate::{
app_state::AppState,
AuthClaims,
jwt_auth::{self, AuthenticatedUser, TokenPair},
};
/// JSON body of the login endpoint.
///
/// `Debug` is implemented by hand so the password never appears in `{:?}`
/// output (e.g. request logging).
#[derive(Deserialize)]
pub struct LoginRequest {
    /// Login identifier, passed as-is to `jwt_auth::authenticate_user`.
    pub login: String,
    /// Plaintext password supplied by the client.
    pub password: String,
}

impl std::fmt::Debug for LoginRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("LoginRequest")
            .field("login", &self.login)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// JSON body of the token refresh endpoint.
///
/// `Debug` is implemented by hand so the refresh token (a bearer credential)
/// never appears in `{:?}` output.
#[derive(Deserialize)]
pub struct RefreshRequest {
    /// Refresh token to exchange for a new token pair.
    pub refresh_token: String,
}

impl std::fmt::Debug for RefreshRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RefreshRequest")
            .field("refresh_token", &"<redacted>")
            .finish()
    }
}
/// JSON body of the bootstrap endpoint that creates the initial admin account.
///
/// `Debug` is implemented by hand so the password never appears in `{:?}`
/// output.
#[derive(Deserialize)]
pub struct BootstrapRequest {
    pub email: String,
    pub username: String,
    pub display_name: String,
    /// Plaintext password for the new admin account.
    pub password: String,
}

impl std::fmt::Debug for BootstrapRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BootstrapRequest")
            .field("email", &self.email)
            .field("username", &self.username)
            .field("display_name", &self.display_name)
            .field("password", &"<redacted>")
            .finish()
    }
}
/// Successful authentication payload: the authenticated user plus the issued
/// token pair.
///
/// Serialization is unchanged. `Debug` is implemented by hand so the token
/// values do not appear in `{:?}` output.
#[derive(Serialize)]
pub struct AuthResponse {
    pub user: AuthenticatedUser,
    pub tokens: TokenPair,
}

impl std::fmt::Debug for AuthResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AuthResponse")
            .field("user", &self.user)
            .field("tokens", &"<redacted>")
            .finish()
    }
}
/// Best-effort extraction of the client IP address from proxy headers.
///
/// Tries `X-Forwarded-For` first and then `X-Real-IP`. Only the first
/// (left-most) comma-separated entry of a header value is considered. A
/// header that is missing, not visible ASCII, empty, or unparsable is skipped
/// and the next header is tried.
///
/// An entry may be:
/// - a bare IPv4/IPv6 address;
/// - a bracketed IPv6 address with an optional port (`[::1]:8080`);
/// - an address followed by `:port` (`1.2.3.4:8080`).
///
/// Returns the normalized textual form of the parsed address.
///
/// NOTE(review): both headers are client-controllable unless a trusted
/// reverse proxy overwrites them. Treat the result as informational
/// (session metadata), not as a basis for access control.
fn extract_client_ip(headers: &HeaderMap) -> Option<String> {
    ["x-forwarded-for", "x-real-ip"]
        .iter()
        .filter_map(|name| headers.get(*name))
        .filter_map(|value| value.to_str().ok())
        .find_map(parse_forwarded_ip)
}

/// Parses the first comma-separated entry of a forwarding header value into a normalized IP string.
fn parse_forwarded_ip(raw: &str) -> Option<String> {
    let first = raw.split(',').next()?.trim();
    if first.is_empty() {
        return None;
    }
    if let Ok(ip) = first.parse::<IpAddr>() {
        return Some(ip.to_string());
    }
    // Bracketed IPv6, optionally followed by a port: "[addr]" or "[addr]:port".
    if let Some(rest) = first.strip_prefix('[') {
        let (inner, _) = rest.split_once(']')?;
        return inner.parse::<IpAddr>().ok().map(|ip| ip.to_string());
    }
    // "host:port" form; splitting on the last ':' leaves the host part.
    let (host, _port) = first.rsplit_once(':')?;
    host.parse::<IpAddr>().ok().map(|ip| ip.to_string())
}
pub async fn bootstrap(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<BootstrapRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, Json<serde_json::Value>)> {
let pool = state
.postgres
.as_ref()
.ok_or_else(|| error(StatusCode::SERVICE_UNAVAILABLE, "PostgreSQL nav pieejams"))?;
let user = jwt_auth::bootstrap_admin(
pool,
&state.settings,
&req.email,
&req.username,
&req.display_name,
&req.password,
)
.await
.map_err(|e| error(StatusCode::BAD_REQUEST, &e.to_string()))?;
let client_ip = extract_client_ip(&headers);
let tokens = jwt_auth::issue_token_pair(
pool,
&state.settings,
&user,
client_ip.as_deref(),
headers.get("user-agent").and_then(|value| value.to_str().ok()),
)
.await
.map_err(|e| error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()))?;
Ok(Json(AuthResponse { user, tokens }))
}
pub async fn login(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<LoginRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, Json<serde_json::Value>)> {
let pool = state
.postgres
.as_ref()
.ok_or_else(|| error(StatusCode::SERVICE_UNAVAILABLE, "PostgreSQL nav pieejams"))?;
let user = jwt_auth::authenticate_user(pool, &state.settings, &req.login, &req.password)
.await
.map_err(|e| error(StatusCode::UNAUTHORIZED, &e.to_string()))?;
let client_ip = extract_client_ip(&headers);
let tokens = jwt_auth::issue_token_pair(
pool,
&state.settings,
&user,
client_ip.as_deref(),
headers.get("user-agent").and_then(|value| value.to_str().ok()),
)
.await
.map_err(|e| error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()))?;
Ok(Json(AuthResponse { user, tokens }))
}
pub async fn refresh(
State(state): State<AppState>,
headers: HeaderMap,
Json(req): Json<RefreshRequest>,
) -> Result<Json<AuthResponse>, (StatusCode, Json<serde_json::Value>)> {
let pool = state
.postgres
.as_ref()
.ok_or_else(|| error(StatusCode::SERVICE_UNAVAILABLE, "PostgreSQL nav pieejams"))?;
let client_ip = extract_client_ip(&headers);
let tokens = jwt_auth::refresh_token_pair(
pool,
&state.settings,
&req.refresh_token,
client_ip.as_deref(),
headers.get("user-agent").and_then(|value| value.to_str().ok()),
)
.await
.map_err(|e| error(StatusCode::UNAUTHORIZED, &e.to_string()))?;
let claims = jwt_auth::verify_access_token(&state.settings, &tokens.access_token)
.map_err(|e| error(StatusCode::UNAUTHORIZED, &e.to_string()))?;
let user = jwt_auth::load_user_from_claims(pool, &state.settings, &claims)
.await
.map_err(|e| error(StatusCode::UNAUTHORIZED, &e.to_string()))?;
Ok(Json(AuthResponse { user, tokens }))
}
pub async fn me(
State(state): State<AppState>,
claims: axum::extract::Extension<AuthClaims>,
) -> Result<Json<AuthenticatedUser>, (StatusCode, Json<serde_json::Value>)> {
let pool = state
.postgres
.as_ref()
.ok_or_else(|| error(StatusCode::SERVICE_UNAVAILABLE, "PostgreSQL nav pieejams"))?;
let user = jwt_auth::load_user_from_claims(pool, &state.settings, &claims.0)
.await
.map_err(|e| error(StatusCode::UNAUTHORIZED, &e.to_string()))?;
Ok(Json(user))
}
pub async fn logout(
State(state): State<AppState>,
claims: axum::extract::Extension<AuthClaims>,
) -> Result<StatusCode, (StatusCode, Json<serde_json::Value>)> {
let pool = state
.postgres
.as_ref()
.ok_or_else(|| error(StatusCode::SERVICE_UNAVAILABLE, "PostgreSQL nav pieejams"))?;
jwt_auth::revoke_session(pool, &claims.sid)
.await
.map_err(|e| error(StatusCode::INTERNAL_SERVER_ERROR, &e.to_string()))?;
Ok(StatusCode::NO_CONTENT)
}
fn error(status: StatusCode, message: &str) -> (StatusCode, Json<serde_json::Value>) {
(status, Json(serde_json::json!({ "error": message })))
}