File size: 15,238 Bytes
8df5700
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
#!/usr/bin/env python3
"""
================================================================================
Priority 3: Arabic Diacritization + Algerian Preprocessing Pipeline
================================================================================

Algerian Arabic (Darija) preprocessing is critical for TTS quality:
1. Text is often undiacritized → phonetic ambiguity
2. Heavy code-switching with French
3. Numerals need normalization
4. Mixed Arabic/Latin script usage

This pipeline provides:
1. Arabic diacritization using Sadeed (SOTA, April 2025)
2. Numeral normalization (Eastern ٠١٢ and Western 012 → Arabic words)
3. Basic French/Arabic code-switching handling
4. Text caching for repeated phrases
5. Sentence-level chunking for streaming

Dependencies:
    pip install numpy
    pip install transformers torch   # optional: only needed for --diacritize

Usage:
    python 03_arabic_preprocessing.py \
        --input "مرحبا كيف حالك 123" \
        --diacritize \
        --normalize_numerals

    python 03_arabic_preprocessing.py \
        --input_file text.txt \
        --output_file processed.txt \
        --diacritize \
        --normalize_numerals \
        --chunk_for_streaming

================================================================================
"""

import argparse
import hashlib
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple

import numpy as np

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Default directory in which TextCache persists preprocessed results.
# NOTE: created at import time; TextCache.__init__ also creates whatever
# cache_dir it is given.
CACHE_DIR = Path.home() / ".cache" / "habibi_tts_preprocess"
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Digit alphabets used with str.maketrans in normalize_numerals: the Eastern
# Arabic-Indic digit at index i is translated to the ASCII digit at index i.
ARABIC_EASTERN_NUMERALS = "٠١٢٣٤٥٦٧٨٩"
ARABIC_WESTERN_NUMERALS = "0123456789"

# Simple Arabic word numbers (for numeral normalization).
# Keys are decimal strings. Covered values: 0-20, the tens 30-90, 100, 1000
# and 1000000; normalize_numerals composes every other value from these.
ARABIC_NUMBERS = {
    "0": "صفر", "1": "واحد", "2": "اثنان", "3": "ثلاثة",
    "4": "أربعة", "5": "خمسة", "6": "ستة", "7": "سبعة",
    "8": "ثمانية", "9": "تسعة", "10": "عشرة",
    "11": "أحد عشر", "12": "اثنا عشر", "13": "ثلاثة عشر",
    "14": "أربعة عشر", "15": "خمسة عشر", "16": "ستة عشر",
    "17": "سبعة عشر", "18": "ثمانية عشر", "19": "تسعة عشر",
    "20": "عشرون", "30": "ثلاثون", "40": "أربعون",
    "50": "خمسون", "60": "ستون", "70": "سبعون",
    "80": "ثمانون", "90": "تسعون", "100": "مائة",
    "1000": "ألف", "1000000": "مليون",
}

# French words and phrases commonly mixed into Algerian Arabic, mapped to an
# Arabic equivalent. handle_code_switching uses this table when
# translate_french=True. Keys are lowercase.
# NOTE(review): only the unaccented "s'il vous plait" is listed; the accented
# spelling "s'il vous plaît" is not matched.
FRENCH_COMMON_WORDS = {
    "bonjour": "صباح الخير", "merci": "شكرا", "s'il vous plait": "من فضلك",
    "excusez-moi": "عذرا", "oui": "نعم", "non": "لا",
    "bon": "جيد", "très": "جدا", "beaucoup": "كثيرا",
    "comment": "كيف", "ça va": "كيف الحال", "au revoir": "مع السلامة",
    "bonsoir": "مساء الخير", "bonne nuit": "تصبح على خير",
    "pardon": "عذرا", "d'accord": "حسنا", "ok": "حسنا",
}


# ---------------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------------

