File size: 55,799 Bytes
da70c8c
51572be
da70c8c
 
e5e1cf3
 
f53a847
 
 
51572be
fe9f63a
51572be
e5e1cf3
51572be
 
e5e1cf3
 
 
 
 
 
 
615a636
51572be
 
 
 
 
 
 
6d44c5f
b0dee38
 
 
 
 
 
51572be
 
 
 
 
 
 
615a636
 
 
 
 
 
51572be
 
 
 
 
 
 
 
 
 
 
 
 
e5e1cf3
 
51572be
 
 
e5e1cf3
f53a847
e5e1cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f53a847
e5e1cf3
 
 
 
 
 
 
 
51572be
e5e1cf3
51572be
9010a30
cef5a6f
 
 
 
51572be
 
 
 
e5e1cf3
32ed20e
 
822ce81
c329763
b0a5a88
 
 
 
 
 
 
 
9eebcc5
920d18e
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a5a88
 
 
 
 
 
 
 
 
 
 
 
 
 
f53a847
5d79055
32ed20e
 
 
 
 
 
 
822ce81
 
 
 
 
 
ff5e826
 
 
 
 
c329763
 
 
 
 
b0a5a88
 
 
 
 
 
 
 
 
07b30a1
 
 
 
 
 
 
0ce399f
 
 
 
 
 
 
b0a5a88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cef5a6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e1cf3
51572be
e5e1cf3
 
51572be
e5e1cf3
 
 
 
 
 
 
f53a847
 
 
51572be
f53a847
51572be
 
 
2220aba
 
 
 
 
 
 
 
51572be
58aa51a
2220aba
 
 
 
 
51572be
2220aba
 
 
 
 
 
 
 
51572be
 
f53a847
 
5d79055
 
51572be
5d79055
 
51572be
5d79055
51572be
e5e1cf3
51572be
5d79055
51572be
5d79055
51572be
f53a847
51572be
863d06f
58efaa5
e5e1cf3
c329763
f53a847
 
e5e1cf3
f53a847
51572be
 
 
f53a847
51572be
b0a5a88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51572be
e699db1
51572be
 
32ed20e
b0a5a88
51572be
b0a5a88
 
 
51572be
e5e1cf3
b0a5a88
 
 
e5e1cf3
 
b0a5a88
 
e5e1cf3
 
 
51572be
e5e1cf3
51572be
 
e699db1
 
51572be
 
 
 
 
 
 
e699db1
58aa51a
e699db1
 
 
 
 
51572be
 
e699db1
51572be
 
 
822ce81
51572be
c329763
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b0a5a88
51572be
 
 
 
 
 
 
 
 
 
 
822ce81
51572be
e699db1
 
 
 
 
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6dba6f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51572be
 
 
e699db1
51572be
 
 
69a2351
51572be
 
 
 
e699db1
51572be
 
 
 
 
 
fafbe00
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fafbe00
51572be
 
fafbe00
 
 
 
 
51572be
 
fafbe00
 
 
 
 
 
51572be
e5e1cf3
 
f53a847
 
51572be
 
 
 
f53a847
 
51572be
 
 
 
 
f53a847
e5e1cf3
f53a847
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
adc2859
51572be
 
 
 
e5e1cf3
 
51572be
 
f53a847
51572be
 
f53a847
51572be
 
 
 
 
 
 
 
 
0a1845d
51572be
 
e699db1
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cef5a6f
 
51572be
 
 
 
 
 
 
 
cef5a6f
 
51572be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f53a847
adc2859
c4afdc5
 
 
51572be
 
 
f53a847
51572be
 
 
 
 
 
c329763
8960a88
51572be
 
 
cef5a6f
 
51572be
cef5a6f
 
b0a5a88
cef5a6f
 
b0a5a88
cef5a6f
 
 
b0a5a88
cef5a6f
51572be
 
 
 
cef5a6f
 
 
 
51572be
 
 
 
cdaa0f4
51572be
 
b0a5a88
 
 
 
cef5a6f
 
 
 
51572be
 
 
 
 
b0a5a88
 
51572be
f53a847
 
51572be
 
 
 
f53a847
 
51572be
 
 
615a636
 
 
f448887
 
 
 
 
 
615a636
 
 
51572be
 
 
 
e5e1cf3
 
615a636
51572be
e5e1cf3
 
5bbaad5
b0a5a88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5e1cf3
51572be
9eebcc5
 
4e7bf3c
 
9eebcc5
 
 
 
 
 
 
 
 
4e7bf3c
920d18e
9eebcc5
 
 
 
 
c329763
 
b0a5a88
c329763
 
 
 
 
b0a5a88
 
 
 
 
07b30a1
0ce399f
9eebcc5
b0a5a88
 
 
 
c329763
 
12a2ba3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c75032a
 
 
 
 
 
 
 
 
 
 
 
 
fd7eb4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c75032a
615a636
fd7eb4d
 
 
 
c75032a
fd7eb4d
c75032a
 
 
615a636
 
c75032a
fd7eb4d
 
 
 
 
 
 
 
 
c75032a
 
 
 
 
 
e699db1
c75032a
e6dba6f
 
 
 
 
 
 
 
 
 
 
 
c75032a
 
56e7960
fe9f63a
 
56e7960
 
fe9f63a
 
 
 
 
 
 
 
 
 
 
56e7960
fe9f63a
 
56e7960
fe9f63a
 
fa68581
 
 
 
 
 
 
 
fe9f63a
 
 
 
 
 
 
 
 
56e7960
b062e61
fa68581
51572be
56e7960
fa68581
fe9f63a
c75032a
fa68581
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c75032a
 
 
 
615a636
 
 
56e7960
615a636
 
c75032a
51572be
b0a5a88
8960a88
b0a5a88
51572be
12a2ba3
51572be
e5e1cf3
 
fa68581
 
 
 
51572be
 
56e7960
fa68581
51572be
 
 
 
 
e5e1cf3
 
 
ff5e826
 
 
 
 
 
 
51572be
35eb04b
51572be
863d06f
 
 
fd7eb4d
863d06f
c1cb918
863d06f
 
615a636
 
 
 
863d06f
 
 
 
 
 
 
 
 
 
51572be
 
 
 
 
 
 
 
 
35eb04b
51572be
 
 
 
 
 
 
35eb04b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51572be
 
 
 
35eb04b
f448887
51572be
 
 
 
 
 
e5e1cf3
51572be
 
 
 
 
e5e1cf3
 
51572be
d0a140b
51572be
 
 
 
 
 
 
 
d0a140b
 
51572be
d0a140b
 
e5e1cf3
56e7960
51572be
 
56e7960
 
 
 
 
 
51572be
 
 
b062e61
56e7960
 
 
51572be
56e7960
 
 
 
 
fe9f63a
56e7960
fe9f63a
56e7960
 
e5e1cf3
 
56e7960
51572be
56e7960
51572be
 
e5e1cf3
 
c75032a
 
 
 
 
 
 
 
 
 
 
 
 
58aa51a
c75032a
 
 
 
51572be
c75032a
8960a88
 
c75032a
fe9f63a
56e7960
 
 
fa68581
 
fe9f63a
 
56e7960
fe9f63a
 
56e7960
c75032a
 
 
 
 
 
 
b062e61
56e7960
 
b062e61
56e7960
 
b062e61
56e7960
fe9f63a
 
56e7960
fe9f63a
56e7960
 
 
 
fe9f63a
 
c75032a
 
56e7960
 
c75032a
e5e1cf3
ba86463
 
 
 
 
 
 
 
 
 
 
 
 
829c030
f53a847
51572be
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
"""
Gradio Space for BlueTTS — multilingual ONNX TTS (slim 4-model pipeline).
Upstream: https://github.com/maxmelichov/BlueTTS
"""
import os
import re
import json
import time
import base64
import glob
import html
import subprocess
from dataclasses import dataclass
from importlib import import_module
from typing import Any, List, Optional, Tuple, Dict, Union
from unicodedata import normalize as uni_normalize

import numpy as np
from num2words import num2words
import gradio as gr
import onnxruntime as ort

from download_models import BLUE_REPO, download_blue_models, download_default_voices, download_renikud

# ------------------------------------------------------------------
# Paths
# ------------------------------------------------------------------
# Local model / voice locations (relative paths, resolved against the working directory).
ONNX_DIR = "onnx_slim"
VOICES_DIR = "voices"
RENIKUD_PATH = "renikud.onnx"
# Prefer a tts.json in the working directory; otherwise use the copy inside ONNX_DIR.
CONFIG_PATH = "tts.json" if os.path.exists("tts.json") else os.path.join(ONNX_DIR, "tts.json")
# First existing vocab.json among: ONNX_DIR, the working directory, this script's
# directory. If none exists yet, fall back to the ONNX_DIR location, so a vocab.json
# written there by the download step below is still found. (Presumably passed to
# UnicodeProcessor, which uses its built-in vocab when the path is missing.)
VOCAB_PATH = next(
    (p for p in (os.path.join(ONNX_DIR, "vocab.json"), "vocab.json",
                 os.path.join(os.path.dirname(os.path.abspath(__file__)), "vocab.json"))
     if os.path.exists(p)),
    os.path.join(ONNX_DIR, "vocab.json"),
)

# ------------------------------------------------------------------
# Fetch models + default voices on first run
# ------------------------------------------------------------------
def _needs_download() -> bool:
    """Decide whether the slim ONNX bundle in ``ONNX_DIR`` must be (re)fetched.

    A download is needed when the ``.repo_id`` marker is missing, names a
    repository other than ``BLUE_REPO``, or when any of the four required
    model files is absent or suspiciously small (< 1000 bytes, e.g. a stub).
    """
    marker_path = os.path.join(ONNX_DIR, ".repo_id")
    if not os.path.exists(marker_path):
        return True

    with open(marker_path) as marker:
        cached_repo = marker.read().strip()
    if cached_repo != BLUE_REPO:
        # Bundle came from a different repo: refresh it.
        return True

    model_files = (
        "text_encoder.onnx",
        "vector_estimator.onnx",
        "vocoder.onnx",
        "duration_predictor.onnx",
    )
    for name in model_files:
        model_path = os.path.join(ONNX_DIR, name)
        if not os.path.exists(model_path) or os.path.getsize(model_path) < 1000:
            return True
    return False


# Import-time side effects: (re)fetch the slim ONNX bundle only when
# _needs_download() reports it missing or stale. The voice and Renikud download
# helpers are called unconditionally; presumably they skip files already on
# disk — NOTE(review): confirm in download_models.
if _needs_download():
    print("[INFO] Slim ONNX bundle incomplete, downloading…")
    download_blue_models(ONNX_DIR)

download_default_voices(VOICES_DIR)
download_renikud(RENIKUD_PATH)

