amfafa commited on
Commit
a8a3da6
·
verified ·
1 Parent(s): 6a1ecee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +344 -157
app.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import json
3
  import math
4
  import time
 
5
  import torch
6
  import torch.nn as nn
7
  import torch.nn.functional as F
@@ -17,9 +18,7 @@ if not hasattr(torchaudio, 'list_audio_backends'):
17
 
18
  from transformers import AutoModel
19
 
20
-
21
- # CONFIGURATION
22
-
23
  CKPT_PATH = 'aam_best.pt'
24
  DB_PATH = 'voiceprint_db.json'
25
  MODEL_NAME = 'microsoft/unispeech-sat-base-sv'
@@ -35,10 +34,27 @@ LOCKOUT_MINUTES = 5
35
  COOLDOWN_SECONDS = 3
36
  ANTISPOOFING_THRESHOLD = 0.02
37
 
38
-
39
-
40
- # AAM-SOFTMAX MODEL
41
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
  class AAMSoftmax(nn.Module):
43
  def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
44
  super().__init__()
@@ -80,9 +96,7 @@ class SpeakerClassifier(nn.Module):
80
  return self.relu(self.fc1(x))
81
 
82
 
83
-
84
- # LOAD MODELS
85
-
86
  print("Loading UniSpeech-SAT base model...")
87
  base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
88
  base_model.eval()
@@ -91,13 +105,10 @@ for param in base_model.parameters():
91
 
92
  print("Loading AAM-Softmax checkpoint...")
93
  ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
94
-
95
- # Auto-detect checkpoint format
96
  print(f"Checkpoint type: {type(ckpt)}")
97
  if isinstance(ckpt, dict):
98
  print(f"Checkpoint keys: {list(ckpt.keys())}")
99
 
100
- # Detect num_classes from checkpoint
101
  num_classes = 227
102
  if isinstance(ckpt, dict):
103
  if 'num_classes' in ckpt:
@@ -105,13 +116,10 @@ if isinstance(ckpt, dict):
105
  elif 'num_speakers' in ckpt:
106
  num_classes = ckpt['num_speakers']
107
 
108
- # Build classifier
109
  classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
110
 
111
- # Load weights - try every possible key format
112
  loaded = False
113
  if isinstance(ckpt, dict):
114
- # Try common key names for classifier state
115
  for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
116
  if key in ckpt:
117
  try:
@@ -120,53 +128,42 @@ if isinstance(ckpt, dict):
120
  loaded = True
121
  break
122
  except Exception as e:
123
- print(f"Key '{key}' found but failed to load: {e}")
124
 
125
- # If no named key worked, try loading the dict directly (maybe ckpt IS the state_dict)
126
  if not loaded:
127
- # Check if the keys look like model parameters (contain dots like 'fc1.weight')
128
  sample_keys = list(ckpt.keys())[:5]
129
- looks_like_state_dict = any('.' in k for k in sample_keys)
130
- if looks_like_state_dict:
131
  try:
132
  classifier.load_state_dict(ckpt)
133
- print("Loaded classifier directly from checkpoint dict (it IS the state_dict)")
134
  loaded = True
135
- except Exception as e:
136
- print(f"Direct load failed: {e}")
137
- # Try with strict=False
138
  try:
139
  classifier.load_state_dict(ckpt, strict=False)
140
  print("Loaded classifier with strict=False")
141
  loaded = True
142
  except Exception as e2:
143
- print(f"Strict=False also failed: {e2}")
144
 
145
- # Try loading base_model state too if present
146
  if 'base_model_state' in ckpt:
147
  try:
148
  base_model.load_state_dict(ckpt['base_model_state'], strict=False)
149
- print("Also loaded fine-tuned base model weights")
150
- except Exception as e:
151
- print(f"Base model load skipped: {e}")
152
-
153
  elif isinstance(ckpt, nn.Module):
154
- # Checkpoint is the model itself
155
  classifier = ckpt.to(DEVICE)
156
- print("Loaded classifier directly (checkpoint is model object)")
157
  loaded = True
158
 
159
  if not loaded:
160
- print("WARNING: Could not load classifier weights. Using random initialization.")
161
- print("The system will still run but verification accuracy will be poor.")
162
 
163
  classifier.eval()
164
  print(f"Models ready. num_classes={num_classes}, loaded={loaded}")
165
 
166
 
167
-
168
- # DATABASE
169
-
170
  def load_db():
171
  if os.path.exists(DB_PATH):
172
  with open(DB_PATH, 'r') as f:
@@ -178,9 +175,7 @@ def save_db(db):
178
  json.dump(db, f, indent=2, default=str)
179
 
180
 
181
-
182
- # AUDIO PROCESSING
183
-
184
  def load_audio(audio_input):
185
  if isinstance(audio_input, tuple):
186
  sr, audio_np = audio_input
@@ -235,54 +230,49 @@ def add_noise(wav_tensor, noise_level=0.005):
235
  return wav_tensor + noise
236
 
237
 
238
- # LIVENESS DETECTION
239
-
240
  def check_liveness(wav_tensor):
241
  wav_np = wav_tensor.numpy()
242
  rms = np.sqrt(np.mean(wav_np ** 2))
243
  if rms < 0.001:
244
- return False, "Audio too quiet — possible silence or empty recording"
245
  std = np.std(wav_np)
246
  if std < 0.001:
247
- return False, "Audio lacks variation — possible synthetic tone"
248
  zero_crossings = np.sum(np.abs(np.diff(np.sign(wav_np)))) / (2 * len(wav_np))
249
  if zero_crossings < 0.01:
250
- return False, "Abnormal audio pattern — possible replay attack"
251
  non_silent = np.abs(wav_np) > 0.01
252
  speech_ratio = np.sum(non_silent) / len(wav_np)
253
  if speech_ratio < 0.1:
