File size: 13,555 Bytes
1295969
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
//! Zero Trust middleware: SPIFFE SVID + JWT on every request.
//! SECURITY FIX: Auth is now enforced by default. ZERO_TRUST_DISABLED requires
//! explicit opt-in AND is blocked in production (RETROSYNC_ENV=production).
use crate::AppState;
use axum::{
    extract::{Request, State},
    http::{HeaderValue, StatusCode},
    middleware::Next,
    response::Response,
};
use tracing::warn;

// ── HTTP Security Headers middleware ──────────────────────────────────────────
//
// Injected as the outermost layer so every response — including 4xx/5xx from
// inner middleware — carries the full set of defensive headers.
//
// Headers enforced:
//   X-Content-Type-Options    — prevents MIME-sniff attacks
//   X-Frame-Options           — blocks clickjacking / framing
//   Referrer-Policy           — restricts referrer leakage
//   X-XSS-Protection          — legacy XSS filter (belt+suspenders)
//   Strict-Transport-Security — forces HTTPS (HSTS); also sent from Replit edge
//   Content-Security-Policy   — strict source allowlist; frame-ancestors 'none'
//   Permissions-Policy        — opt-out of unused browser APIs
//   Cache-Control             — API responses must not be stored by any cache

/// Response middleware that stamps a fixed set of defensive HTTP headers onto
/// every response produced by the inner service.
///
/// Headers already present with the same name are overwritten
/// (`HeaderMap::insert` semantics).
pub async fn add_security_headers(request: Request, next: Next) -> Response {
    use axum::http::header::{HeaderName, HeaderValue};

    // Every entry is a compile-time literal, so it is converted with
    // `from_static` rather than a fallible runtime parse.
    //
    // A malformed entry now fails loudly (panic on the first response, caught
    // by any smoke test). Previously such a header was silently dropped in
    // production.
    //
    // Names must be lowercase for `HeaderName::from_static`.
    const SECURITY_HEADERS: &[(&str, &str)] = &[
        ("x-content-type-options", "nosniff"),
        ("x-frame-options", "DENY"),
        ("referrer-policy", "strict-origin-when-cross-origin"),
        // NOTE(review): current OWASP guidance recommends `0` for this legacy
        // header; value kept unchanged here to preserve behaviour.
        ("x-xss-protection", "1; mode=block"),
        (
            "strict-transport-security",
            "max-age=31536000; includeSubDomains; preload",
        ),
        // CSP: this is an API server (JSON only). No scripts, frames, or embedded
        // content are ever served, so we use the most restrictive possible policy.
        (
            "content-security-policy",
            "default-src 'none'; frame-ancestors 'none'; base-uri 'none'; form-action 'none'",
        ),
        (
            "permissions-policy",
            "geolocation=(), camera=(), microphone=(), payment=(), usb=(), serial=()",
        ),
        // API responses contain real-time financial/rights data and must not be cached.
        (
            "cache-control",
            "no-store, no-cache, must-revalidate, private",
        ),
    ];

    let mut response = next.run(request).await;
    let headers = response.headers_mut();
    for &(name, value) in SECURITY_HEADERS {
        headers.insert(HeaderName::from_static(name), HeaderValue::from_static(value));
    }
    response
}

