"""ATM voice authentication service.

Speaker verification built on a frozen UniSpeech-SAT encoder followed by a
small projection head trained with AAM-Softmax. Exposes a FastAPI REST API
(basic enroll/verify endpoints plus a session-based
verify -> liveness -> transaction flow) and mounts a Gradio UI at "/".
"""
import os
import json
import math
import time
import uuid
import random
import tempfile
from datetime import datetime, timedelta

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# Compatibility shim: some torchaudio builds lack list_audio_backends.
# Presumably needed by the transformers import below -- TODO confirm.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

from transformers import AutoModel

# Config
CKPT_PATH = 'aam_best.pt'
DB_PATH = 'voiceprint_db.json'
MODEL_NAME = 'microsoft/unispeech-sat-base-sv'
SAMPLE_RATE = 16000
MAX_SEC = 4
MAX_LEN = SAMPLE_RATE * MAX_SEC          # every clip is trimmed/zero-padded to this many samples
THRESHOLD = 0.3500                       # minimum cosine similarity for a match
DEVICE = torch.device('cpu')
NUM_CLEAN_SAMPLES = 6                    # recordings required to complete enrollment
NUM_NOISY_COPIES = 4                     # noise-augmented embeddings added per recording
MAX_ATTEMPTS = 3
LOCKOUT_MINUTES = 5
COOLDOWN_SECONDS = 3
ANTISPOOFING_THRESHOLD = 0.02
SESSION_TTL_MINUTES = 5

# Challenge word pool (simple, short, easy to pronounce)
CHALLENGE_WORDS = [
    'Red', 'Blue', 'Gold', 'Star', 'Water', 'Moon', 'Fire', 'Green', 'Black', 'White',
    'Sun', 'Rain', 'Tree', 'Fish', 'Bird', 'Stone', 'Wind', 'Cloud', 'Light', 'Sound'
]

# Session steps
SESSION_STEPS = {
    'STARTED': 'started',
    'VERIFIED': 'verified',
    'LIVENESS_PENDING': 'liveness_pending',
    'AUTHENTICATED': 'authenticated',
    'TRANSACTION_PENDING': 'transaction_pending',
    'COMPLETE': 'complete',
    'DENIED': 'denied'
}

# OS-entropy RNG for security-relevant randomness (challenge phrases).
_secure_random = random.SystemRandom()


