File size: 35,837 Bytes
3e78896
 
 
9e8921c
a8a3da6
3e78896
 
 
 
 
 
9e8921c
3e78896
2dab870
 
 
9e8921c
3e78896
 
 
 
 
 
a8a3da6
3e78896
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a3da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e78896
 
 
 
 
 
 
 
 
 
 
 
9e8921c
3e78896
 
 
9e8921c
 
3e78896
 
 
9e8921c
3e78896
9e8921c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a3da6
9e8921c
3e78896
 
9e8921c
 
 
 
 
7bd4461
 
 
 
 
 
 
 
 
 
 
9e8921c
7bd4461
 
 
 
 
 
 
 
 
 
 
a8a3da6
7bd4461
 
 
a8a3da6
7bd4461
 
a8a3da6
7bd4461
a8a3da6
7bd4461
 
 
 
 
a8a3da6
7bd4461
 
 
 
a8a3da6
 
 
7bd4461
 
a8a3da6
7bd4461
 
 
a8a3da6
7bd4461
9e8921c
7bd4461
3e78896
 
a8a3da6
3e78896
 
 
 
9e8921c
3e78896
9e8921c
 
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e78896
 
 
 
9e8921c
3e78896
 
9e8921c
3e78896
9e8921c
3e78896
 
9e8921c
 
 
3e78896
9e8921c
 
 
3e78896
 
a8a3da6
9e8921c
 
 
 
a8a3da6
9e8921c
 
a8a3da6
9e8921c
 
a8a3da6
9e8921c
 
 
a8a3da6
9e8921c
3e78896
 
a8a3da6
9e8921c
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
a8a3da6
7bd4461
9e8921c
 
 
 
 
a8a3da6
9e8921c
 
 
a8a3da6
7bd4461
9e8921c
 
3e78896
9e8921c
 
 
 
 
 
 
 
3e78896
9e8921c
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
 
 
 
3e78896
9e8921c
 
 
 
 
a8a3da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9e8921c
 
a8a3da6
9e8921c
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3e78896
9e8921c
 
 
 
 
 
3e78896
a8a3da6
9e8921c
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
a8a3da6
 
 
9e8921c
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
3e78896
 
a8a3da6
3e78896
9e8921c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a3da6
9e8921c
 
 
 
 
 
 
 
a8a3da6
9e8921c
 
2dab870
 
9e8921c
2dab870
9e8921c
 
 
 
 
 
 
a8a3da6
2dab870
9e8921c
2dab870
9e8921c
a8a3da6
2dab870
7bd4461
9e8921c
 
 
 
 
 
 
 
 
 
 
2dab870
9e8921c
7bd4461
9e8921c
a8a3da6
2dab870
7bd4461
9e8921c
 
 
 
 
 
 
 
 
7bd4461
9e8921c
 
 
a8a3da6
9e8921c
 
a8a3da6
9e8921c
 
 
 
 
a8a3da6
9e8921c
 
 
a8a3da6
9e8921c
 
7bd4461
9e8921c
 
 
a8a3da6
2dab870
9e8921c
 
a8a3da6
 
9e8921c
a8a3da6
9e8921c
 
7bd4461
9e8921c
a8a3da6
2dab870
9e8921c
3e78896
9e8921c
 
2dab870
9e8921c
 
a8a3da6
2dab870
9e8921c
 
 
 
 
a8a3da6
2dab870
9e8921c
 
 
 
2dab870
 
a8a3da6
 
 
 
 
 
 
 
 
 
 
2dab870
 
 
 
a8a3da6
 
 
2dab870
a8a3da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dab870
a8a3da6
 
 
 
 
 
 
2dab870
a8a3da6
 
 
2dab870
 
a8a3da6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2dab870
a8a3da6
2dab870
a8a3da6
 
9e8921c
2dab870
 
a8a3da6
 
 
 
 
 
 
 
2dab870
a8a3da6
2dab870
 
a8a3da6
 
 
 
2dab870
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8a3da6
2dab870
 
a8a3da6
 
9e8921c
2dab870
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
import os
import json
import math
import time
import uuid
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import random
import tempfile
import gradio as gr
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from datetime import datetime, timedelta

# Compatibility shim: when this torchaudio build does not expose
# `list_audio_backends`, install a stub that reports a single "soundfile"
# backend so that callers of that attribute do not fail.
# NOTE(review): presumably needed by the `transformers` import below — confirm.
if not hasattr(torchaudio, 'list_audio_backends'):
    torchaudio.list_audio_backends = lambda: ["soundfile"]

from transformers import AutoModel

# Config
CKPT_PATH = 'aam_best.pt'  # checkpoint file the classifier head is restored from
DB_PATH = 'voiceprint_db.json'  # JSON file used as the voiceprint database
MODEL_NAME = 'microsoft/unispeech-sat-base-sv'  # encoder passed to AutoModel.from_pretrained
SAMPLE_RATE = 16000  # Hz; load_audio resamples every input to this rate
MAX_SEC = 4
MAX_LEN = SAMPLE_RATE * MAX_SEC  # clip length in samples; load_audio truncates / zero-pads to it
THRESHOLD = 0.3500  # minimum cosine similarity for a voice to count as a match
DEVICE = torch.device('cpu')
NUM_CLEAN_SAMPLES = 6  # recordings needed before a voiceprint is built
NUM_NOISY_COPIES = 4  # noise-augmented copies embedded per enrollment recording
MAX_ATTEMPTS = 3  # failed attempts that trigger a lockout
LOCKOUT_MINUTES = 5  # lockout duration once MAX_ATTEMPTS is reached
COOLDOWN_SECONDS = 3  # minimum spacing between attempts for one user
ANTISPOOFING_THRESHOLD = 0.02  # audio is rejected when spectral flatness exceeds 1 - this value

# Challenge word pool (simple, short, easy to pronounce).
# generate_challenge draws two distinct words from this list.
CHALLENGE_WORDS = [
    'Red', 'Blue', 'Gold', 'Star', 'Water',
    'Moon', 'Fire', 'Green', 'Black', 'White',
    'Sun', 'Rain', 'Tree', 'Fish', 'Bird',
    'Stone', 'Wind', 'Cloud', 'Light', 'Sound'
]

