amfafa committed on
Commit
7bd4461
·
verified ·
1 Parent(s): b659e09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -327
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import io
3
  import json
4
  import math
5
  import time
@@ -10,13 +9,9 @@ import torchaudio
10
  import numpy as np
11
  import random
12
  import tempfile
13
- import base64
14
  import gradio as gr
15
  from datetime import datetime, timedelta
16
 
17
-
18
- # TORCHAUDIO COMPATIBILITY FIX
19
-
20
  if not hasattr(torchaudio, 'list_audio_backends'):
21
  torchaudio.list_audio_backends = lambda: ["soundfile"]
22
 
@@ -39,7 +34,6 @@ MAX_ATTEMPTS = 3
39
  LOCKOUT_MINUTES = 5
40
  COOLDOWN_SECONDS = 3
41
  ANTISPOOFING_THRESHOLD = 0.02
42
- LIVE_AUDIO_THRESHOLD = 0.5
43
 
44
 
45
 
@@ -97,15 +91,81 @@ for param in base_model.parameters():
97
 
98
  print("Loading AAM-Softmax checkpoint...")
99
  ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
100
- num_classes = ckpt.get('num_classes', 227)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
102
- classifier.load_state_dict(ckpt['classifier_state'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  classifier.eval()
104
- print(f"Models loaded. Speakers trained on: {num_classes}")
105
 
106
 
107
 
108
- # DATABASE MANAGEMENT
109
 
110
  def load_db():
111
  if os.path.exists(DB_PATH):
@@ -113,7 +173,6 @@ def load_db():
113
  return json.load(f)
114
  return {}
115
 
116
-
117
  def save_db(db):
118
  with open(DB_PATH, 'w') as f:
119
  json.dump(db, f, indent=2, default=str)
@@ -123,7 +182,6 @@ def save_db(db):
123
  # AUDIO PROCESSING
124
 
125
  def load_audio(audio_input):
126
- """Load audio from file path, tuple (sr, numpy), or bytes."""
127
  if isinstance(audio_input, tuple):
128
  sr, audio_np = audio_input
129
  wav = torch.tensor(audio_np, dtype=torch.float32)
@@ -132,7 +190,6 @@ def load_audio(audio_input):
132
  if wav.shape[0] > 1:
133
  wav = wav.mean(dim=0, keepdim=True)
134
  wav = wav.squeeze(0)
135
- # Normalize int audio to float
136
  if wav.abs().max() > 1.0:
137
  wav = wav / 32768.0
138
  if sr != SAMPLE_RATE:
@@ -158,17 +215,13 @@ def load_audio(audio_input):
158
  else:
159
  raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
160
 
161
- # Pad or trim to MAX_LEN
162
  if wav.shape[0] > MAX_LEN:
163
  wav = wav[:MAX_LEN]
164
  elif wav.shape[0] < MAX_LEN:
165
  wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
166
-
167
  return wav
168
 
169
-
170
  def extract_embedding(wav_tensor):
171
- """Extract 512-dim speaker embedding from audio tensor."""
172
  with torch.no_grad():
173
  wav = wav_tensor.unsqueeze(0).to(DEVICE)
174
  outputs = base_model(wav)
@@ -177,123 +230,87 @@ def extract_embedding(wav_tensor):
177
  embedding = F.normalize(embedding, p=2, dim=1)
178
  return embedding.squeeze(0).cpu().numpy()
179
 
180
-
181
  def add_noise(wav_tensor, noise_level=0.005):
182
- """Add Gaussian noise for data augmentation."""
183
  noise = torch.randn_like(wav_tensor) * noise_level
184
  return wav_tensor + noise
185
 
186
 
187
-
188
  # LIVENESS DETECTION
189
 
190
  def check_liveness(wav_tensor):
191
- """Basic liveness check — detects silence or suspicious patterns."""
192
  wav_np = wav_tensor.numpy()
193
-
194
- # Check if audio has enough energy (not silent)
195
  rms = np.sqrt(np.mean(wav_np ** 2))
196
  if rms < 0.001:
197
  return False, "Audio too quiet — possible silence or empty recording"
198
-
199
- # Check for sufficient variation (not a constant tone)
200
  std = np.std(wav_np)
201
  if std < 0.001:
202
  return False, "Audio lacks variation — possible synthetic tone"
203
-
204
- # Check zero-crossing rate (natural speech has moderate ZCR)
205
  zero_crossings = np.sum(np.abs(np.diff(np.sign(wav_np)))) / (2 * len(wav_np))
206
  if zero_crossings < 0.01:
207
  return False, "Abnormal audio pattern — possible replay attack"
208
-
209
- # Check audio duration has content
210
  non_silent = np.abs(wav_np) > 0.01
211
  speech_ratio = np.sum(non_silent) / len(wav_np)
212
  if speech_ratio < 0.1:
213
  return False, "Insufficient speech content detected"
214
-
215
  return True, "Liveness check passed"
216
 
217
 
218
 
219
- # ANTISPOOFING CHECK
220
 
221
  def check_antispoofing(wav_tensor):
222
- """Basic antispoofing — checks spectral characteristics."""
223
  wav_np = wav_tensor.numpy()
224
-
225
- # Check spectral flatness (natural speech vs synthetic)
226
  fft = np.fft.rfft(wav_np)
227
  magnitude = np.abs(fft)
228
  magnitude = magnitude[magnitude > 0]
229
-
230
  if len(magnitude) == 0:
231
  return False, "No frequency content detected"
232
-
233
  geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
234
  arithmetic_mean = np.mean(magnitude)
235
  spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
236
-
237
  if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
238
  return False, f"Spectral flatness too high ({spectral_flatness:.4f}) — possible synthetic audio"
239
-
240
- # Check for unnaturally uniform amplitude
241
- frame_size = 1600 # 100ms frames
242
  if len(wav_np) >= frame_size * 3:
243
  frames = [wav_np[i:i + frame_size] for i in range(0, len(wav_np) - frame_size, frame_size)]
244
  frame_energies = [np.sqrt(np.mean(f ** 2)) for f in frames]
245
  energy_std = np.std(frame_energies)
246
  if energy_std < 0.001:
247
  return False, "Unnaturally uniform energy — possible synthetic audio"
248
-
249
  return True, "Antispoofing check passed"
250
 
251
 
252
- # SECURITY: LOCKOUT & COOLDOWN
253
 
254
- attempt_tracker = {} # {user_id: {"count": int, "last_attempt": datetime, "locked_until": datetime}}
255
 
 
256
 
257
  def check_security(user_id):
258
- """Check if user is locked out or in cooldown."""
259
  now = datetime.now()
260
-
261
  if user_id not in attempt_tracker:
262
  return True, "OK"
263
-
264
  tracker = attempt_tracker[user_id]
265
-
266
- # Check lockout
267
  if "locked_until" in tracker and tracker["locked_until"]:
268
  locked_until = datetime.fromisoformat(tracker["locked_until"])
269
  if now < locked_until:
270
  remaining = (locked_until - now).seconds
271
  return False, f"Account locked. Try again in {remaining} seconds."
272
  else:
273
- # Lockout expired — reset
274
  tracker["count"] = 0
275
  tracker["locked_until"] = None
276
-
277
- # Check cooldown
278
  if "last_attempt" in tracker and tracker["last_attempt"]:
279
  last = datetime.fromisoformat(tracker["last_attempt"])
280
  elapsed = (now - last).total_seconds()
281
  if elapsed < COOLDOWN_SECONDS:
282
  return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds before trying again."
283
-
284
  return True, "OK"
285
 
286
-
287
  def record_attempt(user_id, success):
288
- """Record a verification attempt."""
289
  now = datetime.now()
290
-
291
  if user_id not in attempt_tracker:
292
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
293
-
294
  tracker = attempt_tracker[user_id]
295
  tracker["last_attempt"] = now.isoformat()
296
-
297
  if success:
298
  tracker["count"] = 0
299
  tracker["locked_until"] = None
@@ -304,10 +321,9 @@ def record_attempt(user_id, success):
304
 
305
 
306
 
307
- # CORE FUNCTIONS: ENROLL & VERIFY
308
 
309
  def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
310
- """Process a single enrollment sample. Collects NUM_CLEAN_SAMPLES then finalizes."""
311
  if not user_id or not user_id.strip():
312
  return "Error: User ID is required."
313
  if not full_name or not full_name.strip():
@@ -321,20 +337,16 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
321
  try:
322
  wav = load_audio(audio_input)
323
 
324
- # Liveness check
325
  is_live, live_msg = check_liveness(wav)
326
  if not is_live:
327
  return f"Enrollment failed: {live_msg}"
328
 
329
- # Antispoofing check
330
  is_real, spoof_msg = check_antispoofing(wav)
331
  if not is_real:
332
  return f"Enrollment failed: {spoof_msg}"
333
 
334
- # Extract clean embedding
335
  clean_emb = extract_embedding(wav)
336
 
337
- # Generate noisy augmented embeddings
338
  noisy_embeddings = []
339
  for i in range(NUM_NOISY_COPIES):
340
  noise_level = 0.003 + (i * 0.002)
@@ -342,7 +354,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
342
  noisy_emb = extract_embedding(noisy_wav)
343
  noisy_embeddings.append(noisy_emb)
344
 
345
- # Load DB and accumulate samples
346
  db = load_db()
347
 
348
  if user_id not in db:
@@ -355,7 +366,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
355
  "samples_collected": 0
356
  }
357
 
358
- # Store this sample's embeddings (1 clean + 4 noisy = 5 per sample)
359
  sample_data = {
360
  "clean": clean_emb.tolist(),
361
  "noisy": [e.tolist() for e in noisy_embeddings]
@@ -367,7 +377,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
367
  samples_collected = db[user_id]["samples_collected"]
368
 
369
  if samples_collected >= total_samples:
370
- # Finalize: average all embeddings (6 clean + 24 noisy = 30 total)
371
  all_embeddings = []
372
  for sample in db[user_id]["sample_embeddings"]:
373
  all_embeddings.append(np.array(sample["clean"]))
@@ -380,7 +389,6 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
380
  db[user_id]["voiceprint"] = avg_embedding.tolist()
381
  db[user_id]["status"] = "enrolled"
382
  db[user_id]["completed_at"] = datetime.now().isoformat()
383
- # Remove raw samples to save space
384
  db[user_id]["sample_embeddings"] = []
385
 
386
  save_db(db)
@@ -394,8 +402,9 @@ def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=
394
  return f"Enrollment error: {str(e)}"
395
 
396
 
 
 
397
  def verify_speaker(audio_input, user_id):
398
- """Verify a speaker against their stored voiceprint."""
399
  if not user_id or not user_id.strip():
400
  return "Error: User ID is required."
401
  if audio_input is None:
@@ -403,12 +412,10 @@ def verify_speaker(audio_input, user_id):
403
 
404
  user_id = user_id.strip().upper()
405
 
406
- # Security check
407
  allowed, sec_msg = check_security(user_id)
408
  if not allowed:
409
  return f"ACCESS DENIED: {sec_msg}"
410
 
411
- # Check user exists
412
  db = load_db()
413
  if user_id not in db:
414
  return f"Error: User '{user_id}' not found. Please enroll first."
@@ -421,19 +428,16 @@ def verify_speaker(audio_input, user_id):
421
  try:
422
  wav = load_audio(audio_input)
423
 
424
- # Liveness check
425
  is_live, live_msg = check_liveness(wav)
426
  if not is_live:
427
  record_attempt(user_id, False)
428
  return f"ACCESS DENIED: {live_msg}"
429
 
430
- # Antispoofing check
431
  is_real, spoof_msg = check_antispoofing(wav)
432
  if not is_real:
433
  record_attempt(user_id, False)
434
  return f"ACCESS DENIED: {spoof_msg}"
435
 
436
- # Extract embedding and compare
437
  test_emb = extract_embedding(wav)
438
  stored_emb = np.array(db[user_id]["voiceprint"])
439
 
@@ -469,12 +473,13 @@ def verify_speaker(audio_input, user_id):
469
  return f"Verification error: {str(e)}"
470
 
471
 
 
 
 
472
  def list_users():
473
- """List all enrolled users."""
474
  db = load_db()
475
  if not db:
476
  return "No users enrolled yet."
477
-
478
  lines = ["=== Enrolled Users ===\n"]
479
  for uid, data in db.items():
480
  name = data.get("full_name", "Unknown")
@@ -484,34 +489,23 @@ def list_users():
484
  lines.append(f"ID: {uid} | Name: {name} | Status: {status} | Samples: {samples} | Enrolled: {enrolled}")
485
  return "\n".join(lines)
486
 
487
-
488
  def delete_user(user_id):
489
- """Delete a user's voiceprint."""
490
  if not user_id or not user_id.strip():
491
  return "Error: User ID is required."
492
-
493
  user_id = user_id.strip().upper()
494
  db = load_db()
495
-
496
  if user_id not in db:
497
  return f"Error: User '{user_id}' not found."
498
-
499
  name = db[user_id].get("full_name", user_id)
500
  del db[user_id]
501
  save_db(db)
502
-
503
- # Clear attempt tracker too
504
  if user_id in attempt_tracker:
505
  del attempt_tracker[user_id]
506
-
507
  return f"User '{name}' ({user_id}) deleted successfully."
508
 
509
-
510
  def reset_lockout(user_id):
511
- """Reset lockout for a user."""
512
  if not user_id or not user_id.strip():
513
  return "Error: User ID is required."
514
-
515
  user_id = user_id.strip().upper()
516
  if user_id in attempt_tracker:
517
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
@@ -522,84 +516,53 @@ def reset_lockout(user_id):
522
 
523
  # GRADIO INTERFACE
524
 
525
- with gr.Blocks(
526
- title="ATM Voice Authentication System",
527
- theme=gr.themes.Soft()
528
- ) as demo:
529
 
530
- gr.Markdown(
531
- """
532
- # ATM Voice Authentication System
533
- ### Voice-Based Speaker Verification for Banking Security
534
- **Model:** UniSpeech-SAT + AAM-Softmax | **EER:** 3.94% | **Speakers Trained:** 227 Akan speakers
535
- """
536
- )
537
 
538
  with gr.Tabs():
539
 
540
- # ---- ENROLL TAB ----
541
  with gr.Tab("Enroll"):
542
- gr.Markdown(
543
- """
544
- ### Enroll New User
545
- Record **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.
546
- The system adds noise augmentation automatically (6 clean + 24 noisy = 30 embeddings averaged).
547
- """
548
- )
549
  with gr.Row():
550
  with gr.Column():
551
- enroll_audio = gr.Audio(
552
- label="Record Voice Sample",
553
- sources=["microphone"],
554
- type="numpy"
555
- )
556
  enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
557
  enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
558
  enroll_sample_num = gr.Number(label="Sample Number (1-6)", value=1, minimum=1, maximum=6, step=1)
559
  enroll_btn = gr.Button("Enroll Sample", variant="primary")
560
  with gr.Column():
561
  enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
 
562
 
563
- enroll_btn.click(
564
- fn=enroll_sample,
565
- inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num],
566
- outputs=enroll_result
567
- )
568
-
569
- # ---- VERIFY TAB ----
570
  with gr.Tab("Verify"):
571
- gr.Markdown(
572
- """
573
- ### Verify Identity
574
- Record your voice to verify against your enrolled voiceprint.
575
- Security: 3 failed attempts = 5-minute lockout. 3-second cooldown between attempts.
576
- """
577
- )
578
  with gr.Row():
579
  with gr.Column():
580
- verify_audio = gr.Audio(
581
- label="Record Voice",
582
- sources=["microphone"],
583
- type="numpy"
584
- )
585
  verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
586
  verify_btn = gr.Button("Verify", variant="primary")
587
  with gr.Column():
588
  verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
 
589
 
590
- verify_btn.click(
591
- fn=verify_speaker,
592
- inputs=[verify_audio, verify_user_id],
593
- outputs=verify_result
594
- )
595
-
596
- # ---- MANAGE USERS TAB ----
597
  with gr.Tab("Users"):
598
  gr.Markdown("### Manage Enrolled Users")
599
  list_btn = gr.Button("List All Users")
600
  users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
601
  list_btn.click(fn=list_users, outputs=users_output)
602
-
603
  gr.Markdown("---")
604
  with gr.Row():
605
  with gr.Column():
@@ -607,125 +570,66 @@ with gr.Blocks(
607
  del_btn = gr.Button("Delete User", variant="stop")
608
  del_result = gr.Textbox(label="Result", interactive=False)
609
  del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
610
-
611
  with gr.Column():
612
  reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
613
  reset_btn = gr.Button("Reset Lockout", variant="secondary")
614
  reset_result = gr.Textbox(label="Result", interactive=False)
615
  reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)
616
 
617
- # ---- API DOCS TAB ----
618
  with gr.Tab("API Docs"):
619
- gr.Markdown(
620
- """
621
- ### REST API Endpoints for Banking Systems
622
-
623
- **Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
624
-
625
- ---
626
-
627
- #### 1. Enroll a Voice Sample
628
- ```
629
- POST /api/enroll
630
- Content-Type: multipart/form-data
631
-
632
- Fields:
633
- - audio: WAV file (required)
634
- - user_id: string (required)
635
- - full_name: string (required)
636
- ```
637
-
638
- **Response:**
639
- ```json
640
- {
641
- "success": true,
642
- "message": "Sample 1/6 recorded...",
643
- "user_id": "ATM_001",
644
- "samples_collected": 1,
645
- "samples_required": 6,
646
- "enrollment_complete": false
647
- }
648
- ```
649
-
650
- ---
651
-
652
- #### 2. Verify a Speaker
653
- ```
654
- POST /api/verify
655
- Content-Type: multipart/form-data
656
-
657
- Fields:
658
- - audio: WAV file (required)
659
- - user_id: string (required)
660
- ```
661
-
662
- **Response:**
663
- ```json
664
- {
665
- "success": true,
666
- "access_granted": true,
667
- "user_id": "ATM_001",
668
- "full_name": "Jochebed Fafa",
669
- "similarity": 0.4521,
670
- "threshold": 0.35,
671
- "liveness_passed": true,
672
- "antispoofing_passed": true
673
- }
674
- ```
675
-
676
- ---
677
-
678
- #### 3. List Enrolled Users
679
- ```
680
- GET /api/users
681
- ```
682
-
683
- ---
684
-
685
- #### 4. Delete a User
686
- ```
687
- DELETE /api/users/{user_id}
688
- ```
689
-
690
- ---
691
-
692
- #### 5. Health Check
693
- ```
694
- GET /api/health
695
- ```
696
- """
697
- )
698
 
 
699
 
 
700
 
701
- # FASTAPI REST API ENDPOINTS
 
 
 
 
 
702
 
703
- app = gr.mount_gradio_app(gr.routes.App(), demo, path="/")
 
 
 
 
 
704
 
 
 
 
 
705
 
706
- # We use Gradio's underlying FastAPI app to add REST endpoints
707
- from fastapi import FastAPI, UploadFile, File, Form, HTTPException
708
- from fastapi.responses import JSONResponse
709
- from fastapi.middleware.cors import CORSMiddleware
710
 
711
- # Get the FastAPI app from Gradio
712
- app = demo.app if hasattr(demo, 'app') else None
 
 
713
 
714
- # Since Gradio 4.x manages its own FastAPI app, we create custom endpoints
715
- # by using Gradio's built-in Blocks.launch() with app_kwargs or by
716
- # adding routes after launch. For Hugging Face Spaces, we use a different approach:
717
- # We create a FastAPI app, mount Gradio on it, and add our REST routes.
 
 
718
 
719
- from fastapi import FastAPI
720
 
721
- api_app = FastAPI(
722
- title="ATM Voice Authentication API",
723
- description="Voice-Based Speaker Verification System for Banking",
724
- version="1.0.0"
725
- )
726
 
727
- # CORS allow mobile app and banking systems to connect
728
- api_app.add_middleware(
 
 
 
 
 
729
  CORSMiddleware,
730
  allow_origins=["*"],
731
  allow_credentials=True,
@@ -733,8 +637,7 @@ api_app.add_middleware(
733
  allow_headers=["*"],
734
  )
735
 
736
-
737
- @api_app.get("/api/health")
738
  async def health_check():
739
  return {
740
  "status": "healthy",
@@ -745,29 +648,19 @@ async def health_check():
745
  "timestamp": datetime.now().isoformat()
746
  }
747
 
748
-
749
- @api_app.post("/api/enroll")
750
- async def api_enroll(
751
- audio: UploadFile = File(...),
752
- user_id: str = Form(...),
753
- full_name: str = Form(...)
754
- ):
755
  try:
756
  audio_bytes = await audio.read()
757
-
758
- # Save to temp file
759
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
760
  tmp.write(audio_bytes)
761
  tmp_path = tmp.name
762
-
763
  result = enroll_sample(tmp_path, user_id, full_name, 1)
764
  os.unlink(tmp_path)
765
-
766
  db = load_db()
767
  uid = user_id.strip().upper()
768
  samples_collected = db.get(uid, {}).get("samples_collected", 0)
769
  is_complete = db.get(uid, {}).get("status") == "enrolled"
770
-
771
  return JSONResponse(content={
772
  "success": "error" not in result.lower() and "failed" not in result.lower(),
773
  "message": result,
@@ -776,93 +669,49 @@ async def api_enroll(
776
  "samples_required": NUM_CLEAN_SAMPLES,
777
  "enrollment_complete": is_complete
778
  })
779
-
780
  except Exception as e:
781
- return JSONResponse(
782
- status_code=500,
783
- content={"success": False, "message": f"Server error: {str(e)}"}
784
- )
785
-
786
 
787
- @api_app.post("/api/verify")
788
- async def api_verify(
789
- audio: UploadFile = File(...),
790
- user_id: str = Form(...)
791
- ):
792
  try:
793
  audio_bytes = await audio.read()
794
-
795
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
796
  tmp.write(audio_bytes)
797
  tmp_path = tmp.name
798
-
799
- # Run verification
800
  uid = user_id.strip().upper()
801
 
802
- # Security check first
803
  allowed, sec_msg = check_security(uid)
804
  if not allowed:
805
  os.unlink(tmp_path)
806
- return JSONResponse(content={
807
- "success": True,
808
- "access_granted": False,
809
- "user_id": uid,
810
- "message": sec_msg,
811
- "locked": True
812
- })
813
-
814
- # Check user exists
815
  db = load_db()
816
  if uid not in db:
817
  os.unlink(tmp_path)
818
- return JSONResponse(content={
819
- "success": False,
820
- "message": f"User '{uid}' not found. Please enroll first."
821
- })
822
 
823
  if db[uid].get("status") != "enrolled":
824
  os.unlink(tmp_path)
825
  samples = db[uid].get("samples_collected", 0)
826
- return JSONResponse(content={
827
- "success": False,
828
- "message": f"Enrollment incomplete. {NUM_CLEAN_SAMPLES - samples} more sample(s) needed."
829
- })
830
 
831
  wav = load_audio(tmp_path)
832
  os.unlink(tmp_path)
833
 
834
- # Liveness
835
  is_live, live_msg = check_liveness(wav)
836
  if not is_live:
837
  record_attempt(uid, False)
838
- return JSONResponse(content={
839
- "success": True,
840
- "access_granted": False,
841
- "user_id": uid,
842
- "message": live_msg,
843
- "liveness_passed": False,
844
- "antispoofing_passed": None
845
- })
846
-
847
- # Antispoofing
848
  is_real, spoof_msg = check_antispoofing(wav)
849
  if not is_real:
850
  record_attempt(uid, False)
851
- return JSONResponse(content={
852
- "success": True,
853
- "access_granted": False,
854
- "user_id": uid,
855
- "message": spoof_msg,
856
- "liveness_passed": True,
857
- "antispoofing_passed": False
858
- })
859
-
860
- # Embedding comparison
861
  test_emb = extract_embedding(wav)
862
  stored_emb = np.array(db[uid]["voiceprint"])
863
- similarity = float(np.dot(test_emb, stored_emb) / (
864
- np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10
865
- ))
866
 
867
  granted = similarity >= THRESHOLD
868
  record_attempt(uid, granted)
@@ -893,15 +742,10 @@ async def api_verify(
893
  response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
894
 
895
  return JSONResponse(content=response)
896
-
897
  except Exception as e:
898
- return JSONResponse(
899
- status_code=500,
900
- content={"success": False, "message": f"Server error: {str(e)}"}
901
- )
902
 
903
-
904
- @api_app.get("/api/users")
905
  async def api_list_users():
906
  db = load_db()
907
  users = []
@@ -916,29 +760,20 @@ async def api_list_users():
916
  })
917
  return JSONResponse(content={"success": True, "users": users, "total": len(users)})
918
 
919
-
920
- @api_app.delete("/api/users/{user_id}")
921
  async def api_delete_user(user_id: str):
922
  result = delete_user(user_id)
923
  success = "error" not in result.lower()
924
  return JSONResponse(content={"success": success, "message": result})
925
 
926
-
927
- @api_app.post("/api/reset-lockout")
928
  async def api_reset_lockout(user_id: str = Form(...)):
929
  result = reset_lockout(user_id)
930
  return JSONResponse(content={"success": True, "message": result})
931
 
932
 
933
 
934
- # MOUNT GRADIO ON FASTAPI
935
-
936
- app = gr.mount_gradio_app(api_app, demo, path="/")
937
-
938
-
939
-
940
  # LAUNCH
941
 
942
  if __name__ == "__main__":
943
- import uvicorn
944
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
 
2
  import json
3
  import math
4
  import time
 
9
  import numpy as np
10
  import random
11
  import tempfile
 
12
  import gradio as gr
13
  from datetime import datetime, timedelta
14
 
 
 
 
15
  if not hasattr(torchaudio, 'list_audio_backends'):
16
  torchaudio.list_audio_backends = lambda: ["soundfile"]
17
 
 
34
  LOCKOUT_MINUTES = 5
35
  COOLDOWN_SECONDS = 3
36
  ANTISPOOFING_THRESHOLD = 0.02
 
37
 
38
 
39
 
 
91
 
92
  print("Loading AAM-Softmax checkpoint...")
93
  ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
94
+
95
+ # Auto-detect checkpoint format
96
+ print(f"Checkpoint type: {type(ckpt)}")
97
+ if isinstance(ckpt, dict):
98
+ print(f"Checkpoint keys: {list(ckpt.keys())}")
99
+
100
+ # Detect num_classes from checkpoint
101
+ num_classes = 227
102
+ if isinstance(ckpt, dict):
103
+ if 'num_classes' in ckpt:
104
+ num_classes = ckpt['num_classes']
105
+ elif 'num_speakers' in ckpt:
106
+ num_classes = ckpt['num_speakers']
107
+
108
+ # Build classifier
109
  classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)
110
+
111
+ # Load weights - try every possible key format
112
+ loaded = False
113
+ if isinstance(ckpt, dict):
114
+ # Try common key names for classifier state
115
+ for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
116
+ if key in ckpt:
117
+ try:
118
+ classifier.load_state_dict(ckpt[key])
119
+ print(f"Loaded classifier from key: '{key}'")
120
+ loaded = True
121
+ break
122
+ except Exception as e:
123
+ print(f"Key '{key}' found but failed to load: {e}")
124
+
125
+ # If no named key worked, try loading the dict directly (maybe ckpt IS the state_dict)
126
+ if not loaded:
127
+ # Check if the keys look like model parameters (contain dots like 'fc1.weight')
128
+ sample_keys = list(ckpt.keys())[:5]
129
+ looks_like_state_dict = any('.' in k for k in sample_keys)
130
+ if looks_like_state_dict:
131
+ try:
132
+ classifier.load_state_dict(ckpt)
133
+ print("Loaded classifier directly from checkpoint dict (it IS the state_dict)")
134
+ loaded = True
135
+ except Exception as e:
136
+ print(f"Direct load failed: {e}")
137
+ # Try with strict=False
138
+ try:
139
+ classifier.load_state_dict(ckpt, strict=False)
140
+ print("Loaded classifier with strict=False")
141
+ loaded = True
142
+ except Exception as e2:
143
+ print(f"Strict=False also failed: {e2}")
144
+
145
+ # Try loading base_model state too if present
146
+ if 'base_model_state' in ckpt:
147
+ try:
148
+ base_model.load_state_dict(ckpt['base_model_state'], strict=False)
149
+ print("Also loaded fine-tuned base model weights")
150
+ except Exception as e:
151
+ print(f"Base model load skipped: {e}")
152
+
153
+ elif isinstance(ckpt, nn.Module):
154
+ # Checkpoint is the model itself
155
+ classifier = ckpt.to(DEVICE)
156
+ print("Loaded classifier directly (checkpoint is model object)")
157
+ loaded = True
158
+
159
+ if not loaded:
160
+ print("WARNING: Could not load classifier weights. Using random initialization.")
161
+ print("The system will still run but verification accuracy will be poor.")
162
+
163
  classifier.eval()
164
+ print(f"Models ready. num_classes={num_classes}, loaded={loaded}")
165
 
166
 
167
 
168
+ # DATABASE
169
 
170
  def load_db():
171
  if os.path.exists(DB_PATH):
 
173
  return json.load(f)
174
  return {}
175
 
 
176
  def save_db(db):
177
  with open(DB_PATH, 'w') as f:
178
  json.dump(db, f, indent=2, default=str)
 
182
  # AUDIO PROCESSING
183
 
184
  def load_audio(audio_input):
 
185
  if isinstance(audio_input, tuple):
186
  sr, audio_np = audio_input
187
  wav = torch.tensor(audio_np, dtype=torch.float32)
 
190
  if wav.shape[0] > 1:
191
  wav = wav.mean(dim=0, keepdim=True)
192
  wav = wav.squeeze(0)
 
193
  if wav.abs().max() > 1.0:
194
  wav = wav / 32768.0
195
  if sr != SAMPLE_RATE:
 
215
  else:
216
  raise ValueError(f"Unsupported audio input type: {type(audio_input)}")
217
 
 
218
  if wav.shape[0] > MAX_LEN:
219
  wav = wav[:MAX_LEN]
220
  elif wav.shape[0] < MAX_LEN:
221
  wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
 
222
  return wav
223
 
 
224
  def extract_embedding(wav_tensor):
 
225
  with torch.no_grad():
226
  wav = wav_tensor.unsqueeze(0).to(DEVICE)
227
  outputs = base_model(wav)
 
230
  embedding = F.normalize(embedding, p=2, dim=1)
231
  return embedding.squeeze(0).cpu().numpy()
232
 
 
233
  def add_noise(wav_tensor, noise_level=0.005):
 
234
  noise = torch.randn_like(wav_tensor) * noise_level
235
  return wav_tensor + noise
236
 
237
 
 
238
  # LIVENESS DETECTION
239
 
240
  def check_liveness(wav_tensor):
 
241
  wav_np = wav_tensor.numpy()
 
 
242
  rms = np.sqrt(np.mean(wav_np ** 2))
243
  if rms < 0.001:
244
  return False, "Audio too quiet — possible silence or empty recording"
 
 
245
  std = np.std(wav_np)
246
  if std < 0.001:
247
  return False, "Audio lacks variation — possible synthetic tone"
 
 
248
  zero_crossings = np.sum(np.abs(np.diff(np.sign(wav_np)))) / (2 * len(wav_np))
249
  if zero_crossings < 0.01:
250
  return False, "Abnormal audio pattern — possible replay attack"
 
 
251
  non_silent = np.abs(wav_np) > 0.01
252
  speech_ratio = np.sum(non_silent) / len(wav_np)
253
  if speech_ratio < 0.1:
254
  return False, "Insufficient speech content detected"
 
255
  return True, "Liveness check passed"
256
 
257
 
258
 
259
+ # ANTISPOOFING
260
 
261
  def check_antispoofing(wav_tensor):
 
262
  wav_np = wav_tensor.numpy()
 
 
263
  fft = np.fft.rfft(wav_np)
264
  magnitude = np.abs(fft)
265
  magnitude = magnitude[magnitude > 0]
 
266
  if len(magnitude) == 0:
267
  return False, "No frequency content detected"
 
268
  geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
269
  arithmetic_mean = np.mean(magnitude)
270
  spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
 
271
  if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
272
  return False, f"Spectral flatness too high ({spectral_flatness:.4f}) — possible synthetic audio"
273
+ frame_size = 1600
 
 
274
  if len(wav_np) >= frame_size * 3:
275
  frames = [wav_np[i:i + frame_size] for i in range(0, len(wav_np) - frame_size, frame_size)]
276
  frame_energies = [np.sqrt(np.mean(f ** 2)) for f in frames]
277
  energy_std = np.std(frame_energies)
278
  if energy_std < 0.001:
279
  return False, "Unnaturally uniform energy — possible synthetic audio"
 
280
  return True, "Antispoofing check passed"
281
 
282
 
 
283
 
284
+ # SECURITY: LOCKOUT & COOLDOWN
285
 
286
+ attempt_tracker = {}
287
 
288
  def check_security(user_id):
 
289
  now = datetime.now()
 
290
  if user_id not in attempt_tracker:
291
  return True, "OK"
 
292
  tracker = attempt_tracker[user_id]
 
 
293
  if "locked_until" in tracker and tracker["locked_until"]:
294
  locked_until = datetime.fromisoformat(tracker["locked_until"])
295
  if now < locked_until:
296
  remaining = (locked_until - now).seconds
297
  return False, f"Account locked. Try again in {remaining} seconds."
298
  else:
 
299
  tracker["count"] = 0
300
  tracker["locked_until"] = None
 
 
301
  if "last_attempt" in tracker and tracker["last_attempt"]:
302
  last = datetime.fromisoformat(tracker["last_attempt"])
303
  elapsed = (now - last).total_seconds()
304
  if elapsed < COOLDOWN_SECONDS:
305
  return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds before trying again."
 
306
  return True, "OK"
307
 
 
308
  def record_attempt(user_id, success):
 
309
  now = datetime.now()
 
310
  if user_id not in attempt_tracker:
311
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
 
312
  tracker = attempt_tracker[user_id]
313
  tracker["last_attempt"] = now.isoformat()
 
314
  if success:
315
  tracker["count"] = 0
316
  tracker["locked_until"] = None
 
321
 
322
 
323
 
324
+ # ENROLL
325
 
326
  def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
 
327
  if not user_id or not user_id.strip():
328
  return "Error: User ID is required."
329
  if not full_name or not full_name.strip():
 
337
  try:
338
  wav = load_audio(audio_input)
339
 
 
340
  is_live, live_msg = check_liveness(wav)
341
  if not is_live:
342
  return f"Enrollment failed: {live_msg}"
343
 
 
344
  is_real, spoof_msg = check_antispoofing(wav)
345
  if not is_real:
346
  return f"Enrollment failed: {spoof_msg}"
347
 
 
348
  clean_emb = extract_embedding(wav)
349
 
 
350
  noisy_embeddings = []
351
  for i in range(NUM_NOISY_COPIES):
352
  noise_level = 0.003 + (i * 0.002)
 
354
  noisy_emb = extract_embedding(noisy_wav)
355
  noisy_embeddings.append(noisy_emb)
356
 
 
357
  db = load_db()
358
 
359
  if user_id not in db:
 
366
  "samples_collected": 0
367
  }