class TextCache:
    """Simple file-based cache for preprocessed text.

    Entries live in memory as a ``str -> str`` dict and are persisted as one
    JSON object in ``<cache_dir>/text_cache.json``. The whole file is
    rewritten on every :meth:`set` and :meth:`clear`.
    """

    def __init__(self, cache_dir: str = str(CACHE_DIR)):
        """
        Args:
            cache_dir: Directory that holds ``text_cache.json``; created if missing.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / "text_cache.json"
        self.cache = {}
        self._load()

    def _load(self):
        """Load entries from disk, starting empty if the file is missing or unusable."""
        if not self.cache_file.exists():
            return
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # Unreadable file, bad encoding or truncated JSON (ValueError covers
            # JSONDecodeError and UnicodeDecodeError). The cache is best-effort,
            # so start fresh instead of failing.
            data = {}
        # Valid JSON that is not an object (e.g. a list) would break get()/set().
        self.cache = data if isinstance(data, dict) else {}

    def _save(self):
        """Persist the cache atomically.

        The JSON is written to a temporary file in the same directory and then
        moved over ``text_cache.json`` with ``os.replace``. A crash mid-write
        can therefore never leave a truncated cache file behind.
        """
        import tempfile

        fd, tmp_name = tempfile.mkstemp(
            dir=self.cache_dir, prefix=".text_cache.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(self.cache, f, ensure_ascii=False, indent=2)
            os.replace(tmp_name, self.cache_file)
        except BaseException:
            # Don't leave an orphaned temp file if writing or renaming failed.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass
            raise

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for ``key``, or None if absent."""
        return self.cache.get(key)

    def set(self, key: str, value: str):
        """Store ``value`` under ``key`` and persist the cache to disk."""
        self.cache[key] = value
        self._save()

    def clear(self):
        """Drop every entry and persist the (now empty) cache."""
        self.cache = {}
        self._save()


# ---------------------------------------------------------------------------
# Numeral Normalization
# ---------------------------------------------------------------------------


def normalize_numerals(text: str) -> str:
    """
    Convert numerals (both Eastern Arabic ٠١٢ and Western 012) to Arabic words.

    Every run of digits is read as a non-negative integer and spelled out, for
    example:

    - ``"15"`` -> ``"خمسة عشر"``
    - ``"21"`` -> ``"واحد و عشرون"`` (units before tens)
    - ``"300"`` -> ``"ثلاثمائة"``
    - ``"2000"`` -> ``"ألفان"``
    - ``"3000"`` -> ``"ثلاثة آلاف"``

    Handles numbers up to millions; larger values are expressed as a count of
    millions (e.g. ``"ألف مليون"``).

    Args:
        text: Input text that may contain digit sequences.

    Returns:
        The text with every digit run replaced by Arabic words. Characters
        between digit runs (``:``, ``.``, ``/`` ...) are left untouched, so
        ``"3:30"`` becomes ``"ثلاثة:ثلاثون"``.
    """
    # First convert Eastern Arabic numerals to Western.
    trans = str.maketrans(ARABIC_EASTERN_NUMERALS, ARABIC_WESTERN_NUMERALS)
    text = text.translate(trans)

    # Compound hundreds are single words (ثلاثمائة), not "ثلاثة مائة".
    hundreds_words = {
        1: "مائة", 2: "مائتان", 3: "ثلاثمائة", 4: "أربعمائة", 5: "خمسمائة",
        6: "ستمائة", 7: "سبعمائة", 8: "ثمانمائة", 9: "تسعمائة",
    }

    def below_hundred(n: int) -> str:
        """Spell 1..99. Lookup keys are strings, hence str(n)."""
        word = ARABIC_NUMBERS.get(str(n))
        if word is not None:
            return word
        tens, ones = divmod(n, 10)
        # Arabic reads the units first: 21 -> "واحد و عشرون".
        return f"{ARABIC_NUMBERS[str(ones)]} و {ARABIC_NUMBERS[str(tens * 10)]}"

    def below_thousand(n: int) -> str:
        """Spell 1..999."""
        hundreds, rest = divmod(n, 100)
        parts = []
        if hundreds:
            parts.append(hundreds_words[hundreds])
        if rest:
            parts.append(below_hundred(rest))
        return " و ".join(parts)

    def scaled(count: int, singular: str, dual: str, plural: str) -> str:
        """Spell ``count`` of a scale word with singular / dual / plural agreement."""
        if count == 1:
            return singular
        if count == 2:
            return dual
        if count <= 10:
            return f"{below_thousand(count)} {plural}"
        return f"{spell(count)} {singular}"

    def spell(num: int) -> str:
        """Spell any non-negative integer."""
        if num == 0:
            return ARABIC_NUMBERS["0"]
        parts = []
        millions, num = divmod(num, 1_000_000)
        if millions:
            parts.append(scaled(millions, "مليون", "مليونان", "ملايين"))
        thousands, num = divmod(num, 1_000)
        if thousands:
            parts.append(scaled(thousands, "ألف", "ألفان", "آلاف"))
        if num:
            parts.append(below_thousand(num))
        return " و ".join(parts)

    def replace_match(match: "re.Match") -> str:
        digits = match.group(0)
        try:
            value = int(digits)
        except ValueError:
            # e.g. digit runs longer than Python's int/str conversion limit:
            # leave them as they are rather than failing the whole text.
            return digits
        return spell(value)

    # Match sequences of digits.
    return re.sub(r"\d+", replace_match, text)