# Session steps: symbolic name -> value stored in a session's "step" field.
# Not every step is referenced by the code visible in this part of the file.
SESSION_STEPS = {
    'STARTED': 'started',
    'VERIFIED': 'verified',
    'LIVENESS_PENDING': 'liveness_pending',
    'AUTHENTICATED': 'authenticated',
    'TRANSACTION_PENDING': 'transaction_pending',
    'COMPLETE': 'complete',
    'DENIED': 'denied'
}


# AAM-Softmax model
class AAMSoftmax(nn.Module):
    """Cosine classification head with an additive angular margin.

    Without labels, ``forward`` returns the cosine similarity between each
    L2-normalised embedding and each L2-normalised class weight row. With
    labels, the target-class cosine is replaced by ``cos(theta + margin)``
    (or by ``cosine - mm`` where ``cosine <= threshold``) and all logits are
    multiplied by ``scale``.
    """

    def __init__(self, in_features, num_classes, margin=0.2, scale=30.0):
        super().__init__()
        self.margin = margin
        self.scale = scale

        # One weight row per class, Xavier-initialised.
        self.weight = nn.Parameter(torch.FloatTensor(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)

        # Constants for cos(theta + m) = cos(theta)cos(m) - sin(theta)sin(m).
        self.cos_m = math.cos(margin)
        self.sin_m = math.sin(margin)
        # Below this cosine, theta + m would exceed pi; a linear penalty is
        # used there instead of the angular formula.
        self.threshold = math.cos(math.pi - margin)
        self.mm = math.sin(math.pi - margin) * margin

    def forward(self, embeddings, labels=None):
        """Return raw cosines (no labels) or scaled margin logits (with labels)."""
        unit_embeddings = F.normalize(embeddings, p=2, dim=1)
        unit_weights = F.normalize(self.weight, p=2, dim=1)
        cosine = F.linear(unit_embeddings, unit_weights)
        if labels is None:
            return cosine

        # The clamp keeps the sqrt argument non-negative under rounding error.
        sine = torch.sqrt(1.0 - torch.clamp(cosine * cosine, 0, 1))
        margin_cosine = cosine * self.cos_m - sine * self.sin_m
        margin_cosine = torch.where(
            cosine > self.threshold, margin_cosine, cosine - self.mm
        )

        # Only the target-class column receives the margin.
        target_mask = F.one_hot(labels, cosine.size(1)).float()
        logits = (target_mask * margin_cosine) + ((1.0 - target_mask) * cosine)
        return logits * self.scale


class SpeakerClassifier(nn.Module):
    """Linear + ReLU projection followed by an AAM-Softmax head.

    ``extract_embedding`` exposes the projected features; ``forward`` feeds
    those same features into the margin head.
    """

    def __init__(self, input_dim=768, hidden_dim=512, num_classes=227):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.aam = AAMSoftmax(hidden_dim, num_classes)

    def extract_embedding(self, x):
        """Return the ReLU-activated projection of ``x``."""
        projected = self.fc1(x)
        return self.relu(projected)

    def forward(self, x, labels=None):
        """Return the AAM head output for ``x`` (cosines, or margin logits with labels)."""
        hidden = self.relu(self.fc1(x))
        return self.aam(hidden, labels)


# Load models
# The encoder is loaded frozen in eval mode; the classifier head is then
# restored from CKPT_PATH, trying several known checkpoint layouts in turn.
print("Loading UniSpeech-SAT base model...")
base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
base_model.eval()
for param in base_model.parameters():
    param.requires_grad = False

print("Loading AAM-Softmax checkpoint...")
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. CKPT_PATH must only ever point at a trusted checkpoint.
ckpt = torch.load(CKPT_PATH, map_location=DEVICE)
print(f"Checkpoint type: {type(ckpt)}")
if isinstance(ckpt, dict):
    print(f"Checkpoint keys: {list(ckpt.keys())}")

# The class count is taken from checkpoint metadata when present, otherwise 227.
num_classes = 227
if isinstance(ckpt, dict):
    if 'num_classes' in ckpt:
        num_classes = ckpt['num_classes']
    elif 'num_speakers' in ckpt:
        num_classes = ckpt['num_speakers']

classifier = SpeakerClassifier(input_dim=768, hidden_dim=512, num_classes=num_classes).to(DEVICE)

loaded = False
if isinstance(ckpt, dict):
    # Layout 1: the state dict is nested under one of these keys.
    for key in ['classifier_state', 'classifier_state_dict', 'model_state_dict', 'state_dict', 'model']:
        if key in ckpt:
            try:
                classifier.load_state_dict(ckpt[key])
                print(f"Loaded classifier from key: '{key}'")
                loaded = True
                break
            except Exception as e:
                print(f"Key '{key}' found but failed: {e}")

    # Layout 2: the checkpoint itself is a state dict (dotted parameter names).
    if not loaded:
        sample_keys = list(ckpt.keys())[:5]
        if any(isinstance(k, str) and '.' in k for k in sample_keys):
            try:
                classifier.load_state_dict(ckpt)
                print("Loaded classifier directly from checkpoint dict")
                loaded = True
            except Exception as e:
                print(f"Strict direct load failed ({e}); retrying with strict=False")
                try:
                    result = classifier.load_state_dict(ckpt, strict=False)
                    print("Loaded classifier with strict=False")
                    # A non-strict load succeeds even when nothing matched,
                    # so report what was skipped.
                    missing = list(getattr(result, 'missing_keys', []))
                    unexpected = list(getattr(result, 'unexpected_keys', []))
                    if missing or unexpected:
                        print(f"WARNING: non-strict load skipped keys. missing={missing} unexpected={unexpected}")
                    loaded = True
                except Exception as e2:
                    print(f"Direct load failed: {e2}")

    # Optional fine-tuned encoder weights; a failure here is reported but
    # not fatal, so the pretrained encoder is kept.
    if 'base_model_state' in ckpt:
        try:
            base_model.load_state_dict(ckpt['base_model_state'], strict=False)
            print("Loaded fine-tuned base model weights")
        except Exception as e:
            print(f"WARNING: could not load fine-tuned base model weights: {e}")
elif isinstance(ckpt, nn.Module):
    # Layout 3: the whole model object was pickled.
    classifier = ckpt.to(DEVICE)
    print("Loaded classifier directly (model object)")
    loaded = True

if not loaded:
    print("WARNING: Could not load classifier weights. Using random init.")

classifier.eval()
print(f"Models ready. num_classes={num_classes}, loaded={loaded}")


# Database
def load_db():
    """Return the voiceprint database from DB_PATH, or an empty dict if the file is absent."""
    if not os.path.exists(DB_PATH):
        return {}
    with open(DB_PATH, 'r') as db_file:
        return json.load(db_file)

def save_db(db):
    """Write ``db`` to DB_PATH as indented JSON, replacing the file atomically.

    The JSON is written to a temporary file in the same directory and then
    moved over DB_PATH with ``os.replace``. If serialisation fails or the
    process dies mid-write, the previous database file is left intact
    instead of being truncated.

    Values that are not JSON-serialisable are stringified (``default=str``).
    Any exception raised while writing propagates to the caller.
    """
    target_dir = os.path.dirname(os.path.abspath(DB_PATH))
    fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix='.voiceprint_db_', suffix='.tmp')
    try:
        with os.fdopen(fd, 'w') as f:
            json.dump(db, f, indent=2, default=str)
        os.replace(tmp_path, DB_PATH)
    finally:
        # After a successful replace the temp name no longer exists; it is
        # only left behind (and removed here) when writing failed.
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)