368
 
 
369
  sample_data = {
370
  "clean": clean_emb.tolist(),
371
  "noisy": [e.tolist() for e in noisy_embeddings]
 
377
  samples_collected = db[user_id]["samples_collected"]
378
 
379
  if samples_collected >= total_samples:
 
380
  all_embeddings = []
381
  for sample in db[user_id]["sample_embeddings"]:
382
  all_embeddings.append(np.array(sample["clean"]))
 
389
  db[user_id]["voiceprint"] = avg_embedding.tolist()
390
  db[user_id]["status"] = "enrolled"
391
  db[user_id]["completed_at"] = datetime.now().isoformat()
 
392
  db[user_id]["sample_embeddings"] = []
393
 
394
  save_db(db)
 
402
  return f"Enrollment error: {str(e)}"
403
 
404
 
405
+ # VERIFY
406
+
407
  def verify_speaker(audio_input, user_id):
 
408
  if not user_id or not user_id.strip():
409
  return "Error: User ID is required."
410
  if audio_input is None:
 
412
 
413
  user_id = user_id.strip().upper()
414
 
 
415
  allowed, sec_msg = check_security(user_id)
416
  if not allowed:
417
  return f"ACCESS DENIED: {sec_msg}"
418
 
 
419
  db = load_db()
420
  if user_id not in db:
421
  return f"Error: User '{user_id}' not found. Please enroll first."
 
428
  try:
429
  wav = load_audio(audio_input)
430
 
 
431
  is_live, live_msg = check_liveness(wav)
432
  if not is_live:
433
  record_attempt(user_id, False)
434
  return f"ACCESS DENIED: {live_msg}"
435
 
 
436
  is_real, spoof_msg = check_antispoofing(wav)
437
  if not is_real:
438
  record_attempt(user_id, False)
439
  return f"ACCESS DENIED: {spoof_msg}"
440
 
 
441
  test_emb = extract_embedding(wav)
442
  stored_emb = np.array(db[user_id]["voiceprint"])
443
 
 
473
  return f"Verification error: {str(e)}"
474
 
475
 
476
+
477
+ # USER MANAGEMENT
478
+
479
  def list_users():
 
480
  db = load_db()
481
  if not db:
482
  return "No users enrolled yet."
 
483
  lines = ["=== Enrolled Users ===\n"]
484
  for uid, data in db.items():
485
  name = data.get("full_name", "Unknown")
 
489
  lines.append(f"ID: {uid} | Name: {name} | Status: {status} | Samples: {samples} | Enrolled: {enrolled}")
490
  return "\n".join(lines)
491
 
 
def delete_user(user_id):
    """Remove a user from the database and drop their attempt history.

    Args:
        user_id: User identifier; surrounding whitespace is stripped and the
            value is upper-cased before lookup.

    Returns:
        A human-readable status string. Failures start with ``"Error:"``.
    """
    normalized_id = user_id.strip().upper() if user_id else ""
    if not normalized_id:
        return "Error: User ID is required."

    records = load_db()
    if normalized_id not in records:
        return f"Error: User '{normalized_id}' not found."

    # Read the display name before the record is removed; fall back to the ID.
    display_name = records[normalized_id].get("full_name", normalized_id)
    records.pop(normalized_id)
    save_db(records)

    # Forget any lockout/cooldown state recorded for the deleted user.
    attempt_tracker.pop(normalized_id, None)

    return f"User '{display_name}' ({normalized_id}) deleted successfully."
505
 
 
506
  def reset_lockout(user_id):
 
507
  if not user_id or not user_id.strip():
508
  return "Error: User ID is required."
 
509
  user_id = user_id.strip().upper()
510
  if user_id in attempt_tracker:
511
  attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
 
516
 
517
  # GRADIO INTERFACE
518
 
519
+ with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
 
 
 
520
 
521
+ gr.Markdown("""
522
+ # ATM Voice Authentication System
523
+ ### Voice-Based Speaker Verification for Banking Security
524
+ **Model:** UniSpeech-SAT + AAM-Softmax | **EER:** 3.94% | **Speakers Trained:** 227 Akan speakers
525
+ """)
 
 
526
 
527
  with gr.Tabs():
528
 
 
529
  with gr.Tab("Enroll"):
530
+ gr.Markdown("""
531
+ ### Enroll New User
532
+ Record **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.
533
+ The system adds noise augmentation automatically (6 clean + 24 noisy = 30 embeddings averaged).
534
+ """)
 
 
535
  with gr.Row():
536
  with gr.Column():
537
+ enroll_audio = gr.Audio(label="Record Voice Sample", sources=["microphone", "upload"], type="numpy")
 
 
 
 
538
  enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
539
  enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
540
  enroll_sample_num = gr.Number(label="Sample Number (1-6)", value=1, minimum=1, maximum=6, step=1)
541
  enroll_btn = gr.Button("Enroll Sample", variant="primary")
542
  with gr.Column():
543
  enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
544
+ enroll_btn.click(fn=enroll_sample, inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num], outputs=enroll_result)
545
 
 
 
 
 
 
 
 
546
  with gr.Tab("Verify"):
