Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import math | |
| import time | |
| import uuid | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchaudio | |
| import numpy as np | |
| import random | |
| import tempfile | |
| import gradio as gr | |
| from fastapi import FastAPI, UploadFile, File, Form | |
| from fastapi.responses import JSONResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from datetime import datetime, timedelta | |
# Compatibility shim: some torchaudio builds do not provide
# `list_audio_backends`. Install a stub reporting only "soundfile" so callers
# of that function do not fail. It is installed before importing
# transformers — NOTE(review): presumably because transformers (or one of its
# dependencies) probes the audio backends at import time; confirm.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]
from transformers import AutoModel
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# File locations and the pretrained backbone identifier.
CKPT_PATH = 'aam_best.pt'
DB_PATH = 'voiceprint_db.json'
MODEL_NAME = 'microsoft/unispeech-sat-base-sv'

# Audio is resampled to SAMPLE_RATE and cropped / zero-padded to MAX_SEC seconds.
SAMPLE_RATE = 16000
MAX_SEC = 4
MAX_LEN = MAX_SEC * SAMPLE_RATE

# Minimum cosine similarity between a probe embedding and the stored voiceprint.
THRESHOLD = 0.35
# Inference runs on the CPU.
DEVICE = torch.device('cpu')

# Enrollment: clean recordings per user, and noise-augmented copies per recording.
NUM_CLEAN_SAMPLES = 6
NUM_NOISY_COPIES = 4

# Failed attempts before lockout, lockout length, and minimum gap between attempts.
MAX_ATTEMPTS = 3
LOCKOUT_MINUTES = 5
COOLDOWN_SECONDS = 3

# check_antispoofing flags audio whose spectral flatness exceeds 1 - this value.
ANTISPOOFING_THRESHOLD = 0.02

# Pool of short, easy-to-pronounce words used to build challenge phrases.
CHALLENGE_WORDS = (
    'Red Blue Gold Star Water '
    'Moon Fire Green Black White '
    'Sun Rain Tree Fish Bird '
    'Stone Wind Cloud Light Sound'
).split()

# Session step identifiers: upper-case name -> lower-case value.
SESSION_STEPS = {
    step: step.lower()
    for step in (
        'STARTED',
        'VERIFIED',
        'LIVENESS_PENDING',
        'AUTHENTICATED',
        'TRANSACTION_PENDING',
        'COMPLETE',
        'DENIED',
    )
}
# ---------------------------------------------------------------------------
# AAM-Softmax (additive angular margin) model
# ---------------------------------------------------------------------------
class AAMSoftmax(nn.Module):
    """Additive-angular-margin softmax head.

    Without labels, returns the cosine similarity between every L2-normalised
    embedding and every L2-normalised class weight. With labels, each row's
    target-class cosine is replaced by cos(theta + margin) (or by the linear
    penalty ``cos(theta) - mm`` where theta + margin would pass pi), and all
    logits are multiplied by ``scale``.

    Args:
        in_features: embedding dimensionality.
        num_classes: number of class weight vectors.
        margin: additive angular margin, in radians.
        scale: multiplier applied to the logits when labels are given.
    """

    def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
        super().__init__()
        self.margin = margin
        self.scale = scale
        # float32 storage that Xavier initialisation fills immediately.
        self.weight = nn.Parameter(torch.empty(num_classes, in_features, dtype=torch.float32))
        nn.init.xavier_uniform_(self.weight)
        # cos(theta + m) = cos(theta) * cos(m) - sin(theta) * sin(m)
        self.cos_m, self.sin_m = math.cos(margin), math.sin(margin)
        # cos(theta) <= threshold  <=>  theta + margin >= pi
        self.threshold = math.cos(math.pi - margin)
        self.mm = margin * math.sin(math.pi - margin)

    def forward(self, embeddings, labels=None):
        """Return cosines (``labels`` is None) or scaled margin logits."""
        cosine = F.linear(
            F.normalize(embeddings, p=2, dim=1),
            F.normalize(self.weight, p=2, dim=1),
        )
        if labels is None:
            return cosine
        sine = torch.sqrt(1.0 - torch.clamp(cosine * cosine, 0, 1))
        with_margin = cosine * self.cos_m - sine * self.sin_m
        with_margin = torch.where(cosine > self.threshold, with_margin, cosine - self.mm)
        target_mask = F.one_hot(labels, cosine.size(1)).float()
        logits = (target_mask * with_margin) + ((1.0 - target_mask) * cosine)
        return self.scale * logits


class SpeakerClassifier(nn.Module):
    """Linear(input_dim -> hidden_dim) + ReLU projection followed by an AAM-Softmax head.

    The submodule attribute names (fc1, relu, aam) define the state_dict keys
    the checkpoint is loaded into, so they must not change.
    """

    def __init__(self, input_dim=768, hidden_dim=512, num_classes=227):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.aam = AAMSoftmax(hidden_dim, num_classes)

    def extract_embedding(self, x):
        """Return the ReLU-activated hidden projection of ``x`` (the speaker embedding)."""
        return self.relu(self.fc1(x))

    def forward(self, x, labels=None):
        """Return AAM-Softmax output for the projected input."""
        return self.aam(self.extract_embedding(x), labels)
# ---------------------------------------------------------------------------
# Load models
# ---------------------------------------------------------------------------
print("Loading UniSpeech-SAT base model...")
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
base_model.eval()
# The backbone is only used as a frozen feature extractor.
for param in base_model.parameters():
    param.requires_grad = False

