# amfafa's picture
# Update app.py
# 2dab870 verified
import os
import json
import math
import time
import uuid
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import random
import tempfile
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime, timedelta
# Compatibility shim: if the installed torchaudio has no
# `list_audio_backends` attribute, install a stub that reports "soundfile".
# NOTE(review): presumably a later import expects this attribute — confirm.
if not hasattr(torchaudio, 'list_audio_backends'):
    setattr(torchaudio, 'list_audio_backends', lambda: ["soundfile"])
from transformers import AutoModel
# Config

# Files and model identifiers.
CKPT_PATH = 'aam_best.pt'            # checkpoint read with torch.load
DB_PATH = 'voiceprint_db.json'       # JSON file holding enrolled users
MODEL_NAME = 'microsoft/unispeech-sat-base-sv'

# Audio framing: every clip is resampled, then cut or padded to MAX_LEN.
SAMPLE_RATE = 16000
MAX_SEC = 4
MAX_LEN = MAX_SEC * SAMPLE_RATE

# Minimum cosine similarity for a voice to be accepted.
THRESHOLD = 0.35
DEVICE = torch.device('cpu')

# Enrollment: clean recordings required, and noisy copies made of each one.
NUM_CLEAN_SAMPLES = 6
NUM_NOISY_COPIES = 4

# Brute-force protection.
MAX_ATTEMPTS = 3
LOCKOUT_MINUTES = 5
COOLDOWN_SECONDS = 3

# Audio is rejected when spectral flatness exceeds 1 - ANTISPOOFING_THRESHOLD.
ANTISPOOFING_THRESHOLD = 0.02
# Challenge word pool (simple, short, easy to pronounce).
# Kept as one whitespace-separated string and split into a list of 20 words.
CHALLENGE_WORDS = (
    'Red Blue Gold Star Water '
    'Moon Fire Green Black White '
    'Sun Rain Tree Fish Bird '
    'Stone Wind Cloud Light Sound'
).split()
# Session steps: symbolic name -> value stored in a session's "step" field.
SESSION_STEPS = dict(
    STARTED='started',
    VERIFIED='verified',
    LIVENESS_PENDING='liveness_pending',
    AUTHENTICATED='authenticated',
    TRANSACTION_PENDING='transaction_pending',
    COMPLETE='complete',
    DENIED='denied',
)
# AAM-Softmax model
class AAMSoftmax(nn.Module):
    """Additive-angular-margin softmax head.

    Holds one weight vector per class. Without labels, ``forward`` returns
    the cosine between each L2-normalised embedding and each L2-normalised
    class vector. With labels, the target-class cosine is replaced by the
    margin-penalised value and every logit is multiplied by ``scale``.
    """

    def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
        super().__init__()
        self.margin = margin
        self.scale = scale
        # Allocated uninitialised, then filled by Xavier-uniform init.
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        # Constants for cos(theta + margin) and its fallback branch.
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        self.threshold = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings, labels=None):
        """Return cosines (no labels) or scaled margin logits (with labels)."""
        unit_emb = F.normalize(embeddings, p=2, dim=1)
        unit_w = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(unit_emb, unit_w)
        if labels is None:
            return cosine

        # sin(theta) from cos(theta); the clamp keeps the sqrt argument >= 0.
        sine = torch.sqrt(1.0 - (cosine * cosine).clamp(0, 1))
        with_margin = cosine * self.cos_m - sine * self.sin_m
        # Where cosine <= threshold, use the linear penalty cosine - mm.
        with_margin = torch.where(cosine > self.threshold, with_margin, cosine - self.mm)

        # Only the target-class column receives the margin.
        is_target = F.one_hot(labels, cosine.size(1)).bool()
        logits = torch.where(is_target, with_margin, cosine)
        return logits * self.scale
class SpeakerClassifier(nn.Module):
    """Linear + ReLU projection followed by an AAM-Softmax head.

    ``extract_embedding`` returns the projected features; ``forward`` feeds
    the same features into the AAM-Softmax head.
    """

    def __init__(self, input_dim=768, hidden_dim=512, num_classes=227):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.aam = AAMSoftmax(hidden_dim, num_classes)

    def forward(self, x, labels=None):
        """Return the AAM-Softmax output for input features ``x``."""
        return self.aam(self.extract_embedding(x), labels)

    def extract_embedding(self, x):
        """Return ReLU(fc1(x)), the embedding before the AAM head."""
        projected = self.fc1(x)
        return self.relu(projected)