# ============================================================
# Vocab — phoneme → id map, shared with the old/new checkpoints.
# A vocab.json next to the slim ONNX files wins; otherwise we fall back to
# this built-in IPA map (same as the upstream Piper-style vocab + extras).
# ============================================================
# Piper-style base vocabulary, ids 0-156. "_" has id 0, which matches
# UnicodeProcessor's default pad_id; "^"/"$" are presumably start/end markers.
# The rest covers punctuation, lower-case Latin, IPA letters/diacritics and digits.
_PIPER_MAP: dict[str, int] = {
    "_": 0, "^": 1, "$": 2, " ": 3, "!": 4, "'": 5, "(": 6, ")": 7, ",": 8, "-": 9, ".": 10,
    ":": 11, ";": 12, "?": 13, "a": 14, "b": 15, "c": 16, "d": 17, "e": 18, "f": 19,
    "h": 20, "i": 21, "j": 22, "k": 23, "l": 24, "m": 25, "n": 26, "o": 27, "p": 28, "q": 29, "r": 30, "s": 31, "t": 32, "u": 33,
    "v": 34, "w": 35, "x": 36, "y": 37, "z": 38, "æ": 39, "ç": 40, "ð": 41, "ø": 42, "ħ": 43, "ŋ": 44, "œ": 45,
    "ǀ": 46, "ǁ": 47, "ǂ": 48, "ǃ": 49, "ɐ": 50, "ɑ": 51, "ɒ": 52, "ɓ": 53, "ɔ": 54, "ɕ": 55,
    "ɖ": 56, "ɗ": 57, "ɘ": 58, "ə": 59, "ɚ": 60, "ɛ": 61, "ɜ": 62, "ɞ": 63, "ɟ": 64, "ɠ": 65, "ɡ": 66, "ɢ": 67,
    "ɣ": 68, "ɤ": 69, "ɥ": 70, "ɦ": 71, "ɧ": 72, "ɨ": 73, "ɪ": 74, "ɫ": 75, "ɬ": 76, "ɭ": 77, "ɮ": 78, "ɯ": 79,
    "ɰ": 80, "ɱ": 81, "ɲ": 82, "ɳ": 83, "ɴ": 84, "ɵ": 85, "ɶ": 86, "ɸ": 87, "ɹ": 88, "ɺ": 89, "ɻ": 90, "ɽ": 91,
    "ɾ": 92, "ʀ": 93, "ʁ": 94, "ʂ": 95, "ʃ": 96, "ʄ": 97, "ʈ": 98, "ʉ": 99, "ʊ": 100, "ʋ": 101, "ʌ": 102, "ʍ": 103,
    "ʎ": 104, "ʏ": 105, "ʐ": 106, "ʑ": 107, "ʒ": 108, "ʔ": 109, "ʕ": 110, "ʘ": 111, "ʙ": 112, "ʛ": 113, "ʜ": 114, "ʝ": 115,
    "ʟ": 116, "ʡ": 117, "ʢ": 118, "ʲ": 119, "ˈ": 120, "ˌ": 121, "ː": 122, "ˑ": 123, "˞": 124,
    "β": 125, "θ": 126, "χ": 127, "ᵻ": 128, "ⱱ": 129, "0": 130, "1": 131, "2": 132, "3": 133, "4": 134,
    "5": 135, "6": 136, "7": 137, "8": 138, "9": 139, "\u0327": 140, "\u0303": 141, "\u032A": 142, "\u032F": 143, "\u0329": 144,
    "ʰ": 145, "ˤ": 146, "ε": 147, "↓": 148, "#": 149, '"': 150, "↑": 151, "\u033A": 152, "\u033B": 153, "g": 154, "ʦ": 155, "X": 156,
}
# Extension ids 157-243: upper-case Latin (except "X", already 156 above), extra IPA
# letters and modifier letters, arrows, quotes/symbols, combining diacritics and tone letters.
_EXTENDED_MAP: dict[str, int] = {
    "A": 157, "B": 158, "C": 159, "D": 160, "E": 161, "F": 162, "G": 163, "H": 164, "I": 165, "J": 166, "K": 167, "L": 168, "M": 169, "N": 170,
    "O": 171, "P": 172, "Q": 173, "R": 174, "S": 175, "T": 176, "U": 177, "V": 178, "W": 179, "Y": 180, "Z": 181,
    "ʤ": 182, "ɝ": 183, "ʧ": 184, "ʼ": 185, "ʴ": 186, "ʱ": 187, "ʷ": 188, "ˠ": 189, "→": 190, "↗": 191, "↘": 192,
    "¡": 193, "¿": 194, "…": 195, "«": 196, "»": 197, "*": 198, "~": 199, "/": 200, "\\": 201, "&": 202,
    "\u0361": 203, "\u035C": 204, "\u0325": 205, "\u032C": 206, "\u0339": 207, "\u031C": 208, "\u031D": 209, "\u031E": 210, "\u031F": 211, "\u0320": 212, "\u0330": 213, "\u0334": 214, "\u031A": 215, "\u0318": 216, "\u0319": 217, "\u0348": 218, "\u0306": 219, "\u0308": 220, "\u031B": 221, "\u0324": 222, "\u033C": 223,
    "\u02C0": 224, "\u02C1": 225, "\u02BE": 226, "\u02BF": 227, "\u02BB": 228, "\u02C9": 229, "\u02CA": 230, "\u02CB": 231, "\u02C6": 232,
    "\u02E5": 233, "\u02E6": 234, "\u02E7": 235, "\u02E8": 236, "\u02E9": 237, "\u0300": 238, "\u0301": 239, "\u0302": 240, "\u0304": 241, "\u030C": 242, "\u0307": 243,
}
# Built-in vocab used by UnicodeProcessor when no vocab file exists; the id ranges
# of the two maps (0-156 and 157-243) do not overlap.
DEFAULT_CHAR_TO_ID: dict[str, int] = {**_PIPER_MAP, **_EXTENDED_MAP}

# Canonical language codes accepted by UnicodeProcessor._preprocess.
AVAILABLE_LANGS = ["en", "es", "de", "it", "he"]
# NOTE(review): presumably the max characters per synthesis chunk; consumer not in view.
BLUE_SYNTH_MAX_CHUNK_LEN = 200
# When pace blending is enabled, durations are nudged toward this many seconds
# per text token so speed feels more consistent on long or mixed-language text.
DURATION_PACE_DPT_REF = 0.0625
# Default weight for blend_duration_pace's `pace_blend`; presumably applied to mixed-language input (consumer not in view).
DEFAULT_MIXED_PACE_BLEND = 0.25
# Legacy/alternate codes mapped to canonical ones (see _canonical_lang).
LANG_CODE_ALIASES: dict[str, str] = {"ge": "de", "en-us": "en"}
# Project language code -> espeak-ng voice. Hebrew is absent: it goes through Renikud.
_ESPEAK_MAP = {
    "en": "en-us", "en-us": "en-us", "de": "de", "ge": "de",
    "it": "it", "es": "es",
}
# Inline language span `<xx>…</xx>`; a repeated opening tag `<xx>…<xx>` is also accepted as the closer.
_INLINE_LANG_PAIR = re.compile(r"<(en|en-us|he|es|de|ge|it)>(.*?)(?:</\1>|<\1>)", re.DOTALL | re.IGNORECASE)
# Whole `<lang_list …>…</lang_list>` helper blocks (never spoken).
_LANG_LIST_BLOCK_RE = re.compile(r"<lang_list\b[^>]*>.*?</lang_list>", re.DOTALL | re.IGNORECASE)
# Any opening or closing XML-like tag.
_LANG_TAG_RE = re.compile(r"</?[^>]+>")
# Hebrew points/cantillation in U+0591-U+05C7, excluding maqaf (U+05BE), paseq (U+05C0),
# sof pasuq (U+05C3) and nun hafukha (U+05C6).
_HEBREW_NIKUD_RE = re.compile(r"[\u0591-\u05BD\u05BF\u05C1-\u05C2\u05C4-\u05C5\u05C7]")
# Any code point in the Hebrew Unicode block.
_HEBREW_CHAR_RE = re.compile(r"[\u0590-\u05ff]")
_EMAIL_RE = re.compile(r"[A-Za-z0-9._%+\-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}")
# Latin tokens: digit-prefixed words (e.g. "5G") or words optionally joined by . ' ’ -
_LATIN_ALNUM_RE = re.compile(r"\d+[A-Za-z]+|[A-Za-z]+(?:[.'’\-][A-Za-z0-9]+)*")
# Same alternatives as _EMAIL_RE | _LATIN_ALNUM_RE, email first so addresses stay whole.
_MIXED_EN_SEGMENT_RE = re.compile(
    r"[A-Za-z0-9._%+\-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}"
    r"|\d+[A-Za-z]+"
    r"|[A-Za-z]+(?:[.'’\-][A-Za-z0-9]+)*"
)
# Day/month/year with "/" or "." separators and a 2- or 4-digit year; consumer not in view.
_DATE_RE = re.compile(r"(?<!\d)([0-3]?\d)[/.]([01]?\d)[/.](\d{2}|\d{4})(?!\d)")
# Month number -> Hebrew "to/of the N-th" ordinal phrase, presumably for reading dates aloud.
_HEBREW_MONTH_ORDINALS = {
    1: "לראשון",
    2: "לשני",
    3: "לשלישי",
    4: "לרביעי",
    5: "לחמישי",
    6: "לשישי",
    7: "לשביעי",
    8: "לשמיני",
    9: "לתשיעי",
    10: "לעשירי",
    11: "לאחד עשר",
    12: "לשנים עשר",
}
# Per-language spoken word for "%".
_PERCENT_WORDS = {
    "he": "אחוז",
    "en": "percent",
    "es": "por ciento",
    "de": "Prozent",
    "it": "per cento",
}
# Per-language word for a ratio separator ("3:1" -> "3 to 1"); consumer not in view.
_RATIO_WORDS = {
    "he": "ל",
    "en": "to",
    "es": "a",
    "de": "zu",
    "it": "a",
}


def _strip_helper_markup(text: str) -> str:
    """Blank out ``<lang_list>`` helper markup so it is never synthesized.

    Complete ``<lang_list …>…</lang_list>`` blocks are replaced first, then any
    stray opening/closing ``lang_list`` tag that survived. Every removed span
    becomes a single space.
    """
    without_blocks = _LANG_LIST_BLOCK_RE.sub(" ", text)
    return re.sub(r"</?lang_list\b[^>]*>", " ", without_blocks, flags=re.IGNORECASE)


def _strip_synthesis_tags(text: str) -> str:
    """Replace every XML-like tag with a space so tag names are never tokenized.

    ``lang_list`` helper blocks are dropped wholesale before the generic
    tag pass.
    """
    return _LANG_TAG_RE.sub(" ", _strip_helper_markup(text))


def strip_language_tags_for_display(text: str) -> str:
    """Drop internal ``<lang>`` tags and tidy whitespace for user-facing phoneme text."""
    untagged = _LANG_TAG_RE.sub("", text)
    single_spaced = re.sub(r"\s+", " ", untagged)
    return single_spaced.strip()


def strip_hebrew_nikud(text: str) -> str:
    """Return *text* with Hebrew vowel points and cantillation marks removed.

    Hebrew letters and all non-matching characters are kept as they are.
    """
    cleaned = _HEBREW_NIKUD_RE.sub("", text)
    return cleaned


def _canonical_lang(lang: str) -> str:
    """Lower-case *lang* and resolve aliases from ``LANG_CODE_ALIASES`` (e.g. ``ge`` -> ``de``)."""
    lowered = lang.lower()
    return LANG_CODE_ALIASES.get(lowered, lowered)


def _has_mixed_hebrew_latin(text: str, lang: str) -> bool:
    """True when *lang* is Hebrew and *text* contains both Hebrew and Latin tokens."""
    if _canonical_lang(lang) != "he":
        return False
    has_hebrew = _HEBREW_CHAR_RE.search(text) is not None
    return has_hebrew and _LATIN_ALNUM_RE.search(text) is not None


def strip_hebrew_abbreviation_quotes(text: str, lang: str) -> str:
    """Drop abbreviation quote marks sitting between Hebrew letters (מנכ"ל -> מנכל).

    ASCII quotes as well as geresh/gershayim are removed; non-Hebrew *lang*
    returns *text* untouched.
    """
    if _canonical_lang(lang) == "he":
        return re.sub(r"(?<=[\u0590-\u05ff])[\"'״׳](?=[\u0590-\u05ff])", "", text)
    return text