# Audio processing
def _mono_from_file(path):
    """Decode ``path`` with torchaudio and return ``(1-D mono waveform, sample_rate)``."""
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    return wav.squeeze(0), sr


def load_audio(audio_input):
    """Convert supported audio inputs to a 1-D float32 clip of MAX_LEN samples.

    Accepted inputs:
      * ``(sample_rate, array)`` tuple: the array may be 1-D or 2-D. For 2-D
        input the shorter axis is taken as the channel axis, so both
        (samples, channels) and (channels, samples) layouts are averaged
        correctly.
      * ``str``: path to an audio file readable by torchaudio.
      * ``bytes``: encoded audio; written to a temporary ``.wav`` file that is
        always removed, even when decoding fails.

    The result is mono, resampled to SAMPLE_RATE, and truncated or
    zero-padded to exactly MAX_LEN samples.

    Raises:
        ValueError: for an unsupported input type or an array that is not
            1-D or 2-D.
    """
    if isinstance(audio_input, tuple):
        sr, audio_np = audio_input
        wav = torch.tensor(audio_np, dtype=torch.float32)
        if wav.dim() == 2:
            # Channel counts are tiny compared with sample counts, so the
            # shorter axis is the channel axis. Ties keep channels-first.
            channel_axis = 1 if wav.shape[1] < wav.shape[0] else 0
            wav = wav.mean(dim=channel_axis)
        elif wav.dim() != 1:
            raise ValueError(f"Unsupported audio array shape: {tuple(wav.shape)}")
        # Values above 1.0 are treated as 16-bit PCM and scaled to [-1, 1].
        # Empty input is skipped so it pads to silence instead of raising.
        if wav.numel() > 0 and wav.abs().max() > 1.0:
            wav = wav / 32768.0
    elif isinstance(audio_input, str):
        wav, sr = _mono_from_file(audio_input)
    elif isinstance(audio_input, bytes):
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_input)
            tmp_path = tmp.name
        try:
            wav, sr = _mono_from_file(tmp_path)
        finally:
            os.unlink(tmp_path)
    else:
        raise ValueError(f"Unsupported audio input type: {type(audio_input)}")

    if sr != SAMPLE_RATE:
        wav = torchaudio.functional.resample(wav, sr, SAMPLE_RATE)

    if wav.shape[0] > MAX_LEN:
        wav = wav[:MAX_LEN]
    elif wav.shape[0] < MAX_LEN:
        wav = F.pad(wav, (0, MAX_LEN - wav.shape[0]))
    return wav

def extract_embedding(wav_tensor):
    """Return the L2-normalised speaker embedding of one waveform as a NumPy array.

    The waveform is given a leading batch axis, passed through ``base_model``,
    mean-pooled over axis 1 of ``last_hidden_state``, projected by
    ``classifier.extract_embedding`` and unit-normalised.
    """
    with torch.no_grad():
        batch = wav_tensor.unsqueeze(0).to(DEVICE)
        hidden_states = base_model(batch).last_hidden_state
        pooled = hidden_states.mean(dim=1)
        projected = classifier.extract_embedding(pooled)
        unit = F.normalize(projected, p=2, dim=1)
    return unit.squeeze(0).cpu().numpy()

def add_noise(wav_tensor, noise_level=0.005):
    """Return ``wav_tensor`` plus standard-normal noise scaled by ``noise_level``."""
    return wav_tensor + torch.randn_like(wav_tensor) * noise_level


# Liveness detection
def check_liveness(wav_tensor):
    """Run four signal heuristics on a waveform and return ``(passed, message)``.

    Checks, in order: RMS level, standard deviation, zero-crossing rate and
    the fraction of samples whose magnitude exceeds 0.01. The first failing
    check determines the message.
    """
    samples = wav_tensor.numpy()
    total = len(samples)

    level = np.sqrt(np.mean(samples ** 2))
    if level < 0.001:
        return False, "Audio too quiet"

    if np.std(samples) < 0.001:
        return False, "Audio lacks variation"

    # Each sign flip contributes 2 to the summed absolute difference of signs.
    sign_steps = np.abs(np.diff(np.sign(samples)))
    crossing_rate = np.sum(sign_steps) / (2 * total)
    if crossing_rate < 0.01:
        return False, "Abnormal audio pattern"

    active_fraction = np.sum(np.abs(samples) > 0.01) / total
    if active_fraction < 0.1:
        return False, "Insufficient speech content"

    return True, "Liveness check passed"