# Load models
print("Loading UniSpeech-SAT base model...")
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
# Inference only: switch to eval mode and freeze every parameter.
base_model.eval()
base_model.requires_grad_(False)
print("Loading AAM-Softmax checkpoint...")
# NOTE(review): torch.load unpickles arbitrary Python objects, so CKPT_PATH
# must only ever point at a trusted file. The nn.Module branch below depends
# on full unpickling, so weights_only loading is not enabled here.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
print(f"Checkpoint type: {type(ckpt)}")
if isinstance(ckpt, dict):
    print(f"Checkpoint keys: {list(ckpt.keys())}")

# The class count comes from checkpoint metadata when present, else 227.
num_classes = 227
if isinstance(ckpt, dict):
    if 'num_classes' in ckpt:
        num_classes = ckpt['num_classes']
    elif 'num_speakers' in ckpt:
        num_classes = ckpt['num_speakers']
classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)

loaded = False
if isinstance(ckpt, dict):
    # 1) Try the well-known wrapper keys, in order.
    for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
        if key in ckpt:
            try:
                classifier.load_state_dict(ckpt[key])
            except Exception as e:
                print(f"Key '{key}' found but failed: {e}")
            else:
                print(f"Loaded classifier from key: '{key}'")
                loaded = True
                break

    # 2) Otherwise the dict itself may be a state dict ("fc1.weight", ...).
    if not loaded:
        sample_keys = list(ckpt.keys())[:5]
        # Non-string keys cannot be parameter names; skip them instead of
        # raising TypeError on the `in` test.
        if any(isinstance(k, str) and '.' in k for k in sample_keys):
            try:
                classifier.load_state_dict(ckpt)
                print("Loaded classifier directly from checkpoint dict")
                loaded = True
            except Exception as e:
                print(f"Strict direct load failed: {e}")
                try:
                    incompatible = classifier.load_state_dict(ckpt, strict=False)
                except Exception as e2:
                    print(f"Direct load failed: {e2}")
                else:
                    # strict=False succeeds even when nothing matched, so
                    # count as loaded only if at least one parameter was set.
                    matched = set(classifier.state_dict()) - set(incompatible.missing_keys)
                    if matched:
                        print("Loaded classifier with strict=False")
                        loaded = True
                    else:
                        print("Direct load failed: no checkpoint keys match the classifier")

    # Optional fine-tuned backbone weights: best effort, but report failures.
    if 'base_model_state' in ckpt:
        try:
            base_model.load_state_dict(ckpt['base_model_state'], strict=False)
            print("Loaded fine-tuned base model weights")
        except Exception as e:
            print(f"Could not load fine-tuned base model weights: {e}")
elif isinstance(ckpt, nn.Module):
    # The checkpoint is a whole pickled model object.
    classifier = ckpt.to(DEVICE)
    print("Loaded classifier directly (model object)")
    loaded = True

if not loaded:
    print("WARNING: Could not load classifier weights. Using random init.")
classifier.eval()
print(f"Models ready. num_classes={num_classes}, loaded={loaded}")
# Database
def load_db():
    """Return the parsed JSON database, or an empty dict if DB_PATH is absent."""
    if not os.path.exists(DB_PATH):
        return {}
    with open(DB_PATH, 'r') as handle:
        return json.load(handle)
def save_db(db):
    """Write ``db`` to DB_PATH as indented JSON, replacing the file atomically.

    The data goes to a temporary file in the same directory and is then
    moved over DB_PATH with ``os.replace``. A crash or serialisation error
    mid-write therefore leaves the previous database intact rather than a
    truncated JSON file.

    Values that are not JSON-serialisable are stringified (``default=str``).
    Note: the temporary file comes from ``tempfile.mkstemp``, so the saved
    database ends up readable and writable by the owner only.
    """
    target_dir = os.path.dirname(os.path.abspath(DB_PATH))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix='.voiceprint_db_', suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(db, f, indent=2, default=str)
        os.replace(tmp_path, DB_PATH)
    except BaseException:
        # Do not leave the partial temp file behind, then re-raise.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise
# Audio processing
def load_audio(audio_input):
    """Decode audio into a mono float32 1-D tensor of exactly MAX_LEN samples.

    Accepted inputs:
      * ``(sample_rate, ndarray)`` tuple, as produced by ``gr.Audio`` with
        ``type="numpy"``; the array is ``(samples,)`` or ``(samples, channels)``.
      * ``str``: path to an audio file readable by torchaudio.
      * ``bytes``: encoded audio, decoded through a temporary ``.wav`` file.

    The signal is down-mixed to mono, resampled to SAMPLE_RATE when needed,
    then truncated or zero-padded to MAX_LEN.

    Raises:
        ValueError: for an unsupported input type, an array that is not
            1-D or 2-D, or audio with no samples.
    """
    if isinstance(audio_input, tuple):
        sr, audio_np = audio_input
        audio_np = np.asarray(audio_np)
        wav = torch.tensor(audio_np, dtype=torch.float32)
        if wav.dim() == 2:
            # Numpy audio is channels-last, so average across axis 1.
            # Averaging axis 0 would collapse time and leave one value
            # per channel.
            wav = wav.mean(dim=1)
        elif wav.dim() != 1:
            raise ValueError(f"Unsupported audio array shape: {tuple(audio_np.shape)}")
        if np.issubdtype(audio_np.dtype, np.signedinteger):
            # Scale integer PCM by its own full-scale value (32768 for int16).
            wav = wav / float(np.iinfo(audio_np.dtype).max + 1)
        elif wav.numel() > 0 and wav.abs().max() > 1.0:
            # Float data outside [-1, 1]: keep the legacy int16-style scaling.
            wav = wav / 32768.0
    elif isinstance(audio_input, str):
        wav, sr = torchaudio.load(audio_input)
        # torchaudio returns (channels, samples); averaging axis 0 gives mono.
        wav = wav.mean(dim=0)
    elif isinstance(audio_input, bytes):
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_input)
            tmp_path = tmp.name
        try:
            wav, sr = torchaudio.load(tmp_path)
        finally:
            # Remove the temp file even when decoding fails.
            os.unlink(tmp_path)
        wav = wav.mean(dim=0)
    else:
        raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

    if wav.numel() == 0:
        raise ValueError("Audio contains no samples")
    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    # Force a fixed length so every clip has the same shape.
    if wav.shape[0] > MAX_LEN:
        wav = wav[:MAX_LEN]
    elif wav.shape[0] < MAX_LEN:
        wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
    return wav
@torch.no_grad()
def extract_embedding(wav_tensor):
    """Return the L2-normalised speaker embedding of one waveform as a numpy array.

    The waveform gets a batch dimension and is run through ``base_model``.
    The hidden states are averaged over dim 1, projected by
    ``classifier.extract_embedding`` and normalised to unit length.
    """
    batch = wav_tensor.unsqueeze(0).to(DEVICE)
    hidden = base_model(batch).last_hidden_state
    pooled = hidden.mean(dim=1)
    projected = classifier.extract_embedding(pooled)
    unit = F.normalize(projected, p=2, dim=1)
    return unit.squeeze(0).cpu().numpy()
def add_noise(wav_tensor, noise_level=0.005):
    """Return ``wav_tensor`` plus standard-normal noise scaled by ``noise_level``."""
    return wav_tensor + torch.randn_like(wav_tensor).mul(noise_level)
# Liveness detection
def check_liveness(wav_tensor):
    """Run four signal heuristics and return ``(passed, message)``.

    Checks run in order and the first failure wins: overall RMS level,
    standard deviation, zero-crossing rate, and the fraction of samples
    whose magnitude exceeds 0.01.
    """
    samples = wav_tensor.numpy()
    total = len(samples)

    level = np.sqrt(np.mean(samples ** 2))
    if level < 0.001:
        return False, "Audio too quiet"

    spread = np.std(samples)
    if spread < 0.001:
        return False, "Audio lacks variation"

    # Each sign flip contributes 2 to the summed |diff(sign)|, hence 2 * total.
    sign_changes = np.abs(np.diff(np.sign(samples)))
    crossing_rate = np.sum(sign_changes) / (2 * total)
    if crossing_rate < 0.01:
        return False, "Abnormal audio pattern"

    active_fraction = np.count_nonzero(np.abs(samples) > 0.01) / total
    if active_fraction < 0.1:
        return False, "Insufficient speech content"

    return True, "Liveness check passed"
# Antispoofing
def check_antispoofing(wav_tensor):
    """Run spectral and energy heuristics and return ``(passed, message)``.

    1. Spectral flatness (geometric mean over arithmetic mean of the
       non-zero FFT magnitudes) must not exceed 1 - ANTISPOOFING_THRESHOLD.
    2. For clips of at least three 1600-sample frames, the standard
       deviation of per-frame RMS energy must be at least 0.001.
    """
    samples = wav_tensor.numpy()
    magnitude = np.abs(np.fft.rfft(samples))
    magnitude = magnitude[magnitude > 0]
    if len(magnitude) == 0:
        return False, "No frequency content"

    geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
    arithmetic_mean = np.mean(magnitude)
    spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
    if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
        return False, "Possible synthetic audio"

    frame_size = 1600
    if len(samples) >= frame_size * 3:
        # The "+ 1" keeps the last full frame. Without it a clip whose length
        # is an exact multiple of frame_size loses its final frame, so the
        # 3-frame minimum would yield only 2 frames.
        frame_starts = range(0, len(samples) - frame_size + 1, frame_size)
        frame_energies = [
            np.sqrt(np.mean(samples[start:start + frame_size] ** 2))
            for start in frame_starts
        ]
        if np.std(frame_energies) < 0.001:
            return False, "Unnaturally uniform energy"
    return True, "Antispoofing check passed"