def expand_hebrew_lamed_before_latin(text: str, lang: str) -> str:
    """Rewrite a standalone prefix lamed before Latin text as "אל".

    Avoids one-letter Hebrew chunks in mixed text, e.g. ``CPU ל-GPU`` ->
    ``CPU אל GPU``. Non-Hebrew *lang* returns *text* untouched.
    """
    if _canonical_lang(lang) == "he":
        return re.sub(r"(?<![\u0590-\u05ff])ל\s*[-–—‑]?\s*(?=[A-Za-z0-9])", "אל ", text)
    return text


def strip_silent_separator_tokens(text: str) -> str:
    """Replace dashes/colons that should not be voiced with spaces.

    Kept: dashes between Latin letters (``well-known``) and colons between
    digits (``10:30``). Afterwards whitespace is collapsed and trimmed.
    """
    silent_patterns = (
        # Hebrew letter, dash run, Latin letter/digit.
        r"(?<=[\u0590-\u05ff])[-–—‑]+(?=[A-Za-z0-9])",
        # Latin letter/digit, dash run, Hebrew letter.
        r"(?<=[A-Za-z0-9])[-–—‑]+(?=[\u0590-\u05ff])",
        # Any remaining dash run (plus surrounding spaces) not flanked by Latin letters.
        r"(?<![A-Za-z])\s*[-–—‑]+\s*(?![A-Za-z])",
        # Colon runs that are not between digits.
        r"(?<!\d)\s*:+\s*(?!\d)",
    )
    for pattern in silent_patterns:
        text = re.sub(pattern, " ", text)
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()


def email_to_spoken_english(email: str) -> str:
    """Make email addresses pronounceable before English phonemization.

    Separators in the local part are spelled out ("." -> "dot",
    "_" -> "underscore", "-" -> "dash", "+" -> "plus"). The domain is read
    label by label joined with "dot", and alphabetic labels of one or two
    letters (e.g. ``co``/``il``) are spelled letter by letter.

    Example: ``"john_doe@mail.co.il"`` ->
    ``"john underscore doe at mail dot c o dot i l"``.

    Input without an "@" is treated as a bare local part (no trailing "at").
    """
    local, sep, domain = email.partition("@")

    def spell_short_label(label: str) -> str:
        return " ".join(label) if 0 < len(label) <= 2 and label.isalpha() else label

    # "." and "_" are distinct characters in an address; reading both as "dot"
    # would dictate the wrong address.
    local = re.sub(r"[.]+", " dot ", local)
    local = re.sub(r"[_]+", " underscore ", local)
    local = re.sub(r"[-]+", " dash ", local)
    local = re.sub(r"[+]+", " plus ", local)
    spoken = local
    if sep:
        domain_parts = [spell_short_label(part) for part in domain.split(".") if part]
        spoken = f"{local} at {' dot '.join(domain_parts)}"
    return re.sub(r"\s+", " ", spoken).strip()


def blend_duration_pace(
    dur: np.ndarray,
    text_mask: np.ndarray,
    pace_blend: float,
    pace_dpt_ref: float = DURATION_PACE_DPT_REF,
) -> np.ndarray:
    """Pull each item's seconds-per-token toward ``pace_dpt_ref``.

    Args:
        dur: predicted durations; flattened to one value per batch item.
        text_mask: token mask summed over axes 1 and 2 to count tokens per
            item (so it must be at least 3-D, e.g. ``(B, 1, T)`` as built by
            ``_length_to_mask``).
        pace_blend: blend weight, clamped to [0, 1]; 0 leaves durations as-is.
        pace_dpt_ref: reference seconds per token to blend toward.

    Returns:
        1-D float32 array of (possibly blended) durations.
    """
    weight = min(max(float(pace_blend), 0.0), 1.0)
    if weight <= 0.0:
        # Nothing to blend: only normalize dtype and shape.
        return np.asarray(dur, dtype=np.float32).reshape(-1)

    seconds = np.asarray(dur, dtype=np.float64).reshape(-1)
    # Tokens per item, floored at 1 so empty masks cannot divide by zero.
    mask_totals = np.asarray(text_mask, dtype=np.float64).sum(axis=(1, 2))
    tokens = np.maximum(mask_totals, 1.0).reshape(-1)
    per_token = seconds / tokens
    per_token = (1.0 - weight) * per_token + weight * float(pace_dpt_ref)
    return (per_token * tokens).astype(np.float32)


# ============================================================
# Phonemization (Renikud for Hebrew, espeak-ng for Latin langs)
# ============================================================
class TextProcessor:
    """Grapheme-to-phoneme front end for the supported languages.

    Hebrew is phonemized with the Renikud ONNX G2P. Other languages mapped in
    ``_ESPEAK_MAP`` go through espeak-ng: the ``phonemizer`` library when it
    initialises, otherwise the ``espeak-ng`` command-line tool. ``phonemize``
    returns the phoneme string with every segment wrapped in ``<lang>…</lang>``.
    """

    def __init__(self, renikud_path: Optional[str] = None):
        """Load Renikud (when its weights exist) and prepare espeak-ng.

        Args:
            renikud_path: path to the Renikud ONNX weights. When ``None`` and a
                ``model.onnx`` exists in the working directory, that file is
                used. A missing file leaves Hebrew G2P disabled; Hebrew input
                then raises in ``_phonemize_segment``.

        Raises:
            RuntimeError: the weights exist but ``renikud-onnx`` is not installed.
        """
        self.renikud = None
        if renikud_path is None and os.path.exists("model.onnx"):
            renikud_path = "model.onnx"
        if renikud_path and os.path.exists(renikud_path):
            try:
                from renikud_onnx import G2P
                self.renikud = G2P(renikud_path)
                print(f"[INFO] Loaded Renikud G2P from {renikud_path}")
            except ImportError as e:
                raise RuntimeError(
                    "Hebrew G2P needs `renikud-onnx`. Install: `uv sync`."
                ) from e
        self._espeak_backends: Dict[str, Any] = {}
        self._espeak_separator = None
        self._espeak_ready = False
        self._init_espeak()

    def _init_espeak(self):
        """Point ``phonemizer`` at the bundled espeak-ng library (best effort).

        On any failure a warning is printed and ``_espeak_ready`` stays False,
        so ``_espeak`` falls back to the ``espeak-ng`` CLI.
        """
        try:
            import espeakng_loader
            from phonemizer.backend.espeak.wrapper import EspeakWrapper
            from phonemizer.separator import Separator
            EspeakWrapper.set_library(espeakng_loader.get_library_path())
            if hasattr(EspeakWrapper, "set_data_path"):
                EspeakWrapper.set_data_path(espeakng_loader.get_data_path())
            self._espeak_separator = Separator(phone="", word=" ", syllable="")
            self._espeak_ready = True
        except Exception as e:
            print(f"[WARN] espeak-ng setup failed: {e}")

    def _get_backend(self, espeak_lang: str):
        """Return a cached ``EspeakBackend`` for *espeak_lang*, creating it on first use."""
        if espeak_lang not in self._espeak_backends:
            from phonemizer.backend import EspeakBackend
            self._espeak_backends[espeak_lang] = EspeakBackend(
                espeak_lang, preserve_punctuation=True,
                with_stress=True, language_switch="remove-flags",
            )
        return self._espeak_backends[espeak_lang]

    def _espeak(self, text: str, lang: str) -> str:
        """Phonemize *text* with espeak-ng; return it unchanged if every backend fails.

        Languages without an espeak voice in ``_ESPEAK_MAP`` are returned as-is.
        """
        espeak_lang = _ESPEAK_MAP.get(lang)
        if espeak_lang is None:
            return text
        if self._espeak_ready:
            try:
                raw = self._get_backend(espeak_lang).phonemize(
                    [text], separator=self._espeak_separator
                )[0]
                return re.sub(r"\s+", " ", raw).strip()
            except Exception as e:
                print(f"[WARN] phonemizer failed for {lang}: {e}")
        try:
            # "--" ends option parsing: user text that starts with "-" (e.g.
            # "-w/tmp/x") must be read as text, never as an espeak-ng option.
            r = subprocess.run(
                ["espeak-ng", "-q", "--ipa=1", "-v", espeak_lang, "--", text],
                check=True, capture_output=True, text=True,
            )
            return re.sub(r"\s+", " ", r.stdout.replace("\n", " ")).strip()
        except Exception as e:
            print(f"[WARN] espeak-ng subprocess failed for {lang}: {e}")
        return text

    def _phonemize_segment(self, content: str, lang: str) -> str:
        """Phonemize one single-language span (tags and niqqud removed first).

        Any Hebrew characters route the span to Renikud. A span tagged ``he``
        that contains no Hebrew letters is returned unchanged.

        Raises:
            ValueError: Hebrew text but no Renikud weights were loaded.
        """
        content = strip_hebrew_nikud(_strip_synthesis_tags(content)).strip()
        if not content:
            return ""
        lang = _canonical_lang(lang)
        has_hebrew = any("\u0590" <= c <= "\u05ff" for c in content)
        if has_hebrew or lang == "he":
            if not has_hebrew:
                return content
            if self.renikud is None:
                raise ValueError("Hebrew text requires Renikud weights (renikud.onnx).")
            return strip_silent_separator_tokens(self.renikud.phonemize(content))
        return strip_silent_separator_tokens(self._espeak(content, lang))

    def _phonemize_tagged_segments(self, content: str, lang: str) -> list[tuple[str, str]]:
        """Phonemize *content*, splitting Hebrew text around embedded Latin tokens.

        Returns ``(lang, phonemes)`` pairs. Latin words and emails inside
        Hebrew text become ``en`` segments; emails are spelled out first.
        """
        content = strip_hebrew_nikud(_strip_synthesis_tags(content)).strip()
        if not content:
            return []
        lang = _canonical_lang(lang)
        if not _has_mixed_hebrew_latin(content, lang):
            seg = self._phonemize_segment(content, lang)
            return [(lang, seg)] if seg else []

        pieces: list[tuple[str, str]] = []

        def add(piece: str, piece_lang: str) -> None:
            if piece_lang == "en" and _EMAIL_RE.fullmatch(piece):
                piece = email_to_spoken_english(piece)
            seg = self._phonemize_segment(piece, piece_lang)
            if seg:
                pieces.append((_canonical_lang(piece_lang), seg))

        last_end = 0
        for m in _MIXED_EN_SEGMENT_RE.finditer(content):
            if m.start() > last_end:
                add(content[last_end:m.start()], lang)
            add(m.group(0), "en")
            last_end = m.end()
        if last_end < len(content):
            add(content[last_end:], lang)
        return pieces

    @staticmethod
    def _wrap_segments(segments: list[tuple[str, str]]) -> str:
        """Join ``(tag, phonemes)`` pairs as ``<tag>phonemes</tag>``, skipping empty ones."""
        return " ".join(f"<{tag}>{seg}</{tag}>" for tag, seg in segments if seg)

    def phonemize(self, text: str, lang: str = "he") -> str:
        """Phonemize *text*, honouring inline ``<xx>…</xx>`` language spans.

        Untagged text uses *lang*. Every resulting segment is re-wrapped in
        ``<lang>…</lang>`` tags and whitespace is collapsed.
        """
        text = _strip_helper_markup(text)
        lang = _canonical_lang(lang)
        if not _INLINE_LANG_PAIR.search(text):
            return self._wrap_segments(self._phonemize_tagged_segments(text, lang))
        pieces: list[tuple[str, str]] = []
        last_end = 0
        for m in _INLINE_LANG_PAIR.finditer(text):
            if m.start() > last_end:
                pieces.extend(self._phonemize_tagged_segments(text[last_end:m.start()], lang))
            tag = _canonical_lang(m.group(1))
            pieces.extend(self._phonemize_tagged_segments(m.group(2), tag))
            last_end = m.end()
        if last_end < len(text):
            pieces.extend(self._phonemize_tagged_segments(text[last_end:], lang))
        return re.sub(r"\s+", " ", self._wrap_segments(pieces)).strip()