# ---------------------------------------------------------------------------
# Diacritization (using Sadeed or fallback)
# ---------------------------------------------------------------------------


class ArabicDiacritizer:
    """
    Arabic text diacritization using Sadeed model (Misraj/Sadeed).

    The model is loaded via ``transformers.pipeline("text2text-generation")``.
    Text is returned unchanged by the fallback in two cases:

    - loading fails for any reason (missing package, download error, ...);
    - a call to the loaded model raises.
    """

    def __init__(self, model_name: str = "Misraj/Sadeed", device: str = "cpu"):
        """
        Args:
            model_name: Model id passed to ``transformers.pipeline``.
            device: ``"cpu"``, ``"cuda"`` or ``"cuda:N"``; any other value runs on CPU.
        """
        self.model_name = model_name
        self.device = device
        self.pipeline = None  # None means "use the fallback"; set by _load_model.
        self._load_model()

    @staticmethod
    def _resolve_device_index(device: str) -> int:
        """Map a device string to the integer index ``transformers.pipeline`` takes.

        ``"cuda"`` maps to 0 and ``"cuda:N"`` maps to N. Anything else
        (``"cpu"``, malformed indices, ...) maps to -1, which means CPU.
        """
        name, _, index = str(device).partition(":")
        if name != "cuda":
            return -1
        if not index:
            return 0
        try:
            return int(index)
        except ValueError:
            return -1

    def _load_model(self):
        """Load the diacritization model, leaving ``self.pipeline`` as None on failure."""
        try:
            from transformers import pipeline
            print(f"[DIACRITIZE] Loading {self.model_name}...")
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model_name,
                device=self._resolve_device_index(self.device),
                torch_dtype="auto",
            )
            print("[DIACRITIZE] Model loaded successfully.")
        except Exception as e:
            # Deliberate best-effort: preprocessing must keep working without the model.
            print(f"[DIACRITIZE] Warning: Could not load model ({e}). Using fallback.")
            self.pipeline = None

    def diacritize(self, text: str) -> str:
        """Add diacritics (tashkeel) to Arabic text.

        Blank input is returned as-is without invoking the model.

        NOTE(review): ``max_length=512`` caps the generated sequence, so long
        inputs presumably come back truncated. Consider chunking text before
        calling this method.
        """
        if not text.strip():
            return text
        if self.pipeline is None:
            return self._fallback_diacritize(text)

        try:
            result = self.pipeline(text, max_length=512, do_sample=False)
            return result[0]["generated_text"]
        except Exception as e:
            print(f"[DIACRITIZE] Error: {e}. Using fallback.")
            return self._fallback_diacritize(text)

    def _fallback_diacritize(self, text: str) -> str:
        """
        Placeholder fallback used when the model is unavailable or fails.

        Returns ``text`` unchanged. Real diacritization requires a trained model.
        """
        return text


# ---------------------------------------------------------------------------
# Code-switching Handling
# ---------------------------------------------------------------------------