print("Loading AAM-Softmax checkpoint...")
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
print(f"Checkpoint type: {type(ckpt)}")
if isinstance(ckpt, dict):
    print(f"Checkpoint keys: {list(ckpt.keys())}")

# Speaker-class count the head was trained with: 227 unless the checkpoint says otherwise.
num_classes = 227
if isinstance(ckpt, dict):
    if 'num_classes' in ckpt:
        num_classes = ckpt['num_classes']
    elif 'num_speakers' in ckpt:
        num_classes = ckpt['num_speakers']

classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
loaded = False

if isinstance(ckpt, dict):
    # 1) The classifier state_dict stored under one of the usual wrapper keys.
    for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
        if key in ckpt:
            try:
                classifier.load_state_dict(ckpt[key])
                print(f"Loaded classifier from key: '{key}'")
                loaded = True
                break
            except Exception as e:
                print(f"Key '{key}' found but failed: {e}")

    # 2) The checkpoint itself may be a flat state_dict ("fc1.weight", ...).
    if not loaded:
        sample_keys = list(ckpt.keys())[:5]
        if any(isinstance(k, str) and '.' in k for k in sample_keys):
            try:
                classifier.load_state_dict(ckpt)
                print("Loaded classifier directly from checkpoint dict")
                loaded = True
            except Exception as e:
                print(f"Strict direct load failed: {e}")
                try:
                    incompatible = classifier.load_state_dict(ckpt, strict=False)
                except Exception as e2:
                    print(f"Direct load failed: {e2}")
                else:
                    # strict=False does not raise when keys are missing, so only
                    # treat the head as loaded if every parameter was filled.
                    if incompatible.missing_keys:
                        print(f"Non-strict load left parameters at random init: {incompatible.missing_keys}")
                    else:
                        print("Loaded classifier with strict=False")
                        loaded = True

    # 3) Optional fine-tuned backbone weights (best effort).
    if 'base_model_state' in ckpt:
        try:
            base_model.load_state_dict(ckpt['base_model_state'], strict=False)
            print("Loaded fine-tuned base model weights")
        except Exception as e:
            # Keep the pretrained backbone, but make the failure visible.
            print(f"Could not load fine-tuned base model weights: {e}")
elif isinstance(ckpt, nn.Module):
    classifier = ckpt.to(DEVICE)
    print("Loaded classifier directly (model object)")
    loaded = True

if not loaded:
    # A randomly initialised projection does not reflect the trained model.
    print("WARNING: Could not load classifier weights. Using random init.")
classifier.eval()
print(f"Models ready. num_classes={num_classes}, loaded={loaded}")
# ---------------------------------------------------------------------------
# Database
# ---------------------------------------------------------------------------
def load_db():
    """Return the voiceprint database read from DB_PATH, or {} if the file does not exist."""
    if not os.path.exists(DB_PATH):
        return {}
    with open(DB_PATH, 'r', encoding='utf-8') as f:
        return json.load(f)


def save_db(db):
    """Persist ``db`` to DB_PATH as indented JSON, atomically.

    The JSON is written to a temporary file in the same directory, which is
    then moved over DB_PATH with os.replace. A failure part-way through
    serialisation therefore leaves the previous database intact instead of a
    truncated file. Values JSON cannot encode are stringified (default=str).
    """
    directory = os.path.dirname(os.path.abspath(DB_PATH))
    fd, tmp_path = tempfile.mkstemp(prefix='.voiceprint_db.', suffix='.tmp', dir=directory)
    try:
        with os.fdopen(fd, 'w', encoding='utf-8') as f:
            json.dump(db, f, indent=2, default=str)
        os.replace(tmp_path, DB_PATH)
    except BaseException:
        # Do not leave the partially written temp file behind.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise
# ---------------------------------------------------------------------------
# Audio processing
# ---------------------------------------------------------------------------
def _mixdown_and_resample(wav, sr):
    """Average a channels-first (C, N) tensor to mono (N,) and resample it to SAMPLE_RATE."""
    if wav.dim() == 2:
        wav = wav.mean(dim=0)
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    return wav


def _numpy_to_tensor(audio_np):
    """Convert a Gradio-style numpy array to a mono float32 tensor roughly in [-1, 1]."""
    arr = np.asarray(audio_np)
    wav = torch.as_tensor(arr.astype(np.float32))
    if wav.numel() == 0:
        return wav.reshape(-1)
    if wav.dim() == 2:
        # gr.Audio(type="numpy") yields (samples, channels); (channels, samples)
        # is also accepted. The shorter axis is treated as the channel axis.
        channel_axis = 0 if wav.shape[0] < wav.shape[1] else 1
        wav = wav.mean(dim=channel_axis)
    if np.issubdtype(arr.dtype, np.signedinteger):
        # Signed PCM: divide by full scale (32768 for int16, 2**31 for int32).
        wav = wav / float(-np.iinfo(arr.dtype).min)
    elif wav.abs().max() > 1.0:
        # Float data outside [-1, 1] is assumed to be int16-scaled.
        wav = wav / 32768.0
    return wav