# ============================================================
# Char-level tokenizer (vocab.json or built-in fallback)
# ============================================================
class UnicodeProcessor:
    """Character-level tokenizer: (phoneme) text -> padded id matrix + mask.

    Vocabulary sources, in priority order:

    * a JSON file ``{"char_to_id": {...}, "pad_id": N}`` (character keys);
    * any other JSON object, read as code point (string) -> id;
    * the built-in ``DEFAULT_CHAR_TO_ID`` when no file path exists.

    Characters missing from the vocabulary are encoded as ``pad_id``.
    """

    # Emoji / pictograph ranges removed before tokenization (compiled once).
    _EMOJI_RE = re.compile(
        "[\U0001f600-\U0001f64f\U0001f300-\U0001f5ff\U0001f680-\U0001f6ff"
        "\U0001f700-\U0001f77f\U0001f780-\U0001f7ff\U0001f800-\U0001f8ff"
        "\U0001f900-\U0001f9ff\U0001fa00-\U0001fa6f\U0001fa70-\U0001faff"
        "\u2600-\u26ff\u2700-\u27bf\U0001f1e6-\U0001f1ff]+", flags=re.UNICODE,
    )
    # Typographic dashes/quotes normalized to ASCII; separators/brackets/arrows -> space.
    # Applied in insertion order.
    _CHAR_REPLACEMENTS = {
        "–": "-", "‑": "-", "—": "-", "_": " ",
        "\u201c": '"', "\u201d": '"', "\u2018": "'", "\u2019": "'",
        "´": "'", "`": "'", "[": " ", "]": " ", "|": " ",
        "/": " ", "#": " ", "→": " ", "←": " ",
    }
    # Symbols/abbreviations spelled out in words.
    _PHRASE_REPLACEMENTS = {"@": " at ", "e.g.,": "for example, ", "i.e.,": "that is, "}

    def __init__(self, indexer_path: Optional[str] = None):
        """Load the vocabulary from *indexer_path*, or use the built-in map.

        Raises:
            json.JSONDecodeError: the vocab file exists but is not valid JSON.
        """
        self._char_to_id: Optional[Dict[str, int]]
        self._codepoint_indexer: Optional[Dict[int, int]]
        self.pad_id: int = 0
        if indexer_path and os.path.exists(indexer_path):
            # Keys are IPA characters: decode as UTF-8 regardless of the OS locale.
            with open(indexer_path, "r", encoding="utf-8") as f:
                raw = json.load(f)
            if isinstance(raw, dict) and "char_to_id" in raw:
                self.pad_id = int(raw.get("pad_id", 0))
                self._char_to_id = {k: int(v) for k, v in raw["char_to_id"].items()}
                self._codepoint_indexer = None
            else:
                self.pad_id = 0
                self._char_to_id = None
                self._codepoint_indexer = {int(k): int(v) for k, v in raw.items()}
            vocab_len = len(self._char_to_id) if self._char_to_id is not None else len(self._codepoint_indexer or {})
            print(f"[INFO] Loaded vocab from {indexer_path} ({vocab_len} entries)")
        else:
            self._char_to_id = dict(DEFAULT_CHAR_TO_ID)
            self._codepoint_indexer = None
            print("[INFO] Using built-in default vocab.")

    def _preprocess(self, text: str, lang: str) -> str:
        """Normalize *text* for tokenization and wrap it in a ``<lang>`` tag.

        The text has its tags stripped, is NFKD-normalized and has niqqud and
        emoji removed. Typographic characters are mapped to ASCII, spaces
        before punctuation are dropped, and repeated quotes are collapsed.
        Silent separators are removed and a final "." is added if the text
        lacks closing punctuation.

        Raises:
            ValueError: *lang* (after alias resolution) is not in ``AVAILABLE_LANGS``.
        """
        text = _strip_synthesis_tags(text)
        text = uni_normalize("NFKD", text)
        text = strip_hebrew_nikud(text)
        text = self._EMOJI_RE.sub("", text)
        for k, v in self._CHAR_REPLACEMENTS.items():
            text = text.replace(k, v)
        text = re.sub(r"[♥☆♡©\\]", "", text)
        for k, v in self._PHRASE_REPLACEMENTS.items():
            text = text.replace(k, v)
        # Remove the space before punctuation, e.g. "word ," -> "word,".
        for pat in (r" ,", r" \.", r" !", r" \?", r" ;", r" :", r" '"):
            text = re.sub(pat, pat.replace(" ", "").replace("\\", ""), text)
        while '""' in text:
            text = text.replace('""', '"')
        while "''" in text:
            text = text.replace("''", "'")
        text = strip_silent_separator_tokens(text)
        text = re.sub(r"\s+", " ", text).strip()
        if not re.search(r"[.!?;:,'\"')\]}…。」』】〉》›»]$", text):
            text += "."
        lang = _canonical_lang(lang)
        if lang not in AVAILABLE_LANGS:
            raise ValueError(f"Invalid language: {lang}")
        # Tags were stripped above, so this always wraps; _encode strips them again.
        if not _INLINE_LANG_PAIR.search(text):
            text = f"<{lang}>{text}</{lang}>"
        return text

    def _encode(self, text: str) -> np.ndarray:
        """Map each character (tags removed) to its id; unknown characters become ``pad_id``."""
        text = _strip_synthesis_tags(text)
        pad = self.pad_id
        if self._char_to_id is not None:
            ids = [self._char_to_id.get(ch, pad) for ch in text]
        else:
            assert self._codepoint_indexer is not None
            ids = [self._codepoint_indexer.get(ord(ch), pad) for ch in text]
        return np.array(ids, dtype=np.int64)

    def __call__(self, text_list: List[str], lang_list: List[str]):
        """Tokenize a batch.

        Returns:
            ``(text_ids, mask)``: int64 ids of shape ``(B, T_max)`` right-padded
            with ``pad_id``, and a float32 mask of shape ``(B, 1, T_max)``.

        Raises:
            ValueError: *text_list* is empty or a language is unsupported.
        """
        if not text_list:
            raise ValueError("UnicodeProcessor needs at least one text to encode.")
        text_list = [self._preprocess(t, lang) for t, lang in zip(text_list, lang_list)]
        encoded = [self._encode(t) for t in text_list]
        lengths = np.array([len(e) for e in encoded], dtype=np.int64)
        text_ids = np.full((len(encoded), int(lengths.max())), self.pad_id, dtype=np.int64)
        for i, ids in enumerate(encoded):
            text_ids[i, :len(ids)] = ids
        mask = _length_to_mask(lengths)
        return text_ids, mask


def _length_to_mask(lengths: np.ndarray, max_len: Optional[int] = None) -> np.ndarray:
    max_len = max_len or int(lengths.max())
    ids = np.arange(0, max_len)
    m = (ids < np.expand_dims(lengths, 1)).astype(np.float32)
    return m.reshape(-1, 1, max_len)


def _latent_mask(wav_lengths: np.ndarray, base_chunk: int, factor: int) -> np.ndarray:
    """Mask over latent frames, each frame covering ``base_chunk * factor`` samples."""
    frame_span = base_chunk * factor
    # Ceiling division: a partially filled trailing frame still counts as valid.
    latent_lengths = (wav_lengths + frame_span - 1) // frame_span
    return _length_to_mask(latent_lengths)


# ============================================================
# Voice style container
# ============================================================
@dataclass
class Style:
    """Voice style arrays read from a voice JSON file.

    Built by ``style_from_dict`` (single voice, shaped by the file's ``dims``)
    or ``load_voice_style_batch`` (float32, shape ``(B, dims[1], dims[2])``).
    """

    # Array from the voice JSON's "style_ttl" entry.
    ttl: np.ndarray
    # Array from the voice JSON's "style_dp" entry; by its name presumably conditions
    # the duration predictor — NOTE(review): confirm against the ONNX graph.
    dp: np.ndarray


def load_voice_style(paths: List[str]) -> Style:
    """Load one voice style from the first JSON file in *paths*.

    Only ``paths[0]`` is read; use ``load_voice_style_batch`` for several voices.
    """
    first_path = paths[0]
    with open(first_path) as handle:
        payload = json.load(handle)
    return style_from_dict(payload)


def style_from_dict(payload: dict[str, Any]) -> Style:
    """Build a ``Style`` from a parsed voice JSON.

    Each of ``style_ttl`` / ``style_dp`` must hold ``data`` (nested numbers)
    and ``dims`` (target shape). Data is cast to float32 and reshaped to ``dims``.
    """
    ttl_meta = payload["style_ttl"]
    ttl_shape = ttl_meta["dims"]
    dp_meta = payload["style_dp"]
    dp_shape = dp_meta["dims"]
    ttl_values = np.array(ttl_meta["data"], dtype=np.float32).flatten()
    dp_values = np.array(dp_meta["data"], dtype=np.float32).flatten()
    return Style(ttl=ttl_values.reshape(ttl_shape), dp=dp_values.reshape(dp_shape))


def load_voice_style_batch(paths: List[str]) -> Style:
    """Stack several voice-style JSON files into one batched Style.

    Shapes come from the first file: only ``dims[1]`` and ``dims[2]`` of its
    ``style_ttl``/``style_dp`` entries are used. Every file's data must have
    exactly that many elements, otherwise ``reshape`` raises ``ValueError``.
    Files are decoded as UTF-8.

    Returns:
        Style with ``ttl`` of shape ``(len(paths), ttl_dims[1], ttl_dims[2])``
        and ``dp`` of shape ``(len(paths), dp_dims[1], dp_dims[2])``, float32.

    Raises:
        IndexError: if ``paths`` is empty.
    """
    def _read(path: str) -> dict:
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    first = _read(paths[0])
    ttl_dims = first["style_ttl"]["dims"]
    dp_dims = first["style_dp"]["dims"]
    B = len(paths)
    ttl = np.zeros([B, ttl_dims[1], ttl_dims[2]], dtype=np.float32)
    dp = np.zeros([B, dp_dims[1], dp_dims[2]], dtype=np.float32)

    for i, p in enumerate(paths):
        # Reuse the already-parsed first payload instead of reading it twice.
        d = first if i == 0 else _read(p)
        ttl[i] = np.array(d["style_ttl"]["data"], dtype=np.float32).reshape(ttl_dims[1], ttl_dims[2])
        dp[i] = np.array(d["style_dp"]["data"], dtype=np.float32).reshape(dp_dims[1], dp_dims[2])
    return Style(ttl=ttl, dp=dp)


# ============================================================
# TextToSpeech core (slim pipeline)
# ============================================================
def _hard_split(s: str, max_len: int) -> List[str]:
    """Split ``s`` into pieces of at most ``max_len`` chars, preferring spaces."""
    s = s.strip()
    if len(s) <= max_len:
        return [s] if s else []
    out: List[str] = []
    i, n = 0, len(s)
    while i < n:
        j = min(i + max_len, n)
        if j < n:
            cut = s.rfind(" ", i, j)
            if cut > i + max_len // 4:
                j = cut
        piece = s[i:j].strip()
        if piece:
            out.append(piece)
        i = j
        while i < n and s[i] == " ":
            i += 1
    return out