def handle_code_switching(
    text: str,
    translate_french: bool = False,
    mapping: Optional[dict] = None,
) -> str:
    """
    Handle French/Arabic code-switching in Algerian text.

    When ``translate_french`` is True, known French words and phrases are
    replaced by their Arabic equivalents. Matching is:

    - case-insensitive;
    - whole-word only, so "ok" inside "token" or "oui" inside "Louis" is left alone;
    - longest-first, so "bonsoir" and "bonne nuit" win over "bon".

    When ``translate_french`` is False the text is returned unchanged.

    Args:
        text: Input text.
        translate_french: Whether to perform the replacement at all.
        mapping: French -> Arabic lexicon; defaults to ``FRENCH_COMMON_WORDS``.

    Returns:
        The (possibly) rewritten text.
    """
    if not translate_french:
        return text
    if mapping is None:
        mapping = FRENCH_COMMON_WORDS

    # Lowercased lookup; empty keys would match the empty string everywhere.
    lookup = {french.lower(): arabic for french, arabic in mapping.items() if french}
    if not lookup:
        return text

    # A single alternation scanned once: replaced Arabic text is never
    # re-scanned, and longer phrases are tried before their prefixes.
    alternatives = sorted(lookup, key=len, reverse=True)
    pattern = re.compile(
        r"(?<!\w)(?:" + "|".join(map(re.escape, alternatives)) + r")(?!\w)",
        re.IGNORECASE,
    )
    return pattern.sub(lambda m: lookup.get(m.group(0).lower(), m.group(0)), text)


# ---------------------------------------------------------------------------
# Sentence Chunking for Streaming
# ---------------------------------------------------------------------------


def chunk_for_streaming(text: str, max_chars: int = 135) -> List[str]:
    """
    Split text into sentence-level chunks for streaming TTS.

    Sentences are greedily packed into chunks whose UTF-8 byte length stays
    within ``max_chars``; each Arabic letter counts as 2 bytes. A single
    sentence longer than the limit becomes its own, oversized chunk.

    Split points:

    - Latin ``; : , . ! ?`` split only when followed by whitespace, so
      ``"3:30"`` or ``"1.5"`` stay intact.
    - The Arabic marks ``؛ ، ؟`` and ``。`` split immediately after the mark.

    Args:
        text: Text to split.
        max_chars: Maximum UTF-8 bytes per chunk.

    Returns:
        Stripped, non-empty chunks; ``[]`` for blank input.
    """
    sentences = re.split(r"(?<=[;:,.!?])\s+|(?<=[؛،؟。])", text)

    def with_separator(piece: str) -> str:
        # A Latin split consumed the whitespace after the punctuation, so add
        # one space back. After an Arabic mark the split is zero-width and the
        # next piece still starts with its own space.
        return piece + " " if piece[-1].isascii() else piece

    chunks: List[str] = []
    current_chunk = ""

    for sentence in sentences:
        if not sentence.strip():
            continue
        # Byte length, not character count (mirrors F5-TTS chunking).
        if len(current_chunk.encode("utf-8")) + len(sentence.encode("utf-8")) <= max_chars:
            current_chunk += with_separator(sentence)
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = with_separator(sentence)

    if current_chunk:
        chunks.append(current_chunk.strip())

    return chunks


# ---------------------------------------------------------------------------
# Main Preprocessing Pipeline
# ---------------------------------------------------------------------------