# Antispoofing
def check_antispoofing(wav_tensor):
    """Apply two spectral/energy heuristics and return ``(passed, message)``.

    1. Spectral flatness (geometric mean / arithmetic mean of the non-zero
       FFT magnitudes) must not exceed ``1 - ANTISPOOFING_THRESHOLD``.
    2. For clips of at least three 1600-sample frames, the standard deviation
       of per-frame RMS energy must be at least 0.001.

    Every complete frame is used for the energy check, including the last
    one when the clip length is an exact multiple of the frame size.
    """
    wav_np = wav_tensor.numpy()
    magnitude = np.abs(np.fft.rfft(wav_np))
    magnitude = magnitude[magnitude > 0]
    if len(magnitude) == 0:
        return False, "No frequency content"

    geometric_mean = np.exp(np.mean(np.log(magnitude + 1e-10)))
    arithmetic_mean = np.mean(magnitude)
    spectral_flatness = geometric_mean / (arithmetic_mean + 1e-10)
    if spectral_flatness > (1.0 - ANTISPOOFING_THRESHOLD):
        return False, "Possible synthetic audio"

    frame_size = 1600
    if len(wav_np) >= frame_size * 3:
        # "+ 1" keeps the final full frame, whose start index is exactly
        # len - frame_size when the length is a multiple of frame_size.
        frame_starts = range(0, len(wav_np) - frame_size + 1, frame_size)
        frame_energies = [
            np.sqrt(np.mean(wav_np[start:start + frame_size] ** 2))
            for start in frame_starts
        ]
        if np.std(frame_energies) < 0.001:
            return False, "Unnaturally uniform energy"
    return True, "Antispoofing check passed"


# Security: lockout and cooldown
# Per-user attempt state, keyed by user id:
# {"count": int, "last_attempt": ISO timestamp or None, "locked_until": ISO timestamp or None}
attempt_tracker = {}

def check_security(user_id):
    """Decide whether ``user_id`` may attempt verification right now.

    Returns ``(allowed, message)``. An attempt is refused while a lockout is
    active or while the previous attempt is less than COOLDOWN_SECONDS old.
    An expired lockout is cleared (count reset to 0) as a side effect.
    """
    now = datetime.now()
    tracker = attempt_tracker.get(user_id)
    if tracker is None:
        return True, "OK"

    locked_until_iso = tracker.get("locked_until")
    if locked_until_iso:
        locked_until = datetime.fromisoformat(locked_until_iso)
        if now < locked_until:
            # total_seconds() includes the days component; timedelta.seconds
            # would under-report any lockout longer than 24 hours.
            remaining = int((locked_until - now).total_seconds())
            return False, f"Account locked. Try again in {remaining} seconds."
        tracker["count"] = 0
        tracker["locked_until"] = None

    last_attempt_iso = tracker.get("last_attempt")
    if last_attempt_iso:
        elapsed = (now - datetime.fromisoformat(last_attempt_iso)).total_seconds()
        if elapsed < COOLDOWN_SECONDS:
            return False, f"Please wait {COOLDOWN_SECONDS - int(elapsed)} seconds."
    return True, "OK"

def record_attempt(user_id, success):
    """Update the attempt state for ``user_id`` after a verification attempt.

    A success clears the failure count and any lockout. A failure increments
    the count and, once it reaches MAX_ATTEMPTS, sets a lockout ending
    LOCKOUT_MINUTES from now. The attempt time is stored in either case.
    """
    now = datetime.now()
    state = attempt_tracker.setdefault(
        user_id, {"count": 0, "last_attempt": None, "locked_until": None}
    )
    state["last_attempt"] = now.isoformat()

    if not success:
        state["count"] += 1
        if state["count"] >= MAX_ATTEMPTS:
            unlock_at = now + timedelta(minutes=LOCKOUT_MINUTES)
            state["locked_until"] = unlock_at.isoformat()
        return

    state["count"] = 0
    state["locked_until"] = None


# Generate random challenge (2 words from pool)
def generate_challenge():
    """Return a challenge phrase of two distinct words from CHALLENGE_WORDS.

    The words are drawn with ``random.SystemRandom`` (OS entropy) rather than
    the default Mersenne Twister generator, because the phrase guards
    against replayed recordings and should not be predictable from earlier
    challenges.
    """
    words = random.SystemRandom().sample(CHALLENGE_WORDS, 2)
    return ' '.join(words)


# Session storage (in-memory)
# In-memory session records, keyed by session id.
sessions = {}

def create_session(user_id):
    """Create, store and return a new session record for ``user_id``.

    The session id is a random UUID4 string, the user id is stripped and
    upper-cased, the step starts at SESSION_STEPS['STARTED'] and the session
    expires five minutes after creation.
    """
    session_id = str(uuid.uuid4())
    record = {
        "session_id": session_id,
        "user_id": user_id.strip().upper(),
        "step": SESSION_STEPS['STARTED'],
        "challenge_phrase": None,
        "full_name": None,
        "similarity": None,
    }
    record["created_at"] = datetime.now().isoformat()
    record["expires_at"] = (datetime.now() + timedelta(minutes=5)).isoformat()
    sessions[session_id] = record
    return record

def get_session(session_id):
    """Return the live session for ``session_id``, or None if unknown or expired.

    An expired session is removed from ``sessions`` before None is returned.
    """
    record = sessions.get(session_id)
    if record is None:
        return None

    expiry = datetime.fromisoformat(record["expires_at"])
    if datetime.now() > expiry:
        del sessions[session_id]
        return None
    return record