# AAM-Softmax model
class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace-style) softmax head.

    Without ``labels`` returns the raw cosine similarity between the
    L2-normalized embeddings and class weights. With ``labels`` returns
    ``scale``-multiplied logits in which the target-class cosine is replaced
    by cos(theta + margin).
    """

    def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
        super().__init__()
        self.margin = margin
        self.scale = scale
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # When cos(theta) <= cos(pi - m), theta + m would exceed pi; use a
        # linear penalty (cosine - mm) there instead.
        self.threshold = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings, labels=None):
        embeddings = F.normalize(embeddings, p=2, dim=1)
        weight = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(embeddings, weight)
        if labels is None:
            return cosine
        sine = torch.sqrt(1.0 - torch.clamp(cosine * cosine, 0, 1))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.threshold, phi, cosine - self.mm)
        one_hot = F.one_hot(labels, cosine.size(1)).float()
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.scale


class SpeakerClassifier(nn.Module):
    """Linear+ReLU projection of pooled encoder features, with an AAM-Softmax head.

    Only ``extract_embedding`` (the projection) is used at inference time;
    the AAM head exists so training checkpoints load cleanly.
    """

    def __init__(self, input_dim=768, hidden_dim=512, num_classes=227):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.aam = AAMSoftmax(hidden_dim, num_classes)

    def forward(self, x, labels=None):
        x = self.relu(self.fc1(x))
        return self.aam(x, labels)

    def extract_embedding(self, x):
        return self.relu(self.fc1(x))


# Load models
print("Loading UniSpeech-SAT base model...")
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
base_model.eval()
for param in base_model.parameters():
    param.requires_grad = False

print("Loading AAM-Softmax checkpoint...")
# NOTE(review): from torch 2.6 on, torch.load defaults to weights_only=True.
# That cannot unpickle a whole nn.Module (the isinstance branch below).
# Only pass weights_only=False for checkpoints you trust.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
print(f"Checkpoint type: {type(ckpt)}")
if isinstance(ckpt, dict):
    print(f"Checkpoint keys: {list(ckpt.keys())}")

num_classes = 227
if isinstance(ckpt, dict):
    if 'num_classes' in ckpt:
        num_classes = ckpt['num_classes']
    elif 'num_speakers' in ckpt:
        num_classes = ckpt['num_speakers']

classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
loaded = False

if isinstance(ckpt, dict):
    # 1) Common wrapper keys used by training scripts.
    for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
        if key in ckpt:
            try:
                classifier.load_state_dict(ckpt[key])
                print(f"Loaded classifier from key: '{key}'")
                loaded = True
                break
            except Exception as e:
                print(f"Key '{key}' found but failed: {e}")
    # 2) Fall back to treating the dict itself as a state_dict
    #    (parameter names such as "fc1.weight" contain a dot).
    if not loaded:
        sample_keys = list(ckpt.keys())[:5]
        if any(isinstance(k, str) and '.' in k for k in sample_keys):
            try:
                classifier.load_state_dict(ckpt)
                print("Loaded classifier directly from checkpoint dict")
                loaded = True
            except Exception as e1:
                print(f"Strict direct load failed: {e1}")
                try:
                    incompatible = classifier.load_state_dict(ckpt, strict=False)
                    # strict=False "succeeds" even if nothing matched; only
                    # count it as loaded when at least one parameter was filled.
                    total_keys = len(classifier.state_dict())
                    if len(incompatible.missing_keys) < total_keys:
                        print("Loaded classifier with strict=False")
                        loaded = True
                    if incompatible.missing_keys:
                        print(f"WARNING: missing classifier keys: {incompatible.missing_keys}")
                except Exception as e2:
                    print(f"Direct load failed: {e2}")
    if 'base_model_state' in ckpt:
        try:
            base_model.load_state_dict(ckpt['base_model_state'], strict=False)
            print("Loaded fine-tuned base model weights")
        except Exception as e:
            # Best effort: keep the pretrained encoder, but make the failure visible.
            print(f"Could not load fine-tuned base model weights: {e}")
elif isinstance(ckpt, nn.Module):
    classifier = ckpt.to(DEVICE)
    print("Loaded classifier directly (model object)")
    loaded = True

if not loaded:
    print("WARNING: Could not load classifier weights. Using random init.")

classifier.eval()
print(f"Models ready. num_classes={num_classes}, loaded={loaded}")


# Database
# NOTE(review): load_db/save_db do unlocked read-modify-write cycles. Gradio
# handlers run in worker threads, so concurrent enrollments may lose updates.
def load_db():
    """Return the voiceprint DB dict loaded from DB_PATH, or {} if the file is absent."""
    if os.path.exists(DB_PATH):
        with open(DB_PATH, 'r') as f:
            return json.load(f)
    return {}


def save_db(db):
    """Persist ``db`` to DB_PATH atomically.

    The JSON is written to a sibling temp file and then swapped in with
    os.replace, so a crash mid-write cannot leave a truncated/corrupt DB.
    """
    directory = os.path.dirname(os.path.abspath(DB_PATH))
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix='.voiceprint_db.', suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(db, f, indent=2, default=str)
        os.replace(tmp_path, DB_PATH)
    except BaseException:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise


# Audio processing
def _read_audio_file(path):
    """Load an audio file as a mono 1-D float tensor plus its sample rate."""
    wav, sr = torchaudio.load(path)  # (channels, samples)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0)
    else:
        wav = wav.squeeze(0)
    return wav, sr


def _numpy_to_mono_tensor(audio_np):
    """Convert Gradio numpy audio (``(samples,)`` or ``(samples, channels)``) to a mono float tensor."""
    audio_np = np.asarray(audio_np)
    is_integer_pcm = np.issubdtype(audio_np.dtype, np.integer)
    wav = torch.from_numpy(audio_np.astype(np.float32))
    if is_integer_pcm:
        # Scale integer PCM by its dtype's full range (int16 -> 32768).
        wav = wav / (float(np.iinfo(audio_np.dtype).max) + 1.0)
    if wav.dim() == 2:
        # Gradio returns (samples, channels): average over the short (channel) axis,
        # not over time.
        channel_axis = 1 if wav.shape[1] <= wav.shape[0] else 0
        wav = wav.mean(dim=channel_axis)
    if not is_integer_pcm and wav.numel() > 0 and wav.abs().max() > 1.0:
        # Float data outside [-1, 1] is assumed to carry int16-scale values.
        wav = wav / 32768.0
    return wav


def load_audio(audio_input):
    """Load audio into a mono float tensor of exactly MAX_LEN samples at SAMPLE_RATE.

    Args:
        audio_input: one of
            - a Gradio ``(sample_rate, numpy_array)`` tuple;
            - a file path (str);
            - raw encoded audio (bytes/bytearray). It is spooled to a temp
              file that is always removed.

    Returns:
        1-D float tensor of length MAX_LEN. Longer clips are truncated;
        shorter clips are zero-padded at the end.

    Raises:
        ValueError: for unsupported input types. torchaudio errors propagate.
    """
    if isinstance(audio_input, tuple):
        sr, audio_np = audio_input
        wav = _numpy_to_mono_tensor(audio_np)
    elif isinstance(audio_input, str):
        wav, sr = _read_audio_file(audio_input)
    elif isinstance(audio_input, (bytes, bytearray)):
        tmp = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
        try:
            with tmp:
                tmp.write(audio_input)
            wav, sr = _read_audio_file(tmp.name)
        finally:
            os.unlink(tmp.name)
    else:
        raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    if wav.shape[0] > MAX_LEN:
        wav = wav[:MAX_LEN]
    elif wav.shape[0] < MAX_LEN:
        wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
    return wav


def extract_embedding(wav_tensor):
    """Return the L2-normalized speaker embedding (numpy 1-D) for a 1-D waveform tensor.

    Encoder frame features are mean-pooled over time, then passed through
    the classifier's projection layer.
    """
    with torch.no_grad():
        wav = wav_tensor.unsqueeze(0).to(DEVICE)
        outputs = base_model(wav)
        base_emb = outputs.last_hidden_state.mean(dim=1)
        embedding = classifier.extract_embedding(base_emb)
        embedding = F.normalize(embedding, p=2, dim=1)
    return embedding.squeeze(0).cpu().numpy()


def add_noise(wav_tensor, noise_level=0.005):
    """Return ``wav_tensor`` plus Gaussian noise with std ``noise_level`` (enrollment augmentation)."""
    noise = torch.randn_like(wav_tensor) * noise_level
    return wav_tensor + noise


def _cosine_similarity(a, b):
    """Cosine similarity of two 1-D vectors as a Python float (epsilon-guarded)."""
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-10))


# Liveness detection
def check_liveness(wav_tensor):
    """Heuristic signal checks on a waveform.

    Checks RMS, std, zero-crossing rate and the fraction of samples with
    |x| > 0.01.

    Returns:
        (passed, message) tuple.
    """
    wav_np = wav_tensor.numpy()
    rms = np.sqrt(np.mean(wav_np ** 2))
    if rms < 0.001:
        return False, "Audio too quiet"
    std = np.std(wav_np)
    if std < 0.001:
        return False, "Audio lacks variation"
    zero_crossings = np.sum(np.abs(np.diff(np.sign(wav_np)))) / (2 * len(wav_np))
    if zero_crossings < 0.01:
        return False, "Abnormal audio pattern"
    non_silent = np.abs(wav_np) > 0.01
    speech_ratio = np.sum(non_silent) / len(wav_np)
    if speech_ratio < 0.1:
        return False, "Insufficient speech content"
    return True, "Liveness check passed"


# Antispoofing
def check_antispoofing(wav_tensor):
    """Heuristic synthetic-audio checks on a waveform.

    Rejects near-white spectra (spectral flatness close to 1) and
    unnaturally constant per-frame energy.

    Returns:
        (passed, message) tuple.
    """
    wav_np = wav_tensor.numpy()
    fft = np.fft.rfft(wav_np)
    magnitude = np.abs(fft)
    magnitude = magnitude[magnitude > 0]
    if len(magnitude) == 0:
        return False, "No frequency content"
    geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
    arithmetic_mean = np.mean(magnitude)
    spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
    if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
        return False, "Possible synthetic audio"
    frame_size = 1600
    if len(wav_np) >= frame_size * 3:
        # "+ 1" so the final full frame is included (range's stop is exclusive).
        frames = [wav_np[i:i + frame_size]
                  for i in range(0, len(wav_np) - frame_size + 1, frame_size)]
        frame_energies = [np.sqrt(np.mean(f ** 2)) for f in frames]
        energy_std = np.std(frame_energies)
        if energy_std < 0.001:
            return False, "Unnaturally uniform energy"
    return True, "Antispoofing check passed"


# Security: lockout and cooldown (in-memory; reset on process restart)
attempt_tracker = {}


def check_security(user_id):
    """Return (allowed, message) for ``user_id`` given lockout and cooldown state.

    An expired lockout clears the failure count as a side effect.
    """
    now = datetime.now()
    if user_id not in attempt_tracker:
        return True, "OK"
    tracker = attempt_tracker[user_id]
    if "locked_until" in tracker and tracker["locked_until"]:
        locked_until = datetime.fromisoformat(tracker["locked_until"])
        if now < locked_until:
            # Round up so the user is never told "0 seconds" while still locked.
            remaining = math.ceil((locked_until - now).total_seconds())
            return False, f"Account locked. Try again in {remaining} seconds."
        else:
            tracker["count"] = 0
            tracker["locked_until"] = None
    if "last_attempt" in tracker and tracker["last_attempt"]:
        last = datetime.fromisoformat(tracker["last_attempt"])
        elapsed = (now - last).total_seconds()
        if elapsed < COOLDOWN_SECONDS:
            return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds."
    return True, "OK"


def record_attempt(user_id, success):
    """Record an attempt.

    Success resets the failure count. The MAX_ATTEMPTS-th consecutive
    failure locks the user for LOCKOUT_MINUTES.
    """
    now = datetime.now()
    if user_id not in attempt_tracker:
        attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
    tracker = attempt_tracker[user_id]
    tracker["last_attempt"] = now.isoformat()
    if success:
        tracker["count"] = 0
        tracker["locked_until"] = None
    else:
        tracker["count"] += 1
        if tracker["count"] >= MAX_ATTEMPTS:
            tracker["locked_until"] = (now + timedelta(minutes=LOCKOUT_MINUTES)).isoformat()


# Generate random challenge (2 distinct words from pool)
def generate_challenge():
    """Return two distinct challenge words joined by a space.

    Uses an OS-entropy RNG so challenges are not predictable from the
    `random` module's state.
    """
    words = _secure_random.sample(CHALLENGE_WORDS, 2)
    return ' '.join(words)


# Session storage (in-memory)
sessions = {}


def _purge_expired_sessions(now):
    """Drop every session whose expiry time is before ``now`` so the dict cannot grow without bound."""
    expired = [sid for sid, s in sessions.items()
               if now > datetime.fromisoformat(s["expires_at"])]
    for sid in expired:
        del sessions[sid]


def create_session(user_id):
    """Create and return a new session dict for ``user_id`` (normalized to upper case)."""
    now = datetime.now()
    _purge_expired_sessions(now)
    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "session_id": session_id,
        "user_id": user_id.strip().upper(),
        "step": SESSION_STEPS['STARTED'],
        "challenge_phrase": None,
        "full_name": None,
        "similarity": None,
        "created_at": now.isoformat(),
        "expires_at": (now + timedelta(minutes=SESSION_TTL_MINUTES)).isoformat()
    }
    return sessions[session_id]


def get_session(session_id):
    """Return the live session dict, or None if unknown or expired (expired ones are deleted)."""
    if session_id not in sessions:
        return None
    session = sessions[session_id]
    if datetime.now() > datetime.fromisoformat(session["expires_at"]):
        del sessions[session_id]
        return None
    return session


# Enroll
def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
    """Add one enrollment recording for ``user_id``.

    Each recording stores one clean embedding plus NUM_NOISY_COPIES
    noise-augmented ones. Once ``total_samples`` recordings exist, all
    stored embeddings are averaged and re-normalized into the voiceprint.

    ``sample_number`` is accepted for UI/API compatibility but not used; the
    sample count comes from the DB.

    Returns:
        A human-readable status string. Errors are reported in the string,
        not raised.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if not full_name or not full_name.strip():
        return "Error: Full Name is required."
    if audio_input is None:
        return "Error: No audio recorded."
    user_id = user_id.strip().upper()
    full_name = full_name.strip()
    try:
        wav = load_audio(audio_input)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return f"Enrollment failed: {live_msg}"
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return f"Enrollment failed: {spoof_msg}"

        clean_emb = extract_embedding(wav)
        noisy_embeddings = []
        for i in range(NUM_NOISY_COPIES):
            noise_level = 0.003 + (i * 0.002)
            noisy_wav = add_noise(wav, noise_level)
            noisy_embeddings.append(extract_embedding(noisy_wav))

        db = load_db()
        if user_id not in db:
            db[user_id] = {
                "full_name": full_name,
                "enrolled_at": datetime.now().isoformat(),
                "sample_embeddings": [],
                "voiceprint": None,
                "status": "enrolling",
                "samples_collected": 0
            }
        sample_data = {
            "clean": clean_emb.tolist(),
            "noisy": [e.tolist() for e in noisy_embeddings]
        }
        db[user_id]["sample_embeddings"].append(sample_data)
        db[user_id]["samples_collected"] = len(db[user_id]["sample_embeddings"])
        db[user_id]["full_name"] = full_name
        samples_collected = db[user_id]["samples_collected"]

        if samples_collected >= total_samples:
            all_embeddings = []
            for sample in db[user_id]["sample_embeddings"]:
                all_embeddings.append(np.array(sample["clean"]))
                for noisy in sample["noisy"]:
                    all_embeddings.append(np.array(noisy))
            avg_embedding = np.mean(all_embeddings, axis=0)
            avg_embedding = avg_embedding / (np.linalg.norm(avg_embedding) + 1e-10)
            db[user_id]["voiceprint"] = avg_embedding.tolist()
            db[user_id]["status"] = "enrolled"
            db[user_id]["completed_at"] = datetime.now().isoformat()
            # Raw per-sample embeddings are no longer needed once averaged.
            db[user_id]["sample_embeddings"] = []
            save_db(db)
            return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
        else:
            save_db(db)
            remaining = total_samples - samples_collected
            return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."
    except Exception as e:
        return f"Enrollment error: {str(e)}"


