File size: 27,380 Bytes
7bc2dab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac8d9b
 
7bc2dab
2ac8d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc2dab
 
 
 
 
 
 
2ac8d9b
7bc2dab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
 
 
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc2dab
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
2ac8d9b
7bc2dab
 
 
 
 
 
 
2ac8d9b
 
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
 
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
 
7bc2dab
2ac8d9b
7bc2dab
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
 
 
7bc2dab
2ac8d9b
 
 
7bc2dab
2ac8d9b
 
 
7bc2dab
2ac8d9b
 
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
 
7bc2dab
 
2ac8d9b
7bc2dab
 
 
 
 
 
2ac8d9b
 
 
7bc2dab
 
 
 
 
 
2ac8d9b
7bc2dab
 
 
 
2ac8d9b
 
 
 
 
7bc2dab
 
 
 
 
 
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
 
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
 
 
 
2ac8d9b
 
 
7bc2dab
 
 
 
 
 
 
2ac8d9b
 
 
 
7bc2dab
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
 
 
 
 
7bc2dab
2ac8d9b
 
7bc2dab
 
 
2ac8d9b
 
7bc2dab
2ac8d9b
 
 
 
 
 
7bc2dab
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
 
2ac8d9b
7bc2dab
 
2ac8d9b
 
 
 
7bc2dab
 
2ac8d9b
 
7bc2dab
 
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
 
7bc2dab
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
 
2ac8d9b
7bc2dab
 
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
2ac8d9b
7bc2dab
 
 
2ac8d9b
7bc2dab
 
 
2ac8d9b
 
 
 
 
 
 
 
 
7bc2dab
 
 
 
 
 
2ac8d9b
 
 
 
 
 
 
 
7bc2dab
2ac8d9b
 
 
7bc2dab
 
 
2ac8d9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7bc2dab
 
 
 
 
 
 
 
 
 
 
 
 
2ac8d9b
7bc2dab
2ac8d9b
 
 
7bc2dab
 
 
2ac8d9b
7bc2dab
 
 
 
 
 
047fc3c
7bc2dab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
import gradio as gr
import torch
import torchaudio
import numpy as np
import json
import os
from datetime import datetime
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
import warnings
warnings.filterwarnings('ignore')

# Import your existing classes and functions
from torch import nn
import torchvision

class AudioPreprocessor:
    """Enhanced audio preprocessing for voice security.

    Converts an audio file into a standardized log-mel spectrogram repeated
    to three channels, so it can be fed to image CNN backbones.

    Args:
        sample_rate: Target sample rate in Hz; input audio is resampled to it.
        n_mels: Number of mel bands.
        n_fft: FFT window size used by the mel spectrogram.
        hop_length: Hop size in samples between spectrogram frames.
        duration: Clip length in seconds. Longer audio is truncated and
            shorter audio is zero-padded at the end. Defaults to 3 seconds,
            the length that was previously hard-coded.
    """

    def __init__(self, sample_rate=16000, n_mels=128, n_fft=2048, hop_length=512, duration=3.0):
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.duration = duration
        self.mel_spectrogram = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            n_fft=n_fft,
            hop_length=hop_length
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def audio_to_melspectrogram(self, audio_path):
        """Convert an audio file to a normalized 3-channel mel-spectrogram.

        Args:
            audio_path: Path to an audio file readable by ``torchaudio.load``.

        Returns:
            ``(mel_spec_rgb, waveform)``:
            - ``mel_spec_rgb`` is the dB mel-spectrogram, standardized to zero
              mean / unit std and repeated to 3 channels.
            - ``waveform`` is the mono, fixed-length waveform as a NumPy array.

            ``(None, None)`` if anything fails; the error is printed rather
            than raised, because callers check for ``None``.
        """
        try:
            waveform, sr = torchaudio.load(audio_path)

            # Down-mix to mono *before* resampling so only one channel is
            # resampled; both steps are linear, so the result is the same.
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)

            if sr != self.sample_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                waveform = resampler(waveform)

            # Pad or truncate to a fixed number of samples.
            target_length = int(round(self.sample_rate * self.duration))
            current_length = waveform.shape[1]
            if current_length > target_length:
                waveform = waveform[:, :target_length]
            else:
                waveform = torch.nn.functional.pad(waveform, (0, target_length - current_length))

            mel_spec = self.mel_spectrogram(waveform)
            mel_spec_db = self.amplitude_to_db(mel_spec)

            # Per-clip standardization; epsilon guards against silent input.
            mel_spec_db = (mel_spec_db - mel_spec_db.mean()) / (mel_spec_db.std() + 1e-8)

            # Repeat the single channel to 3 channels for RGB-input backbones.
            mel_spec_rgb = mel_spec_db.repeat(3, 1, 1)

            return mel_spec_rgb, waveform.numpy()

        except Exception as e:
            print(f"Error processing audio: {e}")
            return None, None