547
+ gr.Markdown("""
548
+ ### Verify Identity
549
+ Record your voice to verify against your enrolled voiceprint.
550
+ Security: 3 failed attempts = 5-minute lockout. 3-second cooldown between attempts.
551
+ """)
 
 
552
  with gr.Row():
553
  with gr.Column():
554
+ verify_audio = gr.Audio(label="Record Voice", sources=["microphone", "upload"], type="numpy")
 
 
 
 
555
  verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
556
  verify_btn = gr.Button("Verify", variant="primary")
557
  with gr.Column():
558
  verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
559
+ verify_btn.click(fn=verify_speaker, inputs=[verify_audio, verify_user_id], outputs=verify_result)
560
 
 
 
 
 
 
 
 
561
  with gr.Tab("Users"):
562
  gr.Markdown("### Manage Enrolled Users")
563
  list_btn = gr.Button("List All Users")
564
  users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
565
  list_btn.click(fn=list_users, outputs=users_output)
 
566
  gr.Markdown("---")
567
  with gr.Row():
568
  with gr.Column():
 
570
  del_btn = gr.Button("Delete User", variant="stop")
571
  del_result = gr.Textbox(label="Result", interactive=False)
572
  del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
 
573
  with gr.Column():
574
  reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
575
  reset_btn = gr.Button("Reset Lockout", variant="secondary")