def load_audio(audio_input):
    """Load audio and return a fixed-length mono waveform at SAMPLE_RATE.

    Accepts:
      * ``(sample_rate, ndarray)`` tuples as produced by ``gr.Audio(type="numpy")``.
        The array may be mono ``(N,)`` or 2-D (channels on the shorter axis);
        signed-integer PCM is scaled to [-1, 1].
      * a file path (``str``) readable by ``torchaudio.load``.
      * raw encoded file bytes, which are written to a temporary ``.wav`` file
        first. The temporary file is always removed.

    Returns:
        1-D float tensor of exactly MAX_LEN samples, truncated or zero-padded
        at the end.

    Raises:
        ValueError: for unsupported input types.
    """
    if isinstance(audio_input, tuple):
        sr, audio_np = audio_input
        wav = _numpy_to_tensor(audio_np)
        if sr != SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)
    elif isinstance(audio_input, str):
        wav, sr = torchaudio.load(audio_input)
        wav = _mixdown_and_resample(wav, sr)
    elif isinstance(audio_input, bytes):
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_input)
            tmp_path = tmp.name
        try:
            wav, sr = torchaudio.load(tmp_path)
        finally:
            os.unlink(tmp_path)
        wav = _mixdown_and_resample(wav, sr)
    else:
        raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

    if wav.shape[0] > MAX_LEN:
        wav = wav[:MAX_LEN]
    elif wav.shape[0] < MAX_LEN:
        wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
    return wav
def extract_embedding(wav_tensor):
    """Map a 1-D waveform tensor to a unit-length speaker embedding (numpy array).

    The backbone's last hidden states are mean-pooled over time, passed
    through the classifier's projection layer, and L2-normalised. Runs with
    autograd disabled.
    """
    with torch.no_grad():
        batch = wav_tensor.unsqueeze(0).to(DEVICE)  # add a batch dimension
        pooled = base_model(batch).last_hidden_state.mean(dim=1)
        projected = classifier.extract_embedding(pooled)
        unit = F.normalize(projected, p=2, dim=1)
        return unit.squeeze(0).cpu().numpy()
def add_noise(wav_tensor, noise_level=0.005):
    """Return ``wav_tensor`` plus zero-mean Gaussian noise with std ``noise_level``."""
    return wav_tensor + noise_level * torch.randn_like(wav_tensor)
# ---------------------------------------------------------------------------
# Liveness detection (simple signal heuristics)
# ---------------------------------------------------------------------------
def check_liveness(wav_tensor):
    """Reject silent, flat or near-empty recordings with cheap signal statistics.

    The checks run in order, and the first failing one is reported:
      1. RMS level below 0.001               -> "Audio too quiet"
      2. standard deviation below 0.001      -> "Audio lacks variation"
      3. zero-crossing rate below 0.01       -> "Abnormal audio pattern"
      4. < 10% of samples with |x| > 0.01    -> "Insufficient speech content"

    The statistics cover the whole tensor, including any zero padding.

    Returns:
        ``(passed, message)``.
    """
    samples = wav_tensor.numpy()
    count = len(samples)
    # (failure message, lazily computed metric, minimum acceptable value)
    criteria = (
        ("Audio too quiet", lambda: np.sqrt(np.mean(np.square(samples))), 0.001),
        ("Audio lacks variation", lambda: np.std(samples), 0.001),
        ("Abnormal audio pattern",
         lambda: np.sum(np.abs(np.diff(np.sign(samples)))) / (2 * count), 0.01),
        ("Insufficient speech content",
         lambda: np.sum(np.abs(samples) > 0.01) / count, 0.1),
    )
    for failure_message, metric, minimum in criteria:
        if metric() < minimum:
            return False, failure_message
    return True, "Liveness check passed"
# ---------------------------------------------------------------------------
# Antispoofing (simple spectral / energy heuristics)
# ---------------------------------------------------------------------------
def check_antispoofing(wav_tensor, frame_size=1600):
    """Reject audio with a near-flat spectrum or unvarying frame energy.

    1. Spectral flatness (geometric / arithmetic mean of the non-zero rFFT
       magnitudes) above ``1 - ANTISPOOFING_THRESHOLD`` -> "Possible synthetic audio".
    2. If at least three complete frames of ``frame_size`` samples fit, a
       standard deviation of per-frame RMS below 0.001 -> "Unnaturally uniform energy".
       Every complete frame is used; a trailing partial frame is ignored.

    Args:
        wav_tensor: 1-D CPU waveform tensor.
        frame_size: samples per energy frame (1600 = 0.1 s at 16 kHz).

    Returns:
        ``(passed, message)``.
    """
    wav_np = wav_tensor.numpy()
    magnitude = np.abs(np.fft.rfft(wav_np))
    magnitude = magnitude[magnitude > 0]
    if magnitude.size == 0:
        return False, "No frequency content"
    geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
    arithmetic_mean = np.mean(magnitude)
    spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
    if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
        return False, "Possible synthetic audio"

    n_frames = len(wav_np) // frame_size
    if n_frames >= 3:
        frames = wav_np[:n_frames * frame_size].reshape(n_frames, frame_size)
        frame_energies = np.sqrt(np.mean(frames ** 2, axis=1))
        if np.std(frame_energies) < 0.001:
            return False, "Unnaturally uniform energy"
    return True, "Antispoofing check passed"
# ---------------------------------------------------------------------------
# Security: lockout and cooldown
# ---------------------------------------------------------------------------
# In-memory, per-process tracker keyed by user ID:
#   {"count": consecutive failures, "last_attempt": ISO time, "locked_until": ISO time or None}
attempt_tracker = {}