# Verify
def verify_speaker(audio_input, user_id):
    """Gradio verification handler.

    Runs the security gate, liveness, antispoofing and similarity checks,
    and records the attempt outcome.

    Returns:
        A multi-line status string.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if audio_input is None:
        return "Error: No audio recorded."
    user_id = user_id.strip().upper()
    allowed, sec_msg = check_security(user_id)
    if not allowed:
        return f"ACCESS DENIED: {sec_msg}"
    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."
    if db[user_id].get("status") != "enrolled":
        samples = db[user_id].get("samples_collected", 0)
        remaining = NUM_CLEAN_SAMPLES - samples
        return f"Error: Enrollment incomplete. {remaining} more sample(s) needed."
    try:
        wav = load_audio(audio_input)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(user_id, False)
            return f"ACCESS DENIED: {live_msg}"
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(user_id, False)
            return f"ACCESS DENIED: {spoof_msg}"
        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[user_id]["voiceprint"])
        similarity = _cosine_similarity(test_emb, stored_emb)
        if similarity >= THRESHOLD:
            record_attempt(user_id, True)
            full_name = db[user_id].get("full_name", user_id)
            return (f"ACCESS GRANTED\nWelcome, {full_name}\n"
                    f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})\n"
                    f"Liveness: Passed | Antispoofing: Passed")
        else:
            record_attempt(user_id, False)
            tracker = attempt_tracker.get(user_id, {})
            attempts_left = MAX_ATTEMPTS - tracker.get("count", 0)
            msg = f"ACCESS DENIED\nVoice does not match.\nSimilarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
            if attempts_left > 0:
                msg += f"Attempts remaining: {attempts_left}"
            else:
                msg += f"Account locked for {LOCKOUT_MINUTES} minutes."
            return msg
    except Exception as e:
        return f"Verification error: {str(e)}"


# User management
def list_users():
    """Return a human-readable listing of every user in the DB."""
    db = load_db()
    if not db:
        return "No users enrolled yet."
    lines = ["=== Enrolled Users ===\n"]
    for uid, data in db.items():
        name = data.get("full_name", "Unknown")
        status = data.get("status", "unknown")
        enrolled = data.get("enrolled_at", "N/A")
        samples = data.get("samples_collected", 0)
        lines.append(f"ID: {uid} | Name: {name} | Status: {status} | Samples: {samples} | Enrolled: {enrolled}")
    return "\n".join(lines)


def delete_user(user_id):
    """Remove a user from the DB and clear their lockout state; returns a status string."""
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    user_id = user_id.strip().upper()
    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."
    name = db[user_id].get("full_name", user_id)
    del db[user_id]
    save_db(db)
    if user_id in attempt_tracker:
        del attempt_tracker[user_id]
    return f"User '{name}' ({user_id}) deleted."


def reset_lockout(user_id):
    """Clear the failure count, cooldown and lockout for ``user_id``; returns a status string."""
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    user_id = user_id.strip().upper()
    if user_id in attempt_tracker:
        attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
        return f"Lockout reset for {user_id}."
    return f"No lockout record for {user_id}."


def _is_valid_amount(amount):
    """True if ``amount`` parses as a finite, positive number (thousands commas allowed)."""
    try:
        value = float(str(amount).strip().replace(',', ''))
    except ValueError:
        return False
    return math.isfinite(value) and value > 0


# CREATE FASTAPI APP FIRST (before Gradio)
# NOTE(review): the endpoints below are `async def` but run blocking model
# inference and file I/O directly, which stalls the event loop while it runs.
app = FastAPI(title="ATM Voice Authentication API", version="1.0.0")
# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# site make credentialed requests; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Health check
@app.get("/api/health")
async def health_check():
    return {"status": "healthy", "model": "UniSpeech-SAT + AAM-Softmax", "threshold": THRESHOLD,
            "device": str(DEVICE), "timestamp": datetime.now().isoformat()}


# Basic enroll endpoint
@app.post("/api/enroll")
async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
    """Enroll one uploaded sample; reports progress toward NUM_CLEAN_SAMPLES."""
    try:
        audio_bytes = await audio.read()
        # load_audio accepts raw bytes and always removes its own temp file.
        result = enroll_sample(audio_bytes, user_id, full_name, 1)
        db = load_db()
        uid = user_id.strip().upper()
        record = db.get(uid, {})
        samples_collected = record.get("samples_collected", 0)
        is_complete = record.get("status") == "enrolled"
        lowered = result.lower()
        return JSONResponse(content={
            "success": "error" not in lowered and "failed" not in lowered,
            "message": result,
            "user_id": uid,
            "samples_collected": samples_collected if not is_complete else NUM_CLEAN_SAMPLES,
            "samples_required": NUM_CLEAN_SAMPLES,
            "enrollment_complete": is_complete,
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})


# Basic verify endpoint
@app.post("/api/verify")
async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
    """One-shot verification of an uploaded sample against a user's voiceprint."""
    try:
        audio_bytes = await audio.read()
        uid = user_id.strip().upper()
        allowed, sec_msg = check_security(uid)
        if not allowed:
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid,
                                         "message": sec_msg, "locked": True})
        db = load_db()
        if uid not in db:
            return JSONResponse(content={"success": False, "message": f"User '{uid}' not found."})
        if db[uid].get("status") != "enrolled":
            return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

        wav = load_audio(audio_bytes)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid,
                                         "message": live_msg, "liveness_passed": False})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid,
                                         "message": spoof_msg, "antispoofing_passed": False})

        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = _cosine_similarity(test_emb, stored_emb)
        granted = similarity >= THRESHOLD
        record_attempt(uid, granted)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        response = {
            "success": True,
            "access_granted": granted,
            "user_id": uid,
            "full_name": db[uid].get("full_name", uid),
            "similarity": round(similarity, 4),
            "threshold": THRESHOLD,
            "liveness_passed": True,
            "antispoofing_passed": True,
            "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
            "locked": attempts_remaining == 0 and not granted,
        }
        if granted:
            response["message"] = "Access granted. Voice verified."
        elif attempts_remaining > 0:
            response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
        else:
            response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})