/// Axum middleware that requires a Bearer JWT on every non-exempt request.
///
/// Policy:
/// - `ZERO_TRUST_DISABLED=1` skips auth outside production. In production
///   (`RETROSYNC_ENV=production`) it blocks the request with 403.
/// - `/health`, `/metrics` and `/api/auth/*` are exempt.
/// - Otherwise a non-empty `Authorization: Bearer <token>` is required.
/// - If `JWT_SECRET` is set to a non-empty value, the token is always verified
///   (signature + `exp`) via `validate_jwt`, in every environment.
/// - If no usable secret exists, production fails closed with 500. Dev accepts
///   the token unverified, with a warning.
///
/// # Errors
/// - 403 for the forbidden production bypass.
/// - 400 for a non-ASCII Authorization header.
/// - 401 for a missing, malformed or invalid token.
/// - 500 for a missing secret in production.
pub async fn verify_zero_trust(
    State(_state): State<AppState>,
    request: Request,
    next: Next,
) -> Result<Response, StatusCode> {
    let env = std::env::var("RETROSYNC_ENV").unwrap_or_else(|_| "development".into());
    let is_production = env == "production";

    // SECURITY: Dev bypass is BLOCKED in production
    if std::env::var("ZERO_TRUST_DISABLED").unwrap_or_default() == "1" {
        if is_production {
            warn!(
                "SECURITY: ZERO_TRUST_DISABLED=1 is not allowed in production β€” blocking request"
            );
            return Err(StatusCode::FORBIDDEN);
        }
        warn!("ZERO_TRUST_DISABLED=1 β€” skipping auth (dev only, NOT for production)");
        return Ok(next.run(request).await);
    }

    // SECURITY: Certain public endpoints are exempt from auth.
    // /api/auth/*: wallet challenge issuance + verification (these PRODUCE auth tokens)
    // /health, /metrics: infra health checks
    let path = request.uri().path();
    if path == "/health" || path == "/metrics" || path.starts_with("/api/auth/") {
        return Ok(next.run(request).await);
    }

    // Extract Authorization header
    let token = match request.headers().get("authorization") {
        None => {
            warn!(path=%path, "Missing Authorization header β€” rejecting request");
            return Err(StatusCode::UNAUTHORIZED);
        }
        Some(v) => v.to_str().map_err(|_| StatusCode::BAD_REQUEST)?,
    };

    // Validate Bearer token format
    let jwt = token.strip_prefix("Bearer ").ok_or_else(|| {
        warn!("Invalid Authorization header format β€” must be Bearer <token>");
        StatusCode::UNAUTHORIZED
    })?;

    if jwt.is_empty() {
        warn!("Empty Bearer token β€” rejecting");
        return Err(StatusCode::UNAUTHORIZED);
    }

    // An empty JWT_SECRET is treated as unset. HMAC keyed with the empty string
    // would let anyone mint tokens that pass verification.
    let secret = std::env::var("JWT_SECRET").ok().filter(|s| !s.is_empty());
    match secret {
        // Verify whenever a secret is configured, not only in production. This
        // keeps the middleware consistent with per-route `extract_caller` checks.
        Some(secret) => validate_jwt(jwt, &secret)?,
        None if is_production => {
            warn!("JWT_SECRET not configured in production");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
        None => {
            warn!(path=%path, "Dev mode: JWT signature not verified β€” non-empty token accepted");
        }
    }

    Ok(next.run(request).await)
}

/// Validate JWT signature and claims (production enforcement).
/// In production, JWT_SECRET must be set and tokens must be properly signed.
/// Validate an HS256 JWT: HMAC-SHA256 signature first, then the `exp` claim.
///
/// The `alg` field of the JWT header is not consulted. The signature is always
/// computed as HMAC-SHA256 over `header.payload` keyed by `secret`, then
/// compared (constant-time) with the token's base64url signature segment.
///
/// `exp`, if present, must be a JSON number (RFC 7519 NumericDate; fractional
/// values allowed) and must not be earlier than the current UTC time. A token
/// without `exp` is accepted.
///
/// # Errors
/// Returns `401 UNAUTHORIZED` in any of these cases:
/// - the token does not have 3 parts;
/// - the signature does not match;
/// - the payload is not base64url JSON;
/// - `exp` is non-numeric;
/// - the token is expired.
fn validate_jwt(token: &str, secret: &str) -> Result<(), StatusCode> {
    // Token structure: header.payload.signature (3 parts)
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        warn!("Malformed JWT: expected 3 parts, got {}", parts.len());
        return Err(StatusCode::UNAUTHORIZED);
    }
    let (header_b64, payload_b64, signature_b64) = (parts[0], parts[1], parts[2]);

    // Authenticate before decoding or trusting anything inside the payload.
    let signing_input = format!("{header_b64}.{payload_b64}");
    let expected_b64 =
        base64_encode_url(&hmac_sha256(secret.as_bytes(), signing_input.as_bytes()));
    if !constant_time_eq(signature_b64.as_bytes(), expected_b64.as_bytes()) {
        warn!("JWT signature verification failed");
        return Err(StatusCode::UNAUTHORIZED);
    }

    let payload_bytes = base64_decode_url(payload_b64).map_err(|_| {
        warn!("JWT payload base64 decode failed");
        StatusCode::UNAUTHORIZED
    })?;

    let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).map_err(|_| {
        warn!("JWT payload JSON parse failed");
        StatusCode::UNAUTHORIZED
    })?;

    // A present-but-non-numeric `exp` must not silently disable expiry.
    // `as_f64` also accepts integer values.
    if let Some(exp) = payload.get("exp") {
        let Some(exp) = exp.as_f64() else {
            warn!("JWT `exp` claim is not a number");
            return Err(StatusCode::UNAUTHORIZED);
        };
        let now = chrono::Utc::now().timestamp();
        if (now as f64) > exp {
            warn!("JWT expired at {} (now: {})", exp, now);
            return Err(StatusCode::UNAUTHORIZED);
        }
    }

    Ok(())
}