def check_security(user_id):
    """Decide whether ``user_id`` may attempt verification right now.

    An expired lockout is cleared (and the failure count reset) as a side
    effect.

    Returns:
        ``(False, message)`` while locked out or within COOLDOWN_SECONDS of the
        last recorded attempt; otherwise ``(True, "OK")``.
    """
    tracker = attempt_tracker.get(user_id)
    if tracker is None:
        return True, "OK"
    now = datetime.now()

    if tracker.get("locked_until"):
        locked_until = datetime.fromisoformat(tracker["locked_until"])
        if now < locked_until:
            # Round up so a still-active lock never reports "0 seconds".
            remaining = math.ceil((locked_until - now).total_seconds())
            return False, f"Account locked. Try again in {remaining} seconds."
        tracker["count"] = 0
        tracker["locked_until"] = None

    if tracker.get("last_attempt"):
        elapsed = (now - datetime.fromisoformat(tracker["last_attempt"])).total_seconds()
        if elapsed < COOLDOWN_SECONDS:
            return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds."
    return True, "OK"


def record_attempt(user_id, success):
    """Record a verification attempt for ``user_id``.

    Success clears the failure count and any lockout. Failure increments the
    count and, once it reaches MAX_ATTEMPTS, locks the user for
    LOCKOUT_MINUTES. Both cases stamp ``last_attempt``, which starts the
    cooldown.
    """
    now = datetime.now()
    tracker = attempt_tracker.setdefault(
        user_id, {"count": 0, "last_attempt": None, "locked_until": None})
    tracker["last_attempt"] = now.isoformat()
    if success:
        tracker["count"] = 0
        tracker["locked_until"] = None
        return
    tracker["count"] += 1
    if tracker["count"] >= MAX_ATTEMPTS:
        tracker["locked_until"] = (now + timedelta(minutes=LOCKOUT_MINUTES)).isoformat()
# Challenge phrase for the session liveness step.
def generate_challenge(num_words=2):
    """Return ``num_words`` distinct words from CHALLENGE_WORDS, joined by single spaces.

    Uses random.SystemRandom (OS entropy). The module-level ``random``
    generator is documented as unsuitable for security purposes, and this
    phrase is a security challenge.

    Raises:
        ValueError: if ``num_words`` exceeds the size of the word pool.
    """
    return ' '.join(random.SystemRandom().sample(CHALLENGE_WORDS, num_words))
# ---------------------------------------------------------------------------
# Session storage (in-memory, per process)
# ---------------------------------------------------------------------------
sessions = {}


def _purge_expired_sessions():
    """Delete every stored session whose expiry time has passed."""
    now = datetime.now()
    expired = [sid for sid, s in sessions.items()
               if now > datetime.fromisoformat(s["expires_at"])]
    for sid in expired:
        del sessions[sid]


def create_session(user_id, ttl_minutes=5):
    """Create, store and return a new authentication session for ``user_id``.

    The user ID is stripped and upper-cased. The session starts in the
    STARTED step and expires ``ttl_minutes`` after creation. Expired sessions
    are purged first. Without the purge, sessions that are never accessed
    again would stay in memory indefinitely.
    """
    _purge_expired_sessions()
    now = datetime.now()
    session_id = str(uuid.uuid4())
    sessions[session_id] = {
        "session_id": session_id,
        "user_id": user_id.strip().upper(),
        "step": SESSION_STEPS['STARTED'],
        "challenge_phrase": None,
        "full_name": None,
        "similarity": None,
        "created_at": now.isoformat(),
        "expires_at": (now + timedelta(minutes=ttl_minutes)).isoformat(),
    }
    return sessions[session_id]


def get_session(session_id):
    """Return the live session for ``session_id``, or None if it is unknown or expired.

    An expired session is deleted when accessed.
    """
    if session_id not in sessions:
        return None
    session = sessions[session_id]
    if datetime.now() > datetime.fromisoformat(session["expires_at"]):
        del sessions[session_id]
        return None
    return session
# ---------------------------------------------------------------------------
# Enroll
# ---------------------------------------------------------------------------
def _build_voiceprint(sample_embeddings):
    """Return the L2-normalised mean of every clean and noisy embedding in ``sample_embeddings``."""
    vectors = []
    for sample in sample_embeddings:
        vectors.append(np.array(sample["clean"]))
        vectors.extend(np.array(noisy) for noisy in sample["noisy"])
    mean = np.mean(vectors, axis=0)
    return mean / (np.linalg.norm(mean) + 1e-10)