# List users
@app.get("/api/users")
async def api_list_users():
    db = load_db()
    users = []
    for uid, data in db.items():
        users.append({"user_id": uid,
                      "full_name": data.get("full_name", "Unknown"),
                      "status": data.get("status", "unknown"),
                      "samples_collected": data.get("samples_collected", 0),
                      "enrolled_at": data.get("enrolled_at", None),
                      "completed_at": data.get("completed_at", None)})
    return JSONResponse(content={"success": True, "users": users, "total": len(users)})


# Delete user
@app.delete("/api/users/{user_id}")
async def api_delete_user(user_id: str):
    result = delete_user(user_id)
    success = "error" not in result.lower()
    return JSONResponse(content={"success": success, "message": result})


# Reset lockout
@app.post("/api/reset-lockout")
async def api_reset_lockout(user_id: str = Form(...)):
    result = reset_lockout(user_id)
    return JSONResponse(content={"success": True, "message": result})


# Session: Start
@app.post("/api/session/start")
async def session_start(user_id: str = Form(...)):
    """Open a session for an enrolled, non-locked user (step: started)."""
    uid = user_id.strip().upper()
    db = load_db()
    if uid not in db:
        return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
    if db[uid].get("status") != "enrolled":
        return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})
    allowed, sec_msg = check_security(uid)
    if not allowed:
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
    session = create_session(uid)
    return JSONResponse(content={
        "success": True,
        "session_id": session["session_id"],
        "user_id": uid,
        "message": "Session started. Please provide a voice sample to verify your identity.",
        "next_step": "verify",
        "instruction": "Record your voice and send it to /api/session/verify",
    })