# Security: lockout and cooldown
# In-memory map: user_id -> {"count", "last_attempt", "locked_until"};
# both timestamps are ISO-format strings or None.
attempt_tracker = {}


def check_security(user_id):
    """Return ``(allowed, message)`` for a new attempt by ``user_id``.

    An attempt is refused while the user is locked out, or when the previous
    attempt was less than COOLDOWN_SECONDS ago. An expired lock is cleared
    here and the failure counter reset. Users with no tracker entry are
    always allowed.
    """
    now = datetime.now()
    if user_id not in attempt_tracker:
        return True, "OK"
    tracker = attempt_tracker[user_id]

    if "locked_until" in tracker and tracker["locked_until"]:
        locked_until = datetime.fromisoformat(tracker["locked_until"])
        if now < locked_until:
            # total_seconds() includes whole days; `.seconds` alone would
            # under-report any lock longer than 24 hours.
            remaining = int((locked_until - now).total_seconds())
            return False, f"Account locked. Try again in {remaining} seconds."
        # Lock has expired: clear it and start counting afresh.
        tracker["count"] = 0
        tracker["locked_until"] = None

    if "last_attempt" in tracker and tracker["last_attempt"]:
        last = datetime.fromisoformat(tracker["last_attempt"])
        elapsed = (now - last).total_seconds()
        if elapsed < COOLDOWN_SECONDS:
            return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds."
    return True, "OK"
def record_attempt(user_id, success):
    """Update the attempt tracker after a verification attempt.

    A success clears the failure count and any lock. A failure increments
    the count and, once it reaches MAX_ATTEMPTS, locks the user for
    LOCKOUT_MINUTES. The attempt time is stored in both cases.
    """
    stamp = datetime.now()
    entry = attempt_tracker.setdefault(
        user_id, {"count": 0, "last_attempt": None, "locked_until": None}
    )
    entry["last_attempt"] = stamp.isoformat()

    if not success:
        entry["count"] += 1
        if entry["count"] >= MAX_ATTEMPTS:
            unlock_at = stamp + timedelta(minutes=LOCKOUT_MINUTES)
            entry["locked_until"] = unlock_at.isoformat()
        return

    entry["count"] = 0
    entry["locked_until"] = None
# Generate random challenge (2 words from pool)
def generate_challenge():
    """Return two distinct challenge words joined by a single space.

    The words are drawn with the OS-backed ``random.SystemRandom`` rather
    than the module-level Mersenne Twister, whose output is predictable.
    The phrase is used as a security challenge, so it should not be
    guessable in advance.
    """
    words = random.SystemRandom().sample(CHALLENGE_WORDS, 2)
    return ' '.join(words)
# Session storage (in-memory)
sessions = {}


def create_session(user_id):
    """Create, store and return a new session record for ``user_id``.

    The session starts in the STARTED step, is keyed by a random UUID4
    string, and expires five minutes after creation.
    """
    session_id = str(uuid.uuid4())
    record = {
        "session_id": session_id,
        "user_id": user_id.strip().upper(),
        "step": SESSION_STEPS['STARTED'],
        "challenge_phrase": None,
        "full_name": None,
        "similarity": None,
        "created_at": datetime.now().isoformat(),
        "expires_at": (datetime.now() + timedelta(minutes=5)).isoformat(),
    }
    sessions[session_id] = record
    return record
def get_session(session_id):
    """Return the live session for ``session_id``, or None.

    Unknown ids yield None. Expired sessions are removed from the store
    before None is returned.
    """
    session = sessions.get(session_id)
    if session is None:
        return None
    expiry = datetime.fromisoformat(session["expires_at"])
    if datetime.now() > expiry:
        sessions.pop(session_id)
        return None
    return session