def chunk_text(text: str, max_len: int = 300) -> List[str]:
    """Split ``text`` into chunks of at most ``max_len`` characters for synthesis.

    Paragraphs (separated by blank lines) are never merged. Inside a paragraph
    the text is split at whitespace that follows ``.``, ``!`` or ``?``. Common
    abbreviations and single capital-letter initials are not treated as
    sentence ends. Sentences are then greedily packed into chunks, and any
    sentence longer than ``max_len`` is broken with ``_hard_split``.
    """
    sentence_boundary = re.compile(
        r"(?<!Mr\.)(?<!Mrs\.)(?<!Ms\.)(?<!Dr\.)(?<!Prof\.)(?<!Sr\.)(?<!Jr\.)"
        r"(?<!Ph\.D\.)(?<!etc\.)(?<!e\.g\.)(?<!i\.e\.)(?<!vs\.)(?<!Inc\.)"
        r"(?<!Ltd\.)(?<!Co\.)(?<!Corp\.)(?<!St\.)(?<!Ave\.)(?<!Blvd\.)"
        r"(?<!\b[A-Z]\.)(?<=[.!?])\s+"
    )
    packed: List[str] = []
    for block in re.split(r"\n\s*\n+", text.strip()):
        block = block.strip()
        if not block:
            continue
        buffer = ""
        for sentence in sentence_boundary.split(block):
            if len(buffer) + len(sentence) + 1 <= max_len:
                buffer = f"{buffer} {sentence}" if buffer else sentence
                continue
            if buffer:
                packed.append(buffer.strip())
                buffer = ""
            if len(sentence) > max_len:
                packed.extend(_hard_split(sentence, max_len))
            else:
                buffer = sentence
        if buffer:
            packed.append(buffer.strip())

    seeds = packed or [text.strip()]
    # Safety net: re-split every chunk so none can exceed max_len.
    result: List[str] = []
    for seed_chunk in seeds:
        result.extend(_hard_split(seed_chunk, max_len))
    return result


class BlueTTS:
    """Text-to-speech engine built from four ONNX Runtime sessions.

    Each batch runs through four stages:

    1. duration predictor;
    2. text encoder;
    3. ``total_step`` vector-estimator iterations over a seeded noisy latent,
       optionally with classifier-free guidance;
    4. vocoder.

    Text is phonemized by ``TextProcessor`` and mapped to ids by
    ``UnicodeProcessor``.
    """

    def __init__(
        self,
        onnx_dir: str = ONNX_DIR,
        config_path: str = CONFIG_PATH,
        vocab_path: str = VOCAB_PATH,
        renikud_path: Optional[str] = RENIKUD_PATH,
        use_gpu: bool = False,
    ):
        """Load the config, the ONNX sessions, optional CFG embeddings and text front-ends.

        Args:
            onnx_dir: Directory containing ``duration_predictor.onnx``,
                ``text_encoder.onnx``, ``vector_estimator.onnx``,
                ``vocoder.onnx`` and, optionally, ``uncond.npz``.
            config_path: JSON config; its ``ae`` and ``ttl`` sections are read.
            vocab_path: Vocabulary file handed to ``UnicodeProcessor``.
            renikud_path: Path handed to ``TextProcessor``.
            use_gpu: Prefer ``CUDAExecutionProvider`` when onnxruntime offers it.
        """
        self.cfgs = self._load_cfg(config_path)
        self.sample_rate = int(self.cfgs["ae"]["sample_rate"])
        self.base_chunk_size = int(self.cfgs["ae"]["base_chunk_size"])
        self.chunk_compress_factor = int(self.cfgs["ttl"]["chunk_compress_factor"])
        self.ldim = int(self.cfgs["ttl"]["latent_dim"])

        opts = ort.SessionOptions()
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # ORT_NUM_THREADS overrides the default of min(8, CPU count).
        n_threads = int(os.environ.get("ORT_NUM_THREADS", min(8, os.cpu_count() or 1)))
        opts.intra_op_num_threads = n_threads
        opts.inter_op_num_threads = 1

        providers = ["CPUExecutionProvider"]
        if use_gpu and "CUDAExecutionProvider" in ort.get_available_providers():
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]

        def _load(name: str) -> ort.InferenceSession:
            return ort.InferenceSession(os.path.join(onnx_dir, name),
                                        sess_options=opts, providers=providers)

        self.dp_ort = _load("duration_predictor.onnx")
        self.text_enc_ort = _load("text_encoder.onnx")
        self.vector_est_ort = _load("vector_estimator.onnx")
        self.vocoder_ort = _load("vocoder.onnx")
        # The estimator's input names decide at inference time whether it
        # applies guidance itself (a "cfg_scale" input) or CFG is done here.
        self._vf_inputs = {i.name for i in self.vector_est_ort.get_inputs()}
        self._vocoder_input_name = self.vocoder_ort.get_inputs()[0].name

        # Optional uncond embeddings for CFG (if shipped with the slim bundle).
        self._u_text = self._u_ref = None
        uncond_path = os.path.join(onnx_dir, "uncond.npz")
        if os.path.exists(uncond_path):
            # The NpzFile keeps its file open; the context manager closes it.
            # Indexing reads each array into memory, so the values stay valid.
            with np.load(uncond_path) as u:
                self._u_text = u["u_text"] if "u_text" in u.files else None
                self._u_ref = u["u_ref"] if "u_ref" in u.files else None

        self.text_processor = UnicodeProcessor(vocab_path)
        self.g2p = TextProcessor(renikud_path)

    @staticmethod
    def _load_cfg(path: str) -> dict:
        """Read the JSON config at ``path`` as UTF-8.

        Raises:
            FileNotFoundError: if ``path`` does not exist.
        """
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing config {path}")
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _sample_noisy_latent(self, duration: np.ndarray, seed: int = 42):
        """Draw the seeded Gaussian starting latent and its validity mask.

        Args:
            duration: Per-item durations. They are multiplied by
                ``sample_rate`` to get sample counts, so presumably seconds.
            seed: Seed for ``np.random.default_rng``.

        Returns:
            ``(xt, latent_mask)``. ``xt`` is float32 noise of shape
            ``(batch, ldim * chunk_compress_factor, frames)``, zeroed past each
            item's length. ``latent_mask`` has shape ``(batch, 1, frames)``.
        """
        bsz = len(duration)
        wav_lengths = (duration * self.sample_rate).astype(np.int64)
        latent_mask = _latent_mask(wav_lengths, self.base_chunk_size, self.chunk_compress_factor)
        # Take the frame count from the mask so noise and mask always agree.
        # Sizing from the un-truncated float maximum could give one frame more
        # than the mask when the fractional sample crossed a chunk boundary.
        latent_len = latent_mask.shape[-1]
        latent_dim = self.ldim * self.chunk_compress_factor
        rng = np.random.default_rng(seed)
        xt = rng.standard_normal((bsz, latent_dim, latent_len)).astype(np.float32)
        return xt * latent_mask, latent_mask

    def _infer(
        self,
        text_list: List[str],
        lang_list: List[str],
        style: Style,
        total_step: int,
        speed: float,
        cfg_scale: float,
        seed: int,
        pace_blend: float = 0.0,
        pace_dpt_ref: float = DURATION_PACE_DPT_REF,
    ):
        """Run the ONNX pipeline on a batch of already-prepared texts.

        Returns:
            ``(wav, dur)``.

            ``wav`` is the vocoder output. When it is longer than
            ``2 * frame_len``, ``frame_len`` samples are trimmed from both
            ends, and a singleton channel axis is dropped.

            ``dur`` holds the per-item durations after pace blending and
            speed scaling.
        """
        bsz = len(text_list)
        assert style.ttl.shape[0] == bsz, "style batch mismatch"

        text_ids, text_mask = self.text_processor(text_list, lang_list)
        dur, *_ = self.dp_ort.run(None, {
            "text_ids": text_ids, "style_dp": style.dp, "text_mask": text_mask,
        })
        dur = np.asarray(dur, dtype=np.float32).reshape(-1)
        dur = blend_duration_pace(dur, text_mask, pace_blend, pace_dpt_ref)
        dur = dur / max(speed, 1e-6)
        text_emb, *_ = self.text_enc_ort.run(None, {
            "text_ids": text_ids, "style_ttl": style.ttl, "text_mask": text_mask,
        })
        xt, latent_mask = self._sample_noisy_latent(dur, seed=seed)
        total_t = np.array([total_step] * bsz, dtype=np.float32)

        # Guidance happens in one of two ways:
        # - the exported estimator takes "cfg_scale" and guides internally, or
        # - when uncond embeddings were shipped and cfg_scale != 1, guidance is
        #   done here with a second, unconditional pass per step.
        native_cfg = "cfg_scale" in self._vf_inputs
        use_cfg = (cfg_scale != 1.0 and self._u_text is not None and self._u_ref is not None)
        cfg_arr = u_text_b = u_ref_b = u_text_mask = None
        if native_cfg:
            cfg_arr = np.array([float(cfg_scale)], dtype=np.float32)
        elif use_cfg:
            # Loop-invariant unconditional inputs: build them once, not per step.
            u_text_b = np.broadcast_to(self._u_text, (bsz, *self._u_text.shape[1:])).astype(np.float32)
            u_ref_b = np.broadcast_to(self._u_ref, (bsz, *self._u_ref.shape[1:])).astype(np.float32)
            u_text_mask = np.ones((bsz, 1, 1), dtype=np.float32)

        for step in range(total_step):
            cur_t = np.array([step] * bsz, dtype=np.float32)
            cond = {
                "noisy_latent": xt, "text_emb": text_emb,
                "style_ttl": style.ttl, "text_mask": text_mask,
                "latent_mask": latent_mask,
                "current_step": cur_t, "total_step": total_t,
            }
            if native_cfg:
                cond["cfg_scale"] = cfg_arr
                xt, *_ = self.vector_est_ort.run(None, cond)
            elif use_cfg:
                v_cond, *_ = self.vector_est_ort.run(None, cond)
                v_uncond, *_ = self.vector_est_ort.run(None, {
                    "noisy_latent": xt, "text_emb": u_text_b,
                    "style_ttl": u_ref_b, "text_mask": u_text_mask,
                    "latent_mask": latent_mask,
                    "current_step": cur_t, "total_step": total_t,
                })
                xt = v_uncond + cfg_scale * (v_cond - v_uncond)
            else:
                xt, *_ = self.vector_est_ort.run(None, cond)

        wav, *_ = self.vocoder_ort.run(None, {self._vocoder_input_name: xt})
        frame_len = self.base_chunk_size * self.chunk_compress_factor
        if wav.shape[-1] > 2 * frame_len:
            wav = wav[..., frame_len:-frame_len]
        if wav.ndim == 3 and wav.shape[1] == 1:
            wav = wav[:, 0, :]
        return wav, dur

    def synthesize(
        self,
        text: Union[str, List[str]],
        lang: Union[str, List[str]],
        style: Style,
        total_step: int = 8,
        speed: float = 0.95,
        cfg_scale: float = 4.0,
        silence_duration: float = 0.15,
        seed: int = 42,
        phonemize: bool = True,
        pace_blend: Optional[float] = None,
        pace_dpt_ref: float = DURATION_PACE_DPT_REF,
    ) -> Tuple[np.ndarray, int]:
        """Synthesize speech for one text or a batch of texts.

        Args:
            text: A single string, or a list of strings. A single string is
                split with ``chunk_text`` and rendered chunk by chunk. A list
                is rendered in one batched inference without chunking.
            lang: Language code, or a list matching ``text`` in batch mode.
            style: Voice style. Its batch size must be 1 in single-text mode
                and ``len(text)`` in batch mode.
            total_step: Number of vector-estimator iterations.
            speed: Durations are divided by this (clamped to at least 1e-6).
            cfg_scale: Guidance scale. Manual CFG is skipped at exactly 1.0.
            silence_duration: Seconds of silence inserted between chunks.
                Negative values are treated as 0.
            seed: Noise seed, reused for every chunk.
            phonemize: Run the G2P front-end before inference.
            pace_blend: Duration pace blend. ``None`` selects
                ``DEFAULT_MIXED_PACE_BLEND`` for mixed-language input, else 0.
            pace_dpt_ref: Reference value forwarded to ``blend_duration_pace``.

        Returns:
            ``(wav, sample_rate)``. In batch mode ``wav`` is the batched
            vocoder output. In single-text mode it is the concatenated
            waveform with the leading batch axis squeezed.
        """
        if isinstance(text, list):
            has_inline_lang = any(_INLINE_LANG_PAIR.search(t) is not None for t in text)
            has_auto_mixed = any(_has_mixed_hebrew_latin(t, l) for t, l in zip(text, lang)) if isinstance(lang, list) else False
        else:
            has_inline_lang = _INLINE_LANG_PAIR.search(text) is not None
            has_auto_mixed = _has_mixed_hebrew_latin(text, lang) if isinstance(lang, str) else False
        pace_blend_eff = (
            float(pace_blend)
            if pace_blend is not None
            else (DEFAULT_MIXED_PACE_BLEND if has_inline_lang or has_auto_mixed else 0.0)
        )
        if isinstance(text, list):
            assert isinstance(lang, list) and len(text) == len(lang)
            if phonemize:
                text = [self.g2p.phonemize(t, lang=l) for t, l in zip(text, lang)]
            wav, _ = self._infer(
                text, lang, style, total_step, speed, cfg_scale, seed,
                pace_blend=pace_blend_eff, pace_dpt_ref=pace_dpt_ref,
            )
            return wav, self.sample_rate

        assert isinstance(lang, str)
        assert style.ttl.shape[0] == 1, "single-text mode needs a single style"
        chunks = chunk_text(text, max_len=BLUE_SYNTH_MAX_CHUNK_LEN)
        # Built once. A negative duration yields no gap instead of making
        # np.zeros raise.
        silence = np.zeros((1, max(0, int(silence_duration * self.sample_rate))), dtype=np.float32)
        pieces: List[np.ndarray] = []
        for raw_chunk in chunks:
            chunk = self.g2p.phonemize(raw_chunk, lang=lang) if phonemize else raw_chunk
            if not chunk:
                continue
            w, _ = self._infer(
                [chunk], [lang], style, total_step, speed, cfg_scale, seed,
                pace_blend=pace_blend_eff, pace_dpt_ref=pace_dpt_ref,
            )
            if pieces:
                pieces.append(silence)
            pieces.append(w)
        if not pieces:
            wav_cat = np.zeros((1, 0), dtype=np.float32)
        elif len(pieces) == 1:
            wav_cat = pieces[0]
        else:
            # One concatenate at the end instead of one per chunk (avoids
            # quadratic copying for long texts).
            wav_cat = np.concatenate(pieces, axis=1)
        return wav_cat.squeeze(0) if wav_cat.ndim == 2 else wav_cat.squeeze(), self.sample_rate