254
- return False, "Insufficient speech content detected"
255
  return True, "Liveness check passed"
256
 
257
 
258
-
259
- # ANTISPOOFING
260
-
261
  def check_antispoofing(wav_tensor):
262
  wav_np = wav_tensor.numpy()
263
  fft = np.fft.rfft(wav_np)
264
  magnitude = np.abs(fft)
265
  magnitude = magnitude[magnitude > 0]
266
  if len(magnitude) == 0:
267
- return False, "No frequency content detected"
268
  geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
269
  arithmetic_mean = np.mean(magnitude)
270
  spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
271
  if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
272
- return False, f"Spectral flatness too high ({spectral_flatness:.4f}) — possible synthetic audio"
273
  frame_size = 1600
274
  if len(wav_np) >= frame_size * 3:
275
  frames = [wav_np[i:i + frame_size] for i in range(0, len(wav_np) - frame_size, frame_size)]
276
  frame_energies = [np.sqrt(np.mean(f ** 2)) for f in frames]
277
  energy_std = np.std(frame_energies)
278
  if energy_std < 0.001:
279
- return False, "Unnaturally uniform energy — possible synthetic audio"
280
  return True, "Antispoofing check passed"
281
 
282
 
283
-
284
- # SECURITY: LOCKOUT & COOLDOWN
285
-
286
  attempt_tracker = {}
287
 
288
  def check_security(user_id):
@@ -302,7 +292,7 @@ def check_security(user_id):
302
  last = datetime.fromisoformat(tracker["last_attempt"])
303
  elapsed = (now - last).total_seconds()
304
  if elapsed < COOLDOWN_SECONDS:
305
- return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds before trying again."
306
  return True, "OK"
307
 
308
  def record_attempt(user_id, success):
@@ -320,33 +310,61 @@ def record_attempt(user_id, success):
320
  tracker["locked_until"] = (now + timedelta(minutes=LOCKOUT_MINUTES)).isoformat()
321
 
322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
 
324
- # ENROLL
325
 
 
326
  def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
327
  if not user_id or not user_id.strip():
328
  return "Error: User ID is required."
329
  if not full_name or not full_name.strip():
330
  return "Error: Full Name is required."
331
  if audio_input is None:
332
- return "Error: No audio recorded. Please record your voice."
333
 
334
  user_id = user_id.strip().upper()
335
  full_name = full_name.strip()
336
 
337
  try:
338
  wav = load_audio(audio_input)
339
-
340
  is_live, live_msg = check_liveness(wav)
341
  if not is_live:
342
  return f"Enrollment failed: {live_msg}"
343
-
344
  is_real, spoof_msg = check_antispoofing(wav)
345
  if not is_real:
346
  return f"Enrollment failed: {spoof_msg}"
347
 
348
  clean_emb = extract_embedding(wav)
349
-
350
  noisy_embeddings = []
351
  for i in range(NUM_NOISY_COPIES):
352
  noise_level = 0.003 + (i * 0.002)
@@ -355,7 +373,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
355
  noisy_embeddings.append(noisy_emb)
356
 
357
  db = load_db()
358
-
359
  if user_id not in db:
360
  db[user_id] = {
361
  "full_name": full_name,
@@ -373,7 +390,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
373
  db[user_id]["sample_embeddings"].append(sample_data)
374
  db[user_id]["samples_collected"] = len(db[user_id]["sample_embeddings"])
375
  db[user_id]["full_name"] = full_name
376
-
377
  samples_collected = db[user_id]["samples_collected"]
378
 
379
  if samples_collected >= total_samples:
@@ -382,44 +398,37 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
382
  all_embeddings.append(np.array(sample["clean"]))
383
  for noisy in sample["noisy"]:
384
  all_embeddings.append(np.array(noisy))
385
-
386
  avg_embedding = np.mean(all_embeddings, axis=0)
387
  avg_embedding = avg_embedding / (np.linalg.norm(avg_embedding) + 1e-10)
388
-
389
  db[user_id]["voiceprint"] = avg_embedding.tolist()
390
  db[user_id]["status"] = "enrolled"
391
  db[user_id]["completed_at"] = datetime.now().isoformat()
392
  db[user_id]["sample_embeddings"] = []
393
-
394
  save_db(db)
395
  return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
396
  else:
397
  save_db(db)
398
  remaining = total_samples - samples_collected
399
  return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."
400
-
401
  except Exception as e:
402
  return f"Enrollment error: {str(e)}"
403
 
404
 
405
- # VERIFY
406
-
407
  def verify_speaker(audio_input, user_id):
408
  if not user_id or not user_id.strip():
409
  return "Error: User ID is required."
410
  if audio_input is None:
411
- return "Error: No audio recorded. Please speak into the microphone."
412
 
413
  user_id = user_id.strip().upper()
414
-
415
  allowed, sec_msg = check_security(user_id)
416
  if not allowed:
417
  return f"ACCESS DENIED: {sec_msg}"
418
 
419
  db = load_db()
420
  if user_id not in db:
421
- return f"Error: User '{user_id}' not found. Please enroll first."
422
-
423
  if db[user_id].get("status") != "enrolled":
424
  samples = db[user_id].get("samples_collected", 0)
425
  remaining = NUM_CLEAN_SAMPLES - samples
@@ -427,12 +436,10 @@ def verify_speaker(audio_input, user_id):
427
 
428
  try:
429
  wav = load_audio(audio_input)
430
-
431
  is_live, live_msg = check_liveness(wav)
432
  if not is_live:
433
  record_attempt(user_id, False)
434
  return f"ACCESS DENIED: {live_msg}"
435
-
436
  is_real, spoof_msg = check_antispoofing(wav)
437
  if not is_real:
438
  record_attempt(user_id, False)
@@ -440,42 +447,29 @@ def verify_speaker(audio_input, user_id):
440
 
441
  test_emb = extract_embedding(wav)
442
  stored_emb = np.array(db[user_id]["voiceprint"])
443
-
444
- similarity = float(np.dot(test_emb, stored_emb) / (
445
- np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10
446
- ))
447
 
448
  if similarity >= THRESHOLD:
449
  record_attempt(user_id, True)
450
  full_name = db[user_id].get("full_name", user_id)
451
- return (
452
- f"ACCESS GRANTED\n"
453
- f"Welcome, {full_name}\n"
454
- f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})\n"
455
- f"Liveness: Passed | Antispoofing: Passed"
456
- )
457
  else:
458
  record_attempt(user_id, False)
459
  tracker = attempt_tracker.get(user_id, {})
460
  attempts_left = MAX_ATTEMPTS - tracker.get("count", 0)
461
- msg = (
462
- f"ACCESS DENIED\n"
463
- f"Voice does not match registered voiceprint.\n"
464
- f"Similarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
465
- )
466
  if attempts_left > 0:
467
  msg += f"Attempts remaining: {attempts_left}"
468
  else:
469
  msg += f"Account locked for {LOCKOUT_MINUTES} minutes."
470
  return msg
471
-
472
  except Exception as e:
473
  return f"Verification error: {str(e)}"
474
 
475
 
476
-
477
- # USER MANAGEMENT
478
-
479
  def list_users():
480
  db = load_db()
481
  if not db:
@@ -501,7 +495,7 @@ def delete_user(user_id):
501
  save_db(db)
502
  if user_id in attempt_tracker:
503
  del attempt_tracker[user_id]
504
- return f"User '{name}' ({user_id}) deleted successfully."
505
 
506
  def reset_lockout(user_id):
507
  if not user_id or not user_id.strip():
@@ -510,16 +504,16 @@ def reset_lockout(user_id):
510
  if user_id in attempt_tracker:
511
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
512
  return f"Lockout reset for {user_id}."
513
- return f"No lockout record found for {user_id}."
514
-
515
 
516
 
517
- # GRADIO INTERFACE
518
-
519
  with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
520
 
521
  gr.Markdown("""
522
- # Voice Authentication System
 
 
523
  """)
524
 
525
  with gr.Tabs():
@@ -528,7 +522,6 @@ with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft())
528
  gr.Markdown("""
529
  ### Enroll New User
530
  Record **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.
531
- The system adds noise augmentation automatically (6 clean + 24 noisy = 30 embeddings averaged).
532
  """)
533
  with gr.Row():
534
  with gr.Column():
@@ -545,7 +538,6 @@ with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft())
545
  gr.Markdown("""
546
  ### Verify Identity
547
  Record your voice to verify against your enrolled voiceprint.
548
- Security: 3 failed attempts = 5-minute lockout. 3-second cooldown between attempts.
549
  """)
550
  with gr.Row():
551
  with gr.Column():
@@ -576,51 +568,42 @@ with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft())
576
 
577
  with gr.Tab("API Docs"):
578
  gr.Markdown("""
579
- ### REST API Endpoints for Banking Systems
580
 
581
  **Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
582
 
583
  ---
584
 
585
- #### 1. Enroll a Voice Sample
586
- ```
587
- POST /api/enroll
588
- Content-Type: multipart/form-data
589
- Fields: audio (WAV file), user_id (string), full_name (string)
590
- ```
591
-
592
- #### 2. Verify a Speaker
593
- ```
594
- POST /api/verify
595
- Content-Type: multipart/form-data
596
- Fields: audio (WAV file), user_id (string)
597
- ```
598
-
599
- #### 3. List Enrolled Users
600
- ```
601
- GET /api/users
602
- ```
603
-
604
- #### 4. Delete a User
605
- ```
606
- DELETE /api/users/{user_id}
607
- ```
608
-
609
- #### 5. Health Check
610
- ```
611
- GET /api/health
612
- ```
613
-
614
- #### 6. Reset Lockout
615
- ```
616
- POST /api/reset-lockout
617
- Field: user_id (string)
618
- ```
619
- """)
620
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
621
 
622
- # REST API ENDPOINTS
623
 
 
624
  from fastapi import UploadFile, File, Form
625
  from fastapi.responses import JSONResponse
626
  from fastapi.middleware.cors import CORSMiddleware
@@ -635,17 +618,18 @@ fastapi_app.add_middleware(
635
  allow_headers=["*"],
636
  )
637
 
 
638
  @fastapi_app.get("/api/health")
639
  async def health_check():
640
  return {
641
  "status": "healthy",
642
  "model": "UniSpeech-SAT + AAM-Softmax",
643
- "eer": "3.94%",
644
  "threshold": THRESHOLD,
645
  "device": str(DEVICE),
646
  "timestamp": datetime.now().isoformat()
647
  }
648
 
 
649
  @fastapi_app.post("/api/enroll")
650
  async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
651
  try:
@@ -670,6 +654,7 @@ async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), fu
670
  except Exception as e:
671
  return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
672
 
 
673
  @fastapi_app.post("/api/verify")
674
  async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
675
  try:
@@ -687,12 +672,10 @@ async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
687
  db = load_db()
688
  if uid not in db:
689
  os.unlink(tmp_path)
690
- return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
691
-
692
  if db[uid].get("status") != "enrolled":
693
  os.unlink(tmp_path)
694
- samples = db[uid].get("samples_collected", 0)
695
- return JSONResponse(content={"success": False, "message": f"Enrollment incomplete. {NUM_CLEAN_SAMPLES - samples} more sample(s) needed."})
696
 
697
  wav = load_audio(tmp_path)
698
  os.unlink(tmp_path)
@@ -700,12 +683,12 @@ async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
700
  is_live, live_msg = check_liveness(wav)
701
  if not is_live:
702
  record_attempt(uid, False)
703
- return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False, "antispoofing_passed": None})
704
 