# Enroll
def enroll_sample(audio_input, user_id, full_name, sample_number, total_samples=NUM_CLEAN_SAMPLES):
    """Add one enrollment recording for a user and return a status message.

    The recording must pass the liveness and antispoofing checks. Its clean
    embedding plus NUM_NOISY_COPIES noise-augmented embeddings are appended
    to the user's record in the database. Once ``total_samples`` recordings
    are stored, all embeddings are averaged and re-normalised into the final
    voiceprint and the per-sample embeddings are discarded.

    ``sample_number`` is accepted but not used; progress is derived from the
    number of samples already stored. Every failure is reported through the
    returned string; no exception escapes.
    """
    if not (user_id and user_id.strip()):
        return "Error: User ID is required."
    if not (full_name and full_name.strip()):
        return "Error: Full Name is required."
    if audio_input is None:
        return "Error: No audio recorded."

    user_id = user_id.strip().upper()
    full_name = full_name.strip()

    try:
        wav = load_audio(audio_input)
        for gate in (check_liveness, check_antispoofing):
            passed, reason = gate(wav)
            if not passed:
                return f"Enrollment failed: {reason}"

        clean_emb = extract_embedding(wav)
        # Noise level rises with each copy: 0.003, 0.005, 0.007, ...
        noisy_embeddings = [
            extract_embedding(add_noise(wav, 0.003 + (i * 0.002)))
            for i in range(NUM_NOISY_COPIES)
        ]

        db = load_db()
        record = db.setdefault(user_id, {
            "full_name": full_name,
            "enrolled_at": datetime.now().isoformat(),
            "sample_embeddings": [],
            "voiceprint": None,
            "status": "enrolling",
            "samples_collected": 0
        })

        record["sample_embeddings"].append({
            "clean": clean_emb.tolist(),
            "noisy": [e.tolist() for e in noisy_embeddings]
        })
        samples_collected = len(record["sample_embeddings"])
        record["samples_collected"] = samples_collected
        record["full_name"] = full_name

        if samples_collected < total_samples:
            save_db(db)
            remaining = total_samples - samples_collected
            return f"Sample {samples_collected}/{total_samples} recorded for {full_name}. {remaining} more sample(s) needed."

        # Enough samples: average every clean and noisy embedding, then
        # rescale the mean to unit length.
        all_embeddings = [
            np.array(vector)
            for sample in record["sample_embeddings"]
            for vector in (sample["clean"], *sample["noisy"])
        ]
        avg_embedding = np.mean(all_embeddings, axis=0)
        avg_embedding = avg_embedding / (np.linalg.norm(avg_embedding) + 1e-10)

        record["voiceprint"] = avg_embedding.tolist()
        record["status"] = "enrolled"
        record["completed_at"] = datetime.now().isoformat()
        record["sample_embeddings"] = []
        save_db(db)
        return f"Enrollment COMPLETE for {full_name} ({user_id}). Voiceprint created from {total_samples} samples ({total_samples * (1 + NUM_NOISY_COPIES)} embeddings averaged)."
    except Exception as e:
        return f"Enrollment error: {str(e)}"


# Verify
def verify_speaker(audio_input, user_id):
    """Verify a recording against a user's stored voiceprint; return a message.

    Order of checks: input validation, lockout/cooldown, enrollment status,
    liveness, antispoofing, then cosine similarity against THRESHOLD. Failed
    liveness, antispoofing and similarity checks are recorded as failed
    attempts; a match resets the attempt counter. Every outcome, including
    unexpected errors, is reported through the returned string.
    """
    if not (user_id and user_id.strip()):
        return "Error: User ID is required."
    if audio_input is None:
        return "Error: No audio recorded."

    user_id = user_id.strip().upper()
    allowed, sec_msg = check_security(user_id)
    if not allowed:
        return f"ACCESS DENIED: {sec_msg}"

    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."
    record = db[user_id]
    if record.get("status") != "enrolled":
        remaining = NUM_CLEAN_SAMPLES - record.get("samples_collected", 0)
        return f"Error: Enrollment incomplete. {remaining} more sample(s) needed."

    try:
        wav = load_audio(audio_input)
        for gate in (check_liveness, check_antispoofing):
            passed, reason = gate(wav)
            if not passed:
                record_attempt(user_id, False)
                return f"ACCESS DENIED: {reason}"

        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[user_id]["voiceprint"])
        norm_product = np.linalg.norm(test_emb) * np.linalg.norm(stored_emb)
        similarity = float(np.dot(test_emb, stored_emb) / (norm_product + 1e-10))

        if similarity < THRESHOLD:
            record_attempt(user_id, False)
            failures = attempt_tracker.get(user_id, {}).get("count", 0)
            attempts_left = MAX_ATTEMPTS - failures
            msg = f"ACCESS DENIED\nVoice does not match.\nSimilarity: {similarity:.4f} (threshold: {THRESHOLD})\n"
            if attempts_left > 0:
                return msg + f"Attempts remaining: {attempts_left}"
            return msg + f"Account locked for {LOCKOUT_MINUTES} minutes."

        record_attempt(user_id, True)
        full_name = db[user_id].get("full_name", user_id)
        return "\n".join([
            "ACCESS GRANTED",
            f"Welcome, {full_name}",
            f"Confidence: {similarity:.4f} (threshold: {THRESHOLD})",
            "Liveness: Passed | Antispoofing: Passed",
        ])
    except Exception as e:
        return f"Verification error: {str(e)}"


# User management
def list_users():
    """Return a text listing of every user in the database, one per line."""
    db = load_db()
    if not db:
        return "No users enrolled yet."

    rows = [
        "ID: {uid} | Name: {name} | Status: {status} | Samples: {samples} | Enrolled: {enrolled}".format(
            uid=uid,
            name=data.get("full_name", "Unknown"),
            status=data.get("status", "unknown"),
            samples=data.get("samples_collected", 0),
            enrolled=data.get("enrolled_at", "N/A"),
        )
        for uid, data in db.items()
    ]
    return "\n".join(["=== Enrolled Users ===\n", *rows])

def delete_user(user_id):
    """Remove a user from the database and drop their attempt state; return a message."""
    if not (user_id and user_id.strip()):
        return "Error: User ID is required."

    user_id = user_id.strip().upper()
    db = load_db()
    if user_id not in db:
        return f"Error: User '{user_id}' not found."

    name = db[user_id].get("full_name", user_id)
    del db[user_id]
    save_db(db)
    # Also forget any lockout/cooldown state held for this user.
    attempt_tracker.pop(user_id, None)
    return f"User '{name}' ({user_id}) deleted."

def reset_lockout(user_id):
    """Reset a user's attempt state to its initial values; return a message."""
    if not (user_id and user_id.strip()):
        return "Error: User ID is required."

    user_id = user_id.strip().upper()
    if user_id not in attempt_tracker:
        return f"No lockout record for {user_id}."

    attempt_tracker[user_id] = {"count": 0, "last_attempt": None, "locked_until": None}
    return f"Lockout reset for {user_id}."


# Create the FastAPI app first (before Gradio).
app = FastAPI(title="ATM Voice Authentication API", version="1.0.0")

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# website make credentialed cross-origin calls to this API — confirm that is
# intended for an authentication service, or list the allowed origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Health check
@app.get("/api/health")
async def health_check():
    # Static service description plus the current server time.
    status_report = {
        "status": "healthy",
        "model": "UniSpeech-SAT + AAM-Softmax",
        "threshold": THRESHOLD,
        "device": str(DEVICE),
        "timestamp": datetime.now().isoformat(),
    }
    return status_report