# Model Classes (same as your original code)
class ResNet18Model(nn.Module):
    """ResNet-18 backbone (no pretrained weights) with a dropout MLP head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        net = torchvision.models.resnet18(pretrained=False)
        width = net.fc.in_features
        # Swap the stock fc layer for a 256-unit hidden head with dropout.
        net.fc = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )
        self.backbone = net

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.backbone(x)

class ResNet50Model(nn.Module):
    """ResNet-50 backbone (no pretrained weights) with a batch-normed MLP head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        net = torchvision.models.resnet50(pretrained=False)
        width = net.fc.in_features
        # BatchNorm -> Dropout -> 512-unit hidden layer -> BatchNorm -> logits.
        net.fc = nn.Sequential(
            nn.BatchNorm1d(width),
            nn.Dropout(0.4),
            nn.Linear(width, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )
        self.backbone = net

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.backbone(x)

class EfficientNetB0Model(nn.Module):
    """EfficientNet-B0 backbone (no pretrained weights) with an MLP head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super(EfficientNetB0Model, self).__init__()
        self.backbone = torchvision.models.efficientnet_b0(pretrained=False)
        # Read the feature width from torchvision's stock classifier
        # (Dropout, Linear) instead of hard-coding 1280. This matches how the
        # ResNet/MobileNet wrappers derive it, and stays correct if the
        # backbone changes.
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.3, inplace=True),
            nn.Linear(in_features=in_features, out_features=512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.backbone(x)

class MobileNetV2Model(nn.Module):
    """MobileNet-V2 backbone (no pretrained weights) with an MLP head.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        net = torchvision.models.mobilenet_v2(pretrained=False)
        width = net.last_channel
        # Replace the stock classifier with a 512-unit hidden head.
        net.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(width, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes),
        )
        self.backbone = net

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.backbone(x)