705
  is_real, spoof_msg = check_antispoofing(wav)
706
  if not is_real:
707
  record_attempt(uid, False)
708
- return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "liveness_passed": True, "antispoofing_passed": False})
709
 
710
  test_emb = extract_embedding(wav)
711
  stored_emb = np.array(db[uid]["voiceprint"])
@@ -713,10 +696,8 @@ async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
713
 
714
  granted = similarity >= THRESHOLD
715
  record_attempt(uid, granted)
716
-
717
  tracker = attempt_tracker.get(uid, {})
718
- attempts_used = tracker.get("count", 0)
719
- attempts_remaining = max(0, MAX_ATTEMPTS - attempts_used)
720
 
721
  response = {
722
  "success": True,
@@ -730,19 +711,17 @@ async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
730
  "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
731
  "locked": attempts_remaining == 0 and not granted
732
  }
733
-
734
  if granted:
735
  response["message"] = "Access granted. Voice verified."
 
 
736
  else:
737
- if attempts_remaining > 0:
738
- response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
739
- else:
740
- response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
741
-
742
  return JSONResponse(content=response)
743
  except Exception as e:
744
  return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
745
 
 
746
  @fastapi_app.get("/api/users")
747
  async def api_list_users():
748
  db = load_db()
@@ -758,20 +737,228 @@ async def api_list_users():
758
  })
759
  return JSONResponse(content={"success": True, "users": users, "total": len(users)})
760
 
 
761
  @fastapi_app.delete("/api/users/{user_id}")
762
  async def api_delete_user(user_id: str):
763
  result = delete_user(user_id)
764
  success = "error" not in result.lower()
765
  return JSONResponse(content={"success": success, "message": result})
766
 
 
767
  @fastapi_app.post("/api/reset-lockout")
768
  async def api_reset_lockout(user_id: str = Form(...)):
769
  result = reset_lockout(user_id)
770
  return JSONResponse(content={"success": True, "message": result})
771
 
772
 
 
 
 
 
 
 
 
 
 
 
 
773
 
774
- # LAUNCH
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
776
  if __name__ == "__main__":
777
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  import json
3
  import math
4
  import time
5
+ import uuid
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
18
 
19
  from transformers import AutoModel
20
 
21
+ # Config
 
 
22
  CKPT_PATH = 'aam_best.pt'
23
  DB_PATH = 'voiceprint_db.json'
24
  MODEL_NAME = 'microsoft/unispeech-sat-base-sv'
 
34
  COOLDOWN_SECONDS = 3
35
  ANTISPOOFING_THRESHOLD = 0.02
36
 
37
+ # Challenge word pool (simple, short, easy to pronounce)
38
+ CHALLENGE_WORDS = [
39
+ 'Red', 'Blue', 'Gold', 'Star', 'Water',
40
+ 'Moon', 'Fire', 'Green', 'Black', 'White',
41
+ 'Sun', 'Rain', 'Tree', 'Fish', 'Bird',
42
+ 'Stone', 'Wind', 'Cloud', 'Light', 'Sound'
43
+ ]
44
+
45
+ # Session steps
46
+ SESSION_STEPS = {
47
+ 'STARTED': 'started',
48
+ 'VERIFIED': 'verified',
49
+ 'LIVENESS_PENDING': 'liveness_pending',
50
+ 'AUTHENTICATED': 'authenticated',
51
+ 'TRANSACTION_PENDING': 'transaction_pending',
52
+ 'COMPLETE': 'complete',
53
+ 'DENIED': 'denied'
54
+ }
55
+
56
+
57
+ # AAM-Softmax model
58
  class AAMSoftmax(nn.Module):
59
  def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
60
  super().__init__()
 
96
  return self.relu(self.fc1(x))
97
 
98
 
99
+ # Load models
 
 
100
  print("Loading UniSpeech-SAT base model...")
101
  base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
102
  base_model.eval()
 
105
 
106
  print("Loading AAM-Softmax checkpoint...")
107
  ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
 
 
108
  print(f"Checkpoint type: {type(ckpt)}")
109
  if isinstance(ckpt, dict):
110
  print(f"Checkpoint keys: {list(ckpt.keys())}")
111
 
 
112
  num_classes = 227
113
  if isinstance(ckpt, dict):
114
  if 'num_classes' in ckpt:
 
116
  elif 'num_speakers' in ckpt:
117
  num_classes = ckpt['num_speakers']
118
 
 
119
  classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
120
 
 
121
  loaded = False
122
  if isinstance(ckpt, dict):
 
123
  for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
124
  if key in ckpt:
125
  try:
 
128
  loaded = True
129
  break
130
  except Exception as e:
131
+ print(f"Key '{key}' found but failed: {e}")
132
 
 
133
  if not loaded:
 
134
  sample_keys = list(ckpt.keys())[:5]
135
+ if any('.' in k for k in sample_keys):
 
136
  try:
137
  classifier.load_state_dict(ckpt)
138
+ print("Loaded classifier directly from checkpoint dict")
139
  loaded = True
140
+ except:
 
 
141
  try:
142
  classifier.load_state_dict(ckpt, strict=False)
143
  print("Loaded classifier with strict=False")
144
  loaded = True
145
  except Exception as e2:
146
+ print(f"Direct load failed: {e2}")
147
 
 
148
  if 'base_model_state' in ckpt:
149
  try:
150
  base_model.load_state_dict(ckpt['base_model_state'], strict=False)
151
+ print("Loaded fine-tuned base model weights")
152
+ except:
153
+ pass
 
154
  elif isinstance(ckpt, nn.Module):
 
155
  classifier = ckpt.to(DEVICE)
156
+ print("Loaded classifier directly (model object)")
157
  loaded = True
158
 
159
  if not loaded:
160
+ print("WARNING: Could not load classifier weights. Using random init.")
 
