import torch
import asyncio
import websockets
import json
import threading
import numpy as np
import logging
import time
import tempfile
import os
import re
from concurrent.futures import ThreadPoolExecutor
import subprocess
import struct

# NeMo imports
import nemo.collections.asr as nemo_asr
import soundfile as sf

# Whisper imports
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor


# Arabic number conversion imports for Whisper
try:
    from pyarabic.number import text2number
    arabic_numbers_available = True
    print("✓ pyarabic library available for Whisper number conversion")
except ImportError:
    arabic_numbers_available = False
    print("✗ pyarabic not available - install with: pip install pyarabic")
    print("Arabic numbers will not be converted to digits for Whisper")
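
# Illustrative sketch (hedged) of the text2number behaviour relied on below:
# pyarabic's text2number returns an int, and 0 when the token is not a number.
#     from pyarabic.number import text2number
#     text2number("خمسة")   # -> 5
#     text2number("كتاب")   # -> 0 (treated as "not a number" in convert_arabic_numbers_whisper)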

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ===== NeMo Arabic number mapping =====
arabic_numbers_nemo = {
    # Basic digits
    "سفر": "0", "فيرو": "0", "هيرو": "0", "صفر": "0", "زيرو": "0", "٠": "0", "زيو": "0", "زير": "0", "زر": "0", "زروا": "0", "زرا": "0", "زيره": "0", "زرو": "0",
    "واحد": "1", "واحدة": "1", "١": "1",
    "اتنين": "2", "اثنين": "2", "إثنين": "2", "اثنان": "2", "إثنان": "2", "٢": "2",
    "تلاتة": "3", "ثلاثة": "3", "٣": "3","تلاته": "3","ثلاثه": "3","ثلاثا": "3","تلاتا": "3",
    "اربعة": "4", "أربعة": "4", "٤": "4","اربعه": "4","أربعه": "4","أربع": "4","اربع": "4","اربعا": "4","أربعا": "4",
    "خمسة": "5", "خمسه": "5", "٥": "5", "خمس": "5", "خمسا": "5",
    "ستة": "6", "سته": "6", "٦": "6", "ست": "6", "ستّا": "6", "ستةً": "6",
    "سبعة": "7", "سبعه": "7", "٧": "7", "سبع": "7", "سبعا": "7",
    "ثمانية": "8", "ثمانيه": "8", "٨": "8", "ثمان": "8", "ثمنية": "8", "ثمنيه": "8", "ثمانيا": "8", "ثمن": "8",
    "تسعة": "9", "تسعه": "9", "٩": "9", "تسع": "9", "تسعا": "9",
    
    # Teens
    "عشرة": "10", "١٠": "10",
    "حداشر": "11", "احد عشر": "11","احداشر": "11",
    "اتناشر": "12", "اثنا عشر": "12",
    "تلتاشر": "13", "ثلاثة عشر": "13",
    "اربعتاشر": "14", "أربعة عشر": "14",
    "خمستاشر": "15", "خمسة عشر": "15",
    "ستاشر": "16", "ستة عشر": "16",
    "سبعتاشر": "17", "سبعة عشر": "17",
    "طمنتاشر": "18", "ثمانية عشر": "18",
    "تسعتاشر": "19", "تسعة عشر": "19",
    
    # Tens
    "عشرين": "20", "٢٠": "20",
    "تلاتين": "30", "ثلاثين": "30", "٣٠": "30",
    "اربعين": "40", "أربعين": "40", "٤٠": "40",
    "خمسين": "50", "٥٠": "50",
    "ستين": "60", "٦٠": "60",
    "سبعين": "70", "٧٠": "70",
    "تمانين": "80", "ثمانين": "80", "٨٠": "80","تمانون": "80","ثمانون": "80",
    "تسعين": "90", "٩٠": "90",
    
    # Hundreds
    "مية": "100", "مائة": "100", "مئة": "100", "١٠٠": "100",
    "ميتين": "200", "مائتين": "200",
    "تلاتمية": "300", "ثلاثمائة": "300",
    "اربعمية": "400", "أربعمائة": "400",
    "خمسمية": "500", "خمسمائة": "500",
    "ستمية": "600", "ستمائة": "600",
    "سبعمية": "700", "سبعمائة": "700",
    "تمانمية": "800", "ثمانمائة": "800",
    "تسعمية": "900", "تسعمائة": "900",
    
    # Thousands
    "ألف": "1000", "الف": "1000", "١٠٠٠": "1000",
    "ألفين": "2000", "الفين": "2000",
    "تلات تلاف": "3000", "ثلاثة آلاف": "3000",
    "اربعة آلاف": "4000", "أربعة آلاف": "4000",
    "خمسة آلاف": "5000",
    "ستة آلاف": "6000",
    "سبعة آلاف": "7000",
    "تمانية آلاف": "8000", "ثمانية آلاف": "8000",
    "تسعة آلاف": "9000",
    
    # Large numbers
    "عشرة آلاف": "10000",
    "مية ألف": "100000", "مائة ألف": "100000",
    "مليون": "1000000", "١٠٠٠٠٠٠": "1000000",
    "ملايين": "1000000",
    "مليار": "1000000000", "١٠٠٠٠٠٠٠٠٠": "1000000000"
}

def replace_arabic_numbers_nemo(text: str) -> str:
    """Convert Arabic number words to digits for NeMo output"""
    # Match longer entries first so multi-word phrases such as "عشرة آلاف"
    # are not partially consumed by shorter keys like "عشرة".
    for word in sorted(arabic_numbers_nemo, key=len, reverse=True):
        text = re.sub(rf"\b{re.escape(word)}\b", arabic_numbers_nemo[word], text)
    return text
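
# Illustrative usage (hedged example; not executed by the server):
#     replace_arabic_numbers_nemo("واحد اثنين ثلاثة")   # -> "1 2 3"
#     replace_arabic_numbers_nemo("عشرة آلاف")          # -> "10000" (longest key matched first)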

def convert_arabic_numbers_whisper(sentence: str) -> str:
    """
    Replace Arabic number words in a sentence with digits for Whisper,
    preserving all other words and punctuation.
    """
    if not arabic_numbers_available or not sentence.strip():
        return sentence
    
    try:
        # Normalization step
        replacements = {
            "اربعة": "أربعة", "اربع": "أربع", "اثنين": "اثنان",
            "اتنين": "اثنان", "ثلاث": "ثلاثة", "خمس": "خمسة",
            "ست": "ستة", "سبع": "سبعة", "ثمان": "ثمانية",
            "تسع": "تسعة", "عشر": "عشرة",
        }
        for wrong, correct in replacements.items():
            sentence = re.sub(rf"\b{wrong}\b", correct, sentence)

        # Split by whitespace but keep spaces
        words = re.split(r'(\s+)', sentence)
        converted_words = []

        for word in words:
            stripped = word.strip()
            if not stripped:  # whitespace token: pass through unchanged
                converted_words.append(word)
                continue

            try:
                num = text2number(stripped)
                if isinstance(num, int):
                    if num != 0 or stripped == "صفر":
                        converted_words.append(str(num))
                    else:
                        converted_words.append(word)
                else:
                    converted_words.append(word)
            except Exception:
                converted_words.append(word)

        return ''.join(converted_words)

    except Exception as e:
        logger.warning(f"Error converting Arabic numbers: {e}")
        return sentence
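
# Illustrative behaviour (hedged): conversion is token-by-token, so standalone number
# words become digits while other words and spacing are preserved:
#     convert_arabic_numbers_whisper("عندي خمسة كتب")   # -> "عندي 5 كتب"
# Compound numbers split across tokens are converted per token and may not combine
# into a single value; that is a limitation of this per-word approach.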

# Global models
asr_model_nemo = None
whisper_model = None
whisper_processor = None
whisper_tokenizer = None
device = None
torch_dtype = None

def initialize_models():
    """Initialize both NeMo and Whisper models"""
    global asr_model_nemo, whisper_model, whisper_processor, whisper_tokenizer, device, torch_dtype
    
    # Initialize device settings
    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    
    logger.info(f"Using device: {device}")
    logger.info(f"CUDA available: {torch.cuda.is_available()}")
    
    # Initialize NeMo model
    logger.info("Loading NeMo FastConformer Arabic ASR model...")
    model_path = "stt_ar_fastconformer_hybrid_large_pcd_v1.0.nemo"
    
    if os.path.exists(model_path):
        try:
            # Note: for a hybrid checkpoint, the generic nemo_asr.models.ASRModel.restore_from
            # may be the safer loader if EncDecCTCModel rejects the file.
            asr_model_nemo = nemo_asr.models.EncDecCTCModel.restore_from(model_path)
            asr_model_nemo.eval()
            logger.info("✓ NeMo FastConformer model loaded successfully")
        except Exception as e:
            logger.error(f"Failed to load NeMo model: {e}")
            asr_model_nemo = None
    else:
        logger.warning(f"NeMo model not found at: {model_path}")
        asr_model_nemo = None
    
    # Initialize Whisper model
    logger.info("Loading Whisper large-v3 model...")
    MODEL_NAME = "alaatiger989/FT_Arabic_Whisper_V1_1"

    try:
        # Try with flash attention first
        try:
            import flash_attn
            whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True,
                attn_implementation="flash_attention_2"
            )
            logger.info("✓ Whisper loaded with flash attention")
        except Exception:  # flash_attn missing or incompatible; fall back to standard attention
            whisper_model = AutoModelForSpeechSeq2Seq.from_pretrained(
                MODEL_NAME,
                torch_dtype=torch_dtype,
                low_cpu_mem_usage=True,
                use_safetensors=True
            )
            logger.info("✓ Whisper loaded with standard attention")

        whisper_model.to(device)
        whisper_processor = AutoProcessor.from_pretrained(MODEL_NAME)

        # Use processor.tokenizer, don’t reload separately
        whisper_tokenizer = whisper_processor.tokenizer  

        logger.info("✓ Whisper model + tokenizer loaded successfully")

    except Exception as e:
        logger.error(f"Failed to load Whisper model: {e}")
        whisper_model = None

# Initialize models on startup
initialize_models()

# Thread pool for processing
executor = ThreadPoolExecutor(max_workers=4)



class JambonzAudioBuffer:
    def __init__(self, sample_rate=8000, chunk_duration=1.0):
        self.sample_rate = sample_rate
        self.chunk_duration = chunk_duration
        self.chunk_samples = int(chunk_duration * sample_rate)
        
        self.buffer = np.array([], dtype=np.float32)
        self.lock = threading.Lock()
        self.total_audio = np.array([], dtype=np.float32)
        
        # Voice Activity Detection - ADJUSTED FOR WHISPER
        self.silence_threshold = 0.01  # Lower threshold for Whisper
        self.min_speech_samples = int(0.3 * sample_rate)  # 300ms minimum speech
        
    def add_audio(self, audio_data):
        with self.lock:
            self.buffer = np.concatenate([self.buffer, audio_data])
            self.total_audio = np.concatenate([self.total_audio, audio_data])
            
            # Log audio addition for debugging
            logger.debug(f"Added {len(audio_data)} audio samples, total: {len(self.total_audio)}")
    
    def has_chunk_ready(self):
        with self.lock:
            ready = len(self.buffer) >= self.chunk_samples
            if ready:
                logger.debug(f"Chunk ready: {len(self.buffer)} >= {self.chunk_samples}")
            return ready
    
    def is_speech(self, audio_chunk):
        """Enhanced VAD based on energy - better for Whisper"""
        if len(audio_chunk) < self.min_speech_samples:
            logger.debug(f"Audio too short for VAD: {len(audio_chunk)} < {self.min_speech_samples}")
            return False
            
        # Calculate RMS energy
        rms_energy = np.sqrt(np.mean(audio_chunk ** 2))
        
        # Also check peak amplitude
        peak_amplitude = np.max(np.abs(audio_chunk))
        
        is_speech = rms_energy > self.silence_threshold or peak_amplitude > (self.silence_threshold * 2)
        
        logger.debug(f"VAD check - RMS: {rms_energy:.4f}, Peak: {peak_amplitude:.4f}, "
                    f"Threshold: {self.silence_threshold}, Speech: {is_speech}")
        
        return is_speech
    
    def get_chunk_for_processing(self):
        """Return a non-None sentinel once a full chunk is buffered; callers fetch audio via get_all_audio()"""
        with self.lock:
            if len(self.buffer) < self.chunk_samples:
                return None
            
            logger.debug(f"Returning processing signal, buffer size: {len(self.buffer)}")
            return np.array([1])  # sentinel only; the buffer is not drained here
    
    def get_all_audio(self):
        """Get all accumulated audio"""
        with self.lock:
            audio_copy = self.total_audio.copy()
            logger.debug(f"Returning {len(audio_copy)} total audio samples")
            return audio_copy
    
    def clear(self):
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)
            logger.debug("Audio buffer cleared")
    
    def reset_for_new_segment(self):
        """Reset buffers for new transcription segment"""
        with self.lock:
            self.buffer = np.array([], dtype=np.float32)
            self.total_audio = np.array([], dtype=np.float32)
            logger.debug("Audio buffer reset for new segment")
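
# Illustrative buffer lifecycle (hedged): one second of 8 kHz audio fills one chunk.
#     buf = JambonzAudioBuffer(sample_rate=8000, chunk_duration=1.0)
#     buf.add_audio(np.zeros(8000, dtype=np.float32))
#     buf.has_chunk_ready()                            # -> True (8000 >= chunk_samples)
#     buf.is_speech(np.zeros(8000, dtype=np.float32))  # -> False (silence: RMS below threshold)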

def linear16_to_audio(audio_bytes, sample_rate=8000):
    """Convert LINEAR16 PCM bytes to numpy array"""
    try:
        audio_array = np.frombuffer(audio_bytes, dtype=np.int16)
        audio_array = audio_array.astype(np.float32) / 32768.0
        return audio_array
    except Exception as e:
        logger.error(f"Error converting LINEAR16 to audio: {e}")
        return np.array([], dtype=np.float32)
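
# Illustrative round-trip (hedged): two little-endian int16 samples map into [-1.0, 1.0).
#     pcm = struct.pack("<2h", 16384, -32768)
#     linear16_to_audio(pcm)   # -> array([ 0.5, -1. ], dtype=float32)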

from scipy.signal import resample_poly, butter, lfilter
import webrtcvad
import noisereduce as nr

# Initialize WebRTC VAD once (0..3, higher = more aggressive/noisy environments)
_vad = webrtcvad.Vad(2)

def resample_audio(audio_data, source_rate, target_rate=16000,
                   lowcut=80.0, highcut=7600.0,
                   frame_ms=30, required_ratio=0.55):
    """
    Resample -> bandpass filter -> noise reduction -> WebRTC VAD speech detection.

    Returns:
        processed_audio (np.ndarray float32): cleaned/resampled audio
        is_speech (bool): True if the VAD detects enough voiced frames
    """

    # --- Resample ---
    if source_rate != target_rate:
        gcd = np.gcd(source_rate, target_rate)
        up = target_rate // gcd
        down = source_rate // gcd
        try:
            audio_data = resample_poly(audio_data, up, down).astype(np.float32)
        except Exception:
            audio_data = np.repeat(audio_data, int(target_rate/source_rate)).astype(np.float32)
    else:
        audio_data = audio_data.astype(np.float32)

    # --- Bandpass filter (speech range) ---
    try:
        nyq = 0.5 * target_rate
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(4, [low, high], btype='band')
        audio_data = lfilter(b, a, audio_data).astype(np.float32)
    except Exception:
        pass

    # --- Noise reduction ---
    try:
        if len(audio_data) >= int(0.25 * target_rate):
            noise_clip = audio_data[:int(0.25 * target_rate)]
            audio_data = nr.reduce_noise(y=audio_data, y_noise=noise_clip, sr=target_rate).astype(np.float32)
    except Exception:
        pass

    # --- WebRTC VAD ---
    def frame_generator(frame_ms, audio, sample_rate):
        n = int(sample_rate * (frame_ms / 1000.0))
        if len(audio) < n:
            return
        offset = 0
        while offset + n <= len(audio):
            frame = audio[offset:offset+n]
            yield (frame * 32767).astype(np.int16).tobytes()
            offset += n

    frames = list(frame_generator(frame_ms, audio_data, target_rate))
    voiced = 0
    for f in frames:
        try:
            if _vad.is_speech(f, target_rate):
                voiced += 1
        except Exception:
            pass
    ratio = voiced / max(1, len(frames))
    is_speech = ratio >= required_ratio

    return audio_data, is_speech
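
# Illustrative usage (hedged): the function both upsamples and acts as a speech gate.
#     chunk = np.zeros(8000, dtype=np.float32)        # 1 s of 8 kHz silence
#     audio16k, voiced = resample_audio(chunk, 8000)  # len(audio16k) == 16000, voiced == False
#     if voiced:
#         ...  # hand audio16k to one of the transcribe_* functions below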

def transcribe_with_nemo(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe audio using NeMo FastConformer"""
    try:
        if len(audio_data) == 0 or asr_model_nemo is None:
            return ""
        
        # Resample to 16kHz (NeMo models typically expect 16kHz)
        resampled_audio, has_speech = resample_audio(audio_data, source_sample_rate, target_sample_rate)
        
        if has_speech:
            print("Speech detected, sending to ASR...")
            # Skip very short audio
            min_samples = int(0.3 * target_sample_rate)
            if len(resampled_audio) < min_samples:
                return ""
            
            start_time = time.time()
            
            # Save audio to temporary file (NeMo expects file path)
            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
                sf.write(tmp_file.name, resampled_audio, target_sample_rate)
                tmp_path = tmp_file.name
            
            try:
                # Transcribe with NeMo
                result = asr_model_nemo.transcribe([tmp_path])
                
                if result and len(result) > 0:
                    # Handle different NeMo result formats
                    if hasattr(result[0], 'text'):
                        raw_text = result[0].text
                    elif isinstance(result[0], str):
                        raw_text = result[0]
                    else:
                        raw_text = str(result[0])
                    
                    if not isinstance(raw_text, str):
                        raw_text = str(raw_text)
                    
                    if raw_text and raw_text.strip():
                        # Convert Arabic numbers to digits for NeMo
                        cleaned_text = replace_arabic_numbers_nemo(raw_text)
                        end_time = time.time()
                        
                        if cleaned_text.strip():
                            logger.info(f"NeMo transcription: '{cleaned_text}' (processed in {end_time - start_time:.2f}s)")
                        
                        return cleaned_text.strip()
                        
            finally:
                # Clean up temporary file
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
            
            return ""
        else:
            print("Silence/noise, skipping...")
            return ""
        
    except Exception as e:
        logger.error(f"Error during NeMo transcription: {e}")
        return ""

def transcribe_with_whisper(audio_data, source_sample_rate=8000, target_sample_rate=16000):
    """Transcribe audio chunk using Whisper model directly"""
    try:
        if len(audio_data) == 0 or whisper_model is None:
            return ""
        
        # Resample from 8kHz to 16kHz for Whisper
        resampled_audio, has_speech = resample_audio(audio_data, source_sample_rate, target_sample_rate)
        if has_speech:
            print("Speech detected, sending to ASR...")
            # Ensure minimum length for Whisper
            min_samples = int(0.1 * target_sample_rate)  # 100ms minimum
            if len(resampled_audio) < min_samples:
                return ""
            
            start_time = time.time()
            
            # Prepare input features with proper dtype
            input_features = whisper_processor(
                resampled_audio, 
                sampling_rate=target_sample_rate, 
                return_tensors="pt"
            ).input_features
            
            # Ensure correct dtype and device
            input_features = input_features.to(device=device, dtype=torch_dtype)
            
            # Create attention mask to avoid warnings
            attention_mask = torch.ones(
                input_features.shape[:-1], 
                dtype=torch.long, 
                device=device
            )
            
            # Generate transcription using the model directly.
            # language/task are set to Arabic transcription so the output matches the
            # Arabic number post-processing applied to Whisper finals downstream.
            with torch.no_grad():
                predicted_ids = whisper_model.generate(
                    input_features,
                    attention_mask=attention_mask,
                    max_new_tokens=128,
                    do_sample=False,
                    num_beams=1,
                    language="arabic",
                    task="transcribe",
                    pad_token_id=whisper_tokenizer.pad_token_id,
                    eos_token_id=whisper_tokenizer.eos_token_id
                )
            
            # Decode the transcription
            transcription = whisper_tokenizer.batch_decode(
                predicted_ids, 
                skip_special_tokens=True
            )[0].strip()
            
            end_time = time.time()
            
            logger.info(f"Whisper transcription completed in {end_time - start_time:.2f}s: '{transcription}'")
            return transcription
        else:
            print("Silence/noise, skipping...")
            return ""
    except Exception as e:
        logger.error(f"Error during Whisper transcription: {e}")
        return ""

class UnifiedSTTHandler:
    def __init__(self, websocket):
        self.websocket = websocket
        self.audio_buffer = None
        self.config = {}
        self.running = False
        self.transcription_task = None
        self.use_nemo = False  # Flag to determine which model to use
        
        # Auto-final detection variables
        self.interim_count = 0
        self.last_interim_time = None
        self.silence_timeout = 2.9  # currently unused; _monitor_for_auto_final applies model-specific timeouts
        self.min_interim_count = 1
        self.auto_final_task = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.last_partial = ""
        
        # Processing tracking
        self.processing_count = 0

    # Debugging variant of add_audio_data: logs buffer growth once per second of audio

    async def add_audio_data(self, audio_bytes):
        """Add audio data to buffer with enhanced debugging"""
        if self.audio_buffer and self.running:
            audio_data = linear16_to_audio(audio_bytes, self.config["sample_rate"])
            self.audio_buffer.add_audio(audio_data)
            
            model_name = "NeMo" if self.use_nemo else "Whisper"
            
            # Debug logging every few audio packets
            if len(audio_data) > 0:
                total_samples = len(self.audio_buffer.get_all_audio())
                total_seconds = total_samples / self.config["sample_rate"]
                
                # Log every second of audio
                if int(total_seconds) != getattr(self, '_last_logged_second', -1):
                    logger.info(f"{model_name} - Accumulated {total_seconds:.1f}s of audio ({total_samples} samples)")
                    self._last_logged_second = int(total_seconds)
                    
                    # Check if we should have chunks ready
                    chunk_ready = self.audio_buffer.has_chunk_ready()
                    logger.info(f"{model_name} - Chunk ready: {chunk_ready}")        

    async def start_processing(self, start_message):
        """Initialize with start message from jambonz"""
        self.config = {
            "language": start_message.get("language", "ar-EG"),
            "format": start_message.get("format", "raw"),
            "encoding": start_message.get("encoding", "LINEAR16"),
            "sample_rate": start_message.get("sampleRateHz", 8000),
            "interim_results": True,  # Always enable for internal processing
            "options": start_message.get("options", {})
        }
        
        # Determine which model to use based on language parameter
        language = self.config["language"]
        if language == "ar-EG":
            logger.info("Selected NeMo FastConformer")
            self.use_nemo = True
            model_name = "NeMo FastConformer"
        elif language == "ar-EG-whis":
            logger.info("Selected Whisper large-v3")
            self.use_nemo = False
            model_name = "Whisper large-v3"
        else:
            # Default to NeMo for any other Arabic variant
            self.use_nemo = True
            model_name = "NeMo FastConformer (default)"
        
        logger.info(f"STT session started with {model_name} for language: {language}")
        logger.info(f"Config: {self.config}")
        
        # Check if selected model is available
        if self.use_nemo and asr_model_nemo is None:
            await self.send_error("NeMo model not available")
            return
        elif not self.use_nemo and whisper_model is None:
            await self.send_error("Whisper model not available")
            return
        
        # Initialize audio buffer with model-specific settings
        if self.use_nemo:
            chunk_duration = 1.0  # NeMo processes every 1 second
        else:
            chunk_duration = 2.0  # Whisper processes every 2 seconds for better accuracy
        
        self.audio_buffer = JambonzAudioBuffer(
            sample_rate=self.config["sample_rate"],
            chunk_duration=chunk_duration
        )
        
        # Adjust VAD threshold for Whisper
        if not self.use_nemo:
            self.audio_buffer.silence_threshold = 0.005  # Lower threshold for Whisper
        
        # Reset session variables
        self.running = True
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.segment_number = 0
        self.processing_count = 0
        self.last_partial = ""
        
        # Start background transcription task
        self.transcription_task = asyncio.create_task(self._process_audio_chunks())
        
        # Start auto-final detection task
        self.auto_final_task = asyncio.create_task(self._monitor_for_auto_final())
        
        logger.info(f"Background tasks started for {model_name}")



    async def stop_processing(self):
        """Stop current processing session"""
        logger.info("Stopping STT session...")
        self.running = False
        
        # Cancel background tasks
        for task in [self.transcription_task, self.auto_final_task]:
            if task:
                task.cancel()
                try:
                    await task
                except asyncio.CancelledError:
                    pass
        
        # Send final transcription if not already sent
        if not self.final_sent and self.accumulated_transcript.strip():
            await self.send_transcription(self.accumulated_transcript, is_final=True)
        
        # Process any remaining audio for comprehensive final transcription
        if self.audio_buffer:
            all_audio = self.audio_buffer.get_all_audio()
            if len(all_audio) > 0 and not self.final_sent:
                loop = asyncio.get_event_loop()
                
                if self.use_nemo:
                    final_transcription = await loop.run_in_executor(
                        executor, transcribe_with_nemo, all_audio, self.config["sample_rate"]
                    )
                else:
                    final_transcription = await loop.run_in_executor(
                        executor, transcribe_with_whisper, all_audio, self.config["sample_rate"]
                    )
                
                if final_transcription.strip():
                    await self.send_transcription(final_transcription, is_final=True)
        
        # Clear audio buffer
        if self.audio_buffer:
            self.audio_buffer.clear()
        
        logger.info("STT session stopped")
    
    async def start_new_segment(self):
        """Start a new transcription segment"""
        self.segment_number += 1
        self.interim_count = 0
        self.last_interim_time = None
        self.accumulated_transcript = ""
        self.final_sent = False
        self.last_partial = ""
        self.processing_count = 0
        
        if self.audio_buffer:
            self.audio_buffer.reset_for_new_segment()
        
        logger.info(f"Started new transcription segment #{self.segment_number}")
    
    async def _process_audio_chunks(self):
        """Process audio chunks for interim results - with debugging"""
        model_name = "NeMo" if self.use_nemo else "Whisper"
        logger.info(f"Starting audio chunk processing for {model_name}")
        
        chunk_count = 0
        
        while self.running:
            try:
                if self.audio_buffer and self.audio_buffer.has_chunk_ready():
                    chunk_count += 1
                    logger.info(f"{model_name} - Processing chunk #{chunk_count}")
                    
                    chunk_signal = self.audio_buffer.get_chunk_for_processing()
                    if chunk_signal is not None:
                        all_audio = self.audio_buffer.get_all_audio()
                        
                        logger.info(f"{model_name} - Got {len(all_audio)} samples for processing")
                        
                        if len(all_audio) > 0:
                            # Get the latest chunk for VAD check
                            latest_chunk_start = max(0, len(all_audio) - self.audio_buffer.chunk_samples)
                            latest_chunk = all_audio[latest_chunk_start:]
                            
                            # Check for speech activity
                            has_speech = self.audio_buffer.is_speech(latest_chunk)
                            logger.info(f"{model_name} - Speech detected: {has_speech}")
                            
                            if has_speech:
                                logger.info(f"{model_name} - Starting transcription...")
                                
                                loop = asyncio.get_event_loop()
                                start_time = time.time()
                                
                                try:
                                    # Choose transcription method based on model selection
                                    if self.use_nemo:
                                        transcription = await loop.run_in_executor(
                                            executor, transcribe_with_nemo, all_audio, self.config["sample_rate"]
                                        )
                                    else:
                                        transcription = await loop.run_in_executor(
                                            executor, transcribe_with_whisper, all_audio, self.config["sample_rate"]
                                        )
                                    
                                    process_time = time.time() - start_time
                                    logger.info(f"{model_name} - Transcription completed in {process_time:.2f}s: '{transcription}'")
                                    
                                    if transcription and transcription.strip():
                                        self.processing_count += 1
                                        self.accumulated_transcript = transcription
                                        
                                        if transcription != self.last_partial or self.interim_count == 0:
                                            self.last_partial = transcription
                                            self.interim_count += 1
                                            self.last_interim_time = time.time()
                                            logger.info(f"{model_name} - Updated interim_count to {self.interim_count}")
                                        else:
                                            self.last_interim_time = time.time()
                                            logger.info(f"{model_name} - Same transcription, updating time only")
                                    else:
                                        logger.info(f"{model_name} - No transcription result")
                                        
                                except Exception as e:
                                    logger.error(f"{model_name} - Transcription error: {e}")
                                    import traceback
                                    traceback.print_exc()
                            else:
                                logger.debug(f"{model_name} - No speech in chunk")
                    else:
                        logger.warning(f"{model_name} - Chunk signal was None")
                else:
                    # Log why chunk is not ready
                    if self.audio_buffer:
                        current_size = len(self.audio_buffer.buffer)
                        required_size = self.audio_buffer.chunk_samples
                        if current_size > 0:
                            logger.debug(f"{model_name} - Buffer: {current_size}/{required_size} samples")
                
                await asyncio.sleep(0.1)
                
            except Exception as e:
                logger.error(f"{model_name} - Error in chunk processing: {e}")
                import traceback
                traceback.print_exc()
                await asyncio.sleep(1)

    async def _monitor_for_auto_final(self):
        """Monitor for auto-final conditions with model-specific timeouts"""
        model_name = "NeMo" if self.use_nemo else "Whisper"
        timeout = 2.0 if self.use_nemo else 3.0  # Longer timeout for Whisper
        
        logger.info(f"Starting auto-final monitoring for {model_name} (timeout: {timeout}s)")
        
        while self.running:
            try:
                current_time = time.time()
                
                if (self.interim_count >= self.min_interim_count and 
                    self.last_interim_time is not None and 
                    (current_time - self.last_interim_time) >= timeout and
                    not self.final_sent and
                    self.accumulated_transcript.strip()):
                    
                    silence_duration = current_time - self.last_interim_time
                    logger.info(f"Auto-final triggered for segment #{self.segment_number} ({model_name}) - "
                            f"Interim count: {self.interim_count}, Silence: {silence_duration:.1f}s")
                    
                    await self.send_transcription(self.accumulated_transcript, is_final=True)
                    await self.start_new_segment()
                
                await asyncio.sleep(0.5)  # Check every 500ms
                
            except Exception as e:
                logger.error(f"Error in auto-final monitoring: {e}")
                await asyncio.sleep(0.5)



    async def send_transcription(self, text, is_final=True, confidence=0.9):
        """Send transcription in jambonz format"""
        try:
            # Apply number conversion only for Whisper
            if not self.use_nemo and is_final:
                original_text = text
                converted_text = convert_arabic_numbers_whisper(text)
                
                if original_text != converted_text:
                    logger.info(f"Whisper - Arabic numbers converted: '{original_text}' -> '{converted_text}'")
                text = converted_text
            
            message = {
                "type": "transcription",
                "is_final": True,  # this server only emits final results; interims are tracked internally
                "alternatives": [
                    {
                        "transcript": text,
                        "confidence": confidence
                    }
                ],
                "language": self.config.get("language", "ar-EG"),
                "channel": 1
            }
            
            await self.websocket.send(json.dumps(message))
            self.final_sent = True
            
            model_name = "NeMo" if self.use_nemo else "Whisper"
            logger.info(f"Sent FINAL transcription ({model_name}): '{text}'")
            
        except Exception as e:
            logger.error(f"Error sending transcription: {e}")
    
    async def send_error(self, error_message):
        """Send error message in jambonz format"""
        try:
            message = {
                "type": "error",
                "error": error_message
            }
            await self.websocket.send(json.dumps(message))
            logger.error(f"Sent error: {error_message}")
        except Exception as e:
            logger.error(f"Error sending error message: {e}")

async def handle_jambonz_websocket(websocket):
    """Handle jambonz WebSocket connections"""
    
    client_id = f"jambonz_{id(websocket)}"
    logger.info(f"New unified STT connection: {client_id}")
    
    handler = UnifiedSTTHandler(websocket)
    
    try:
        async for message in websocket:
            try:
                if isinstance(message, str):
                    data = json.loads(message)
                    message_type = data.get("type")
                    
                    if message_type == "start":
                        logger.info(f"Received start message: {data}")
                        await handler.start_processing(data)
                        
                    elif message_type == "stop":
                        logger.info("Received stop message - closing WebSocket")
                        await handler.stop_processing()
                        await websocket.close(code=1000, reason="Session stopped by client")
                        break
                        
                    else:
                        logger.warning(f"Unknown message type: {message_type}")
                        await handler.send_error(f"Unknown message type: {message_type}")
                
                else:
                    # Handle binary audio data
                    if not handler.running or handler.audio_buffer is None:
                        logger.warning("Received audio data outside of active session")
                        await handler.send_error("Received audio before start message or after stop")
                        continue
                    
                    await handler.add_audio_data(message)
            
            except json.JSONDecodeError as e:
                logger.error(f"JSON decode error: {e}")
                await handler.send_error(f"Invalid JSON: {str(e)}")
            except Exception as e:
                logger.error(f"Error processing message: {e}")
                await handler.send_error(f"Processing error: {str(e)}")
    
    except websockets.exceptions.ConnectionClosed:
        logger.info(f"Unified STT connection closed: {client_id}")
    except Exception as e:
        logger.error(f"Unified STT WebSocket error: {e}")
        try:
            await handler.send_error(str(e))
        except Exception:
            pass
    finally:
        if handler.running:
            await handler.stop_processing()
        logger.info(f"Unified STT connection ended: {client_id}")
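
# Illustrative client flow (hedged sketch; the message shapes mirror the handler
# above, while the URL and audio payload are placeholders):
#
#     import asyncio, json, websockets
#
#     async def demo():
#         async with websockets.connect("ws://localhost:3007") as ws:
#             await ws.send(json.dumps({
#                 "type": "start", "language": "ar-EG",
#                 "format": "raw", "encoding": "LINEAR16", "sampleRateHz": 8000,
#             }))
#             await ws.send(b"\x00" * 16000)   # 1 s of silent LINEAR16 PCM (8000 samples * 2 bytes)
#             print(await ws.recv())           # first server message (may wait until speech/final)
#             await ws.send(json.dumps({"type": "stop"}))
#
#     asyncio.run(demo())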

async def main():
    """Start the Unified Arabic STT WebSocket server"""
    logger.info("Starting Unified Arabic STT WebSocket server on port 3007...")
    
    # Check model availability
    models_available = []
    if asr_model_nemo is not None:
        models_available.append("NeMo FastConformer (ar-EG)")
    if whisper_model is not None:
        models_available.append("Whisper large-v3 (ar-EG-whis)")
    
    if not models_available:
        logger.error("No models available! Please check model paths and installations.")
        return
    
    # Start WebSocket server
    server = await websockets.serve(
        handle_jambonz_websocket,
        "0.0.0.0", 
        3007,
        ping_interval=20,
        ping_timeout=10,
        close_timeout=10
    )
    
    logger.info("Unified Arabic STT WebSocket server started on ws://0.0.0.0:3007")
    logger.info("Ready to handle jambonz STT requests with both models")
    logger.info("ROUTING:")
    logger.info("- language: 'ar-EG' → NeMo FastConformer (with built-in number conversion)")
    logger.info("- language: 'ar-EG-whis' → Whisper large-v3 (with pyarabic number conversion)")
    logger.info("FEATURES:")
    logger.info("- Continuous transcription with segmentation")
    logger.info("- Voice Activity Detection")
    logger.info("- Auto-final detection (2s silence for NeMo, 3s for Whisper)")
    logger.info("- Model-specific number conversion")
    logger.info(f"AVAILABLE MODELS: {', '.join(models_available)}")
    
    # Wait for the server to close
    await server.wait_closed()

if __name__ == "__main__":
    print("=" * 80)
    print("Unified Arabic STT Server (NeMo + Whisper)")
    print("=" * 80)
    print("WebSocket Port: 3007")
    print("Protocol: jambonz STT API")
    print("Audio Format: LINEAR16 PCM @ 8kHz → 16kHz")
    print()
    print("LANGUAGE ROUTING:")
    print("- 'ar-EG' → NeMo FastConformer")
    print("  • Built-in Arabic number word to digit conversion")
    print("  • Optimized for Arabic dialects")
    print("- 'ar-EG-whis' → Whisper large-v3")
    print("  • pyarabic library number conversion (final transcripts only)")
    print("  • OpenAI Whisper model")
    print()
    print("FEATURES:")
    print("- Automatic model selection based on language parameter")
    print("- Voice Activity Detection")
    print("- Auto-final detection (2s silence for NeMo, 3s for Whisper)")
    print("- Model-specific number conversion strategies")
    print("- Continuous transcription with segmentation")
    print()
    
    # Check model availability for startup info
    nemo_status = "✓ Available" if asr_model_nemo is not None else "✗ Not Available"
    whisper_status = "✓ Available" if whisper_model is not None else "✗ Not Available"
    arabic_numbers_status = "✓ Available" if arabic_numbers_available else "✗ Not Available (install pyarabic)"
    
    print("MODEL STATUS:")
    print(f"- NeMo FastConformer: {nemo_status}")
    print(f"- Whisper large-v3: {whisper_status}")
    print(f"- pyarabic (Whisper numbers): {arabic_numbers_status}")
    print("=" * 80)
    
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\nShutting down unified server...")
    except Exception as e:
        print(f"Server error: {e}")