def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
    """Add one enrollment recording for ``user_id`` and build the voiceprint when enough exist.

    Each recording must pass check_liveness and check_antispoofing. The
    following are stored per recording:
      * its clean embedding;
      * NUM_NOISY_COPIES embeddings of noise-augmented copies.

    Once ``total_samples`` recordings exist:
      * all stored embeddings are averaged into a unit-length voiceprint;
      * the user is marked "enrolled";
      * the per-sample embeddings are discarded.

    A user who is already enrolled is rejected, so that new samples cannot
    replace an existing voiceprint. Delete the user first to re-enroll.

    Args:
        audio_input: anything load_audio accepts.
        user_id: user identifier (stripped and upper-cased).
        full_name: display name (stripped).
        sample_number: accepted for UI compatibility; not used.
        total_samples: recordings required to complete enrollment.

    Returns:
        A human-readable status string. Failures start with "Error:",
        "Enrollment failed:" or "Enrollment error:".
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if not full_name or not full_name.strip():
        return "Error: Full Name is required."
    if audio_input is None:
        return "Error: No audio recorded."
    user_id = user_id.strip().upper()
    full_name = full_name.strip()
    already_enrolled = (f"Error: User '{user_id}' is already enrolled. "
                        "Delete the user before enrolling again.")
    try:
        # Cheap early rejection before running the model.
        if load_db().get(user_id, {}).get("status") == "enrolled":
            return already_enrolled

        wav = load_audio(audio_input)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return f"Enrollment failed: {live_msg}"
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return f"Enrollment failed: {spoof_msg}"

        clean_emb = extract_embedding(wav)
        # Progressively noisier copies (std 0.003, 0.005, ...), presumably to make
        # the averaged voiceprint less sensitive to background noise.
        noisy_embeddings = [
            extract_embedding(add_noise(wav, 0.003 + (i * 0.002)))
            for i in range(NUM_NOISY_COPIES)
        ]

        db = load_db()
        # Re-check on the copy that will actually be written.
        if db.get(user_id, {}).get("status") == "enrolled":
            return already_enrolled
        record = db.setdefault(user_id, {
            "full_name": full_name,
            "enrolled_at": datetime.now().isoformat(),
            "sample_embeddings": [],
            "voiceprint": None,
            "status": "enrolling",
            "samples_collected": 0,
        })
        record["sample_embeddings"].append({
            "clean": clean_emb.tolist(),
            "noisy": [e.tolist() for e in noisy_embeddings],
        })
        record["samples_collected"] = len(record["sample_embeddings"])
        record["full_name"] = full_name
        samples_collected = record["samples_collected"]

        if samples_collected < total_samples:
            save_db(db)
            remaining = total_samples - samples_collected
            return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."

        record["voiceprint"] = _build_voiceprint(record["sample_embeddings"]).tolist()
        record["status"] = "enrolled"
        record["completed_at"] = datetime.now().isoformat()
        record["sample_embeddings"] = []  # raw per-sample embeddings are no longer needed
        save_db(db)
        return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
    except Exception as e:
        return f"Enrollment error: {str(e)}"
# ---------------------------------------------------------------------------
# Verify
# ---------------------------------------------------------------------------
def verify_speaker(audio_input, user_id):
    """Gradio handler: verify a recording against ``user_id``'s stored voiceprint.

    Checks, in order:
      1. input validation;
      2. lockout and cooldown;
      3. user lookup and enrollment status;
      4. liveness;
      5. antispoofing;
      6. cosine similarity >= THRESHOLD.

    Liveness failures, antispoofing failures and mismatches are recorded as
    failed attempts; a match resets the counter.

    Returns:
        A human-readable (possibly multi-line) result string.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if audio_input is None:
        return "Error: No audio recorded."
    uid = user_id.strip().upper()

    allowed, sec_msg = check_security(uid)
    if not allowed:
        return f"ACCESS DENIED: {sec_msg}"

    db = load_db()
    if uid not in db:
        return f"Error: User '{uid}' not found."
    record = db[uid]
    if record.get("status") != "enrolled":
        missing = NUM_CLEAN_SAMPLES - record.get("samples_collected", 0)
        return f"Error: Enrollment incomplete. {missing} more sample(s) needed."

    try:
        wav = load_audio(audio_input)
        for gate in (check_liveness, check_antispoofing):
            passed, reason = gate(wav)
            if not passed:
                record_attempt(uid, False)
                return f"ACCESS DENIED: {reason}"

        probe = extract_embedding(wav)
        voiceprint = np.array(record["voiceprint"])
        similarity = float(np.dot(probe, voiceprint) / (np.linalg.norm(probe) * np.linalg.norm(voiceprint) + 1e-10))

        if similarity >= THRESHOLD:
            record_attempt(uid, True)
            full_name = record.get("full_name", uid)
            return (f"ACCESS GRANTED\nWelcome, {full_name}\n"
                    f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})\n"
                    f"Liveness: Passed | Antispoofing: Passed")

        record_attempt(uid, False)
        attempts_left = MAX_ATTEMPTS - attempt_tracker.get(uid, {}).get("count", 0)
        header = f"ACCESS DENIED\nVoice does not match.\nSimilarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
        if attempts_left > 0:
            return header + f"Attempts remaining: {attempts_left}"
        return header + f"Account locked for {LOCKOUT_MINUTES} minutes."
    except Exception as e:
        return f"Verification error: {str(e)}"
# ---------------------------------------------------------------------------
# User management
# ---------------------------------------------------------------------------
def list_users():
    """Return a human-readable listing of every user record in the database."""
    db = load_db()
    if not db:
        return "No users enrolled yet."
    rows = [
        f"ID: {uid} | Name: {data.get('full_name', 'Unknown')} | "
        f"Status: {data.get('status', 'unknown')} | "
        f"Samples: {data.get('samples_collected', 0)} | "
        f"Enrolled: {data.get('enrolled_at', 'N/A')}"
        for uid, data in db.items()
    ]
    return "\n".join(["=== Enrolled Users ===\n", *rows])