576
  reset_result = gr.Textbox(label="Result", interactive=False)
577
  reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)
578
 
 
579
  with gr.Tab("API Docs"):
580
+ gr.Markdown("""
581
+ ### REST API Endpoints for Banking Systems
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
582
 
583
+ **Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
584
 
585
+ ---
586
 
587
+ #### 1. Enroll a Voice Sample
588
+ ```
589
+ POST /api/enroll
590
+ Content-Type: multipart/form-data
591
+ Fields: audio (WAV file), user_id (string), full_name (string)
592
+ ```
593
 
594
+ #### 2. Verify a Speaker
595
+ ```
596
+ POST /api/verify
597
+ Content-Type: multipart/form-data
598
+ Fields: audio (WAV file), user_id (string)
599
+ ```
600
 
601
+ #### 3. List Enrolled Users
602
+ ```
603
+ GET /api/users
604
+ ```
605
 
606
+ #### 4. Delete a User
607
+ ```
608
+ DELETE /api/users/{user_id}
609
+ ```
610
 
611
+ #### 5. Health Check
612
+ ```
613
+ GET /api/health
614
+ ```
615
 
616
+ #### 6. Reset Lockout
617
+ ```
618
+ POST /api/reset-lockout
619
+ Field: user_id (string)
620
+ ```
621
+ """)
622
 
 
623
 