# Enroll
def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
    """Record one enrollment sample and return a human-readable status string.

    The clip must pass the liveness and antispoofing checks. Its embedding
    and NUM_NOISY_COPIES noise-augmented embeddings are appended to the
    user's record. Once ``total_samples`` samples exist, all stored
    embeddings are averaged and normalised into the voiceprint and the user
    is marked "enrolled".

    ``sample_number`` is accepted but not used by this function.
    NOTE(review): an already-enrolled user id can be enrolled again. That
    overwrites ``full_name`` and, after enough samples, the voiceprint.
    Confirm that this is intended.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if not full_name or not full_name.strip():
        return "Error: Full Name is required."
    if audio_input is None:
        return "Error: No audio recorded."

    user_id = user_id.strip().upper()
    full_name = full_name.strip()
    try:
        wav = load_audio(audio_input)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return f"Enrollment failed: {live_msg}"
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return f"Enrollment failed: {spoof_msg}"

        clean_emb = extract_embedding(wav)
        # Noise level rises with each copy: 0.003, 0.005, 0.007, ...
        noisy_embeddings = [
            extract_embedding(add_noise(wav, 0.003 + (i * 0.002)))
            for i in range(NUM_NOISY_COPIES)
        ]

        db = load_db()
        record = db.setdefault(user_id, {
            "full_name": full_name,
            "enrolled_at": datetime.now().isoformat(),
            "sample_embeddings": [],
            "voiceprint": None,
            "status": "enrolling",
            "samples_collected": 0
        })
        record["sample_embeddings"].append({
            "clean": clean_emb.tolist(),
            "noisy": [emb.tolist() for emb in noisy_embeddings]
        })
        record["samples_collected"] = len(record["sample_embeddings"])
        record["full_name"] = full_name
        samples_collected = record["samples_collected"]

        if samples_collected < total_samples:
            save_db(db)
            remaining = total_samples - samples_collected
            return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."

        # Enough samples: average clean and noisy embeddings of every sample.
        all_embeddings = [
            np.array(vector)
            for sample in record["sample_embeddings"]
            for vector in [sample["clean"], *sample["noisy"]]
        ]
        avg_embedding = np.mean(all_embeddings, axis=0)
        avg_embedding = avg_embedding / (np.linalg.norm(avg_embedding) + 1e-10)
        record["voiceprint"] = avg_embedding.tolist()
        record["status"] = "enrolled"
        record["completed_at"] = datetime.now().isoformat()
        # Per-sample embeddings are dropped once the voiceprint exists.
        record["sample_embeddings"] = []
        save_db(db)
        return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
    except Exception as e:
        return f"Enrollment error: {str(e)}"
# Verify
def verify_speaker(audio_input, user_id):
    """Verify a voice clip against a user's voiceprint; return a status string.

    Order of checks: input validation, lockout and cooldown, user existence
    and enrollment status, liveness, antispoofing, then cosine similarity
    against THRESHOLD. Every failed check after the database lookup is
    recorded as a failed attempt.
    """
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    if audio_input is None:
        return "Error: No audio recorded."
    user_id = user_id.strip().upper()

    allowed, sec_msg = check_security(user_id)
    if not allowed:
        return f"ACCESS DENIED: {sec_msg}"

    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."
    record = db[user_id]
    if record.get("status") != "enrolled":
        remaining = NUM_CLEAN_SAMPLES - record.get("samples_collected", 0)
        return f"Error: Enrollment incomplete. {remaining} more sample(s) needed."

    try:
        wav = load_audio(audio_input)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(user_id, False)
            return f"ACCESS DENIED: {live_msg}"
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(user_id, False)
            return f"ACCESS DENIED: {spoof_msg}"

        test_emb = extract_embedding(wav)
        stored_emb = np.array(record["voiceprint"])
        norm_product = np.linalg.norm(test_emb) * np.linalg.norm(stored_emb)
        similarity = float(np.dot(test_emb, stored_emb) / (norm_product + 1e-10))

        if similarity < THRESHOLD:
            record_attempt(user_id, False)
            failures = attempt_tracker.get(user_id, {}).get("count", 0)
            attempts_left = MAX_ATTEMPTS - failures
            msg = f"ACCESS DENIED\nVoice does not match.\nSimilarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
            if attempts_left > 0:
                msg += f"Attempts remaining: {attempts_left}"
            else:
                msg += f"Account locked for {LOCKOUT_MINUTES} minutes."
            return msg

        record_attempt(user_id, True)
        full_name = record.get("full_name", user_id)
        return (f"ACCESS GRANTED\nWelcome, {full_name}\n"
                f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})\n"
                f"Liveness: Passed | Antispoofing: Passed")
    except Exception as e:
        return f"Verification error: {str(e)}"
# User management
def list_users():
    """Return a text listing of every user record, one line per user."""
    db = load_db()
    if not db:
        return "No users enrolled yet."
    rows = [
        f"ID: {uid} | Name: {data.get('full_name', 'Unknown')} | "
        f"Status: {data.get('status', 'unknown')} | "
        f"Samples: {data.get('samples_collected', 0)} | "
        f"Enrolled: {data.get('enrolled_at', 'N/A')}"
        for uid, data in db.items()
    ]
    return "\n".join(["=== Enrolled Users ===\n", *rows])
def delete_user(user_id):
    """Delete a user record and its attempt-tracker entry; return a status string."""
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    user_id = user_id.strip().upper()

    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."

    removed = db.pop(user_id)
    name = removed.get("full_name", user_id)
    save_db(db)
    # Drop any lockout state so a re-created id starts clean.
    attempt_tracker.pop(user_id, None)
    return f"User '{name}' ({user_id}) deleted."
def reset_lockout(user_id):
    """Reset the attempt tracker for ``user_id``; return a status string."""
    if not user_id or not user_id.strip():
        return "Error: User ID is required."
    user_id = user_id.strip().upper()
    if user_id not in attempt_tracker:
        return f"No lockout record for {user_id}."
    attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
    return f"Lockout reset for {user_id}."
# CREATE FASTAPI APP FIRST (before Gradio)
app = FastAPI(title="ATM Voice Authentication API", version="1.0.0")
# CORS: any origin may call the API. Credentials are disabled because a
# wildcard origin must not be combined with allow_credentials=True, which
# would let any website make credentialed cross-origin requests. None of
# the endpoints in this file read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Health check
@app.get("/api/health")
async def health_check():
    """Report service status, model description, threshold and device."""
    payload = {
        "status": "healthy",
        "model": "UniSpeech-SAT + AAM-Softmax",
        "threshold": THRESHOLD,
        "device": str(DEVICE),
        "timestamp": datetime.now().isoformat(),
    }
    return payload
# Basic enroll endpoint
@app.post("/api/enroll")
async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
    """Enroll one uploaded voice sample and report enrollment progress.

    The upload is written to a temporary ``.wav`` file, passed to
    ``enroll_sample``, and the file is always removed afterwards.
    ``success`` is derived from the prefix of the status message. A
    substring search would misreport success for users whose name contains
    "error" or "failed", because the success messages embed the full name.
    """
    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(audio_bytes)
        result = enroll_sample(tmp_path, user_id, full_name, 1)

        db = load_db()
        uid = user_id.strip().upper()
        record = db.get(uid, {})
        samples_collected = record.get("samples_collected", 0)
        is_complete = record.get("status") == "enrolled"
        # Every failure message from enroll_sample starts with one of these.
        succeeded = not result.startswith(("Error", "Enrollment failed", "Enrollment error"))
        return JSONResponse(content={
            "success": succeeded,
            "message": result,
            "user_id": uid,
            "samples_collected": samples_collected if not is_complete else NUM_CLEAN_SAMPLES,
            "samples_required": NUM_CLEAN_SAMPLES,
            "enrollment_complete": is_complete,
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        # Remove the upload copy on every path, including failures.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Basic verify endpoint
@app.post("/api/verify")
async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
    """Verify an uploaded voice sample against a user's voiceprint.

    Checks run in order: lockout and cooldown, user existence, enrollment
    status, liveness, antispoofing, then cosine similarity against
    THRESHOLD. The temporary copy of the upload is created only once it is
    needed and is removed even when decoding fails.
    """
    try:
        audio_bytes = await audio.read()
        uid = user_id.strip().upper()

        allowed, sec_msg = check_security(uid)
        if not allowed:
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": sec_msg, "locked": True})
        db = load_db()
        if uid not in db:
            return JSONResponse(content={"success": False, "message": f"User '{uid}' not found."})
        if db[uid].get("status") != "enrolled":
            return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            wav = load_audio(tmp_path)
        finally:
            # Without this a failed decode leaves the upload on disk.
            os.unlink(tmp_path)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "antispoofing_passed": False})

        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
        granted = similarity >= THRESHOLD
        record_attempt(uid, granted)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))

        response = {
            "success": True,
            "access_granted": granted,
            "user_id": uid,
            "full_name": db[uid].get("full_name", uid),
            "similarity": round(similarity, 4),
            "threshold": THRESHOLD,
            "liveness_passed": True,
            "antispoofing_passed": True,
            "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
            "locked": attempts_remaining == 0 and not granted,
        }
        if granted:
            response["message"] = "Access granted. Voice verified."
        elif attempts_remaining > 0:
            response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
        else:
            response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
# List users
@app.get("/api/users")
async def api_list_users():
    """Return a summary of every user record in the database."""
    users = [
        {
            "user_id": uid,
            "full_name": data.get("full_name", "Unknown"),
            "status": data.get("status", "unknown"),
            "samples_collected": data.get("samples_collected", 0),
            "enrolled_at": data.get("enrolled_at", None),
            "completed_at": data.get("completed_at", None),
        }
        for uid, data in load_db().items()
    ]
    return JSONResponse(content={"success": True, "users": users, "total": len(users)})
# Delete user
@app.delete("/api/users/{user_id}")
async def api_delete_user(user_id: str):
    """Delete a user and report whether the deletion succeeded.

    ``delete_user`` prefixes every failure message with "Error", so the
    prefix decides ``success``. Searching the whole message for "error"
    would report a failure when the deleted user's name or id happens to
    contain that word.
    NOTE(review): this endpoint performs no caller authentication — confirm
    it is protected elsewhere.
    """
    result = delete_user(user_id)
    success = not result.startswith("Error")
    return JSONResponse(content={"success": success, "message": result})
# Reset lockout
@app.post("/api/reset-lockout")
async def api_reset_lockout(user_id: str = Form(...)):
    """Clear the lockout state of ``user_id``; always reports success.

    NOTE(review): no caller authentication is visible here — confirm that
    access to this endpoint is restricted elsewhere.
    """
    message = reset_lockout(user_id)
    return JSONResponse(content={"success": True, "message": message})
# Session: Start
@app.post("/api/session/start")
async def session_start(user_id: str = Form(...)):
    """Open an authentication session for an enrolled user who is not locked out."""
    uid = user_id.strip().upper()
    record = load_db().get(uid)
    if record is None:
        return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
    if record.get("status") != "enrolled":
        return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

    allowed, sec_msg = check_security(uid)
    if not allowed:
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})

    session = create_session(uid)
    return JSONResponse(content={
        "success": True,
        "session_id": session["session_id"],
        "user_id": uid,
        "message": "Session started. Please provide a voice sample to verify your identity.",
        "next_step": "verify",
        "instruction": "Record your voice and send it to /api/session/verify",
    })
# Session: Verify identity
@app.post("/api/session/verify")
async def session_verify(audio: UploadFile = File(...), session_id: str = Form(...)):
    """First session step: match the uploaded voice to the session's user.

    On a match the session moves to LIVENESS_PENDING and a two-word
    challenge phrase is issued. On a mismatch the failure is recorded, and
    the session is marked DENIED once the user is locked out. The temporary
    copy of the upload is removed even when decoding fails.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['STARTED']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    allowed, sec_msg = check_security(uid)
    if not allowed:
        session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            wav = load_audio(tmp_path)
        finally:
            # Without this a failed decode leaves the upload on disk.
            os.unlink(tmp_path)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))

        if similarity >= THRESHOLD:
            record_attempt(uid, True)
            full_name = db[uid].get("full_name", uid)
            challenge = generate_challenge()
            session["step"] = SESSION_STEPS['LIVENESS_PENDING']
            session["full_name"] = full_name
            session["similarity"] = round(similarity, 4)
            session["challenge_phrase"] = challenge
            return JSONResponse(content={
                "success": True,
                "verified": True,
                "greeting": f"Welcome, {full_name}",
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "liveness",
                "challenge_phrase": challenge,
                "instruction": f"Say these words: {challenge}",
                "message": f"Voice verified. Welcome, {full_name}. For security, please say these words: {challenge}",
            })

        record_attempt(uid, False)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        locked = attempts_remaining == 0
        if locked:
            session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={
            "success": True,
            "verified": False,
            "similarity": round(similarity, 4),
            "attempts_remaining": attempts_remaining,
            "locked": locked,
            "message": f"Voice does not match. {attempts_remaining} attempt(s) remaining." if not locked else f"Account locked for {LOCKOUT_MINUTES} minutes.",
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
# Session: Liveness check
@app.post("/api/session/liveness")
async def session_liveness(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Second session step: re-check the voice after the challenge was issued.

    The clip must pass the liveness and antispoofing heuristics and match
    the stored voiceprint; the session then moves to AUTHENTICATED. The
    temporary copy of the upload is removed even when decoding fails.

    NOTE(review): the recording is compared only to the voiceprint; nothing
    here checks that the spoken words equal ``challenge_phrase``.
    NOTE(review): failures in this step are not passed to ``record_attempt``,
    so the lockout counter does not apply — confirm whether that is intended.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['LIVENESS_PENDING']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            wav = load_audio(tmp_path)
        finally:
            # Without this a failed decode leaves the upload on disk.
            os.unlink(tmp_path)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))

        if similarity >= THRESHOLD:
            session["step"] = SESSION_STEPS['AUTHENTICATED']
            full_name = session["full_name"]
            return JSONResponse(content={
                "success": True,
                "liveness_passed": True,
                "authenticated": True,
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "transaction",
                "instruction": "How much would you like to withdraw?",
                "message": f"Liveness confirmed. You are fully authenticated, {full_name}. How much would you like to withdraw?",
            })
        return JSONResponse(content={
            "success": True,
            "liveness_passed": False,
            "message": "Voice mismatch during liveness check. Please try again.",
            "challenge_phrase": session["challenge_phrase"],
            "instruction": f"Please say these words again: {session['challenge_phrase']}",
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
# Session: Transaction
@app.post("/api/session/transaction")
async def session_transaction(session_id: str = Form(...), amount: str = Form(...)):
    """Final session step: approve the withdrawal and mark the session COMPLETE.

    Only sessions in the AUTHENTICATED step are accepted.
    NOTE(review): ``amount`` is echoed back as received; no numeric
    validation happens here — confirm that the caller validates it.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})

    current_step = session["step"]
    if current_step != SESSION_STEPS['AUTHENTICATED']:
        return JSONResponse(content={"success": False, "message": f"Not authenticated. Current step: {current_step}"})

    full_name = session["full_name"]
    session["step"] = SESSION_STEPS['COMPLETE']
    payload = {
        "success": True,
        "transaction_approved": True,
        "full_name": full_name,
        "amount": amount,
        "message": f"Transaction approved. {full_name}, you are withdrawing {amount} cedis. Please collect your cash.",
        "instruction": "Transaction complete. Session ended.",
        "note": "In production, this step communicates with the banks core system to process the actual withdrawal.",
    }
    return JSONResponse(content=payload)
# Session status
@app.get("/api/session/{session_id}")
async def session_status(session_id: str):
    """Return the public fields of a live session, or a not-found response."""
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})

    exposed_fields = (
        "session_id", "user_id", "step", "full_name",
        "challenge_phrase", "created_at", "expires_at",
    )
    payload = {"success": True}
    for field in exposed_fields:
        payload[field] = session[field]
    return JSONResponse(content=payload)
# Gradio interface
# Four tabs: Enroll, Verify, Users (list / delete / reset lockout), API Docs.
with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# ATM Voice Authentication System
### Voice-Based Speaker Verification for Banking Security
Voice biometric authentication system for secure ATM access
""")
    with gr.Tabs():
        # Enrollment tab: one recorded sample per click.
        with gr.Tab("Enroll"):
            gr.Markdown("### Enroll New User\nRecord **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.")
            with gr.Row():
                with gr.Column():
                    enroll_audio = gr.Audio(
                        label="Record Voice Sample",
                        sources=["microphone", "upload"],
                        type="numpy",
                    )
                    enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
                    enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
                    enroll_sample_num = gr.Number(
                        label="Sample Number (1-6)",
                        value=1,
                        minimum=1,
                        maximum=6,
                        step=1,
                    )
                    enroll_btn = gr.Button("Enroll Sample", variant="primary")
                with gr.Column():
                    enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
            enroll_btn.click(
                fn=enroll_sample,
                inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num],
                outputs=enroll_result,
            )

        # Verification tab.
        with gr.Tab("Verify"):
            gr.Markdown("### Verify Identity\nRecord your voice to verify against your enrolled voiceprint.")
            with gr.Row():
                with gr.Column():
                    verify_audio = gr.Audio(
                        label="Record Voice",
                        sources=["microphone", "upload"],
                        type="numpy",
                    )
                    verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
                    verify_btn = gr.Button("Verify", variant="primary")
                with gr.Column():
                    verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
            verify_btn.click(
                fn=verify_speaker,
                inputs=[verify_audio, verify_user_id],
                outputs=verify_result,
            )

        # User administration tab.
        with gr.Tab("Users"):
            gr.Markdown("### Manage Enrolled Users")
            list_btn = gr.Button("List All Users")
            users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
            list_btn.click(fn=list_users, outputs=users_output)
            gr.Markdown("---")
            with gr.Row():
                with gr.Column():
                    del_user_id = gr.Textbox(label="User ID to Delete", placeholder="ATM_001")
                    del_btn = gr.Button("Delete User", variant="stop")
                    del_result = gr.Textbox(label="Result", interactive=False)
                    del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
                with gr.Column():
                    reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
                    reset_btn = gr.Button("Reset Lockout", variant="secondary")
                    reset_result = gr.Textbox(label="Result", interactive=False)
                    reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)

        # Static REST API reference.
        with gr.Tab("API Docs"):
            gr.Markdown("""
### REST API Endpoints
**Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
---
#### Basic Endpoints
- `POST /api/enroll` - Enroll a voice sample (audio, user_id, full_name)
- `POST /api/verify` - Verify a voice (audio, user_id)
- `GET /api/users` - List enrolled users
- `DELETE /api/users/{user_id}` - Delete a user
- `GET /api/health` - Health check
---
#### Session-Based Voice Authentication Flow
- `POST /api/session/start` - Start session (user_id)
- `POST /api/session/verify` - Verify identity (audio, session_id) - Returns greeting + challenge words
- `POST /api/session/liveness` - Liveness check (audio, session_id) - Returns authenticated or denied
- `POST /api/session/transaction` - Confirm transaction (amount, session_id)
- `GET /api/session/{session_id}` - Check session status
""")
# Mount Gradio ON the FastAPI app (not the other way around)
# The UI is served at "/"; the /api/* routes were registered on `app` above.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch: run a uvicorn server only when executed as a script.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, port=7860, host="0.0.0.0")