/// Decode a base64url string (RFC 4648 §5), as used for JWT segments.
///
/// JWTs omit `=` padding, so padding is restored from the length. The URL-safe
/// alphabet (`-`, `_`) is then mapped back to the standard one (`+`, `/`)
/// before decoding.
///
/// # Errors
/// Returns `Err(())` in either of these cases:
/// - the input length is ≡ 1 (mod 4), which no base64 encoding can produce;
/// - the underlying decoder rejects the content.
fn base64_decode_url(s: &str) -> Result<Vec<u8>, ()> {
    let padded = match s.len() % 4 {
        0 => s.to_string(),
        2 => format!("{s}=="),
        3 => format!("{s}="),
        // Remainder 1: a lone trailing symbol carries only 6 bits and can never
        // complete a byte, so the input is not valid base64. Reject it instead
        // of letting it be silently truncated.
        _ => return Err(()),
    };
    let standard = padded.replace('-', "+").replace('_', "/");
    base64_simple_decode(&standard).map_err(|_| ())
}

/// Decode base64, accepting both the standard (`+`, `/`) and URL-safe
/// (`-`, `_`) alphabets.
///
/// Trailing `=` padding is optional. At most two padding characters are
/// allowed, and only at the end of the input.
///
/// # Errors
/// Returns a descriptive `String` in any of these cases:
/// - a character is outside the alphabet (including `=` anywhere but the end);
/// - there are more than two padding characters;
/// - the number of data symbols is ≡ 1 (mod 4), which cannot encode whole bytes.
fn base64_simple_decode(s: &str) -> Result<Vec<u8>, String> {
    let data = s.trim_end_matches('=');
    if s.len() - data.len() > 2 {
        return Err("invalid base64 padding".to_string());
    }

    // Map each symbol to its 6-bit value. Any `=` left here was not trailing,
    // so it falls through to the error arm.
    let sextets = data
        .chars()
        .map(|c| match c {
            'A'..='Z' => Ok(c as u8 - b'A'),
            'a'..='z' => Ok(c as u8 - b'a' + 26),
            '0'..='9' => Ok(c as u8 - b'0' + 52),
            '+' | '-' => Ok(62),
            '/' | '_' => Ok(63),
            _ => Err(format!("invalid base64 character: {c:?}")),
        })
        .collect::<Result<Vec<u8>, String>>()?;

    // A lone trailing sextet carries only 6 bits and cannot form a byte.
    if sextets.len() % 4 == 1 {
        return Err("invalid base64 length".to_string());
    }

    let mut out = Vec::with_capacity(sextets.len() / 4 * 3 + 2);
    // Every chunk has at least 2 symbols (guaranteed by the length check above).
    // The `<<` shifts on u8 deliberately discard the high bits.
    for chunk in sextets.chunks(4) {
        out.push((chunk[0] << 2) | (chunk[1] >> 4));
        if let Some(&c2) = chunk.get(2) {
            out.push((chunk[1] << 4) | (c2 >> 2));
            if let Some(&c3) = chunk.get(3) {
                out.push((c2 << 6) | c3);
            }
        }
    }
    Ok(out)
}