def delete_user(user_id):
    """Remove ``user_id`` from the database and forget its attempt history.

    Returns a status message; failures start with "Error:".
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    uid = user_id.strip().upper()
    db = load_db()
    if uid not in db:
        return f"Error: User '{uid}' not found."
    name = db.pop(uid).get("full_name", uid)
    save_db(db)
    attempt_tracker.pop(uid, None)
    return f"User '{name}' ({uid}) deleted."


def reset_lockout(user_id):
    """Clear the failure count, cooldown and lockout for ``user_id``.

    Does nothing if the user has no record. Returns a status message.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    uid = user_id.strip().upper()
    if uid not in attempt_tracker:
        return f"No lockout record for {uid}."
    attempt_tracker[uid] = {"count": 0, "last_attempt": None, "locked_until": None}
    return f"Lockout reset for {uid}."
# ---------------------------------------------------------------------------
# FastAPI app: created before the Gradio UI, which is mounted onto it at the
# end of the file.
# ---------------------------------------------------------------------------
app = FastAPI(title="ATM Voice Authentication API", version="1.0.0")
# Any origin may call the API, but without credentials. Combining a wildcard
# origin with allow_credentials=True makes Starlette echo back the caller's
# Origin for credentialed requests, so any site could make credentialed
# calls. These endpoints carry the session ID in the form body, not in cookies.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Health check
@app.get("/api/health")
async def health_check():
    """Report service health plus the model, threshold and device in use."""
    return {
        "status": "healthy",
        "model": "UniSpeech-SAT + AAM-Softmax",
        "threshold": THRESHOLD,
        "device": str(DEVICE),
        "timestamp": datetime.now().isoformat(),
    }
# Basic enroll endpoint
@app.post("/api/enroll")
async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
    """Enroll one uploaded voice sample (multipart fields: audio, user_id, full_name).

    The upload is written to a temporary .wav file and passed to
    enroll_sample. The temporary file is always removed. The JSON response
    reports progress towards NUM_CLEAN_SAMPLES.
    """
    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        result = enroll_sample(tmp_path, user_id, full_name, 1)
        uid = user_id.strip().upper()
        record = load_db().get(uid, {})
        is_complete = record.get("status") == "enrolled"
        # Judge failure by enroll_sample's message prefixes. Searching the
        # whole message would also match the user's name or ID.
        failed = result.startswith(("Error:", "Enrollment failed:", "Enrollment error:"))
        return JSONResponse(content={
            "success": not failed,
            "message": result,
            "user_id": uid,
            "samples_collected": NUM_CLEAN_SAMPLES if is_complete else record.get("samples_collected", 0),
            "samples_required": NUM_CLEAN_SAMPLES,
            "enrollment_complete": is_complete,
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Basic verify endpoint
@app.post("/api/verify")
async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
    """Verify an uploaded recording against ``user_id``'s voiceprint (multipart: audio, user_id).

    Same checks as verify_speaker, with a JSON response. The temporary upload
    file is removed on every path, including when decoding raises.
    """
    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        uid = user_id.strip().upper()
        allowed, sec_msg = check_security(uid)
        if not allowed:
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": sec_msg, "locked": True})
        db = load_db()
        if uid not in db:
            return JSONResponse(content={"success": False, "message": f"User '{uid}' not found."})
        if db[uid].get("status") != "enrolled":
            return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

        wav = load_audio(tmp_path)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "antispoofing_passed": False})

        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
        granted = similarity >= THRESHOLD
        record_attempt(uid, granted)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        response = {
            "success": True,
            "access_granted": granted,
            "user_id": uid,
            "full_name": db[uid].get("full_name", uid),
            "similarity": round(similarity, 4),
            "threshold": THRESHOLD,
            "liveness_passed": True,
            "antispoofing_passed": True,
            "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
            "locked": attempts_remaining == 0 and not granted,
        }
        if granted:
            response["message"] = "Access granted. Voice verified."
        elif attempts_remaining > 0:
            response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
        else:
            response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# List users
@app.get("/api/users")
async def api_list_users():
    """Return every user record (ID, name, status, sample count, timestamps) as JSON.

    NOTE(review): no authentication — any caller can list all users.
    """
    users = [
        {
            "user_id": uid,
            "full_name": data.get("full_name", "Unknown"),
            "status": data.get("status", "unknown"),
            "samples_collected": data.get("samples_collected", 0),
            "enrolled_at": data.get("enrolled_at", None),
            "completed_at": data.get("completed_at", None),
        }
        for uid, data in load_db().items()
    ]
    return JSONResponse(content={"success": True, "users": users, "total": len(users)})
# Delete user
@app.delete("/api/users/{user_id}")
async def api_delete_user(user_id: str):
    """Delete ``user_id`` (see delete_user) and report the outcome as JSON.

    NOTE(review): no authentication — any caller can delete any user.
    """
    result = delete_user(user_id)
    # delete_user's failures start with "Error:". The success message embeds
    # the user's name, so a substring search for "error" could misfire.
    return JSONResponse(content={"success": not result.startswith("Error:"), "message": result})
# Reset lockout
async def api_reset_lockout(user_id: str = Form(...)):
    """Clear ``user_id``'s lockout (see reset_lockout) and report the outcome as JSON.

    NOTE(review): not registered as a route — no path is documented for it.
    Before exposing it, put it behind authentication: any caller could
    otherwise clear lockouts and defeat the brute-force protection.
    """
    result = reset_lockout(user_id)
    return JSONResponse(content={"success": not result.startswith("Error:"), "message": result})