class VoiceSecuritySystem:
    """Holds the voice-classification models and runs authentication requests.

    Attributes:
        device: Torch device used for inference (CUDA when available).
        preprocessor: ``AudioPreprocessor`` that turns files into model input.
        models: Mapping of model key -> ``nn.Module`` in eval mode.
        label_encoder: Maps class indices back to user names.
        model_info: Static, hard-coded display metadata per model key.
    """

    def __init__(self):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.preprocessor = AudioPreprocessor()
        self.models = {}
        self.label_encoder = LabelEncoder()

        # Display metadata per model key. These are fixed strings (reported
        # training results); nothing here is computed at runtime.
        self.model_info = {
            "resnet18": {
                "name": "ResNet-18 πŸ† CHAMPION",
                "description": "πŸ₯‡ BEST PERFORMING MODEL - Perfect 100% accuracy with 11.3M parameters (4.9M trainable). Exceptional security with 0.06% FAR and 0% FRR. Ideal for high-security applications requiring zero false rejections.",
                "accuracy": "100.00%",
                "far": "0.0006",
                "frr": "0.0000",
                "parameters": "11.3M total (4.9M trainable)",
                "status": "πŸ† CHAMPION"
            },
            "resnet50": {
                "name": "ResNet-50 πŸ₯ˆ HIGH PERFORMER",
                "description": "πŸ₯ˆ EXCELLENT ACCURACY - 99.94% accuracy with 24.6M parameters (16.0M trainable). Near-perfect performance with robust feature extraction. Best for applications requiring high accuracy with acceptable computational overhead.",
                "accuracy": "99.94%",
                "far": "0.0006",
                "frr": "0.0000",
                "parameters": "24.6M total (16.0M trainable)",
                "status": "πŸ₯ˆ RUNNER-UP"
            },
            "efficientnet_b0": {
                "name": "EfficientNet-B0 ⚑ EFFICIENT",
                "description": "⚑ MOBILE OPTIMIZED - 99.76% accuracy with only 4.7M parameters (3.8M trainable). Excellent efficiency-accuracy trade-off. Perfect for mobile deployment with minimal computational requirements.",
                "accuracy": "99.76%",
                "far": "0.0030",
                "frr": "0.0000",
                "parameters": "4.7M total (3.8M trainable)",
                "status": "⚑ EFFICIENT"
            },
            "mobilenet_v2": {
                "name": "MobileNet-V2 πŸ“± LIGHTWEIGHT",
                "description": "πŸ“± ULTRA-LIGHTWEIGHT - 99.76% accuracy with just 2.9M parameters (1.1M trainable). Smallest model with excellent performance. Ideal for edge devices and real-time applications with limited resources.",
                "accuracy": "99.76%",
                "far": "0.0012",
                "frr": "0.0000",
                "parameters": "2.9M total (1.1M trainable)",
                "status": "πŸ“± COMPACT"
            }
        }
        self.load_models()

    def load_models(self):
        """Instantiate every model architecture and put it in eval mode.

        NOTE: as written, no trained weights are loaded (the
        ``load_state_dict`` call below is commented out). The label encoder
        is also fitted on placeholder names ``user_1`` .. ``user_26``.
        Failures are printed and the model is skipped.
        """
        num_classes = 26  # Based on your training output (26 users)

        dummy_classes = [f"user_{i+1}" for i in range(num_classes)]
        self.label_encoder.fit(dummy_classes)

        model_classes = {
            "resnet18": ResNet18Model,
            "resnet50": ResNet50Model,
            "efficientnet_b0": EfficientNetB0Model,
            "mobilenet_v2": MobileNetV2Model
        }

        for model_name, model_class in model_classes.items():
            try:
                model = model_class(num_classes).to(self.device)
                # In actual deployment, you would load the trained weights:
                # model.load_state_dict(torch.load(f"models/{model_name}.pth", map_location=self.device))
                model.eval()
                self.models[model_name] = model
                print(f"βœ… Loaded {model_name} successfully")
            except Exception as e:
                print(f"❌ Error loading {model_name}: {e}")

    def _error_result(self, message, details):
        """Build the 5-tuple returned to the UI when authentication cannot run."""
        return "❌ Error", message, 0.0, self.create_empty_plot(), details

    def _format_report(self, model_name, predicted_class, confidence_score, confidence_threshold, granted):
        """Build the Markdown analysis report for one prediction.

        Lines are joined without source indentation. Markdown renders lines
        indented by 4+ spaces as a code block, which would hide the headings.
        """
        model_stats = self.model_info[model_name]
        decision = 'βœ… GRANT ACCESS' if granted else '❌ DENY ACCESS'
        if confidence_score >= 0.8:
            security_level = 'πŸ”’ HIGH'
        elif confidence_score >= 0.5:
            security_level = 'πŸ”“ MEDIUM'
        else:
            security_level = '⚠️ LOW'

        # Trailing double spaces are Markdown hard line breaks.
        lines = [
            "## πŸ€– Model Performance",
            f"**Model Used:** {model_stats['name']}  ",
            f"**Training Accuracy:** {model_stats['accuracy']}  ",
            f"**Model Size:** {model_stats['parameters']}  ",
            f"**Status:** {model_stats['status']}",
            "",
            "## πŸ” Prediction Results",
            f"**Predicted User:** {predicted_class}  ",
            f"**Confidence Score:** {confidence_score:.3f}  ",
            f"**Security Threshold:** {confidence_threshold}  ",
            f"**Decision:** {decision}",
            "",
            "## πŸ›‘οΈ Security Metrics",
            f"**False Accept Rate (FAR):** {model_stats['far']}  ",
            f"**False Reject Rate (FRR):** {model_stats['frr']}  ",
            f"**Security Level:** {security_level}",
        ]
        return "\n".join(lines) + "\n"

    def predict_voice(self, audio_file, model_name, confidence_threshold):
        """Authenticate a voice sample with the selected model.

        Args:
            audio_file: Path to the uploaded audio file, or ``None``.
            model_name: Key into ``self.models`` (e.g. ``"resnet18"``).
            confidence_threshold: Minimum top-class softmax probability
                required to grant access.

        Returns:
            ``(status, message, confidence, figure, report_markdown)``. On any
            failure, confidence is ``0.0`` and the figure is a placeholder.
        """
        if audio_file is None:
            return self._error_result("No audio file provided", "Please upload an audio file")

        try:
            features, _ = self.preprocessor.audio_to_melspectrogram(audio_file)
            if features is None:
                return self._error_result("Failed to process audio", "Audio processing failed")

            model = self.models.get(model_name)
            if model is None:
                return self._error_result("Model not found", "Selected model is not available")

            # Add the batch dimension expected by the CNN backbones.
            batch = features.unsqueeze(0).to(self.device)
            with torch.no_grad():
                probabilities = torch.softmax(model(batch), dim=1)
            confidence, predicted = torch.max(probabilities, 1)

            predicted_class = self.label_encoder.inverse_transform([predicted.item()])[0]
            confidence_score = confidence.item()

            viz_plot = self.create_prediction_visualization(
                probabilities.cpu().numpy()[0], predicted_class, confidence_score
            )

            granted = confidence_score >= confidence_threshold
            if granted:
                status = "🟒 ACCESS GRANTED"
                message = f"Welcome, {predicted_class}!"
            else:
                status = "πŸ”΄ ACCESS DENIED"
                message = "Access denied - Low confidence"

            detailed_info = self._format_report(
                model_name, predicted_class, confidence_score, confidence_threshold, granted
            )
            return status, message, confidence_score, viz_plot, detailed_info

        except Exception as e:
            return self._error_result(f"Prediction failed: {str(e)}", "An error occurred during prediction")

    def create_prediction_visualization(self, probabilities, predicted_class, confidence):
        """Plot the top-5 class probabilities and a semicircular confidence gauge.

        Args:
            probabilities: 1-D array of per-class probabilities.
            predicted_class: Name of the top class (kept for API compatibility).
            confidence: Top-class probability; clamped to [0, 1] for the gauge.

        Returns:
            A ``matplotlib.figure.Figure``. It is built without pyplot, so it is
            not registered in pyplot's global figure list. This avoids leaking
            one figure per request and keeps server threads independent.
        """
        fig = Figure(figsize=(14, 6))
        ax1 = fig.add_subplot(1, 2, 1)
        ax2 = fig.add_subplot(1, 2, 2, projection='polar')

        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#F7DC6F', '#BB8FCE', '#85C1E9', '#F8C471', '#82E0AA', '#F1948A']

        # Left panel: top-5 predictions, highest first.
        top_indices = np.argsort(probabilities)[-5:][::-1]
        top_probs = probabilities[top_indices]
        top_labels = list(self.label_encoder.inverse_transform(top_indices))

        positions = range(len(top_labels))
        bars = ax1.barh(positions, top_probs, color=colors[:len(top_labels)])
        ax1.set_yticks(positions)
        ax1.set_yticklabels(top_labels)
        ax1.set_xlabel('Confidence Score', fontweight='bold')
        ax1.set_title('🎯 Top 5 User Predictions', fontweight='bold', fontsize=12)
        ax1.set_xlim(0, 1)
        ax1.grid(axis='x', alpha=0.3)

        # Highlight the top prediction in gold.
        bars[0].set_color('#FFD700')
        bars[0].set_edgecolor('#FF8C00')
        bars[0].set_linewidth(3)

        for bar, prob in zip(bars, top_probs):
            ax1.text(prob + 0.02, bar.get_y() + bar.get_height() / 2,
                     f'{prob:.3f}', va='center', fontweight='bold', fontsize=10)

        # Right panel: semicircular gauge on a polar axis.
        ax2.set_theta_zero_location('S')
        ax2.set_theta_direction(1)

        if confidence < 0.3:
            color = '#FF4757'  # Red
            status_text = '⚠️ LOW'
            risk_level = 'HIGH RISK'
        elif confidence < 0.7:
            color = '#FFA726'  # Orange
            status_text = '🟑 MEDIUM'
            risk_level = 'MODERATE RISK'
        else:
            color = '#66BB6A'  # Green
            status_text = 'βœ… HIGH'
            risk_level = 'LOW RISK'

        theta = np.linspace(0, np.pi, 100)
        r = np.ones_like(theta)
        # Map the clamped confidence directly to an angle. The old
        # theta[int(confidence * len(theta))] raised IndexError at exactly 1.0.
        needle = float(np.clip(confidence, 0.0, 1.0)) * np.pi
        filled = theta <= needle

        ax2.fill_between(theta, 0, r, alpha=0.2, color='lightgray')
        ax2.plot([needle, needle], [0, 1], color=color, linewidth=10)
        ax2.fill_between(theta[filled], 0, r[filled], alpha=0.8, color=color)

        ax2.set_title(f'🎚️ Confidence Level\n{confidence:.3f} - {status_text}\n{risk_level}',
                      pad=30, fontweight='bold')
        ax2.set_ylim(0, 1)
        ax2.set_yticklabels([])
        # Pin the five tick angles (0, 45, 90, 135, 180 degrees) the labels
        # describe; labels without fixed ticks depend on the default locator.
        ax2.set_xticks(np.linspace(0, np.pi, 5))
        ax2.set_xticklabels(['πŸ”΄ Low', '', '🟑 Med', '', '🟒 High'], fontweight='bold')

        fig.tight_layout()
        return fig

    def create_empty_plot(self):
        """Return a placeholder figure for error cases (not tracked by pyplot)."""
        fig = Figure(figsize=(10, 6))
        ax = fig.add_subplot(1, 1, 1)
        ax.text(0.5, 0.5, 'πŸ“Š No Data Available\nPlease upload an audio file',
                ha='center', va='center', fontsize=18, color='gray', fontweight='bold')
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
        ax.axis('off')
        return fig

    def get_model_comparison(self):
        """Return one row per model for the comparison table.

        Row order: name, accuracy, FAR, FRR, parameters, status.
        """
        return [
            [info['name'], info['accuracy'], info['far'], info['frr'], info['parameters'], info['status']]
            for info in self.model_info.values()
        ]