class AlgerianTTSPipeline:
    """Complete preprocessing pipeline for Algerian Arabic TTS.

    The steps run in this order, each only if enabled:

    1. numeral normalization;
    2. French code-switch replacement;
    3. diacritization.

    Results can be cached, keyed by the input text together with the settings
    that affect the output.
    """

    def __init__(
        self,
        diacritize: bool = True,
        normalize_numerals: bool = True,
        handle_code_switch: bool = True,
        cache_enabled: bool = True,
        device: str = "cpu",
    ):
        self.diacritize = diacritize
        self.normalize_numerals = normalize_numerals
        self.handle_code_switch = handle_code_switch
        self.cache = TextCache() if cache_enabled else None
        self.diacritizer = ArabicDiacritizer(device=device) if diacritize else None

    def _cache_key(self, text: str) -> str:
        """Return a cache key covering the input AND every setting that shapes the output.

        Hashing only the text would let a run without diacritization (or with
        the model unavailable) serve its output to later runs that do diacritize.
        """
        diacritizer = self.diacritizer
        if self.diacritize and diacritizer is not None:
            model = getattr(diacritizer, "model_name", "")
            backend = "model" if getattr(diacritizer, "pipeline", None) is not None else "fallback"
            diacritize_tag = f"{model}:{backend}"
        else:
            diacritize_tag = "off"
        settings = json.dumps(
            [bool(self.normalize_numerals), bool(self.handle_code_switch), diacritize_tag],
            ensure_ascii=False,
        )
        return hashlib.sha256(f"{settings}\n{text}".encode("utf-8")).hexdigest()

    def preprocess(self, text: str) -> str:
        """Run the full preprocessing pipeline on ``text`` (cached when enabled)."""
        cache_key = self._cache_key(text) if self.cache is not None else None
        if cache_key is not None:
            cached = self.cache.get(cache_key)
            if cached is not None:
                return cached

        result = text

        # Step 1: Normalize numerals
        if self.normalize_numerals:
            result = normalize_numerals(result)

        # Step 2: Handle code-switching
        if self.handle_code_switch:
            result = handle_code_switching(result, translate_french=True)

        # Step 3: Diacritize
        if self.diacritize and self.diacritizer:
            result = self.diacritizer.diacritize(result)

        if cache_key is not None:
            self.cache.set(cache_key, result)

        return result

    def preprocess_streaming(self, text: str, max_chars: int = 135) -> List[str]:
        """Preprocess ``text`` and split it into chunks of at most ``max_chars`` UTF-8 bytes."""
        processed = self.preprocess(text)
        return chunk_for_streaming(processed, max_chars=max_chars)


def main():
    """Command-line entry point.

    Parses flags, preprocesses the input text and prints the result. The
    result is also written to ``--output_file`` when given.
    """
    parser = argparse.ArgumentParser(description="Algerian Arabic TTS Preprocessing Pipeline")
    parser.add_argument("--input", help="Input text string")
    parser.add_argument("--input_file", help="Input text file")
    parser.add_argument("--output_file", help="Output file for processed text")
    parser.add_argument("--diacritize", action="store_true", help="Add diacritics")
    parser.add_argument("--normalize_numerals", action="store_true", help="Convert numerals to words")
    parser.add_argument("--handle_code_switch", action="store_true", help="Handle French/Arabic mixing")
    parser.add_argument("--chunk_for_streaming", action="store_true", help="Split into streaming chunks")
    parser.add_argument(
        "--max_chars", type=int, default=135,
        help="Max UTF-8 bytes per chunk (Arabic letters take 2 bytes)",
    )
    parser.add_argument("--device", default="cpu", help="Device for diacritization model")
    parser.add_argument("--clear_cache", action="store_true", help="Clear text cache")
    parser.add_argument("--no_cache", action="store_true", help="Disable the on-disk text cache")
    args = parser.parse_args()

    if args.clear_cache:
        cache = TextCache()
        cache.clear()
        print("[CACHE] Cleared.")
        return

    # Get input text (--input takes precedence over --input_file).
    if args.input:
        text = args.input
    elif args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        # Demo text
        text = "مرحبا، كيف حالك اليوم؟ أنا بخير شكرا. الساعة 3:30 والطقس جميل."
        print(f"[DEMO] Using demo text: {text}")

    pipeline = AlgerianTTSPipeline(
        diacritize=args.diacritize,
        normalize_numerals=args.normalize_numerals,
        handle_code_switch=args.handle_code_switch,
        cache_enabled=not args.no_cache,
        device=args.device,
    )

    # perf_counter is monotonic and high-resolution, unlike time.time().
    t0 = time.perf_counter()
    if args.chunk_for_streaming:
        result = pipeline.preprocess_streaming(text, max_chars=args.max_chars)
        print(f"\n[RESULT] Processed into {len(result)} chunks:")
        for i, chunk in enumerate(result, start=1):
            print(f"  Chunk {i}: {chunk}")
    else:
        result = pipeline.preprocess(text)
        print("\n[RESULT] Processed text:")
        print(f"  Input:  {text}")
        print(f"  Output: {result}")
    elapsed = time.perf_counter() - t0
    print(f"\n[TIME] Processing took {elapsed:.3f}s")

    # Save output: one chunk per line in streaming mode, raw text otherwise.
    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            if isinstance(result, list):
                f.writelines(f"{chunk}\n" for chunk in result)
            else:
                f.write(result)
        print(f"[SAVE] Saved to {args.output_file}")


if __name__ == "__main__":
    main()