# ============================================================
# App setup
# ============================================================
# Module-level engine shared by every handler below. It is built once at
# import, with the default use_gpu=False.
TTS = BlueTTS(
    onnx_dir=ONNX_DIR,
    config_path=CONFIG_PATH,
    vocab_path=VOCAB_PATH,
    renikud_path=RENIKUD_PATH,
)


def discover_voices() -> Dict[str, str]:
    """Map display names to the voice-style JSON files found in ``VOICES_DIR``.

    Files are visited in sorted order, and two kinds are skipped with a
    printed note:

    - files that cannot be read or parsed;
    - files whose ``style_ttl`` data has a standard deviation above 0.3
      (treated as incompatible).

    The display name is the file stem with ``_`` turned into spaces,
    ``"spk "`` into ``"Speaker "``, and the result title-cased.
    """
    voices: Dict[str, str] = {}
    for path in sorted(glob.glob(os.path.join(VOICES_DIR, "*.json"))):
        try:
            with open(path) as fh:
                payload = json.load(fh)
            style_ttl = payload.get("style_ttl")
            spread = (
                float(np.array(style_ttl["data"], dtype=np.float32).std())
                if style_ttl
                else None
            )
        except Exception as e:
            print(f"[WARN] Skipping unreadable voice JSON {path}: {e}")
            continue
        if spread is not None and spread > 0.3:
            print(f"[INFO] Skipping incompatible voice JSON {path} (style_ttl std={spread:.3f})")
            continue
        stem = os.path.splitext(os.path.basename(path))[0]
        display = stem.replace("_", " ").replace("spk ", "Speaker ").title()
        voices[display] = path
    return voices


VOICES: Dict[str, str] = discover_voices()
# Parse every discovered voice once at import; requests then only do a dict lookup.
VOICE_STYLES: Dict[str, Style] = {
    label: load_voice_style([json_path]) for label, json_path in VOICES.items()
}


def expand_numbers(text: str, lang: str = "en") -> str:
    """Spell out standalone integers and decimals with ``num2words``.

    A single ``.`` or ``,`` between two digit runs is read as a decimal mark.
    English and Hebrew conventionally write ``,`` as the thousands separator,
    so for those languages grouping commas are removed first. Without this,
    ``1,000`` would be read as ``1.000`` ("one") and ``1,000,000`` would be
    mangled. Tokens that ``num2words`` rejects are left as written.
    """
    lang = _canonical_lang(lang)
    if lang in ("en", "he"):
        # Drop a comma that sits between a digit and exactly three more digits.
        text = re.sub(r"(?<=\d),(?=\d{3}(?!\d))", "", text)

    def repl(m: re.Match[str]) -> str:
        raw = m.group(0)
        try:
            value: Union[int, float]
            if "." in raw or "," in raw:
                value = float(raw.replace(",", "."))
            else:
                value = int(raw)
            return num2words(value, lang=lang)
        except Exception:
            # Best effort: keep anything num2words cannot verbalize.
            return raw

    return re.sub(r"(?<![\w])\d+(?:[.,]\d+)?(?![\w])", repl, text)


def expand_percent_symbols(text: str, lang: str = "en") -> str:
    """Replace ``%`` with the language's percent word, falling back to English.

    A number followed by ``%`` (optionally with spaces) becomes
    ``"<number> <word>"``. Any other ``%`` becomes ``" <word> "``.
    """
    percent_word = _PERCENT_WORDS.get(_canonical_lang(lang), _PERCENT_WORDS["en"])
    after_numbers = re.sub(r"(\d+(?:[.,]\d+)?)\s*%", rf"\1 {percent_word}", text)
    return re.sub(r"%", f" {percent_word} ", after_numbers)


def expand_ratios(text: str, lang: str = "en") -> str:
    """Rewrite ``A:B`` digit pairs as ``"A <word> B"`` using the language's ratio word.

    NOTE: the pattern also matches clock-style times such as ``10:30``.
    """
    ratio_word = _RATIO_WORDS.get(_canonical_lang(lang), _RATIO_WORDS["en"])
    ratio_pattern = r"(?<!\d)(\d+)\s*:\s*(\d+)(?!\d)"
    return re.sub(ratio_pattern, rf"\1 {ratio_word} \2", text)


def expand_dates(text: str, lang: str = "en") -> str:
    """Normalize numeric day/month/year dates before generic number expansion.

    Each ``_DATE_RE`` match is handled as follows:

    - matches with day outside 1-31 or month outside 1-12 are left untouched;
    - two-digit years map to 2000-2069 (below 70) or 1970-1999;
    - the result is ``"<day> <month> <year>"``, where Hebrew uses the entry
      from ``_HEBREW_MONTH_ORDINALS`` in place of the month number.
    """
    lang = _canonical_lang(lang)

    def _rewrite(match: re.Match[str]) -> str:
        day, month = int(match.group(1)), int(match.group(2))
        year_text = match.group(3)
        if not (1 <= day <= 31 and 1 <= month <= 12):
            return match.group(0)
        year = int(year_text)
        if len(year_text) == 2:
            year += 1900 if year >= 70 else 2000
        month_part = _HEBREW_MONTH_ORDINALS[month] if lang == "he" else month
        return f"{day} {month_part} {year}"

    return _DATE_RE.sub(_rewrite, text)


def normalize_common_text(text: str) -> str:
    """Apply ``strip_hebrew_nikud``, then split "anymore" into "any more".

    The match is case-insensitive. The replacement is "Any more" when the
    original started with a capital letter, "any more" otherwise.
    """
    def _split_anymore(match: re.Match[str]) -> str:
        return "Any more" if match.group(0)[0].isupper() else "any more"

    without_nikud = strip_hebrew_nikud(text)
    return re.sub(r"\banymore\b", _split_anymore, without_nikud, flags=re.IGNORECASE)


def prepare_text_for_synthesis(text: str, lang: str) -> str:
    """Run the text-normalization pipeline used before phonemization.

    The stages run in this order:

    1. common cleanup;
    2. Hebrew abbreviation quotes;
    3. Hebrew lamed-before-Latin;
    4. dates, percents, ratios, then numbers;
    5. silent separator tokens are stripped.
    """
    text = strip_hebrew_abbreviation_quotes(normalize_common_text(text), lang)
    text = expand_hebrew_lamed_before_latin(text, lang)
    # Order matters: dates/percents/ratios must be handled before bare numbers.
    for expand in (expand_dates, expand_percent_symbols, expand_ratios, expand_numbers):
        text = expand(text, lang=lang)
    return strip_silent_separator_tokens(text)


def normalize_generated_audio(wav: np.ndarray, target_rms: float = 0.08, peak_limit: float = 0.95) -> np.ndarray:
    """Gently lift quiet generations while leaving normal/loud audio unclipped.

    The RMS is measured over "active" samples: those above 2% of the peak
    (and above 1e-4). The applied gain is the smallest of three values:

    - ``target_rms / rms``;
    - ``peak_limit / peak``;
    - 4.0.

    The gain is applied only if it exceeds 1. Empty, non-finite or near-silent
    input is returned unchanged, as float32.
    """
    audio = np.asarray(wav, dtype=np.float32)
    if not audio.size or not np.all(np.isfinite(audio)):
        return audio

    magnitude = np.abs(audio)
    peak = float(magnitude.max())
    if peak < 1e-6:
        return audio

    active = magnitude > max(peak * 0.02, 1e-4)
    voiced = audio[active] if active.any() else audio
    rms = float(np.sqrt(np.mean(voiced ** 2)))
    if rms < 1e-6:
        return audio

    # Boost is capped so a very quiet or bad generation doesn't turn harsh or noisy.
    gain = min(target_rms / rms, peak_limit / peak, 4.0)
    return audio if gain <= 1.0 else (audio * gain).astype(np.float32)


# Styles derived from uploaded reference WAVs, keyed by the hex SHA-1 digest of
# the uploaded file. No size bound is applied to this mapping.
_REF_WAV_CACHE: Dict[str, Style] = dict()


def _hash_file(path: str) -> str:
    import hashlib
    h = hashlib.sha1()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 16), b""):
            h.update(chunk)
    return h.hexdigest()


def _env_truthy(name: str) -> bool:
    return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"}


def _pt_marker_ok(marker_path: str, repo_id: str, stamp: str) -> bool:
    if not os.path.exists(marker_path):
        return False
    try:
        lines = open(marker_path, encoding="utf-8").read().splitlines()
    except OSError:
        return False
    if len(lines) < 2:
        return False
    return lines[0].strip() == repo_id and lines[1].strip() == stamp