# Initialize the system
# Module-level instance shared by the Gradio callbacks. Constructing it runs
# VoiceSecuritySystem.load_models(), so every model is built at import time.
voice_system = VoiceSecuritySystem()

def process_voice(audio_file, model_name, confidence_threshold):
    """Gradio callback for the authenticate button.

    Forwards the uploaded file path, selected model key and threshold to the
    shared ``voice_system`` and returns its 5-tuple result unchanged.
    """
    system = voice_system
    outcome = system.predict_voice(audio_file, model_name, confidence_threshold)
    return outcome

def get_model_info(model_name):
    """Return a Markdown summary card for the model key ``model_name``.

    Falls back to a plain notice when the key is not present in
    ``voice_system.model_info``.
    """
    if model_name not in voice_system.model_info:
        return "Model information not available"

    info = voice_system.model_info[model_name]
    card_lines = [
        f"## {info['name']}",
        "",
        f"{info['description']}",
        "",
        "**πŸ“Š Key Stats:**",
        f"- Accuracy: {info['accuracy']}",
        f"- Parameters: {info['parameters']}",
        f"- FAR: {info['far']} | FRR: {info['frr']}",
    ]
    return "\n".join(card_lines)

# Enhanced custom CSS
# Passed to gr.Blocks(css=...) when the UI is built.
# NOTE(review): selectors such as .gr-button-primary, .gr-panel, .gr-form and
# .gr-box look like legacy Gradio class names; confirm they still match the
# DOM of the installed Gradio version, or these rules silently do nothing.
# NOTE(review): .champion-badge is not referenced by any HTML in this file.
custom_css = """
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif !important;
}
.gr-button-primary {
    background: linear-gradient(45deg, #FF6B6B, #FF8E53) !important;
    border: none !important;
    font-weight: bold !important;
    text-transform: uppercase !important;
    letter-spacing: 1px !important;
}
.gr-button-secondary {
    background: linear-gradient(45deg, #4ECDC4, #44A08D) !important;
    border: none !important;
}
.gr-panel {
    background: rgba(255, 255, 255, 0.95) !important;
    backdrop-filter: blur(15px) !important;
    border-radius: 20px !important;
    border: 2px solid rgba(255, 255, 255, 0.3) !important;
    box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1) !important;
}
.gr-form {
    background: transparent !important;
}
.gr-box {
    border-radius: 15px !important;
    border: 1px solid #E0E0E0 !important;
    box-shadow: 0 4px 16px rgba(0, 0, 0, 0.05) !important;
}
h1, h2, h3 {
    color: #2C3E50 !important;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1) !important;
}
.champion-badge {
    background: linear-gradient(45deg, #FFD700, #FFA500);
    padding: 5px 10px;
    border-radius: 20px;
    color: #333;
    font-weight: bold;
    display: inline-block;
    margin: 5px;
}
"""