/// Encode `bytes` as unpadded base64url (RFC 4648 §5 alphabet, no `=`).
///
/// Each 3-byte group becomes 4 symbols. A trailing group of 1 or 2 bytes
/// becomes 2 or 3 symbols.
fn base64_encode_url(bytes: &[u8]) -> String {
    const URL_ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

    let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
    for group in bytes.chunks(3) {
        // Pack up to three bytes big-endian into the low 24 bits of a word.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &b)| acc | (u32::from(b) << (16 - 8 * i)));
        // A group of k bytes yields k + 1 six-bit symbols (no padding emitted).
        for k in 0..=group.len() {
            let index = (word >> (18 - 6 * k)) & 0x3f;
            encoded.push(char::from(URL_ALPHABET[index as usize]));
        }
    }
    encoded
}

/// Compute HMAC-SHA256(`key`, `msg`) per RFC 2104 and return the 32-byte tag.
///
/// Keys longer than the 64-byte SHA-256 block are first hashed. Shorter keys
/// are zero-padded to the block size.
fn hmac_sha256(key: &[u8], msg: &[u8]) -> Vec<u8> {
    use sha2::{Digest, Sha256};
    const BLOCK_LEN: usize = 64;

    let mut key_block = [0u8; BLOCK_LEN];
    if key.len() > BLOCK_LEN {
        let hashed_key = Sha256::digest(key);
        key_block[..hashed_key.len()].copy_from_slice(&hashed_key);
    } else {
        key_block[..key.len()].copy_from_slice(key);
    }

    let inner_pad = key_block.map(|b| b ^ 0x36);
    let outer_pad = key_block.map(|b| b ^ 0x5c);

    // H((K ^ ipad) || msg)
    let mut inner = Sha256::new();
    inner.update(inner_pad);
    inner.update(msg);
    let inner_hash = inner.finalize();

    // H((K ^ opad) || inner_hash)
    let mut outer = Sha256::new();
    outer.update(outer_pad);
    outer.update(inner_hash);
    outer.finalize().to_vec()
}

/// Compare two byte slices without an early exit on the first mismatch.
///
/// Only differing lengths short-circuit. That is acceptable here because the
/// compared values (encoded signatures) have a public, fixed length.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let mut diff = 0u8;
        for (x, y) in a.iter().zip(b.iter()) {
            diff |= x ^ y;
        }
        diff == 0
    } else {
        false
    }
}

/// Build CORS headers restricted to allowed origins.
/// Call this in main.rs instead of CorsLayer::new().allow_origin(Any).
pub fn allowed_origins() -> Vec<HeaderValue> {
    let origins = std::env::var("ALLOWED_ORIGINS")
        .unwrap_or_else(|_| "http://localhost:5173,http://localhost:3000".into());
    origins
        .split(',')
        .filter_map(|o| o.trim().parse::<HeaderValue>().ok())
        .collect()
}