624
+ # REST API ENDPOINTS
 
 
 
 
625
 
626
+ from fastapi import UploadFile, File, Form
627
+ from fastapi.responses import JSONResponse
628
+ from fastapi.middleware.cors import CORSMiddleware
629
+
630
+ fastapi_app = demo.app
631
+
632
+ fastapi_app.add_middleware(
633
  CORSMiddleware,
634
  allow_origins=["*"],
635
  allow_credentials=True,
 
637
  allow_headers=["*"],
638
  )
639
 
640
+ @fastapi_app.get("/api/health")
 
641
  async def health_check():
642
  return {
643
  "status": "healthy",
 
648
  "timestamp": datetime.now().isoformat()
649
  }
650
 
651
+ @fastapi_app.post("/api/enroll")
652
+ async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
 
 
 
 
 
653
  try:
654
  audio_bytes = await audio.read()
 
 
655
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
656
  tmp.write(audio_bytes)
657
  tmp_path = tmp.name
 
658
  result = enroll_sample(tmp_path, user_id, full_name, 1)
659
  os.unlink(tmp_path)
 
660
  db = load_db()
661
  uid = user_id.strip().upper()
662
  samples_collected = db.get(uid, {}).get("samples_collected", 0)
663
  is_complete = db.get(uid, {}).get("status") == "enrolled"
 