# Basic enroll endpoint
@app.post("/api/enroll")
async def api_enroll(audio: UploadFile = File(...), user_id: str = Form(...), full_name: str = Form(...)):
    # Stores one enrollment recording for the user and reports progress.
    # The upload is spooled to a temporary .wav file that is removed on every
    # path, including failures. Unexpected errors produce a 500 JSON response.
    tmp_path = None
    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(audio_bytes)
        result = enroll_sample(tmp_path, user_id, full_name, 1)

        db = load_db()
        uid = user_id.strip().upper()
        user_record = db.get(uid, {})
        samples_collected = user_record.get("samples_collected", 0)
        is_complete = user_record.get("status") == "enrolled"

        # enroll_sample marks every failure with one of these prefixes.
        # Matching the prefix, not a substring, keeps a name such as
        # "Terror" from turning a successful enrollment into a failure.
        failed = result.startswith(("Error", "Enrollment failed", "Enrollment error"))
        return JSONResponse(content={
            "success": not failed,
            "message": result,
            "user_id": uid,
            "samples_collected": samples_collected if not is_complete else NUM_CLEAN_SAMPLES,
            "samples_required": NUM_CLEAN_SAMPLES,
            "enrollment_complete": is_complete,
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)

# Basic verify endpoint
@app.post("/api/verify")
async def api_verify(audio: UploadFile = File(...), user_id: str = Form(...)):
    # One-shot verification: lockout/cooldown check, enrollment check,
    # liveness, antispoofing, then cosine similarity against THRESHOLD.
    # The temporary upload file is created only once the cheap checks have
    # passed and is always removed, even when decoding the audio fails.
    try:
        audio_bytes = await audio.read()
        uid = user_id.strip().upper()

        allowed, sec_msg = check_security(uid)
        if not allowed:
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": sec_msg, "locked": True})

        db = load_db()
        if uid not in db:
            return JSONResponse(content={"success": False, "message": f"User '{uid}' not found."})
        if db[uid].get("status") != "enrolled":
            return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            wav = load_audio(tmp_path)
        finally:
            os.unlink(tmp_path)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": live_msg, "liveness_passed": False})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "access_granted": False, "user_id": uid, "message": spoof_msg, "antispoofing_passed": False})

        test_emb = extract_embedding(wav)
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
        granted = similarity >= THRESHOLD
        record_attempt(uid, granted)

        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        response = {
            "success": True,
            "access_granted": granted,
            "user_id": uid,
            "full_name": db[uid].get("full_name", uid),
            "similarity": round(similarity, 4),
            "threshold": THRESHOLD,
            "liveness_passed": True,
            "antispoofing_passed": True,
            "attempts_remaining": attempts_remaining if not granted else MAX_ATTEMPTS,
            "locked": attempts_remaining == 0 and not granted,
        }
        if granted:
            response["message"] = "Access granted. Voice verified."
        elif attempts_remaining > 0:
            response["message"] = f"Voice does not match. {attempts_remaining} attempt(s) remaining."
        else:
            response["message"] = f"Account locked for {LOCKOUT_MINUTES} minutes."
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})

# List users
@app.get("/api/users")
async def api_list_users():
    # Summarises every database record; embeddings are not included.
    users = [
        {
            "user_id": uid,
            "full_name": data.get("full_name", "Unknown"),
            "status": data.get("status", "unknown"),
            "samples_collected": data.get("samples_collected", 0),
            "enrolled_at": data.get("enrolled_at", None),
            "completed_at": data.get("completed_at", None),
        }
        for uid, data in load_db().items()
    ]
    return JSONResponse(content={"success": True, "users": users, "total": len(users)})

# Delete user
@app.delete("/api/users/{user_id}")
async def api_delete_user(user_id: str):
    # Deletes the user and reports the outcome.
    # delete_user starts every failure message with "Error". Testing that
    # prefix, rather than searching the whole message, keeps a deleted user
    # whose name contains "error" from being reported as a failure.
    result = delete_user(user_id)
    success = not result.startswith("Error")
    return JSONResponse(content={"success": success, "message": result})

# Reset lockout
@app.post("/api/reset-lockout")
async def api_reset_lockout(user_id: str = Form(...)):
    # Always reports success; the message says whether a record existed.
    outcome = reset_lockout(user_id)
    payload = {"success": True, "message": outcome}
    return JSONResponse(content=payload)

# Session: Start
@app.post("/api/session/start")
async def session_start(user_id: str = Form(...)):
    # Opens a session for a fully enrolled user who is not locked out or
    # cooling down.
    uid = user_id.strip().upper()
    record = load_db().get(uid)

    if record is None:
        return JSONResponse(content={"success": False, "message": f"User '{uid}' not found. Please enroll first."})
    if record.get("status") != "enrolled":
        return JSONResponse(content={"success": False, "message": "Enrollment incomplete."})

    allowed, sec_msg = check_security(uid)
    if not allowed:
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})

    new_session = create_session(uid)
    return JSONResponse(content={
        "success": True,
        "session_id": new_session["session_id"],
        "user_id": uid,
        "message": "Session started. Please provide a voice sample to verify your identity.",
        "next_step": "verify",
        "instruction": "Record your voice and send it to /api/session/verify",
    })