/// Extract the authenticated caller's wallet address from the JWT in the
/// Authorization header.  Returns the `sub` claim (normalised to lowercase).
///
/// Used by per-user auth guards in kyc.rs and privacy.rs to verify the
/// caller is accessing their own data only.
///
/// Always performs full HMAC-SHA256 signature verification when JWT_SECRET
/// is set.  If JWT_SECRET is absent (dev mode), falls back to expiry-only
/// check with a warning β€” matching the behaviour of the outer middleware.
/// Extract the authenticated caller's wallet address from the JWT in the
/// Authorization header. Returns the `sub` claim, normalised to lowercase.
///
/// Used by per-user auth guards in kyc.rs and privacy.rs to verify the
/// caller is accessing their own data only.
///
/// Verification policy:
/// - `JWT_SECRET` set to a non-empty value: full HMAC-SHA256 signature and
///   claims verification via `validate_jwt`.
/// - No usable secret and `RETROSYNC_ENV=production`: refused with 500 (fail
///   closed) rather than trusting an unsigned token.
/// - No usable secret otherwise (dev): signature NOT verified; a warning is logged.
///
/// In every case, an `exp` claim, if present, must be numeric and not in the past.
///
/// # Errors
/// - `400 BAD_REQUEST`: the Authorization header is not visible ASCII.
/// - `401 UNAUTHORIZED`: any of
///   - a missing, non-Bearer, empty or malformed token;
///   - a failed signature check;
///   - an expired token or non-numeric `exp`;
///   - a missing or non-string `sub`.
/// - `500 INTERNAL_SERVER_ERROR`: `JWT_SECRET` is missing or empty in production.
pub fn extract_caller(headers: &axum::http::HeaderMap) -> Result<String, axum::http::StatusCode> {
    use axum::http::StatusCode;

    let auth_header = headers
        .get("authorization")
        .ok_or_else(|| {
            warn!("extract_caller: missing Authorization header");
            StatusCode::UNAUTHORIZED
        })?
        .to_str()
        .map_err(|_| StatusCode::BAD_REQUEST)?;

    let token = auth_header.strip_prefix("Bearer ").ok_or_else(|| {
        warn!("extract_caller: invalid Authorization format");
        StatusCode::UNAUTHORIZED
    })?;

    if token.is_empty() {
        warn!("extract_caller: empty token");
        return Err(StatusCode::UNAUTHORIZED);
    }

    // An empty secret is treated as unset. HMAC keyed with the empty string
    // is forgeable by anyone.
    let secret = std::env::var("JWT_SECRET").ok().filter(|s| !s.is_empty());
    let is_production = matches!(std::env::var("RETROSYNC_ENV").as_deref(), Ok("production"));
    match secret {
        Some(secret) => validate_jwt(token, &secret)?,
        None if is_production => {
            warn!("extract_caller: JWT_SECRET not configured in production");
            return Err(StatusCode::INTERNAL_SERVER_ERROR);
        }
        None => {
            warn!("extract_caller: JWT_SECRET not set β€” signature not verified (dev mode only)");
        }
    }

    // Decode the payload once. It has been authenticated above whenever a
    // secret is configured.
    let parts: Vec<&str> = token.split('.').collect();
    if parts.len() != 3 {
        warn!("extract_caller: malformed JWT ({} parts)", parts.len());
        return Err(StatusCode::UNAUTHORIZED);
    }

    let payload_bytes = base64_decode_url(parts[1]).map_err(|_| {
        warn!("extract_caller: base64 decode failed");
        StatusCode::UNAUTHORIZED
    })?;

    let payload: serde_json::Value = serde_json::from_slice(&payload_bytes).map_err(|_| {
        warn!("extract_caller: JSON parse failed");
        StatusCode::UNAUTHORIZED
    })?;

    // Expiry is checked here in every mode, so unverified dev tokens still
    // expire. `as_f64` accepts both integer and fractional NumericDates.
    if let Some(exp) = payload.get("exp") {
        let Some(exp) = exp.as_f64() else {
            warn!("extract_caller: JWT `exp` claim is not a number");
            return Err(StatusCode::UNAUTHORIZED);
        };
        if (chrono::Utc::now().timestamp() as f64) > exp {
            warn!("extract_caller: JWT expired at {exp}");
            return Err(StatusCode::UNAUTHORIZED);
        }
    }

    let sub = payload
        .get("sub")
        .and_then(|v| v.as_str())
        .ok_or_else(|| {
            warn!("extract_caller: no `sub` claim in JWT");
            StatusCode::UNAUTHORIZED
        })?
        .to_ascii_lowercase();

    Ok(sub)
}