# Create enhanced Gradio interface
# Components are placed in the order they are created inside each Row/Column
# context, so statement order here is layout order.
with gr.Blocks(css=custom_css, title="πŸ”Š Voice Recognition Security System - Trained Results") as app:
    # Header banner. NOTE(review): the statistics shown here are hard-coded
    # text, not values computed from the loaded models.
    gr.HTML("""
    <div style="text-align: center; padding: 30px; background: linear-gradient(45deg, #667eea, #764ba2); color: white; border-radius: 20px; margin-bottom: 25px; box-shadow: 0 10px 30px rgba(0,0,0,0.3);">
        <h1 style="margin: 0; font-size: 3em; text-shadow: 3px 3px 6px rgba(0,0,0,0.4);">πŸ”Š Voice Recognition Security System</h1>
        <p style="margin: 15px 0 10px 0; font-size: 1.3em; opacity: 0.95;">Advanced AI-powered voice authentication with 4 deep learning models</p>
        <div style="background: rgba(255,255,255,0.2); padding: 10px; border-radius: 10px; margin-top: 15px;">
            <p style="margin: 0; font-size: 1.1em; font-weight: bold;">πŸ† Training Complete: 26 Users | 1,693 Samples | Best Accuracy: 100%</p>
        </div>
    </div>
    """)
    
    with gr.Row():
        # Left column: inputs (audio, model choice, threshold) and model card.
        with gr.Column(scale=1):
            gr.HTML("<h2>🎯 Authentication Control Panel</h2>")
            
            # Audio input with enhanced styling
            # type="filepath": the callback receives a path string, which is
            # what AudioPreprocessor.audio_to_melspectrogram passes to torchaudio.load.
            audio_input = gr.Audio(
                label="🎀 Upload Voice Sample (WAV, MP3, FLAC supported)",
                type="filepath",
                elem_id="audio_input"
            )
            
            # Model selection with performance indicators
            # Each choice is a (display label, value) pair; the value is the
            # model key used by voice_system.models / model_info.
            model_selector = gr.Dropdown(
                choices=[
                    ("πŸ† ResNet-18 - CHAMPION (100% Accuracy)", "resnet18"),
                    ("πŸ₯ˆ ResNet-50 - HIGH PERFORMER (99.94% Accuracy)", "resnet50"),
                    ("⚑ EfficientNet-B0 - EFFICIENT (99.76% Accuracy)", "efficientnet_b0"),
                    ("πŸ“± MobileNet-V2 - LIGHTWEIGHT (99.76% Accuracy)", "mobilenet_v2")
                ],
                value="resnet18",
                label="πŸ€– Select AI Model (Ranked by Performance)",
                info="All models trained on 26 users with augmented dataset"
            )
            
            # Enhanced confidence threshold
            # Compared against the top-class softmax probability in predict_voice.
            confidence_slider = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.8,
                step=0.05,
                label="🎚️ Security Threshold (Recommended: 0.8 for high security)",
                info="Higher values = More secure but may increase false rejections"
            )
            
            # Enhanced process button
            process_btn = gr.Button(
                "πŸ” AUTHENTICATE VOICE",
                variant="primary",
                size="lg"
            )
            
            # Enhanced model info display
            # Initial content matches the dropdown's default value ("resnet18").
            model_info_display = gr.Markdown(
                get_model_info("resnet18"),
                label="πŸ“‹ Model Performance Details"
            )
        
        # Right column: authentication outputs.
        with gr.Column(scale=2):
            gr.HTML("<h2>πŸ“Š Authentication Results & Analysis</h2>")
            
            with gr.Row():
                with gr.Column():
                    # Enhanced status display
                    status_output = gr.Textbox(
                        label="🚦 Access Decision",
                        interactive=False,
                        elem_id="status_output"
                    )
                    
                    # Enhanced message display
                    message_output = gr.Textbox(
                        label="πŸ’¬ System Response",
                        interactive=False
                    )
                    
                    # Enhanced confidence display
                    confidence_output = gr.Number(
                        label="πŸ“ˆ Confidence Score (0.000-1.000)",
                        interactive=False,
                        precision=3
                    )
                
                with gr.Column():
                    # Enhanced detailed information
                    detailed_info = gr.Markdown(
                        label="πŸ” Comprehensive Analysis Report"
                    )
            
            # Enhanced visualization plot
            plot_output = gr.Plot(
                label="πŸ“ˆ Prediction Visualization & Confidence Analysis",
                elem_id="plot_output"
            )
    
    # Enhanced model comparison section
    with gr.Row():
        gr.HTML("<h2>βš–οΈ Model Performance Comparison (Training Results)</h2>")
    
    with gr.Row():
        # Column headers must stay in the same order as the rows produced by
        # VoiceSecuritySystem.get_model_comparison.
        comparison_table = gr.Dataframe(
            headers=["Model", "Accuracy", "FAR (False Accept)", "FRR (False Reject)", "Parameters", "Status"],
            value=voice_system.get_model_comparison(),
            label="πŸ“Š Actual Training Performance Metrics",
            interactive=False
        )
    
    # Enhanced information sections
    with gr.Row():
        with gr.Column():
            gr.HTML("""
            <div style="background: linear-gradient(45deg, #FFF3E0, #FFE0B2); padding: 25px; border-radius: 15px; border-left: 6px solid #FF9800; box-shadow: 0 6px 20px rgba(0,0,0,0.1);">
                <h3>πŸ›‘οΈ Advanced Security Features</h3>
                <ul style="line-height: 1.8;">
                    <li><strong>πŸ† Champion Model:</strong> ResNet-18 achieved perfect 100% accuracy</li>
                    <li><strong>πŸ“Š Multi-Model Architecture:</strong> 4 state-of-the-art models to choose from</li>
                    <li><strong>🎯 Zero False Rejections:</strong> All models achieved 0% FRR</li>
                    <li><strong>⚑ Real-Time Processing:</strong> Optimized for fast authentication</li>
                    <li><strong>πŸ“ˆ Detailed Analytics:</strong> Comprehensive prediction visualization</li>
                    <li><strong>πŸ”’ Adjustable Security:</strong> Customizable confidence thresholds</li>
                </ul>
            </div>
            """)
        
        with gr.Column():
            gr.HTML("""
            <div style="background: linear-gradient(45deg, #E8F5E8, #C8E6C9); padding: 25px; border-radius: 15px; border-left: 6px solid #4CAF50; box-shadow: 0 6px 20px rgba(0,0,0,0.1);">
                <h3>πŸ“– Usage Instructions</h3>
                <ol style="line-height: 1.8;">
                    <li><strong>🎀 Upload Audio:</strong> Record or upload voice sample (3 seconds optimal)</li>
                    <li><strong>πŸ€– Select Model:</strong> Choose from our trained models (ResNet-18 recommended)</li>
                    <li><strong>🎚️ Set Threshold:</strong> Adjust security level (0.8 recommended for high security)</li>
                    <li><strong>πŸ” Authenticate:</strong> Click to process and analyze your voice</li>
                    <li><strong>πŸ“Š Review Results:</strong> Check detailed analysis and confidence metrics</li>
                </ol>
                <div style="background: rgba(76, 175, 80, 0.1); padding: 10px; border-radius: 8px; margin-top: 15px;">
                    <strong>πŸ’‘ Tip:</strong> ResNet-18 offers perfect accuracy with optimal performance!
                </div>
            </div>
            """)
    
    # Training details section
    with gr.Row():
        gr.HTML("""
        <div style="background: linear-gradient(45deg, #E3F2FD, #BBDEFB); padding: 25px; border-radius: 15px; border-left: 6px solid #2196F3; box-shadow: 0 6px 20px rgba(0,0,0,0.1);">
            <h3>πŸŽ“ Training Details & Achievements</h3>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(300px, 1fr)); gap: 20px; margin-top: 15px;">
                <div>
                    <h4>πŸ“Š Dataset Information</h4>
                    <ul>
                        <li><strong>Users:</strong> 26 unique speakers</li>
                        <li><strong>Samples:</strong> 1,693 base samples</li>
                        <li><strong>Augmentation:</strong> 3x factor for training</li>
                        <li><strong>GPU:</strong> Tesla T4 (14.7 GB)</li>
                    </ul>
                </div>
                <div>
                    <h4>πŸ† Best Model Achievements</h4>
                    <ul>
                        <li><strong>ResNet-18:</strong> 100% Perfect Accuracy πŸ₯‡</li>
                        <li><strong>Parameters:</strong> 11.3M (4.9M trainable)</li>
                        <li><strong>Training Time:</strong> 20 epochs (~14 minutes)</li>
                        <li><strong>Security Score:</strong> 0.9997</li>
                    </ul>
                </div>
            </div>
        </div>
        """)
    
    # Event handlers
    # Re-render the model card whenever the dropdown value changes.
    model_selector.change(
        fn=get_model_info,
        inputs=[model_selector],
        outputs=[model_info_display]
    )
    
    # The outputs list must match, in order, the 5-tuple returned by
    # VoiceSecuritySystem.predict_voice: (status, message, confidence,
    # figure, report markdown).
    process_btn.click(
        fn=process_voice,
        inputs=[audio_input, model_selector, confidence_slider],
        outputs=[status_output, message_output, confidence_output, plot_output, detailed_info]
    )
    
    # Enhanced footer
    gr.HTML("""
    <div style="text-align: center; padding: 25px; margin-top: 40px; background: linear-gradient(45deg, #37474F, #455A64); color: white; border-radius: 15px; box-shadow: 0 8px 25px rgba(0,0,0,0.2);">
        <h4>Developed with PyTorch & Gradio</h4>
        <p>&copy; 2025 - Voice Security System. All rights reserved.</p>
    </div>
    """)


# Launch configuration
if __name__ == "__main__":
    app.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True
    )