664
  return JSONResponse(content={
665
  "success": "error" not in result.lower() and "failed" not in result.lower(),
666
  "message": result,
 
669
  "samples_required": NUM_CLEAN_SAMPLES,
670
  "enrollment_complete": is_complete
671
  })
 
672
  except Exception as e:
673
+ return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
 
 
 
 
674
 
675
+ @fastapi_app.post("/api/verify")
676
+ async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
 
 
 
677
  try:
678
  audio_bytes = await audio.read()
 
679
  with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
680
  tmp.write(audio_bytes)
681
  tmp_path = tmp.name
 
 
682
  uid = user_id.strip().upper()
683
 
 
684
  allowed, sec_msg = check_security(uid)
685
  if not allowed:
686
  os.unlink(tmp_path)
687
+ return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": sec_msg, "locked": True})
688
+
 
 
 
 
 
 
 
689
  db = load_db()
690
  if uid not in db:
691
  os.unlink(tmp_path)
692
+ return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
 
 
 
693
 
694
  if db[uid].get("status") != "enrolled":
695
  os.unlink(tmp_path)
696
  samples = db[uid].get("samples_collected", 0)
697
+ return JSONResponse(content={"success": False, "message": f"Enrollment incomplete. {NUM_CLEAN_SAMPLES - samples} more sample(s) needed."})
 
 
 