def _ensure_pt_weights() -> dict[str, str]:
    """Make sure v2 PyTorch/safetensors checkpoints are on disk.

    The checkpoint files are downloaded from the Hugging Face repo
    ``BLUE_PT_REPO`` (default ``notmax123/blue-v2``) into ``pt_weights/``
    when any of these holds:

    - a checkpoint cannot be found via ``_find_pt_weight``;
    - the ``pt_weights/.repo_id`` marker does not match the repo id and
      ``BLUE_PT_BUNDLE_STAMP``;
    - ``BLUE_PT_FORCE_DOWNLOAD`` is truthy.

    ``HF_TOKEN`` is used for authentication when set.

    Returns:
        Mapping from checkpoint role (keys of ``PT_WEIGHT_ALIASES``) to a local path.

    Raises:
        RuntimeError: if a checkpoint is still missing after the download step.
    """
    repo_id = os.environ.get("BLUE_PT_REPO", "notmax123/blue-v2")
    stamp = os.environ.get("BLUE_PT_BUNDLE_STAMP", "1")
    marker = os.path.join("pt_weights", ".repo_id")
    force = _env_truthy("BLUE_PT_FORCE_DOWNLOAD") or not _pt_marker_ok(marker, repo_id, stamp)
    needed: dict[str, Optional[str]] = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
    if force or any(v is None for v in needed.values()):
        from huggingface_hub import hf_hub_download
        import shutil
        os.makedirs("pt_weights", exist_ok=True)
        for fn in ("blue_codec.safetensors", "duration_predictor_final.safetensors",
                   "vf_estimetor.safetensors", "stats_multilingual.safetensors"):
            dest = os.path.join("pt_weights", fn)
            print(f"[INFO] Fetching {repo_id}/{fn} …")
            cached = hf_hub_download(
                repo_id=repo_id, filename=fn, repo_type="model",
                token=os.environ.get("HF_TOKEN") or None,
                force_download=force,
            )
            shutil.copy2(cached, dest)
        # Record what was fetched so later calls skip the download.
        with open(marker, "w", encoding="utf-8") as f:
            f.write(repo_id + "\n" + stamp + "\n")
        needed = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
    # An explicit raise, not `assert`, so the check survives `python -O`.
    if any(v is None for v in needed.values()):
        raise RuntimeError(f"still missing: {needed}")
    return {k: v for k, v in needed.items() if v is not None}


def style_from_wav(ref_wav: str) -> Style:
    """Derive a voice Style from a reference WAV via ``export_new_voice.export_voice_style``.

    Checkpoints come from ``_ensure_pt_weights`` (which may download them).
    The export runs on CPU.
    """
    weights = _ensure_pt_weights()
    from export_new_voice import export_voice_style

    checkpoint_kwargs = {role: weights[role] for role in ("ae_ckpt", "ttl_ckpt", "dp_ckpt", "stats")}
    exported = export_voice_style(ref_wav, config=CONFIG_PATH, device="cpu", **checkpoint_kwargs)
    return style_from_dict(exported)


def _reference_audio_status(ref_wav: Optional[str]):
    if not ref_wav:
        return (
            '<div class="ref-status muted">No reference uploaded — '
            'using the saved voice above. Upload or record a clip to clone a custom voice.</div>'
        )
    try:
        import soundfile as sf
        info = sf.info(ref_wav)
        dur = float(info.frames) / float(info.samplerate or 1)
        channels = int(info.channels or 1)
        if dur < 2.0:
            level = "warn"
            msg = "Too short for cloning; use at least 3 seconds."
        elif dur > 20.0:
            level = "warn"
            msg = "Long clips work, but only the early frames are used. Trim to the cleanest 3-12 seconds."
        elif channels > 2:
            level = "warn"
            msg = "Many channels detected; mono or stereo speech works best."
        else:
            level = "ok"
            try:
                cached = _hash_file(ref_wav) in _REF_WAV_CACHE
            except Exception:
                cached = False
            if cached:
                msg = "Cloned voice cached — next generation will be fast."
            else:
                msg = "Ready. First generation exports the voice (~20-40s); subsequent ones are instant."
        return (
            f'<div class="ref-status {level}">'
            f'Reference: {dur:.1f}s, {info.samplerate} Hz, {channels} channel(s). {html.escape(msg)}'
            '</div>'
        )
    except Exception as e:
        return f'<div class="ref-status warn">Could not inspect uploaded audio: {html.escape(str(e))}</div>'


def synthesize_text(text: str, voice: str, lang: str, steps: int, speed: float,
                    ref_wav: Optional[str] = None,
                    progress: "gr.Progress | None" = gr.Progress()):
    """Gradio handler: synthesize ``text`` with a saved voice or a cloned reference.

    When ``ref_wav`` is given, its style is exported via ``style_from_wav``
    and cached by file hash. Otherwise the saved style named ``voice`` is used.

    Returns:
        ``((sample_rate, wav), stats_html)`` on success, or
        ``(None, error_html)`` when the voice cannot be resolved.
    """
    t0 = time.time()
    using_ref = bool(ref_wav)
    export_time = 0.0
    if using_ref:
        try:
            cache_key = _hash_file(ref_wav)
            if cache_key in _REF_WAV_CACHE:
                if progress is not None:
                    progress(0.9, desc="Using cached cloned voice")
                style = _REF_WAV_CACHE[cache_key]
            else:
                if progress is not None:
                    progress(
                        0.05,
                        desc="Exporting cloned voice (first time ~20-40s, cached after)",
                    )
                t_exp = time.time()
                style = style_from_wav(ref_wav)
                export_time = time.time() - t_exp
                _REF_WAV_CACHE[cache_key] = style
            if progress is not None:
                progress(0.6, desc="Synthesizing speech")
        except Exception as e:
            # Escape: the exception text can echo user-controlled data (upload paths etc.).
            err = f'<div class="stats-bar"><span class="stat-pill">❌ voice clone failed: {html.escape(str(e))}</span></div>'
            return None, err
    else:
        if not VOICE_STYLES:
            err = (
                '<div class="stats-bar"><span class="stat-pill">'
                'No saved voices installed. Upload a reference clip to clone a voice.</span></div>'
            )
            return None, err
        style = VOICE_STYLES.get(voice)
        if style is None:
            # E.g. a stale dropdown value after the voice list was refreshed.
            err = (
                '<div class="stats-bar"><span class="stat-pill">'
                f'Unknown voice: {html.escape(str(voice))}. Pick one from the list.</span></div>'
            )
            return None, err
    wav, sr = TTS.synthesize(
        prepare_text_for_synthesis(text, lang=lang), lang=lang, style=style,
        total_step=int(steps), speed=float(speed), cfg_scale=4.0,
        pace_blend=None,
    )
    # atleast_1d keeps a 1-sample result from squeezing to a 0-d array (len() would fail).
    wav = normalize_generated_audio(np.atleast_1d(np.asarray(wav).squeeze()))
    proc_time = time.time() - t0
    audio_dur = len(wav) / sr if len(wav) > 0 else 0.0
    rtf = proc_time / audio_dur if audio_dur > 0 else 0
    export_pill = (
        f'<span class="stat-pill">🧬 clone export {export_time:.1f}s</span>'
        if using_ref and export_time > 0 else ''
    )
    stats = (
        f'<div class="stats-bar">'
        f'<span class="stat-pill">Voice: {"cloned from upload" if using_ref else html.escape(voice)}</span>'
        f'{export_pill}'
        f'<span class="stat-pill">⏱ {proc_time:.2f}s</span>'
        f'<span class="stat-pill">🔊 {audio_dur:.1f}s audio</span>'
        f'<span class="stat-pill">⚡ {rtf:.2f}x RTF</span>'
        f'</div>'
    )
    return (sr, wav), stats


def phonemes_for_display(text: str, lang: str) -> str:
    """Phonemize ``text`` for the UI, with the internal <lang> routing tags removed.

    The text goes through ``prepare_text_for_synthesis`` and then the
    engine's G2P front-end, so what is shown matches what is synthesized.
    """
    return strip_language_tags_for_display(
        TTS.g2p.phonemize(prepare_text_for_synthesis(text, lang=lang), lang=lang)
    )


# ============================================================
# Voice-clone tab
# ============================================================
# Directories probed, in order, for the PyTorch checkpoints used by voice
# cloning, plus the accepted filename(s) for each checkpoint role. The
# "vf_estimetor" spelling is the filename fetched from the model repo, so it
# must not be "corrected".
PT_WEIGHTS_SEARCH = ["pt_weights", "pt_models", os.path.join("fonts", "pt_models")]
PT_WEIGHT_ALIASES: dict[str, list[str]] = dict(
    ae_ckpt=["blue_codec.safetensors"],
    ttl_ckpt=["vf_estimetor.safetensors"],
    dp_ckpt=["duration_predictor_final.safetensors"],
    stats=["stats_multilingual.safetensors"],
)


def _find_pt_weight(aliases: list[str]) -> Optional[str]:
    """Return the first existing ``<dir>/<alias>`` path, or None.

    Directories are scanned in ``PT_WEIGHTS_SEARCH`` order (outer loop) and
    aliases in the given order (inner loop).
    """
    candidates = (os.path.join(folder, alias) for folder in PT_WEIGHTS_SEARCH for alias in aliases)
    return next((path for path in candidates if os.path.exists(path)), None)


def _refresh_voices() -> None:
    """Rescan ``VOICES_DIR`` and rebuild ``VOICES`` and ``VOICE_STYLES``.

    Both globals are replaced together, and only after every style has
    loaded. If a JSON fails to load, the previous tables stay in place and
    remain consistent with each other.
    """
    global VOICES, VOICE_STYLES
    voices = discover_voices()
    styles = {name: load_voice_style([path]) for name, path in voices.items()}
    VOICES, VOICE_STYLES = voices, styles


def clone_voice(ref_wav: Optional[str], voice_name: str):
    """Export a new voice JSON from a reference WAV and refresh the voice list.

    The name is sanitized to word characters and ``-``, so path separators
    and dots become ``_``. An empty or missing name becomes
    ``custom_<unix time>``. The file is written to ``VOICES_DIR/<name>.json``;
    an existing file with that name is overwritten.

    Returns:
        ``(status_message, dropdown_update)`` for the Gradio UI.
    """
    if not ref_wav:
        return "Please upload a reference WAV first.", gr.update()
    name = (voice_name or "").strip()
    if not name:
        name = f"custom_{int(time.time())}"
    safe = re.sub(r"[^\w\-]+", "_", name)
    out_path = os.path.join(VOICES_DIR, f"{safe}.json")

    needed = _ensure_pt_weights()
    from export_new_voice import export_voice_style

    payload = export_voice_style(
        ref_wav,
        config=CONFIG_PATH,
        ae_ckpt=needed["ae_ckpt"],
        ttl_ckpt=needed["ttl_ckpt"],
        dp_ckpt=needed["dp_ckpt"],
        stats=needed["stats"],
        device="cpu",
    )
    os.makedirs(os.path.dirname(out_path) or ".", exist_ok=True)
    with open(out_path, "w") as f:
        json.dump(payload, f)

    _refresh_voices()
    # Same label transformation discover_voices() uses for the dropdown keys.
    pretty = safe.replace("_", " ").replace("spk ", "Speaker ").title()
    choices_update = gr.update(choices=list(VOICES.keys()))
    if pretty not in VOICES:
        # Discovery can reject a file (e.g. its style fails the compatibility check).
        return (
            f"Saved {out_path}, but it was not picked up by voice discovery "
            "(it may have been rejected as incompatible).",
            choices_update,
        )
    return (
        f"Saved {out_path}. New voice '{pretty}' is now selectable in the Synthesize tab.",
        choices_update,
    )


# ============================================================
# Gradio UI (styling retained from previous version)
# ============================================================
# Sample [text, language code] rows for the Gradio UI: the same sentence in
# each demo language.
EXAMPLES = [
    [sentence, code]
    for sentence, code in (
        ("The power to change begins the moment you believe it's possible!", "en"),
        ("הכוח לשנות מתחיל ברגע שבו אתה מאמין שזה אפשרי!", "he"),
        ("¡El poder de cambiar comienza en el momento en que crees que es posible!", "es"),
        ("Il potere di cambiare inizia nel momento in cui credi che sia possibile!", "it"),
        ("Die Kraft zur Veränderung beginnt in dem Moment, in dem du glaubst, dass es möglich ist!", "de"),
    )
]