# Session: Start
@app.post("/api/session/start")
async def session_start(user_id: str = Form(...)):
    """Open a multi-step authentication session for an enrolled, non-locked user.

    Responds with the new ``session_id``; the next step is /api/session/verify.
    """
    uid = user_id.strip().upper()
    db = load_db()
    if uid not in db:
        return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
    if db[uid].get("status") != "enrolled":
        return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})
    allowed, sec_msg = check_security(uid)
    if not allowed:
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
    session = create_session(uid)
    return JSONResponse(content={
        "success": True,
        "session_id": session["session_id"],
        "user_id": uid,
        "message": "Session started. Please provide a voice sample to verify your identity.",
        "next_step": "verify",
        "instruction": "Record your voice and send it to /api/session/verify",
    })
# Session: Verify identity
@app.post("/api/session/verify")
async def session_verify(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Step 1: verify the speaker and, on a match, issue a challenge phrase.

    This step is only valid while the session is in STARTED.
    Liveness/antispoofing failures and voice mismatches count as failed
    attempts, and reaching the lockout moves the session to DENIED. On
    success the session advances to LIVENESS_PENDING with a fresh two-word
    challenge. The temporary upload file is removed on every path.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['STARTED']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    allowed, sec_msg = check_security(uid)
    if not allowed:
        session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})

    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        wav = load_audio(tmp_path)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        record = load_db().get(uid)
        if not record or record.get("status") != "enrolled":
            # The user was deleted (or is re-enrolling) since the session started.
            session["step"] = SESSION_STEPS['DENIED']
            return JSONResponse(content={"success": False, "message": f"User '{uid}' is no longer enrolled."})
        stored_emb = np.array(record["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))

        if similarity >= THRESHOLD:
            record_attempt(uid, True)
            full_name = record.get("full_name", uid)
            challenge = generate_challenge()
            session["step"] = SESSION_STEPS['LIVENESS_PENDING']
            session["full_name"] = full_name
            session["similarity"] = round(similarity, 4)
            session["challenge_phrase"] = challenge
            return JSONResponse(content={
                "success": True,
                "verified": True,
                "greeting": f"Welcome, {full_name}",
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "liveness",
                "challenge_phrase": challenge,
                "instruction": f"Say these words: {challenge}",
                "message": f"Voice verified. Welcome, {full_name}. For security, please say these words: {challenge}",
            })

        record_attempt(uid, False)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        locked = attempts_remaining == 0
        if locked:
            session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={
            "success": True,
            "verified": False,
            "similarity": round(similarity, 4),
            "attempts_remaining": attempts_remaining,
            "locked": locked,
            "message": f"Voice does not match. {attempts_remaining} attempt(s) remaining." if not locked else f"Account locked for {LOCKOUT_MINUTES} minutes.",
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Session: Liveness check
@app.post("/api/session/liveness")
async def session_liveness(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Step 2: re-verify the speaker on a recording of the challenge phrase.

    This step is only valid in LIVENESS_PENDING. On a match the session
    becomes AUTHENTICATED. A voice mismatch counts as a failed attempt, as in
    session_verify, so this step cannot be retried without limit. Once the
    user is locked out, the session moves to DENIED.

    NOTE(review): the spoken words are never compared with
    ``session["challenge_phrase"]`` (there is no speech recognition here).
    This step only re-runs the liveness, antispoofing and speaker checks, so
    a replayed recording of the genuine user would pass it.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['LIVENESS_PENDING']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]

    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        wav = load_audio(tmp_path)
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        record = load_db().get(uid)
        if not record or record.get("status") != "enrolled":
            # The user was deleted (or is re-enrolling) since the session started.
            session["step"] = SESSION_STEPS['DENIED']
            return JSONResponse(content={"success": False, "message": f"User '{uid}' is no longer enrolled."})
        stored_emb = np.array(record["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))

        if similarity >= THRESHOLD:
            session["step"] = SESSION_STEPS['AUTHENTICATED']
            full_name = session["full_name"]
            return JSONResponse(content={
                "success": True,
                "liveness_passed": True,
                "authenticated": True,
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "transaction",
                "instruction": "How much would you like to withdraw?",
                "message": f"Liveness confirmed. You are fully authenticated, {full_name}. How much would you like to withdraw?",
            })

        record_attempt(uid, False)
        attempts_remaining = max(0, MAX_ATTEMPTS - attempt_tracker.get(uid, {}).get("count", 0))
        if attempts_remaining == 0:
            session["step"] = SESSION_STEPS['DENIED']
            return JSONResponse(content={
                "success": True,
                "liveness_passed": False,
                "attempts_remaining": 0,
                "locked": True,
                "message": f"Account locked for {LOCKOUT_MINUTES} minutes.",
            })
        return JSONResponse(content={
            "success": True,
            "liveness_passed": False,
            "message": "Voice mismatch during liveness check. Please try again.",
            "challenge_phrase": session["challenge_phrase"],
            "instruction": f"Please say these words again: {session['challenge_phrase']}",
            "attempts_remaining": attempts_remaining,
            "locked": False,
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Session: Transaction
@app.post("/api/session/transaction")
async def session_transaction(session_id: str = Form(...), amount: str = Form(...)):
    """Step 3: approve the requested withdrawal for a fully authenticated session.

    This step is only valid in AUTHENTICATED; the session is then marked
    COMPLETE. ``amount`` is echoed back exactly as received — it is not
    parsed or validated here.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['AUTHENTICATED']:
        return JSONResponse(content={"success": False, "message": f"Not authenticated. Current step: {session['step']}"})
    full_name = session["full_name"]
    session["step"] = SESSION_STEPS['COMPLETE']
    return JSONResponse(content={
        "success": True,
        "transaction_approved": True,
        "full_name": full_name,
        "amount": amount,
        "message": f"Transaction approved. {full_name}, you are withdrawing {amount} cedis. Please collect your cash.",
        "instruction": "Transaction complete. Session ended.",
        "note": "In production, this step communicates with the banks core system to process the actual withdrawal.",
    })
# Session status
@app.get("/api/session/{session_id}")
async def session_status(session_id: str):
    """Return a live session's step and metadata, or success=False if it is unknown or expired."""
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    return JSONResponse(content={
        "success": True,
        "session_id": session["session_id"],
        "user_id": session["user_id"],
        "step": session["step"],
        "full_name": session["full_name"],
        "challenge_phrase": session["challenge_phrase"],
        "created_at": session["created_at"],
        "expires_at": session["expires_at"],
    })
# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
# Four tabs: Enroll, Verify, Users (list / delete / reset lockout) and a static
# API reference. Each button's .click() wires one of the handlers defined
# above to the listed input and output components.
with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ATM Voice Authentication System
### Voice-Based Speaker Verification for Banking Security
Voice biometric authentication system for secure ATM access
""")
    with gr.Tabs():
        # Enrollment: each click submits one recording to enroll_sample.
        with gr.Tab("Enroll"):
            # NOTE(review): the "6" in this text and the Number maximum below
            # mirror NUM_CLEAN_SAMPLES; keep them in sync if it changes.
            gr.Markdown("### Enroll New User\nRecord **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.")
            with gr.Row():
                with gr.Column():
                    # type="numpy" hands load_audio a (sample_rate, ndarray) tuple.
                    enroll_audio = gr.Audio(label="Record Voice Sample", sources=["microphone", "upload"], type="numpy")
                    enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
                    enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
                    # Passed as enroll_sample's sample_number, which it does not
                    # use; progress comes from the stored samples_collected.
                    enroll_sample_num = gr.Number(label="Sample Number (1-6)", value=1, minimum=1, maximum=6, step=1)
                    enroll_btn = gr.Button("Enroll Sample", variant="primary")
                with gr.Column():
                    enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
            enroll_btn.click(fn=enroll_sample, inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num], outputs=enroll_result)
        # One-shot verification (no session / challenge flow).
        with gr.Tab("Verify"):
            gr.Markdown("### Verify Identity\nRecord your voice to verify against your enrolled voiceprint.")
            with gr.Row():
                with gr.Column():
                    verify_audio = gr.Audio(label="Record Voice", sources=["microphone", "upload"], type="numpy")
                    verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
                    verify_btn = gr.Button("Verify", variant="primary")
                with gr.Column():
                    verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
            verify_btn.click(fn=verify_speaker, inputs=[verify_audio, verify_user_id], outputs=verify_result)
        # Administration. NOTE(review): no authentication gates these actions.
        with gr.Tab("Users"):
            gr.Markdown("### Manage Enrolled Users")
            list_btn = gr.Button("List All Users")
            users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
            list_btn.click(fn=list_users, outputs=users_output)
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    del_user_id = gr.Textbox(label="User ID to Delete", placeholder="ATM_001")
                    del_btn = gr.Button("Delete User", variant="stop")
                    del_result = gr.Textbox(label="Result", interactive=False)
                    del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
                with gr.Column():
                    reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
                    reset_btn = gr.Button("Reset Lockout", variant="secondary")
                    reset_result = gr.Textbox(label="Result", interactive=False)
                    reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)
        # Static reference. NOTE(review): keep these paths in sync with the
        # routes actually registered on `app`.
        with gr.Tab("API Docs"):
            gr.Markdown("""
### REST API Endpoints
**Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
---
#### Basic Endpoints
- `POST /api/enroll` - Enroll a voice sample (audio, user_id, full_name)
- `POST /api/verify` - Verify a voice (audio, user_id)
- `GET /api/users` - List enrolled users
- `DELETE /api/users/{user_id}` - Delete a user
- `GET /api/health` - Health check
---
#### Session-Based Voice Authentication Flow
- `POST /api/session/start` - Start session (user_id)
- `POST /api/session/verify` - Verify identity (audio, session_id) - Returns greeting + challenge words
- `POST /api/session/liveness` - Liveness check (audio, session_id) - Returns authenticated or denied
- `POST /api/session/transaction` - Confirm transaction (amount, session_id)
- `GET /api/session/{session_id}` - Check session status
""")
# Mount the Gradio UI onto the existing FastAPI app at "/", rather than
# wrapping FastAPI inside Gradio. Routes registered on `app` before this point
# are presumably matched ahead of the catch-all mount — confirm if reordering.
app = gr.mount_gradio_app(app, demo, path="/")


def _serve(host="0.0.0.0", port=7860):
    """Run the combined FastAPI + Gradio app with uvicorn."""
    import uvicorn
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    _serve()