698
 
699
  wav = load_audio(tmp_path)
700
  os.unlink(tmp_path)
701
 
 
702
  is_live, live_msg = check_liveness(wav)
703
  if not is_live:
704
  record_attempt(uid, False)
705
+ return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False, "antispoofing_passed": None})
706
+
 
 
 
 
 
 
 
 
707
  is_real, spoof_msg = check_antispoofing(wav)
708
  if not is_real:
709
  record_attempt(uid, False)
710
+ return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "liveness_passed": True, "antispoofing_passed": False})
711
+
 
 
 
 
 
 
 
 
712
  test_emb = extract_embedding(wav)
713
  stored_emb = np.array(db[uid]["voiceprint"])
714
+ similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
 
 
715
 
716
  granted = similarity >= THRESHOLD
717
  record_attempt(uid, granted)
 
742
  response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
743
 
744
  return JSONResponse(content=response)
 
745
  except Exception as e:
746
+ return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
 
 
 
747
 
748
+ @fastapi_app.get("/api/users")
 
749
  async def api_list_users():
750
  db = load_db()
751
  users = []
 
760
  })
761
  return JSONResponse(content={"success": True, "users": users, "total": len(users)})
762
 
@fastapi_app.delete("/api/users/{user_id}")
async def api_delete_user(user_id: str):
    """REST wrapper around ``delete_user``.

    Args:
        user_id: User identifier taken from the URL path.

    Returns:
        A JSON response with ``success`` (bool) and ``message`` (the status
        string produced by ``delete_user``).
    """
    result = delete_user(user_id)
    # delete_user marks failures with an "Error:" prefix. Searching the whole
    # message for "error" would misreport a successful delete whenever the
    # user's name or ID happens to contain that substring, because both are
    # echoed back in the success message.
    success = not result.startswith("Error:")
    return JSONResponse(content={"success": success, "message": result})
768
 
@fastapi_app.post("/api/reset-lockout")
async def api_reset_lockout(user_id: str = Form(...)):
    """REST wrapper around ``reset_lockout``.

    Args:
        user_id: User identifier supplied as a form field.

    Returns:
        A JSON response with ``success`` (bool) and ``message`` (the status
        string produced by ``reset_lockout``).
    """
    result = reset_lockout(user_id)
    # reset_lockout reports a blank ID as "Error: User ID is required.", so
    # success must be derived from the message rather than hard-coded to True.
    success = not result.startswith("Error:")
    return JSONResponse(content={"success": success, "message": result})
773
 
774
 
775
 
 
 
 
 
 
 
776
  # LAUNCH
777
 
if __name__ == "__main__":
    # Start the Gradio app on an explicit host and port when run as a script.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)