def _load_font_face() -> str:
    p = "fonts/EuclidCircularB.woff2"
    if os.path.exists(p):
        b64 = base64.b64encode(open(p, "rb").read()).decode()
        return (
            f"@font-face {{ font-family: 'EuclidCircularB'; "
            f"src: url(data:font/woff2;base64,{b64}) format('woff2'); "
            f"font-weight: 100 900; font-style: normal; }}"
        )
    return ""


# Custom stylesheet for the Gradio UI (dark blue theme, header, card, controls,
# reference-audio panel).
css = (
    # CSS only honours @import when it precedes every other rule (apart from
    # @charset/@layer). _load_font_face() emits an @font-face rule, so the
    # import must be emitted first or browsers silently drop it and
    # JetBrains Mono never loads.
    "@import url('https://fonts.googleapis.com/css2?family=JetBrains+Mono:wght@400;500&display=swap');\n"
    + _load_font_face()
    + """
* { box-sizing: border-box; }
body, .gradio-container { background:#06101f !important; font-family:'EuclidCircularB',sans-serif !important; color:#e6efff !important; }
.gradio-container { max-width:900px !important; margin:0 auto !important; padding:2rem 1.5rem !important; }
.app-header { text-align:center; margin-bottom:2rem; padding:2rem 0 1rem; }
.app-header h1 { font-size:2.8rem; font-weight:600; letter-spacing:-0.03em; background:linear-gradient(135deg,#38bdf8 0%,#3b82f6 50%,#1d4ed8 100%); -webkit-background-clip:text; -webkit-text-fill-color:transparent; background-clip:text; margin:0 0 0.5rem; }
.app-header p { color:#7ea3d4; font-size:1rem; margin:0 0 1rem; }
.app-header .github-link { display:inline-flex; align-items:center; gap:0.4rem; margin-top:0.75rem; padding:0.45rem 1rem; font-size:0.9rem; font-weight:500; text-decoration:none !important; color:#93c5fd !important; border:1px solid #1e40af; border-radius:999px; background:rgba(59,130,246,0.12); }
.card { background:#0b1a30; border:1px solid #163056; border-radius:16px; padding:1.5rem; margin-bottom:1rem; }
.big-input textarea { background:#081327 !important; border:1px solid #1e3a66 !important; border-radius:10px !important; color:#e6efff !important; font-size:1.1rem !important; line-height:1.6 !important; padding:1rem !important; unicode-bidi:plaintext !important; }
.big-input textarea:focus { border-color:#3b82f6 !important; outline:none !important; box-shadow:0 0 0 3px rgba(59,130,246,0.18) !important; }
.controls-row { margin-top:1rem; display:flex !important; flex-direction:column !important; gap:0.75rem !important; }
.ctrl-row1, .ctrl-row2, .ctrl-row3 { display:flex !important; flex-direction:row !important; gap:0.75rem !important; width:100% !important; }
.ctrl-lang { flex:2 !important; min-width:0 !important; } .ctrl-voice { flex:3 !important; min-width:0 !important; }
.ctrl-steps, .ctrl-speed { flex:1 !important; min-width:0 !important; }
.gen-btn { background:linear-gradient(135deg,#2563eb,#1d4ed8) !important; border:none !important; border-radius:10px !important; color:#fff !important; font-size:1rem !important; font-weight:600 !important; padding:0.75rem 2rem !important; width:100% !important; margin-top:1rem !important; box-shadow:0 6px 18px rgba(37,99,235,0.35) !important; }
.gen-btn:hover { opacity:0.9 !important; filter:brightness(1.05); }
.gradio-audio { background:#0b1a30 !important; border:1px solid #163056 !important; border-radius:12px !important; }
.stats-bar { display:flex; gap:0.75rem; flex-wrap:wrap; margin-top:0.75rem; padding:0.75rem 0; }
.stat-pill { background:#0e2545; border:1px solid #1e40af; border-radius:20px; padding:0.3rem 0.9rem; font-family:'JetBrains Mono',monospace; font-size:0.8rem; color:#93c5fd; }
.gradio-dropdown select, .gradio-dropdown input { background:#081327 !important; border:1px solid #1e3a66 !important; color:#e6efff !important; border-radius:8px !important; }
.ref-panel { margin-top:1rem; padding:1rem; border:1px dashed #1e40af; border-radius:12px; background:#091a34; }
.ref-panel label { color:#bfdbfe !important; }
.ref-panel h3 { color:#dbeafe; margin:0 0 0.25rem; font-size:1rem; font-weight:600; }
.ref-status { margin-top:0.6rem; padding:0.75rem 0.9rem; border-radius:10px; font-size:0.9rem; line-height:1.4; }
.ref-status.ok { color:#bae6fd; background:rgba(14,165,233,0.12); border:1px solid rgba(14,165,233,0.35); }
.ref-status.warn { color:#fde68a; background:rgba(245,158,11,0.10); border:1px solid rgba(245,158,11,0.25); }
.ref-status.muted { color:#93a6c4; background:rgba(59,130,246,0.08); border:1px solid rgba(59,130,246,0.20); }
.ref-help { color:#7ea3d4; font-size:0.86rem; line-height:1.45; margin-top:0.5rem; }
"""
)

# Build the single-page Gradio UI. Components render in the order they are
# created inside these context managers, so statement order here defines the
# on-screen layout.
with gr.Blocks(title="BlueTTS V2 — Multilingual TTS") as demo:
    # Static header: title, tagline listing the supported languages, GitHub link.
    gr.HTML(
        '<div class="app-header"><h1>BlueTTS V2</h1>'
        '<p>Slim multilingual text-to-speech · English · Hebrew · Spanish · German · Italian</p>'
        '<a class="github-link" href="https://github.com/maxmelichov/BlueTTS" target="_blank">GitHub · maxmelichov/BlueTTS</a></div>'
    )

    # Main card: text box, language/voice pickers, quality/speed sliders,
    # optional voice-cloning panel, and the generate button.
    with gr.Column(elem_classes="card"):
        text_input = gr.Textbox(
            label="Text", placeholder="Type or paste text here…",
            lines=4, elem_classes="big-input",
            value="Great ideas become real when a small team keeps building every single day.",
        )
        with gr.Column(elem_classes="controls-row"):
            with gr.Row(elem_classes="ctrl-row1"):
                # Choices are (label, value) pairs; handlers receive the
                # language code (e.g. "en"), not the label.
                lang_input = gr.Dropdown(
                    choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"),
                             ("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"),
                             ("Italian 🇮🇹", "it")],
                    value="en", label="Language", elem_classes="ctrl-lang",
                )
                # Voice names are read from the VOICES registry when the UI is
                # built; the first one is preselected (None if VOICES is empty).
                voice_input = gr.Dropdown(
                    choices=list(VOICES.keys()),
                    value=next(iter(VOICES.keys()), None),
                    label="Voice", elem_classes="ctrl-voice",
                )
            with gr.Row(elem_classes="ctrl-row2"):
                # Slider(min, max, default): steps 5-16 (default 8),
                # speed 0.8-1.2 (default 0.95).
                steps_input = gr.Slider(5, 16, 8, step=1, label="Quality (steps)", elem_classes="ctrl-steps")
                speed_input = gr.Slider(0.8, 1.2, 0.95, step=0.05, label="Speed", elem_classes="ctrl-speed")

        # Optional reference-audio panel used for voice cloning.
        with gr.Column(elem_classes="ref-panel"):
            gr.HTML(
                '<h3 style="color:#dbeafe;margin:0 0 0.25rem;font-size:1rem;font-weight:600;">Clone a voice (optional)</h3>'
                '<div class="ref-help">Upload or record 3-12 seconds of clean speech to clone it. '
                'Leave empty to use the saved voice selected above. Generation starts automatically when you upload. '
                '<b>Heads up:</b> the first sentence with a new clone takes ~20-40s to export the voice — after that, regeneration is instant.</div>'
            )
            # type="filepath": handlers receive a file path string (falsy when empty).
            ref_wav_input = gr.Audio(
                label="Reference audio",
                sources=["upload", "microphone"], type="filepath",
            )
            # Initial status rendered for the "no reference audio" state.
            ref_status = gr.HTML(_reference_audio_status(None))

        btn = gr.Button("⚡ Generate Speech", elem_classes="gen-btn")
    # type="numpy": presumably synthesize_text returns a (sample_rate, ndarray)
    # pair for this component — TODO confirm against synthesize_text.
    audio_out = gr.Audio(label="Output", type="numpy", autoplay=True)
    stats_out = gr.HTML()

    # Example rows are [text, lang] and fill these two inputs positionally.
    gr.Examples(examples=EXAMPLES, inputs=[text_input, lang_input], label="Examples")

    # Shared wiring: both the button and the auto-synth path pass these inputs
    # positionally, in this order, and write to the same two outputs.
    synth_inputs = [text_input, voice_input, lang_input, steps_input, speed_input, ref_wav_input]
    synth_outputs = [audio_out, stats_out]

    def _auto_synth(text, voice, lang, steps, speed, ref_wav):
        """Synthesize automatically, but only when reference audio is present.

        Chained after the reference-audio status refresh. When the audio input
        was cleared (falsy ``ref_wav``), return two no-op ``gr.update()``s so
        the current output audio and stats are left untouched; otherwise
        delegate to ``synthesize_text`` with the same arguments.
        """
        if not ref_wav:
            return gr.update(), gr.update()
        return synthesize_text(text, voice, lang, steps, speed, ref_wav)

    # Whenever the reference audio value changes: refresh the status HTML
    # first, then (via .then) attempt an automatic synthesis with it.
    ref_wav_input.change(
        _reference_audio_status,
        inputs=[ref_wav_input],
        outputs=[ref_status],
    ).then(
        _auto_synth,
        inputs=synth_inputs,
        outputs=synth_outputs,
    )

    # Manual generation from the button.
    btn.click(
        synthesize_text,
        inputs=synth_inputs,
        outputs=synth_outputs,
    )

    # Intended to set dir="auto" on the main textarea so right-to-left text
    # (e.g. Hebrew) lays out correctly; a MutationObserver waits for the
    # textarea if it is not mounted yet.
    # NOTE(review): browsers do not execute <script> tags inserted via
    # innerHTML, and gr.HTML may not run them either — verify this actually
    # fires; if not, move it to Gradio's `js`/`head` hook. The
    # `unicode-bidi:plaintext` rule on `.big-input textarea` in `css` handles
    # text direction independently of this script.
    gr.HTML("""
    <script>
    (function applyDirAuto() {
        const ta = document.querySelector('.big-input textarea');
        if (ta) { ta.setAttribute('dir', 'auto'); return; }
        const obs = new MutationObserver(() => {
            const ta = document.querySelector('.big-input textarea');
            if (ta) { ta.setAttribute('dir', 'auto'); obs.disconnect(); }
        });
        obs.observe(document.body, { childList: true, subtree: true });
    })();
    </script>
    """)

# Script entry point: serve the UI built above.
if __name__ == "__main__":
    # Theme and stylesheet are passed to launch() rather than gr.Blocks().
    # NOTE(review): this matches the newer Gradio API where css/theme are
    # launch() parameters — confirm against the pinned Gradio version. Because
    # they are supplied only here, any caller that imports `demo` and launches
    # it itself will not get this theme/CSS.
    demo.launch(theme=gr.themes.Base(), css=css)