161
 
162
  classifier.eval()
163
  print(f"Models ready. num_classes={num_classes}, loaded={loaded}")
164
 
165
 
166
+ # Database
 
 
167
  def load_db():
168
  if os.path.exists(DB_PATH):
169
  with open(DB_PATH, 'r') as f:
 
175
  json.dump(db, f, indent=2, default=str)
176
 
177
 
178
+ # Audio processing
 
 
179
  def load_audio(audio_input):
180
  if isinstance(audio_input, tuple):
181
  sr, audio_np = audio_input
 
230
  return wav_tensor + noise
231
 
232
 
233
+ # Liveness detection
 
234
  def check_liveness(wav_tensor):
235
  wav_np = wav_tensor.numpy()
236
  rms = np.sqrt(np.mean(wav_np ** 2))
237
  if rms < 0.001:
238
+ return False, "Audio too quiet"
239
  std = np.std(wav_np)
240
  if std < 0.001:
241
+ return False, "Audio lacks variation"
242
  zero_crossings = np.sum(np.abs(np.diff(np.sign(wav_np)))) / (2 * len(wav_np))
243
  if zero_crossings < 0.01:
244
+ return False, "Abnormal audio pattern"
245
  non_silent = np.abs(wav_np) > 0.01
246
  speech_ratio = np.sum(non_silent) / len(wav_np)
247
  if speech_ratio < 0.1:
248
+ return False, "Insufficient speech content"
249
  return True, "Liveness check passed"
250
 
251
 
252
+ # Antispoofing
 
 
253
  def check_antispoofing(wav_tensor):
254
  wav_np = wav_tensor.numpy()
255
  fft = np.fft.rfft(wav_np)
256
  magnitude = np.abs(fft)
257
  magnitude = magnitude[magnitude > 0]
258
  if len(magnitude) == 0:
259
+ return False, "No frequency content"
260
  geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
261
  arithmetic_mean = np.mean(magnitude)
262
  spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
263
  if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
264
+ return False, "Possible synthetic audio"
265
  frame_size = 1600
266
  if len(wav_np) >= frame_size * 3:
267
  frames = [wav_np[i:i + frame_size] for i in range(0, len(wav_np) - frame_size, frame_size)]
268
  frame_energies = [np.sqrt(np.mean(f ** 2)) for f in frames]
269
  energy_std = np.std(frame_energies)
270
  if energy_std < 0.001:
271
+ return False, "Unnaturally uniform energy"
272
  return True, "Antispoofing check passed"
273
 
274
 
275
+ # Security: lockout and cooldown
 
 
276
  attempt_tracker = {}
277
 
278
  def check_security(user_id):
 
292
  last = datetime.fromisoformat(tracker["last_attempt"])
293
  elapsed = (now - last).total_seconds()
294
  if elapsed < COOLDOWN_SECONDS:
295
+ return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds."
296
  return True, "OK"
297
 
298
  def record_attempt(user_id, success):
 
310
  tracker["locked_until"] = (now + timedelta(minutes=LOCKOUT_MINUTES)).isoformat()
311
 
312
 
313
+ # Generate random challenge (2 words from pool)
314
+ def generate_challenge():
315
+ words = random.sample(CHALLENGE_WORDS, 2)
316
+ return ' '.join(words)
317
+
318
+
319
+ # Session storage (in-memory)
320
+ sessions = {}
321
+
322
+ def create_session(user_id):
323
+ session_id = str(uuid.uuid4())
324
+ sessions[session_id] = {
325
+ "session_id": session_id,
326
+ "user_id": user_id.strip().upper(),
327
+ "step": SESSION_STEPS['STARTED'],
328
+ "challenge_phrase": None,
329
+ "full_name": None,
330
+ "similarity": None,
331
+ "created_at": datetime.now().isoformat(),
332
+ "expires_at": (datetime.now() + timedelta(minutes=5)).isoformat()
333
+ }
334
+ return sessions[session_id]
335
+
336
+ def get_session(session_id):
337
+ if session_id not in sessions:
338
+ return None
339
+ session = sessions[session_id]
340
+ if datetime.now() > datetime.fromisoformat(session["expires_at"]):
341
+ del sessions[session_id]
342
+ return None
343
+ return session
344
 
 
345
 
346
+ # Enroll
347
  def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
348
  if not user_id or not user_id.strip():
349
  return "Error: User ID is required."
350
  if not full_name or not full_name.strip():
351
  return "Error: Full Name is required."
352
  if audio_input is None:
353
+ return "Error: No audio recorded."
354
 
355
  user_id = user_id.strip().upper()
356
  full_name = full_name.strip()
357
 
358
  try:
359
  wav = load_audio(audio_input)
 
360
  is_live, live_msg = check_liveness(wav)
361
  if not is_live:
362
  return f"Enrollment failed: {live_msg}"
 
363
  is_real, spoof_msg = check_antispoofing(wav)
364
  if not is_real:
365
  return f"Enrollment failed: {spoof_msg}"
366
 
367
  clean_emb = extract_embedding(wav)
 
368
  noisy_embeddings = []
369
  for i in range(NUM_NOISY_COPIES):
370
  noise_level = 0.003 + (i * 0.002)
 
373
  noisy_embeddings.append(noisy_emb)
374
 
375
  db = load_db()
 
376
  if user_id not in db:
377
  db[user_id] = {
378
  "full_name": full_name,
 
390
  db[user_id]["sample_embeddings"].append(sample_data)
391
  db[user_id]["samples_collected"] = len(db[user_id]["sample_embeddings"])
392
  db[user_id]["full_name"] = full_name
 
393
  samples_collected = db[user_id]["samples_collected"]
394
 
395
  if samples_collected >= total_samples:
 
398
  all_embeddings.append(np.array(sample["clean"]))
399
  for noisy in sample["noisy"]:
400
  all_embeddings.append(np.array(noisy))
 
401
  avg_embedding = np.mean(all_embeddings, axis=0)
402
  avg_embedding = avg_embedding / (np.linalg.norm(avg_embedding) + 1e-10)
 
403
  db[user_id]["voiceprint"] = avg_embedding.tolist()
404
  db[user_id]["status"] = "enrolled"
405
  db[user_id]["completed_at"] = datetime.now().isoformat()
406
  db[user_id]["sample_embeddings"] = []
 
407
  save_db(db)
408
  return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
409
  else:
410
  save_db(db)
411
  remaining = total_samples - samples_collected
412
  return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."
 
413
  except Exception as e:
414
  return f"Enrollment error: {str(e)}"
415
 
416
 
417
+ # Verify
 
418
  def verify_speaker(audio_input, user_id):
419
  if not user_id or not user_id.strip():
420
  return "Error: User ID is required."
421
  if audio_input is None:
422
+ return "Error: No audio recorded."
423
 
424
  user_id = user_id.strip().upper()
 
425
  allowed, sec_msg = check_security(user_id)
426
  if not allowed:
427
  return f"ACCESS DENIED: {sec_msg}"
428
 
429
  db = load_db()
430
  if user_id not in db:
431
+ return f"Error: User '{user_id}' not found."
 
432
  if db[user_id].get("status") != "enrolled":
433
  samples = db[user_id].get("samples_collected", 0)
434
  remaining = NUM_CLEAN_SAMPLES - samples
 
436
 
437
  try:
438
  wav = load_audio(audio_input)
 
439
  is_live, live_msg = check_liveness(wav)
440
  if not is_live:
441
  record_attempt(user_id, False)
442
  return f"ACCESS DENIED: {live_msg}"
 
443
  is_real, spoof_msg = check_antispoofing(wav)
444
  if not is_real:
445
  record_attempt(user_id, False)
 
447
 
448
  test_emb = extract_embedding(wav)
449
  stored_emb = np.array(db[user_id]["voiceprint"])
450
+ similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
 
 
 
451
 
452
  if similarity >= THRESHOLD:
453
  record_attempt(user_id, True)
454
  full_name = db[user_id].get("full_name", user_id)
455
+ return (f"ACCESS GRANTED\nWelcome, {full_name}\n"
456
+ f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})\n"
457
+ f"Liveness: Passed | Antispoofing: Passed")
 
 
 
458
  else:
459
  record_attempt(user_id, False)
460
  tracker = attempt_tracker.get(user_id, {})
461
  attempts_left = MAX_ATTEMPTS - tracker.get("count", 0)
462
+ msg = f"ACCESS DENIED\nVoice does not match.\nSimilarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
 
 
 
 
463
  if attempts_left > 0:
464
  msg += f"Attempts remaining: {attempts_left}"
465
  else:
466
  msg += f"Account locked for {LOCKOUT_MINUTES} minutes."
467
  return msg
 
468
  except Exception as e:
469
  return f"Verification error: {str(e)}"
470
 
471
 
472
+ # User management
 
 
473
  def list_users():
474
  db = load_db()
475
  if not db:
 
495
  save_db(db)
496
  if user_id in attempt_tracker:
497
  del attempt_tracker[user_id]
498
+ return f"User '{name}' ({user_id}) deleted."
499
 
500
  def reset_lockout(user_id):
501
  if not user_id or not user_id.strip():
 
504
  if user_id in attempt_tracker:
505
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
506
  return f"Lockout reset for {user_id}."
507
+ return f"No lockout record for {user_id}."
 
508
 
509
 
510
+ # Gradio interface
 
511
  with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
512
 
513
  gr.Markdown("""
514
+ # ATM Voice Authentication System
515
+ ### Voice-Based Speaker Verification for Banking Security
516
+ Voice biometric authentication system for secure ATM access
517
  """)
518
 
519
  with gr.Tabs():
 
522
  gr.Markdown("""
523
  ### Enroll New User
524
  Record **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.
 
525
  """)
526
  with gr.Row():
527
  with gr.Column():
 
538
  gr.Markdown("""
539
  ### Verify Identity
540
  Record your voice to verify against your enrolled voiceprint.
 
541
  """)
542
  with gr.Row():
543
  with gr.Column():
 
568
 
569
  with gr.Tab("API Docs"):
570
  gr.Markdown("""
571
+ ### REST API Endpoints
572
 
573
  **Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
574
 
575
  ---
576
 
577
+ #### Basic Endpoints
578
+ - `POST /api/enroll` — Enroll a voice sample (audio, user_id, full_name)
579
+ - `POST /api/verify` — Verify a voice (audio, user_id)
580
+ - `GET /api/users` — List enrolled users
581
+ - `DELETE /api/users/{user_id}` Delete a user
582
+ - `GET /api/health` — Health check
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
583
 
584
+ ---
585
+
586
+ #### Session-Based Voice Authentication Flow
587
+ These endpoints power the full conversational ATM experience.
588
+
589
+ **Step 1: Start session**
590
+ `POST /api/session/start` — Send `user_id` → Returns session_id
591
+
592
+ **Step 2: Verify identity**
593
+ `POST /api/session/verify` — Send audio + session_id → Returns greeting with user's name + challenge words
594
+
595
+ **Step 3: Liveness check**
596
+ `POST /api/session/liveness` — Send audio of challenge words + session_id → Returns authenticated or denied
597
+
598
+ **Step 4: Confirm transaction (simulated)**
599
+ `POST /api/session/transaction` — Send amount + session_id → Returns confirmation
600
+
601
+ **Check session**
602
+ `GET /api/session/{session_id}` — Returns current session state
603
+ """)
604
 
 
605
 
606
+ # REST API endpoints
607
  from fastapi import UploadFile, File, Form
608
  from fastapi.responses import JSONResponse
