//! Zero Trust middleware: SPIFFE SVID + JWT on every request.
//! SECURITY FIX: Auth is now enforced by default. ZERO_TRUST_DISABLED requires
//! explicit opt-in AND is blocked in production (RETROSYNC_ENV=production).
use crate::AppState;
use axum::{
extract::{Request, State},
http::{HeaderValue, StatusCode},
middleware::Next,
response::Response,
};
use tracing::warn;
// ── HTTP Security Headers middleware ──────────────────────────────────────────
//
// Injected as the outermost layer so every response — including 4xx/5xx from
// inner middleware — carries the full set of defensive headers.
//
// Headers enforced:
// X-Content-Type-Options — prevents MIME-sniff attacks
// X-Frame-Options — blocks clickjacking / framing
// Referrer-Policy — restricts referrer leakage
// X-XSS-Protection — legacy XSS filter (belt+suspenders)
// Strict-Transport-Security — forces HTTPS (HSTS); also sent from Replit edge
// Content-Security-Policy — strict source allowlist; frame-ancestors 'none'
// Permissions-Policy — opt-out of unused browser APIs
// Cache-Control — API responses must not be cached by shared caches
/// Middleware that runs the inner service and then stamps a fixed set of
/// defensive HTTP headers onto the response.
///
/// Each header is written with `insert`, so a value already set by an inner
/// layer under the same name is replaced.
pub async fn add_security_headers(request: Request, next: Next) -> Response {
    use axum::http::header::{HeaderName, HeaderValue};

    // Compile-time table of (lowercase header name, value). Every entry is a
    // visible-ASCII literal, so the `from_static` constructors below cannot
    // fail; they would panic loudly on a malformed literal instead of
    // silently dropping a security header, and no per-request parsing error
    // path is needed.
    const SECURITY_HEADERS: &[(&str, &str)] = &[
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
        ("x-xss-protection", "1; mode=block"),
        (
            "strict-transport-security",
            "max-age=31536000; includeSubDomains; preload",
        ),
        // CSP: this is an API server (JSON only) — no scripts, frames, or
        // embedded content are ever served, so the most restrictive policy
        // is used.
        (
            "content-security-policy",
            "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'none'",
        ),
        (
            "permissions-policy",
            "geolocation=(), camera=(), microphone=(), payment=(), usb=(), serial=()",
        ),
        // API responses contain real-time financial/rights data — must not be cached.
        (
            "cache-control",
            "no-store, no-cache, must-revalidate, private",
        ),
    ];

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for &(name, value) in SECURITY_HEADERS {
        headers.insert(
            HeaderName::from_static(name),
            HeaderValue::from_static(value),
        );
    }
    response
}
/// Authentication middleware applied to every request.
///
/// Decision order:
/// 1. `ZERO_TRUST_DISABLED=1` — rejected with 403 when `RETROSYNC_ENV` is
///    `production`, otherwise the request is passed through unauthenticated.
/// 2. `/health`, `/metrics` and anything under `/api/auth/` are public.
/// 3. Every other request needs `Authorization: Bearer <token>` with a
///    non-empty token (401 otherwise, 400 if the header is not valid text).
/// 4. In production the token is checked by `validate_jwt` using
///    `JWT_SECRET` (500 if that variable is missing); outside production any
///    non-empty token is accepted and a warning is logged.
pub async fn verify_zero_trust(
    State(_state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    // An unset or unreadable RETROSYNC_ENV counts as non-production.
    let is_production = std::env::var("RETROSYNC_ENV")
        .map(|env| env == "production")
        .unwrap_or(false);
    let bypass_requested = matches!(
        std::env::var("ZERO_TRUST_DISABLED").as_deref(),
        Ok("1")
    );

    // SECURITY: the dev bypass is never honoured in production.
    if bypass_requested && is_production {
        warn!(
            "SECURITY: ZERO_TRUST_DISABLED=1 is not allowed in production β blocking request"
        );
        return Err(StatusCode::FORBIDDEN);
    }
    if bypass_requested {
        warn!("ZERO_TRUST_DISABLED=1 β skipping auth (dev only, NOT for production)");
        return Ok(next.run(request).await);
    }

    // Public endpoints: infra health checks, plus the auth routes that
    // themselves issue tokens and therefore cannot require one.
    let path = request.uri().path();
    let is_public =
        matches!(path, "/health" | "/metrics") || path.starts_with("/api/auth/");
    if is_public {
        return Ok(next.run(request).await);
    }

    // Pull the raw Authorization header value as text.
    let header_text = match request.headers().get("authorization") {
        Some(value) => value.to_str().map_err(|_| StatusCode::BAD_REQUEST)?,
        None => {
            warn!(path=%path, "Missing Authorization header β rejecting request");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    // Only the `Bearer <token>` scheme is accepted.
    let jwt = match header_text.strip_prefix("Bearer ") {
        Some(rest) => rest,
        None => {
            warn!("Invalid Authorization header format β must be Bearer <token>");
            return Err(StatusCode::UNAUTHORIZED);
        }
    };
    if jwt.is_empty() {
        warn!("Empty Bearer token β rejecting");
        return Err(StatusCode::UNAUTHORIZED);
    }

    if !is_production {
        // Development: the token is not inspected beyond being non-empty.
        warn!(path=%path, "Dev mode: JWT signature not verified β non-empty token accepted");
    } else {
        // Production: full validation, which needs the signing secret.
        let secret = match std::env::var("JWT_SECRET") {
            Ok(secret) => secret,
            Err(_) => {
                warn!("JWT_SECRET not configured in production");
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
        };
        validate_jwt(jwt, &secret)?;
    }

    Ok(next.run(request).await)
}
/// Validate a compact JWT (`header.payload.signature`) signed with
/// HMAC-SHA256 under `secret`.
///
/// Checks, in order:
/// 1. `secret` is non-empty (an empty key would make every token forgeable).
/// 2. The token has exactly three dot-separated parts.
/// 3. The signature matches HMAC-SHA256 over `header.payload`.
/// 4. The payload is valid base64url-encoded JSON.
/// 5. If an `exp` claim is present it is numeric and not in the past.
///
/// The signature is verified *before* the payload is parsed so that no
/// claim is read from a token that has not been authenticated. A token
/// without an `exp` claim is still accepted.
///
/// # Errors
/// * `StatusCode::INTERNAL_SERVER_ERROR` — `secret` is empty.
/// * `StatusCode::UNAUTHORIZED` — any other validation failure.
fn validate_jwt(token: &str, secret: &str) -> Result<(), StatusCode> {
    if secret.is_empty() {
        warn!("JWT secret is empty; refusing to validate tokens");
        return Err(StatusCode::INTERNAL_SERVER_ERROR);
    }

    // Token structure: header.payload.signature (3 parts)
    let parts: Vec<&str> = token.split('.').collect();
    let (header_b64, payload_b64, signature_b64) = match parts.as_slice() {
        [header, payload, signature] => (*header, *payload, *signature),
        _ => {
            warn!("Malformed JWT: expected 3 parts, got {}", parts.len());
            return Err(StatusCode::UNAUTHORIZED);
        }
    };

    // HMAC-SHA256 signature verification over the exact bytes received.
    let signing_input = format!("{header_b64}.{payload_b64}");
    let expected_sig = hmac_sha256(secret.as_bytes(), signing_input.as_bytes());
    let expected_b64 = base64_encode_url(&expected_sig);
    if !constant_time_eq(signature_b64.as_bytes(), expected_b64.as_bytes()) {
        warn!("JWT signature verification failed");
        return Err(StatusCode::UNAUTHORIZED);
    }

    // The payload is now authenticated; decode it to inspect the claims.
    let payload_bytes = base64_decode_url(payload_b64).map_err(|_| {
        warn!("JWT payload base64 decode failed");
        StatusCode::UNAUTHORIZED
    })?;
    let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).map_err(|_| {
        warn!("JWT payload JSON parse failed");
        StatusCode::UNAUTHORIZED
    })?;

    // Check expiry. A NumericDate may be fractional, so fall back to the
    // floating-point view; an `exp` that is present but not a number is
    // rejected rather than treated as "never expires".
    if let Some(exp_claim) = payload.get("exp") {
        let exp = exp_claim
            .as_i64()
            // `as` saturates on out-of-range floats, which is acceptable here.
            .or_else(|| exp_claim.as_f64().map(|secs| secs.floor() as i64))
            .ok_or_else(|| {
                warn!("JWT `exp` claim is not a number");
                StatusCode::UNAUTHORIZED
            })?;
        let now = chrono::Utc::now().timestamp();
        if now > exp {
            warn!("JWT expired at {} (now: {})", exp, now);
            return Err(StatusCode::UNAUTHORIZED);
        }
    }

    Ok(())
}
/// Decode unpadded URL-safe base64 by rewriting it as padded standard
/// base64 and handing it to `base64_simple_decode`.
///
/// `-` becomes `+`, `_` becomes `/`, and `=` padding is appended when the
/// input length leaves a remainder of 2 or 3 modulo 4. Any decoder error is
/// collapsed to `Err(())`.
fn base64_decode_url(s: &str) -> Result<Vec<u8>, ()> {
    let padding = match s.len() % 4 {
        2 => "==",
        3 => "=",
        _ => "",
    };
    let standard: String = s
        .chars()
        .map(|c| match c {
            '-' => '+',
            '_' => '/',
            other => other,
        })
        .chain(padding.chars())
        .collect();
    base64_simple_decode(&standard).map_err(|_| ())
}
/// Lenient base64 decoder accepting both the standard (`+`, `/`) and the
/// URL-safe (`-`, `_`) alphabets.
///
/// `=` is skipped wherever it appears. A trailing group holding a single
/// 6-bit value carries less than one byte and is dropped without error.
///
/// # Errors
/// Returns `Err` with a message naming the first character that is outside
/// both alphabets.
fn base64_simple_decode(s: &str) -> Result<Vec<u8>, String> {
    // Pass 1: translate every character to its 6-bit value.
    let mut sextets: Vec<u8> = Vec::with_capacity(s.len());
    for symbol in s.chars() {
        let value = match symbol {
            'A'..='Z' => symbol as u8 - b'A',
            'a'..='z' => symbol as u8 - b'a' + 26,
            '0'..='9' => symbol as u8 - b'0' + 52,
            '+' | '-' => 62,
            '/' | '_' => 63,
            '=' => continue, // standard padding — skip
            other => return Err(format!("invalid base64 character: {other:?}")),
        };
        sextets.push(value);
    }

    // Pass 2: repack groups of up to four 6-bit values into bytes. The
    // shifts deliberately discard the bits that fall off the top of the u8.
    let mut decoded = Vec::new();
    for group in sextets.chunks(4) {
        match *group {
            [s0, s1, s2, s3] => decoded.extend_from_slice(&[
                (s0 << 2) | (s1 >> 4),
                (s1 << 4) | (s2 >> 2),
                (s2 << 6) | s3,
            ]),
            [s0, s1, s2] => {
                decoded.extend_from_slice(&[(s0 << 2) | (s1 >> 4), (s1 << 4) | (s2 >> 2)])
            }
            [s0, s1] => decoded.push((s0 << 2) | (s1 >> 4)),
            // Fewer than two values cannot form a byte.
            _ => break,
        }
    }
    Ok(decoded)
}
/// Encode `bytes` as URL-safe base64 (`-` and `_` in place of `+` and `/`)
/// without `=` padding.
fn base64_encode_url(bytes: &[u8]) -> String {
    // URL-safe alphabet, indexed by 6-bit value.
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for group in bytes.chunks(3) {
        // Missing trailing bytes are treated as zero for the bit arithmetic.
        let first = group[0];
        let second = group.get(1).copied().unwrap_or(0);
        let third = group.get(2).copied().unwrap_or(0);

        let sextets = [
            first >> 2,
            ((first & 0x03) << 4) | (second >> 4),
            ((second & 0x0f) << 2) | (third >> 6),
            third & 0x3f,
        ];
        // A group of n input bytes yields n + 1 output characters.
        for &sextet in &sextets[..=group.len()] {
            encoded.push(ALPHABET[sextet as usize] as char);
        }
    }
    encoded
}
/// Compute HMAC-SHA256 of `msg` under `key` and return the 32-byte tag.
///
/// Keys longer than the 64-byte block are hashed first; the resulting key
/// is zero-padded to the block size, XORed with the 0x36 / 0x5c pad bytes,
/// and fed through the usual inner and outer hash.
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
    use sha2::{Digest, Sha256};
    const BLOCK: usize = 64;

    // Normalise the key to exactly one block, zero-filled on the right.
    let mut block_key = [0u8; BLOCK];
    if key.len() > BLOCK {
        let hashed = Sha256::digest(key);
        block_key[..hashed.len()].copy_from_slice(hashed.as_slice());
    } else {
        block_key[..key.len()].copy_from_slice(key);
    }

    // Start from the pad constants and fold the key in.
    let mut inner_pad = [0x36u8; BLOCK];
    let mut outer_pad = [0x5cu8; BLOCK];
    for (index, key_byte) in block_key.iter().enumerate() {
        inner_pad[index] ^= key_byte;
        outer_pad[index] ^= key_byte;
    }

    // inner = H(ipad || msg)
    let mut inner = Sha256::new();
    inner.update(inner_pad);
    inner.update(msg);
    let inner_digest = inner.finalize();

    // tag = H(opad || inner)
    let mut outer = Sha256::new();
    outer.update(outer_pad);
    outer.update(inner_digest.as_slice());
    outer.finalize().to_vec()
}
/// Compare two byte slices without stopping at the first differing byte.
///
/// Slices of different length are reported unequal immediately; for equal
/// lengths every byte pair is XORed into an accumulator that is tested once
/// at the end.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let mut difference = 0u8;
    for (lhs, rhs) in a.iter().zip(b.iter()) {
        difference |= lhs ^ rhs;
    }
    difference == 0
}
/// Build CORS headers restricted to allowed origins.
/// Call this in main.rs instead of CorsLayer::new().allow_origin(Any).
pub fn allowed_origins() -> Vec<HeaderValue> {
let origins = std::env::var("ALLOWED_ORIGINS")
.unwrap_or_else(|_| "http://localhost:5173,http://localhost:3000".into());
origins
.split(',')
.filter_map(|o| o.trim().parse::<HeaderValue>().ok())
.collect()
}
/// Extract the authenticated caller's wallet address from the JWT in the
/// Authorization header. Returns the `sub` claim (normalised to lowercase).
///
/// Used by per-user auth guards in kyc.rs and privacy.rs to verify the
/// caller is accessing their own data only.
///
/// Verification policy:
/// * `JWT_SECRET` set — the token goes through `validate_jwt`.
/// * `JWT_SECRET` unset and `RETROSYNC_ENV=production` — fails closed with
///   500, matching the outer middleware; an unsigned token is never trusted
///   in production.
/// * `JWT_SECRET` unset otherwise (dev) — the signature is NOT verified; only
///   an integer `exp` claim, if present, is enforced, and a warning is logged.
///
/// # Errors
/// * `StatusCode::BAD_REQUEST` — the header is not valid visible ASCII.
/// * `StatusCode::INTERNAL_SERVER_ERROR` — production without `JWT_SECRET`.
/// * `StatusCode::UNAUTHORIZED` — missing/ill-formed header, empty token,
///   malformed or expired token, or no string `sub` claim.
/// * Whatever `validate_jwt` returns when the secret is configured.
pub fn extract_caller(headers: &axum::http::HeaderMap) -> Result<String, axum::http::StatusCode> {
    use axum::http::StatusCode;

    let auth_header = headers
        .get("authorization")
        .ok_or_else(|| {
            warn!("extract_caller: missing Authorization header");
            StatusCode::UNAUTHORIZED
        })?
        .to_str()
        .map_err(|_| StatusCode::BAD_REQUEST)?;
    let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
        warn!("extract_caller: invalid Authorization format");
        StatusCode::UNAUTHORIZED
    })?;
    if token.is_empty() {
        warn!("extract_caller: empty token");
        return Err(StatusCode::UNAUTHORIZED);
    }

    // Full signature + claims verification when JWT_SECRET is configured.
    let signature_verified = match std::env::var("JWT_SECRET") {
        Ok(secret) => {
            validate_jwt(token, &secret)?;
            true
        }
        Err(_) => {
            let is_production = std::env::var("RETROSYNC_ENV")
                .map(|env| env == "production")
                .unwrap_or(false);
            if is_production {
                // Never fall back to unverified tokens in production.
                warn!("extract_caller: JWT_SECRET not configured in production");
                return Err(StatusCode::INTERNAL_SERVER_ERROR);
            }
            warn!("extract_caller: JWT_SECRET not set β signature not verified (dev mode only)");
            false
        }
    };

    // Decode the payload once; it serves both the dev-mode expiry check and
    // the `sub` lookup.
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        warn!("extract_caller: malformed JWT ({} parts)", parts.len());
        return Err(StatusCode::UNAUTHORIZED);
    }
    let payload_bytes = base64_decode_url(parts[1]).map_err(|_| {
        warn!("extract_caller: base64 decode failed");
        StatusCode::UNAUTHORIZED
    })?;
    let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).map_err(|_| {
        warn!("extract_caller: JSON parse failed");
        StatusCode::UNAUTHORIZED
    })?;

    // Dev fallback: expiry-only check so dev tokens still expire correctly.
    if !signature_verified {
        if let Some(exp) = payload.get("exp").and_then(|v| v.as_i64()) {
            if chrono::Utc::now().timestamp() > exp {
                warn!("extract_caller: JWT expired at {exp}");
                return Err(StatusCode::UNAUTHORIZED);
            }
        }
    }

    let sub = payload
        .get("sub")
        .and_then(|v| v.as_str())
        .ok_or_else(|| {
            warn!("extract_caller: no `sub` claim in JWT");
            StatusCode::UNAUTHORIZED
        })?
        .to_ascii_lowercase();
    Ok(sub)
}
|