# Session: Verify identity
@app.post("/api/session/verify")
async def session_verify(audio: UploadFile = File(...), session_id: str = Form(...)):
    # Session step 1: match the caller's voice against the stored voiceprint.
    # On a match the session moves to LIVENESS_PENDING and a challenge phrase
    # is issued. A mismatch that exhausts the attempts moves it to DENIED.
    # The temporary upload file is always removed, even when decoding fails.
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['STARTED']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})

    uid = session["user_id"]
    allowed, sec_msg = check_security(uid)
    if not allowed:
        session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={"success": False, "message": sec_msg, "locked": True})

    try:
        audio_bytes = await audio.read()
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            wav = load_audio(tmp_path)
        finally:
            os.unlink(tmp_path)

        is_live, live_msg = check_liveness(wav)
        if not is_live:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            record_attempt(uid, False)
            return JSONResponse(content={"success": True, "verified": False, "message": spoof_msg})

        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))

        if similarity >= THRESHOLD:
            record_attempt(uid, True)
            full_name = db[uid].get("full_name", uid)
            challenge = generate_challenge()
            session["step"] = SESSION_STEPS['LIVENESS_PENDING']
            session["full_name"] = full_name
            session["similarity"] = round(similarity, 4)
            session["challenge_phrase"] = challenge
            return JSONResponse(content={
                "success": True,
                "verified": True,
                "greeting": f"Welcome, {full_name}",
                "full_name": full_name,
                "similarity": round(similarity, 4),
                "next_step": "liveness",
                "challenge_phrase": challenge,
                "instruction": f"Say these words: {challenge}",
                "message": f"Voice verified. Welcome, {full_name}. For security, please say these words: {challenge}",
            })

        record_attempt(uid, False)
        tracker = attempt_tracker.get(uid, {})
        attempts_remaining = max(0, MAX_ATTEMPTS - tracker.get("count", 0))
        locked = attempts_remaining == 0
        if locked:
            session["step"] = SESSION_STEPS['DENIED']
        return JSONResponse(content={
            "success": True,
            "verified": False,
            "similarity": round(similarity, 4),
            "attempts_remaining": attempts_remaining,
            "locked": locked,
            "message": f"Voice does not match. {attempts_remaining} attempt(s) remaining." if not locked else f"Account locked for {LOCKOUT_MINUTES} minutes.",
        })
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})

# Session: Liveness check
@app.post("/api/session/liveness")
async def session_liveness(audio: UploadFile = File(...), session_id: str = Form(...)):
    """Run the liveness step of a session on a freshly uploaded recording.

    The session must exist and currently be at the LIVENESS_PENDING step.
    The recording has to pass ``check_liveness`` and ``check_antispoofing``
    and then reach a cosine similarity of at least ``THRESHOLD`` against the
    voiceprint stored for the session's user. Only then is the session moved
    to the AUTHENTICATED step.

    Args:
        audio: Uploaded recording; its bytes are written to a temporary
            ``.wav`` file so ``load_audio`` can read it from disk.
        session_id: Identifier used to look the session up via ``get_session``.

    Returns:
        A ``JSONResponse``. ``success`` is False for a missing session or a
        wrong step, and for unexpected errors (status 500). Otherwise
        ``success`` is True and ``liveness_passed`` reports the outcome; a
        voice mismatch repeats the session's challenge phrase so the caller
        can retry.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})
    if session["step"] != SESSION_STEPS['LIVENESS_PENDING']:
        return JSONResponse(content={"success": False, "message": f"Invalid step. Current step: {session['step']}"})
    uid = session["user_id"]
    try:
        audio_bytes = await audio.read()
        # delete=False keeps the file on disk after the handle is closed so
        # load_audio can open it by path; the finally clause guarantees the
        # file is removed even when writing or decoding fails.
        tmp_path = None
        try:
            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp:
                tmp_path = tmp.name
                tmp.write(audio_bytes)
            wav = load_audio(tmp_path)
        finally:
            if tmp_path is not None:
                os.unlink(tmp_path)
        # NOTE(review): the challenge phrase is not passed to check_liveness,
        # so the spoken words do not appear to be compared against it here -
        # confirm whether that is intended.
        is_live, live_msg = check_liveness(wav)
        if not is_live:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": live_msg})
        is_real, spoof_msg = check_antispoofing(wav)
        if not is_real:
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": spoof_msg})
        test_emb = extract_embedding(wav)
        db = load_db()
        stored_emb = np.array(db[uid]["voiceprint"])
        # Cosine similarity; the 1e-10 term keeps the division defined when
        # either vector has zero norm.
        similarity = float(np.dot(test_emb, stored_emb) / (np.linalg.norm(test_emb) * np.linalg.norm(stored_emb) + 1e-10))
        if similarity >= THRESHOLD:
            session["step"] = SESSION_STEPS['AUTHENTICATED']
            full_name = session["full_name"]
            return JSONResponse(content={"success": True, "liveness_passed": True, "authenticated": True, "full_name": full_name, "similarity": round(similarity, 4), "next_step": "transaction", "instruction": "How much would you like to withdraw?", "message": f"Liveness confirmed. You are fully authenticated, {full_name}. How much would you like to withdraw?"})
        else:
            # NOTE(review): failures in this handler are not passed to
            # record_attempt, unlike the session verify handler - confirm
            # whether liveness retries should count towards lockout.
            return JSONResponse(content={"success": True, "liveness_passed": False, "message": "Voice mismatch during liveness check. Please try again.", "challenge_phrase": session["challenge_phrase"], "instruction": f"Please say these words again: {session['challenge_phrase']}"})
    except Exception as e:
        return JSONResponse(status_code=500, content={"success": False, "message": f"Server error: {str(e)}"})

# Session: Transaction
@app.post("/api/session/transaction")
async def session_transaction(session_id: str = Form(...), amount: str = Form(...)):
    """Approve a withdrawal for a session that has reached AUTHENTICATED.

    Args:
        session_id: Identifier used to look the session up via ``get_session``.
        amount: Requested amount; echoed back verbatim in the response.

    Returns:
        A ``JSONResponse``. ``success`` is False when the session is missing
        or is not at the AUTHENTICATED step. Otherwise the session step is
        set to COMPLETE and an approval payload is returned.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})

    current_step = session["step"]
    if current_step != SESSION_STEPS['AUTHENTICATED']:
        return JSONResponse(content={"success": False, "message": f"Not authenticated. Current step: {current_step}"})

    # Read the name before changing the step so a missing key leaves the
    # session state untouched.
    account_holder = session["full_name"]
    session["step"] = SESSION_STEPS['COMPLETE']

    approval = {
        "success": True,
        "transaction_approved": True,
        "full_name": account_holder,
        "amount": amount,
        "message": f"Transaction approved. {account_holder}, you are withdrawing {amount} cedis. Please collect your cash.",
        "instruction": "Transaction complete. Session ended.",
        "note": "In production, this step communicates with the banks core system to process the actual withdrawal.",
    }
    return JSONResponse(content=approval)