609
  from fastapi.middleware.cors import CORSMiddleware
 
618
  allow_headers=["*"],
619
  )
620
 
621
+ # Health check
622
  @fastapi_app.get("/api/health")
623
  async def health_check():
624
  return {
625
  "status": "healthy",
626
  "model": "UniSpeech-SAT + AAM-Softmax",
 
627
  "threshold": THRESHOLD,
628
  "device": str(DEVICE),
629
  "timestamp": datetime.now().isoformat()
630
  }
631
 
632
+ # Basic enroll endpoint
633
  @fastapi_app.post("/api/enroll")
634
  async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
635
  try:
 
654
  except Exception as e:
655
  return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
656
 
657
+ # Basic verify endpoint
658
  @fastapi_app.post("/api/verify")
659
  async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
660
  try:
 
672
  db = load_db()
673
  if uid not in db:
674
  os.unlink(tmp_path)
675
+ return JSONResponse(content={"success": False, "message": f"User '{uid}' not found."})
 
676
  if db[uid].get("status") != "enrolled":
677
  os.unlink(tmp_path)
678
+ return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})
 
679
 
680
  wav = load_audio(tmp_path)
681
  os.unlink(tmp_path)
 
683
  is_live, live_msg = check_liveness(wav)
684
  if not is_live:
685
  record_attempt(uid, False)
686
+ return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False})
687
 
688
  is_real, spoof_msg = check_antispoofing(wav)
689
  if not is_real:
690
  record_attempt(uid, False)
691
+ return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "antispoofing_passed": False})
692
 
693
  test_emb = extract_embedding(wav)
694
  stored_emb = np.array(db[uid]["voiceprint"])
 
696
 
697
  granted = similarity >= THRESHOLD
698
  record_attempt(uid, granted)
 
699
  tracker = attempt_tracker.get(uid, {})
700
+ attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
 
701
 
702
  response = {
703
  "success": True,
 
711
  "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
712
  "locked": attempts_remaining == 0 and not granted
713
  }
 
714
  if granted:
715
  response["message"] = "Access granted. Voice verified."
716
+ elif attempts_remaining > 0:
717
+ response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
718
  else:
719
+ response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
 
 
 
 
720
  return JSONResponse(content=response)
721
  except Exception as e:
722
  return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
723
 
724
+ # List users
725
  @fastapi_app.get("/api/users")
726
  async def api_list_users():
727
  db = load_db()
 
737
  })
738
  return JSONResponse(content={"success": True, "users": users, "total": len(users)})
739
 
740
+ # Delete user
741
  @fastapi_app.delete("/api/users/{user_id}")
742
  async def api_delete_user(user_id: str):
743
  result = delete_user(user_id)
744
  success = "error" not in result.lower()
745
  return JSONResponse(content={"success": success, "message": result})
746
 
747
+ # Reset lockout
748
  @fastapi_app.post("/api/reset-lockout")
749
  async def api_reset_lockout(user_id: str = Form(...)):
750
  result = reset_lockout(user_id)
751
  return JSONResponse(content={"success": True, "message": result})
752
 
753
 
754
+ # SESSION-BASED ENDPOINTS (conversational ATM flow)
755
+
756
+ # Step 1: Start a session
757
+ @fastapi_app.post("/api/session/start")
758
+ async def session_start(user_id: str = Form(...)):
759
+ uid = user_id.strip().upper()
760
+ db = load_db()
761
+ if uid not in db:
762
+ return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
763
+ if db[uid].get("status") != "enrolled":
764
+ return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})
765
 