# Session: Verify identity
@app.post("/api/session/verify")
async def session_verify(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Verify the speaker for a started session; on success issue a challenge phrase."""
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['STARTED']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    allowed, sec_msg = check_security(uid)
    if not allowed:
        session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
    try:
        audio_bytes = await audio.read()
        wav = load_audio(audio_bytes)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = _cosine_similarity(test_emb, stored_emb)
        if similarity >= THRESHOLD:
            record_attempt(uid, True)
            full_name = db[uid].get("full_name", uid)
            challenge = generate_challenge()
            session["step"] = SESSION_STEPS['LIVENESS_PENDING']
            session["full_name"] = full_name
            session["similarity"] = round(similarity, 4)
            session["challenge_phrase"] = challenge
            return JSONResponse(content={
                "success": True,
                "verified": True,
                "greeting": f"Welcome, {full_name}",
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "liveness",
                "challenge_phrase": challenge,
                "instruction": f"Say these words: {challenge}",
                "message": f"Voice verified. Welcome, {full_name}. For security, please say these words: {challenge}",
            })
        else:
            record_attempt(uid, False)
            tracker = attempt_tracker.get(uid, {})
            attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
            locked = attempts_remaining == 0
            if locked:
                session["step"] = SESSION_STEPS['DENIED']
            return JSONResponse(content={
                "success": True,
                "verified": False,
                "similarity": round(similarity, 4),
                "attempts_remaining": attempts_remaining,
                "locked": locked,
                "message": f"Voice does not match. {attempts_remaining} attempt(s) remaining." if not locked else f"Account locked for {LOCKOUT_MINUTES} minutes.",
            })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})


# Session: Liveness check
@app.post("/api/session/liveness")
async def session_liveness(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Second voice sample for a verified session; on match the session becomes authenticated.

    NOTE(review): the spoken challenge words are never transcribed or
    compared; this step only re-runs the signal checks and speaker
    similarity, so it does not by itself defend against replayed recordings.
    Failures here are also not counted toward lockout.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['LIVENESS_PENDING']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    try:
        audio_bytes = await audio.read()
        wav = load_audio(audio_bytes)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = _cosine_similarity(test_emb, stored_emb)
        if similarity >= THRESHOLD:
            session["step"] = SESSION_STEPS['AUTHENTICATED']
            full_name = session["full_name"]
            return JSONResponse(content={
                "success": True,
                "liveness_passed": True,
                "authenticated": True,
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "transaction",
                "instruction": "How much would you like to withdraw?",
                "message": f"Liveness confirmed. You are fully authenticated, {full_name}. How much would you like to withdraw?",
            })
        else:
            return JSONResponse(content={
                "success": True,
                "liveness_passed": False,
                "message": "Voice mismatch during liveness check. Please try again.",
                "challenge_phrase": session["challenge_phrase"],
                "instruction": f"Please say these words again: {session['challenge_phrase']}",
            })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})


# Session: Transaction
@app.post("/api/session/transaction")
async def session_transaction(session_id: str = Form(...), amount: str = Form(...)):
    """Approve a withdrawal for an authenticated session and mark it complete.

    Invalid amounts are rejected without consuming the session, so the
    client can retry.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['AUTHENTICATED']:
        return JSONResponse(content={"success": False, "message": f"Not authenticated. Current step: {session['step']}"})
    if not _is_valid_amount(amount):
        return JSONResponse(content={"success": False, "message": "Invalid amount. Please provide a positive number."})
    full_name = session["full_name"]
    session["step"] = SESSION_STEPS['COMPLETE']
    return JSONResponse(content={
        "success": True,
        "transaction_approved": True,
        "full_name": full_name,
        "amount": amount,
        "message": f"Transaction approved. {full_name}, you are withdrawing {amount} cedis. Please collect your cash.",
        "instruction": "Transaction complete. Session ended.",
        "note": "In production, this step communicates with the banks core system to process the actual withdrawal.",
    })


# Session status
@app.get("/api/session/{session_id}")
async def session_status(session_id: str):
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    return JSONResponse(content={
        "success": True,
        "session_id": session["session_id"],
        "user_id": session["user_id"],
        "step": session["step"],
        "full_name": session["full_name"],
        "challenge_phrase": session["challenge_phrase"],
        "created_at": session["created_at"],
        "expires_at": session["expires_at"],
    })


# Gradio interface
with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ATM Voice Authentication System
### Voice-Based Speaker Verification for Banking Security
Voice biometric authentication system for secure ATM access
""")
    with gr.Tabs():
        with gr.Tab("Enroll"):
            gr.Markdown("### Enroll New User\nRecord **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.")
            with gr.Row():
                with gr.Column():
                    enroll_audio = gr.Audio(label="Record Voice Sample", sources=["microphone", "upload"], type="numpy")
                    enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
                    enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
                    enroll_sample_num = gr.Number(label="Sample Number (1-6)", value=1, minimum=1, maximum=6, step=1)
                    enroll_btn = gr.Button("Enroll Sample", variant="primary")
                with gr.Column():
                    enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
            enroll_btn.click(fn=enroll_sample,
                             inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num],
                             outputs=enroll_result)

        with gr.Tab("Verify"):
            gr.Markdown("### Verify Identity\nRecord your voice to verify against your enrolled voiceprint.")
            with gr.Row():
                with gr.Column():
                    verify_audio = gr.Audio(label="Record Voice", sources=["microphone", "upload"], type="numpy")
                    verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
                    verify_btn = gr.Button("Verify", variant="primary")
                with gr.Column():
                    verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
            verify_btn.click(fn=verify_speaker, inputs=[verify_audio, verify_user_id], outputs=verify_result)

        with gr.Tab("Users"):
            gr.Markdown("### Manage Enrolled Users")
            list_btn = gr.Button("List All Users")
            users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
            list_btn.click(fn=list_users, outputs=users_output)
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    del_user_id = gr.Textbox(label="User ID to Delete", placeholder="ATM_001")
                    del_btn = gr.Button("Delete User", variant="stop")
                    del_result = gr.Textbox(label="Result", interactive=False)
                    del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
                with gr.Column():
                    reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
                    reset_btn = gr.Button("Reset Lockout", variant="secondary")
                    reset_result = gr.Textbox(label="Result", interactive=False)
                    reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)

        with gr.Tab("API Docs"):
            gr.Markdown("""
### REST API Endpoints

**Base URL:** `https://amfafa-voice-authentication-sys.hf.space`

---

#### Basic Endpoints
- `POST /api/enroll` - Enroll a voice sample (audio, user_id, full_name)
- `POST /api/verify` - Verify a voice (audio, user_id)
- `GET /api/users` - List enrolled users
- `DELETE /api/users/{user_id}` - Delete a user
- `POST /api/reset-lockout` - Reset a user's lockout (user_id)
- `GET /api/health` - Health check

---

#### Session-Based Voice Authentication Flow
- `POST /api/session/start` - Start session (user_id)
- `POST /api/session/verify` - Verify identity (audio, session_id) - Returns greeting + challenge words
- `POST /api/session/liveness` - Liveness check (audio, session_id) - Returns authenticated or denied
- `POST /api/session/transaction` - Confirm transaction (amount, session_id)
- `GET /api/session/{session_id}` - Check session status
""")

# Mount Gradio ON the FastAPI app (not the other way around)
app = gr.mount_gradio_app(app, demo, path="/")

# Launch
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)