# Session status
@app.get("/api/session/{session_id}")
async def session_status(session_id: str):
    """Report the stored state of a session.

    Args:
        session_id: Identifier used to look the session up via ``get_session``.

    Returns:
        A ``JSONResponse`` with ``success`` False when the session is missing,
        otherwise ``success`` True plus the session's id, user id, step,
        full name, challenge phrase and creation/expiry values.
    """
    session = get_session(session_id)
    if not session:
        return JSONResponse(content={"success": False, "message": "Session expired or not found."})

    # Fields copied from the session into the response, in output order.
    exposed_fields = (
        "session_id",
        "user_id",
        "step",
        "full_name",
        "challenge_phrase",
        "created_at",
        "expires_at",
    )
    snapshot = {"success": True}
    snapshot.update((field, session[field]) for field in exposed_fields)
    return JSONResponse(content=snapshot)


# Gradio interface
# Builds the browser UI as a gr.Blocks app (bound to `demo`) with four tabs:
# Enroll, Verify, Users and API Docs. The callbacks wired to the buttons
# (enroll_sample, verify_speaker, list_users, delete_user, reset_lockout) are
# not defined in this block.
with gr.Blocks(title="ATM Voice Authentication System", theme=gr.themes.Soft()) as demo:
    # Page header shown above the tabs.
    gr.Markdown("""
    # ATM Voice Authentication System
    ### Voice-Based Speaker Verification for Banking Security
    Voice biometric authentication system for secure ATM access
    """)
    with gr.Tabs():
        # Enroll tab: inputs on the left, result text on the right.
        with gr.Tab("Enroll"):
            gr.Markdown("### Enroll New User\nRecord **6 voice samples** to create your voiceprint. Speak naturally for 3-4 seconds each time.")
            with gr.Row():
                with gr.Column():
                    enroll_audio = gr.Audio(label="Record Voice Sample", sources=["microphone", "upload"], type="numpy")
                    enroll_user_id = gr.Textbox(label="User ID (e.g., ATM_001)", placeholder="ATM_001")
                    enroll_name = gr.Textbox(label="Full Name", placeholder="Jochebed Fafa")
                    # Sample index is restricted to 1-6, matching the six samples requested above.
                    enroll_sample_num = gr.Number(label="Sample Number (1-6)", value=1, minimum=1, maximum=6, step=1)
                    enroll_btn = gr.Button("Enroll Sample", variant="primary")
                with gr.Column():
                    enroll_result = gr.Textbox(label="Result", lines=4, interactive=False)
            # Inputs are passed to enroll_sample in this order: audio, user id, full name, sample number.
            enroll_btn.click(fn=enroll_sample, inputs=[enroll_audio, enroll_user_id, enroll_name, enroll_sample_num], outputs=enroll_result)
        # Verify tab: one recording plus a user id, checked by verify_speaker.
        with gr.Tab("Verify"):
            gr.Markdown("### Verify Identity\nRecord your voice to verify against your enrolled voiceprint.")
            with gr.Row():
                with gr.Column():
                    verify_audio = gr.Audio(label="Record Voice", sources=["microphone", "upload"], type="numpy")
                    verify_user_id = gr.Textbox(label="User ID", placeholder="ATM_001")
                    verify_btn = gr.Button("Verify", variant="primary")
                with gr.Column():
                    verify_result = gr.Textbox(label="Result", lines=6, interactive=False)
            verify_btn.click(fn=verify_speaker, inputs=[verify_audio, verify_user_id], outputs=verify_result)
        # Users tab: list enrolled users, delete a user, or reset a user's lockout.
        with gr.Tab("Users"):
            gr.Markdown("### Manage Enrolled Users")
            list_btn = gr.Button("List All Users")
            users_output = gr.Textbox(label="Enrolled Users", lines=10, interactive=False)
            # list_users takes no inputs; its return value fills the textbox.
            list_btn.click(fn=list_users, outputs=users_output)
            gr.Markdown("---")
            with gr.Row():
                # Left column: delete a user by id.
                with gr.Column():
                    del_user_id = gr.Textbox(label="User ID to Delete", placeholder="ATM_001")
                    del_btn = gr.Button("Delete User", variant="stop")
                    del_result = gr.Textbox(label="Result", interactive=False)
                    del_btn.click(fn=delete_user, inputs=del_user_id, outputs=del_result)
                # Right column: reset the lockout for a user id.
                with gr.Column():
                    reset_user_id = gr.Textbox(label="User ID to Reset Lockout", placeholder="ATM_001")
                    reset_btn = gr.Button("Reset Lockout", variant="secondary")
                    reset_result = gr.Textbox(label="Result", interactive=False)
                    reset_btn.click(fn=reset_lockout, inputs=reset_user_id, outputs=reset_result)
        # API Docs tab: static Markdown reference for the REST endpoints.
        # NOTE(review): this text is maintained by hand; keep it in sync with the route definitions.
        with gr.Tab("API Docs"):
            gr.Markdown("""
            ### REST API Endpoints
            **Base URL:** `https://amfafa-voice-authentication-sys.hf.space`
            ---
            #### Basic Endpoints
            - `POST /api/enroll` - Enroll a voice sample (audio, user_id, full_name)
            - `POST /api/verify` - Verify a voice (audio, user_id)
            - `GET /api/users` - List enrolled users
            - `DELETE /api/users/{user_id}` - Delete a user
            - `GET /api/health` - Health check
            ---
            #### Session-Based Voice Authentication Flow
            - `POST /api/session/start` - Start session (user_id)
            - `POST /api/session/verify` - Verify identity (audio, session_id) - Returns greeting + challenge words
            - `POST /api/session/liveness` - Liveness check (audio, session_id) - Returns authenticated or denied
            - `POST /api/session/transaction` - Confirm transaction (amount, session_id)
            - `GET /api/session/{session_id}` - Check session status
            """)

# Mount Gradio ON the FastAPI app (not the other way around).
# The `demo` Blocks UI is attached at path "/" and the returned object is
# rebound to `app`, which is what gets served below.
app = gr.mount_gradio_app(app, demo, path="/")

# Launch
# Only when this file is executed as a script: serve `app` with uvicorn,
# listening on all interfaces (0.0.0.0) at port 7860.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)