766
+ allowed, sec_msg = check_security(uid)
767
+ if not allowed:
768
+ return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
769
+
770
+ session = create_session(uid)
771
+ return JSONResponse(content={
772
+ "success": True,
773
+ "session_id": session["session_id"],
774
+ "user_id": uid,
775
+ "message": "Session started. Please provide a voice sample to verify your identity.",
776
+ "next_step": "verify",
777
+ "instruction": "Record your voice and send it to /api/session/verify"
778
+ })
779
+
780
+ # Step 2: Verify identity (returns greeting + challenge)
781
+ @fastapi_app.post("/api/session/verify")
782
+ async def session_verify(audio: UploadFile = File(...), session_id: str = Form(...)):
783
+ session = get_session(session_id)
784
+ if not session:
785
+ return JSONResponse(content={"success": False, "message": "Session expired or not found. Start a new session."})
786
+
787
+ if session["step"] != SESSION_STEPS['STARTED']:
788
+ return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
789
+
790
+ uid = session["user_id"]
791
+ allowed, sec_msg = check_security(uid)
792
+ if not allowed:
793
+ session["step"] = SESSION_STEPS['DENIED']
794
+ return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})
795
+
796
+ try:
797
+ audio_bytes = await audio.read()
798
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
799
+ tmp.write(audio_bytes)
800
+ tmp_path = tmp.name
801
+
802
+ wav = load_audio(tmp_path)
803
+ os.unlink(tmp_path)
804
+
805
+ is_live, live_msg = check_liveness(wav)
806
+ if not is_live:
807
+ record_attempt(uid, False)
808
+ return JSONResponse(content={"success": True, "verified": False, "message": live_msg})
809
+
810
+ is_real, spoof_msg = check_antispoofing(wav)
811
+ if not is_real:
812
+ record_attempt(uid, False)
813
+ return JSONResponse(content={"success": True, "verified": False, "message": spoof_msg})
814
+
815
+ test_emb = extract_embedding(wav)
816
+ db = load_db()
817
+ stored_emb = np.array(db[uid]["voiceprint"])
818
+ similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
819
+
820
+ if similarity >= THRESHOLD:
821
+ record_attempt(uid, True)
822
+ full_name = db[uid].get("full_name", uid)
823
+ challenge = generate_challenge()
824
+
825
+ session["step"] = SESSION_STEPS['LIVENESS_PENDING']
826
+ session["full_name"] = full_name
827
+ session["similarity"] = round(similarity, 4)
828
+ session["challenge_phrase"] = challenge
829
+
830
+ return JSONResponse(content={
831
+ "success": True,
832
+ "verified": True,
833
+ "greeting": f"Welcome, {full_name}",
834
+ "full_name": full_name,
835
+ "similarity": round(similarity, 4),
836
+ "next_step": "liveness",
837
+ "challenge_phrase": challenge,
838
+ "instruction": f"Say these words: {challenge}",
839
+ "message": f"Voice verified. Welcome, {full_name}. For security, please say these words: {challenge}"
840
+ })
841
+ else:
842
+ record_attempt(uid, False)
843
+ tracker = attempt_tracker.get(uid, {})
844
+ attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
845
+ locked = attempts_remaining == 0
846
+
847
+ if locked:
848
+ session["step"] = SESSION_STEPS['DENIED']
849
+
850
+ return JSONResponse(content={
851
+ "success": True,
852
+ "verified": False,
853
+ "similarity": round(similarity, 4),
854
+ "attempts_remaining": attempts_remaining,
855
+ "locked": locked,
856
+ "message": f"Voice does not match. {attempts_remaining} attempt(s) remaining." if not locked else f"Account locked for {LOCKOUT_MINUTES} minutes."
857
+ })
858
+ except Exception as e:
859
+ return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
860
+
861
+ # Step 3: Liveness check (verify challenge phrase voice)
862
+ @fastapi_app.post("/api/session/liveness")
863
+ async def session_liveness(audio: UploadFile = File(...), session_id: str = Form(...)):
864
+ session = get_session(session_id)
865
+ if not session:
866
+ return JSONResponse(content={"success": False, "message": "Session expired or not found."})
867
+
868
+ if session["step"] != SESSION_STEPS['LIVENESS_PENDING']:
869
+ return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
870
+
871
+ uid = session["user_id"]
872
+
873
+ try:
874
+ audio_bytes = await audio.read()
875
+ with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
876
+ tmp.write(audio_bytes)
877
+ tmp_path = tmp.name
878
+
879
+ wav = load_audio(tmp_path)
880
+ os.unlink(tmp_path)
881
+
882
+ is_live, live_msg = check_liveness(wav)
883
+ if not is_live:
884
+ return JSONResponse(content={"success": True, "liveness_passed": False, "message": live_msg})
885
+
886
+ is_real, spoof_msg = check_antispoofing(wav)
887
+ if not is_real:
888
+ return JSONResponse(content={"success": True, "liveness_passed": False, "message": spoof_msg})
889
+
890
+ # Verify it's still the same person speaking
891
+ test_emb = extract_embedding(wav)
892
+ db = load_db()
893
+ stored_emb = np.array(db[uid]["voiceprint"])
894
+ similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
895
+
896
+ if similarity >= THRESHOLD:
897
+ session["step"] = SESSION_STEPS['AUTHENTICATED']
898
+ full_name = session["full_name"]
899
+
900
+ return JSONResponse(content={
901
+ "success": True,
902
+ "liveness_passed": True,
903
+ "authenticated": True,
904
+ "full_name": full_name,
905
+ "similarity": round(similarity, 4),
906
+ "next_step": "transaction",
907
+ "instruction": "How much would you like to withdraw?",
908
+ "message": f"Liveness confirmed. You are fully authenticated, {full_name}. How much would you like to withdraw?"
909
+ })
910
+ else:
911
+ return JSONResponse(content={
912
+ "success": True,
913
+ "liveness_passed": False,
914
+ "message": "Voice mismatch during liveness check. Please try again.",
915
+ "challenge_phrase": session["challenge_phrase"],
916
+ "instruction": f"Please say these words again: {session['challenge_phrase']}"
917
+ })
918
+ except Exception as e:
919
+ return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
920
 
921
+ # Step 4: Transaction (simulated)
922
+ @fastapi_app.post("/api/session/transaction")
923
+ async def session_transaction(session_id: str = Form(...), amount: str = Form(...)):
924
+ session = get_session(session_id)
925
+ if not session:
926
+ return JSONResponse(content={"success": False, "message": "Session expired or not found."})
927
+
928
+ if session["step"] != SESSION_STEPS['AUTHENTICATED']:
929
+ return JSONResponse(content={"success": False, "message": f"Not authenticated. Current step: {session['step']}"})
930
+
931
+ full_name = session["full_name"]
932
+ session["step"] = SESSION_STEPS['COMPLETE']
933
+
934
+ return JSONResponse(content={
935
+ "success": True,
936
+ "transaction_approved": True,
937
+ "full_name": full_name,
938
+ "amount": amount,
939
+ "message": f"Transaction approved. {full_name}, you are withdrawing {amount} cedis. Please collect your cash.",
940
+ "instruction": "Transaction complete. Session ended.",
941
+ "note": "In production, this step communicates with the bank's core system to process the actual withdrawal."
942
+ })
943
+
944
+ # Get session status
945
+ @fastapi_app.get("/api/session/{session_id}")
946
+ async def session_status(session_id: str):
947
+ session = get_session(session_id)
948
+ if not session:
949
+ return JSONResponse(content={"success": False, "message": "Session expired or not found."})
950
+ return JSONResponse(content={
951
+ "success": True,
952
+ "session_id": session["session_id"],
953
+ "user_id": session["user_id"],
954
+ "step": session["step"],
955
+ "full_name": session["full_name"],
956
+ "challenge_phrase": session["challenge_phrase"],
957
+ "created_at": session["created_at"],
958
+ "expires_at": session["expires_at"]
959
+ })
960
+
961
+
962
+ # Launch
963
  if __name__ == "__main__":
964
  demo.launch(server_name="0.0.0.0", server_port=7860)