Habibi-TTS-ALG-Prod / scripts /03_arabic_preprocessing.py
medyas's picture
Upload scripts/03_arabic_preprocessing.py with huggingface_hub
8df5700 verified
#!/usr/bin/env python3
"""
================================================================================
Priority 3: Arabic Diacritization + Algerian Preprocessing Pipeline
================================================================================
Algerian Arabic (Darija) preprocessing is critical for TTS quality:
1. Text is often undiacritized → phonetic ambiguity
2. Heavy code-switching with French
3. Numerals need normalization
4. Mixed Arabic/Latin script usage
This pipeline provides:
1. Arabic diacritization using Sadeed (SOTA, April 2025)
2. Numeral normalization (Eastern ٠١٢ and Western 012 → Arabic words)
3. Basic French/Arabic code-switching handling
4. Text caching for repeated phrases
5. Sentence-level chunking for streaming
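Dependencies:
pip install numpy transformers torch
(numpy is imported at module level; transformers/torch are only needed for
--diacritize, and if the model cannot be loaded the text is left undiacritized)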
Usage:
python 03_arabic_preprocessing.py \
--input "مرحبا كيف حالك 123" \
--diacritize \
--normalize_numerals
python 03_arabic_preprocessing.py \
--input_file text.txt \
--output_file processed.txt \
--diacritize \
--normalize_numerals \
--chunk_for_streaming
================================================================================
"""
import argparse
import hashlib
import json
import os
import re
import sys
import time
from pathlib import Path
from typing import List, Optional, Tuple
import numpy as np
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Directory for the on-disk preprocessing cache; created as soon as the module
# is imported.
CACHE_DIR = Path.home().joinpath(".cache", "habibi_tts_preprocess")
CACHE_DIR.mkdir(parents=True, exist_ok=True)

# Index-aligned digit alphabets used with str.maketrans: Eastern Arabic-Indic
# digits and their ASCII counterparts.
ARABIC_EASTERN_NUMERALS = "٠١٢٣٤٥٦٧٨٩"
ARABIC_WESTERN_NUMERALS = "".join(map(str, range(10)))

# Arabic words for the numbers that have their own spelling (units, teens,
# round tens and the scale words), keyed by their decimal string.
ARABIC_NUMBERS = {
    # units
    "0": "صفر",
    "1": "واحد",
    "2": "اثنان",
    "3": "ثلاثة",
    "4": "أربعة",
    "5": "خمسة",
    "6": "ستة",
    "7": "سبعة",
    "8": "ثمانية",
    "9": "تسعة",
    "10": "عشرة",
    # teens
    "11": "أحد عشر",
    "12": "اثنا عشر",
    "13": "ثلاثة عشر",
    "14": "أربعة عشر",
    "15": "خمسة عشر",
    "16": "ستة عشر",
    "17": "سبعة عشر",
    "18": "ثمانية عشر",
    "19": "تسعة عشر",
    # round tens
    "20": "عشرون",
    "30": "ثلاثون",
    "40": "أربعون",
    "50": "خمسون",
    "60": "ستون",
    "70": "سبعون",
    "80": "ثمانون",
    "90": "تسعون",
    # scales
    "100": "مائة",
    "1000": "ألف",
    "1000000": "مليون",
}

# French words and phrases that commonly appear in Algerian Arabic, mapped to
# a Standard Arabic equivalent. Insertion order is significant for callers that
# apply the replacements sequentially.
FRENCH_COMMON_WORDS = {
    "bonjour": "صباح الخير",
    "merci": "شكرا",
    "s'il vous plait": "من فضلك",
    "excusez-moi": "عذرا",
    "oui": "نعم",
    "non": "لا",
    "bon": "جيد",
    "très": "جدا",
    "beaucoup": "كثيرا",
    "comment": "كيف",
    "ça va": "كيف الحال",
    "au revoir": "مع السلامة",
    "bonsoir": "مساء الخير",
    "bonne nuit": "تصبح على خير",
    "pardon": "عذرا",
    "d'accord": "حسنا",
    "ok": "حسنا",
}
# ---------------------------------------------------------------------------
# Caching
# ---------------------------------------------------------------------------
class TextCache:
    """Simple file-based cache for preprocessed text.

    The whole cache is held in memory as a ``dict`` and persisted as JSON to
    ``text_cache.json`` inside ``cache_dir``. Every mutation (``set``/``clear``)
    rewrites the file.
    """

    def __init__(self, cache_dir: str = str(CACHE_DIR)):
        """
        Args:
            cache_dir: Directory holding ``text_cache.json``; created if missing.
        """
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.cache_file = self.cache_dir / "text_cache.json"
        self.cache = {}
        self._load()

    def _load(self):
        """Populate ``self.cache`` from disk; start empty if the file is missing or unusable."""
        if not self.cache_file.exists():
            return
        try:
            with open(self.cache_file, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError:
            # an unreadable/corrupt cache is discarded rather than crashing.
            data = {}
        # A syntactically valid file that is not a JSON object (e.g. a list)
        # would otherwise break .get()/.set() later on.
        self.cache = data if isinstance(data, dict) else {}

    def _save(self):
        """Write the cache to disk without ever leaving a truncated file behind."""
        tmp_file = self.cache_file.with_name(self.cache_file.name + ".tmp")
        with open(tmp_file, "w", encoding="utf-8") as f:
            json.dump(self.cache, f, ensure_ascii=False, indent=2)
        # Path.replace (os.replace) renames over the target; on POSIX this is
        # atomic, so an interrupted dump cannot corrupt the existing cache.
        tmp_file.replace(self.cache_file)

    def get(self, key: str) -> Optional[str]:
        """Return the cached value for *key*, or ``None`` if absent."""
        return self.cache.get(key)

    def set(self, key: str, value: str):
        """Store *value* under *key* and persist the cache."""
        self.cache[key] = value
        self._save()

    def clear(self):
        """Drop every entry and persist the now-empty cache."""
        self.cache = {}
        self._save()
# ---------------------------------------------------------------------------
# Numeral Normalization
# ---------------------------------------------------------------------------
def normalize_numerals(text: str) -> str:
    """
    Convert numerals (both Eastern Arabic ٠١٢ and Western 012) to Arabic words.

    Eastern Arabic-Indic digits are first mapped to ASCII digits, then every
    run of digits is replaced by its Arabic cardinal spelling (masculine
    forms, no case endings). Components are joined with " و " and follow
    Arabic word order:

    * 11-19 use the compound forms from ``ARABIC_NUMBERS`` ("أحد عشر");
    * 21-99 name the ones before the tens ("واحد و عشرون");
    * hundreds are single words ("مائتان", "ثلاثمائة");
    * thousands/millions take singular, dual or plural nouns
      ("ألف", "ألفان", "ثلاثة آلاف", "أحد عشر ألف").

    Values of a thousand million or more are spelled as a count of millions
    ("ألف مليون"). A digit run that ``int()`` refuses to parse (e.g. one
    exceeding Python's integer-string length limit) is left unchanged.
    """
    # First convert Eastern Arabic numerals to Western
    trans = str.maketrans(ARABIC_EASTERN_NUMERALS, ARABIC_WESTERN_NUMERALS)
    text = text.translate(trans)

    hundreds_words = {
        1: "مائة", 2: "مائتان", 3: "ثلاثمائة", 4: "أربعمائة", 5: "خمسمائة",
        6: "ستمائة", 7: "سبعمائة", 8: "ثمانمائة", 9: "تسعمائة",
    }
    # (scale, (singular, dual, plural)) from largest to smallest.
    scale_nouns = (
        (1_000_000, ("مليون", "مليونان", "ملايين")),
        (1_000, ("ألف", "ألفان", "آلاف")),
    )

    def below_hundred(n: int) -> str:
        """Spell 1 <= n <= 99."""
        # ARABIC_NUMBERS is keyed by strings, so look up str(n), not n.
        word = ARABIC_NUMBERS.get(str(n))
        if word is not None:
            return word
        tens, ones = divmod(n, 10)
        # Arabic compounds name the ones before the tens: 21 -> "واحد و عشرون".
        return f"{ARABIC_NUMBERS[str(ones)]} و {ARABIC_NUMBERS[str(tens * 10)]}"

    def number_to_arabic_words(num: int) -> str:
        """Spell a non-negative integer in Arabic words."""
        if num == 0:
            return ARABIC_NUMBERS["0"]
        parts = []
        remaining = num
        for scale, (singular, dual, plural) in scale_nouns:
            count, remaining = divmod(remaining, scale)
            if count == 0:
                continue
            if count == 1:
                parts.append(singular)
            elif count == 2:
                parts.append(dual)
            else:
                # Counts 3-10 take the plural noun; 11 and above the singular.
                noun = plural if count <= 10 else singular
                parts.append(f"{number_to_arabic_words(count)} {noun}")
        hundreds, remaining = divmod(remaining, 100)
        if hundreds:
            parts.append(hundreds_words[hundreds])
        if remaining:
            parts.append(below_hundred(remaining))
        return " و ".join(parts)

    def replace_match(match):
        digits = match.group(0)
        try:
            value = int(digits)
        except ValueError:
            return digits
        return number_to_arabic_words(value)

    # Match sequences of digits
    return re.sub(r"\d+", replace_match, text)
# ---------------------------------------------------------------------------
# Diacritization (using Sadeed or fallback)
# ---------------------------------------------------------------------------
class ArabicDiacritizer:
    """
    Arabic text diacritization using the Sadeed model (Misraj/Sadeed).

    The model is loaded through a ``transformers`` "text2text-generation"
    pipeline. If loading fails for any reason, or a generation call raises,
    the text is returned unchanged (the fallback is currently a no-op).
    """

    def __init__(
        self,
        model_name: str = "Misraj/Sadeed",
        device: str = "cpu",
        max_length: int = 512,
    ):
        """
        Args:
            model_name: Hugging Face model id to load.
            device: "cpu", "cuda" or "cuda:N"; any other value runs on CPU.
            max_length: ``max_length`` forwarded to each generation call.
        """
        self.model_name = model_name
        self.device = device
        self.max_length = max_length
        self.pipeline = None
        self._load_model()

    @staticmethod
    def _pipeline_device(device: str) -> int:
        """Map a device string to a transformers pipeline device index (-1 = CPU)."""
        name, _, index = str(device).strip().lower().partition(":")
        if name != "cuda":
            return -1
        # "cuda" -> GPU 0, "cuda:N" -> GPU N.
        return int(index) if index.isdigit() else 0

    def _load_model(self):
        """Load the diacritization model, leaving ``self.pipeline`` as None on failure."""
        try:
            from transformers import pipeline
            print(f"[DIACRITIZE] Loading {self.model_name}...")
            self.pipeline = pipeline(
                "text2text-generation",
                model=self.model_name,
                device=self._pipeline_device(self.device),
                torch_dtype="auto",
            )
            print("[DIACRITIZE] Model loaded successfully.")
        except Exception as e:
            # Deliberately broad: a missing package, network error or
            # incompatible model must degrade to the fallback, not crash.
            print(f"[DIACRITIZE] Warning: Could not load model ({e}). Using fallback.")
            self.pipeline = None

    def diacritize(self, text: str) -> str:
        """Add diacritics (tashkeel) to Arabic text."""
        if self.pipeline is None:
            return self._fallback_diacritize(text)
        if not text.strip():
            # Nothing to diacritize; don't ask the model to generate from blank input.
            return text
        try:
            result = self.pipeline(text, max_length=self.max_length, do_sample=False)
            return result[0]["generated_text"]
        except Exception as e:
            print(f"[DIACRITIZE] Error: {e}. Using fallback.")
            return self._fallback_diacritize(text)

    def _fallback_diacritize(self, text: str) -> str:
        """
        Fallback used when the model is unavailable: returns *text* unchanged.

        Placeholder only. Real diacritization requires a trained model.
        """
        return text
# ---------------------------------------------------------------------------
# Code-switching Handling
# ---------------------------------------------------------------------------
def handle_code_switching(
    text: str,
    translate_french: bool = False,
    vocabulary: Optional[dict] = None,
) -> str:
    """
    Handle French/Arabic code-switching in Algerian text.

    If ``translate_french`` is false the text is returned unchanged. Otherwise
    every whole-word, case-insensitive occurrence of a phrase from the
    vocabulary is replaced by its Arabic equivalent in a single pass, trying
    longer phrases first. "bonsoir" is therefore not mangled by the shorter
    entry "bon", and "ok" does not fire inside "book". Apostrophes in a phrase
    match both ``'`` and ``’``, and spaces match any run of whitespace.

    Args:
        text: Input text.
        translate_french: Enable phrase replacement.
        vocabulary: Optional phrase -> replacement mapping; defaults to
            ``FRENCH_COMMON_WORDS``.

    Returns:
        The text with known phrases replaced.
    """
    if not translate_french or not text:
        return text
    lexicon = FRENCH_COMMON_WORDS if vocabulary is None else vocabulary

    def canonical(phrase: str) -> str:
        # Case, apostrophe style and whitespace runs are not significant.
        return " ".join(phrase.lower().replace("’", "'").split())

    replacements = {}
    for phrase, replacement in lexicon.items():
        key = canonical(phrase)
        if key:  # a blank phrase would match between every pair of characters
            replacements[key] = replacement
    if not replacements:
        return text

    def phrase_pattern(phrase: str) -> str:
        words = (
            "['’]".join(re.escape(piece) for piece in word.split("'"))
            for word in phrase.split(" ")
        )
        return r"\s+".join(words)

    alternation = "|".join(
        phrase_pattern(p) for p in sorted(replacements, key=len, reverse=True)
    )
    # Lookarounds rather than \b so phrases that begin/end with punctuation
    # would still be delimited correctly.
    pattern = re.compile(r"(?<!\w)(?:" + alternation + r")(?!\w)", re.IGNORECASE)
    return pattern.sub(
        lambda m: replacements.get(canonical(m.group(0)), m.group(0)), text
    )
# ---------------------------------------------------------------------------
# Sentence Chunking for Streaming
# ---------------------------------------------------------------------------
def chunk_for_streaming(text: str, max_chars: int = 135) -> List[str]:
    """
    Split text into sentence-level chunks for streaming TTS.

    Text is split after ASCII ``; : , . ! ?`` when followed by whitespace
    (the whitespace is consumed). It is also split immediately after Arabic
    ``؛ ، ؟``, ``。`` and ``! ?``, even without a following space. ":" is
    deliberately not in the second group, so times such as "3:30" stay
    intact.

    The pieces are then packed greedily into chunks. Despite its name,
    ``max_chars`` is a UTF-8 *byte* budget; an Arabic letter costs 2 bytes.
    A single piece larger than the budget becomes its own, oversized chunk.
    Blank input yields ``[]``.
    """
    pieces = re.split(r"(?<=[;:,.!?])\s+|(?<=[؛،؟。!?])", text)

    def with_separator(piece: str) -> str:
        # Whitespace after ASCII punctuation was consumed by the split, so a
        # single space is restored; after Arabic punctuation the original
        # whitespace is still at the start of the next piece.
        return piece + " " if piece[-1].isascii() else piece

    chunks = []
    current_chunk = ""
    for piece in pieces:
        if not piece.strip():
            continue
        # Check byte length (F5-TTS uses UTF-8 byte length for chunking)
        if len(current_chunk.encode("utf-8")) + len(piece.encode("utf-8")) <= max_chars:
            current_chunk += with_separator(piece)
        else:
            if current_chunk:
                chunks.append(current_chunk.strip())
            current_chunk = with_separator(piece)
    if current_chunk:
        chunks.append(current_chunk.strip())
    return chunks
# ---------------------------------------------------------------------------
# Main Preprocessing Pipeline
# ---------------------------------------------------------------------------
class AlgerianTTSPipeline:
    """Complete preprocessing pipeline for Algerian Arabic TTS.

    Optional steps, applied in order: numeral normalization, French
    code-switch translation, diacritization. Results can be memoized in a
    ``TextCache``. The cache key covers both the input text and every setting
    that affects the output.
    """

    def __init__(
        self,
        diacritize: bool = True,
        normalize_numerals: bool = True,
        handle_code_switch: bool = True,
        cache_enabled: bool = True,
        device: str = "cpu",
    ):
        self.diacritize = diacritize
        self.normalize_numerals = normalize_numerals
        self.handle_code_switch = handle_code_switch
        self.cache = TextCache() if cache_enabled else None
        self.diacritizer = ArabicDiacritizer(device=device) if diacritize else None

    def _cache_key(self, text: str) -> str:
        """Hash *text* together with the configuration that shapes the output."""
        diacritizer_id = None
        if self.diacritize and self.diacritizer is not None:
            # Distinguish real model output from the pass-through fallback, so
            # a run without the model cannot poison the cache for a later run
            # where the model loads.
            if self.diacritizer.pipeline is not None:
                diacritizer_id = self.diacritizer.model_name
            else:
                diacritizer_id = "fallback"
        config = {
            "diacritizer": diacritizer_id,
            "normalize_numerals": bool(self.normalize_numerals),
            "code_switch": bool(self.handle_code_switch),
        }
        payload = json.dumps([config, text], ensure_ascii=False, sort_keys=True)
        return hashlib.sha256(payload.encode("utf-8")).hexdigest()

    def preprocess(self, text: str) -> str:
        """Run the enabled preprocessing steps on *text*, using the cache if enabled."""
        cache_key = None
        if self.cache is not None:
            cache_key = self._cache_key(text)
            cached = self.cache.get(cache_key)
            # `is not None` so a legitimately empty cached result is still a hit.
            if cached is not None:
                return cached
        result = text
        # Step 1: Normalize numerals
        if self.normalize_numerals:
            result = normalize_numerals(result)
        # Step 2: Handle code-switching
        if self.handle_code_switch:
            result = handle_code_switching(result, translate_french=True)
        # Step 3: Diacritize
        if self.diacritize and self.diacritizer is not None:
            result = self.diacritizer.diacritize(result)
        if cache_key is not None:
            self.cache.set(cache_key, result)
        return result

    def preprocess_streaming(self, text: str, max_chars: int = 135) -> List[str]:
        """Preprocess *text*, then split it into streaming chunks (``max_chars`` is a UTF-8 byte budget)."""
        processed = self.preprocess(text)
        return chunk_for_streaming(processed, max_chars=max_chars)
def main():
    """Command-line entry point for the preprocessing pipeline.

    Reads text from --input, --input_file, or a built-in demo sentence, runs
    AlgerianTTSPipeline with the selected steps, and prints the result (whole
    text or streaming chunks) with the processing time. If --output_file is
    given, the result is also written there, one chunk per line in streaming
    mode.
    """
    parser = argparse.ArgumentParser(description="Algerian Arabic TTS Preprocessing Pipeline")
    parser.add_argument("--input", help="Input text string")
    parser.add_argument("--input_file", help="Input text file")
    parser.add_argument("--output_file", help="Output file for processed text")
    parser.add_argument("--diacritize", action="store_true", help="Add diacritics")
    parser.add_argument("--normalize_numerals", action="store_true", help="Convert numerals to words")
    parser.add_argument("--handle_code_switch", action="store_true", help="Handle French/Arabic mixing")
    parser.add_argument("--chunk_for_streaming", action="store_true", help="Split into streaming chunks")
    parser.add_argument("--max_chars", type=int, default=135, help="Max chars per chunk (measured in UTF-8 bytes)")
    parser.add_argument("--device", default="cpu", help="Device for diacritization model")
    parser.add_argument("--clear_cache", action="store_true", help="Clear text cache")
    parser.add_argument("--no_cache", action="store_true", help="Do not read or write the text cache")
    args = parser.parse_args()

    if args.clear_cache:
        TextCache().clear()
        print("[CACHE] Cleared.")
        return

    # Get input text. `is not None` so an explicit --input "" is processed
    # rather than silently replaced by the demo text.
    if args.input is not None:
        text = args.input
    elif args.input_file:
        with open(args.input_file, "r", encoding="utf-8") as f:
            text = f.read()
    else:
        # Demo text
        text = "مرحبا، كيف حالك اليوم؟ أنا بخير شكرا. الساعة 3:30 والطقس جميل."
        print(f"[DEMO] Using demo text: {text}")

    pipeline = AlgerianTTSPipeline(
        diacritize=args.diacritize,
        normalize_numerals=args.normalize_numerals,
        handle_code_switch=args.handle_code_switch,
        cache_enabled=not args.no_cache,
        device=args.device,
    )

    # Time only the processing itself (perf_counter is monotonic), not the printing.
    t0 = time.perf_counter()
    if args.chunk_for_streaming:
        result = pipeline.preprocess_streaming(text, max_chars=args.max_chars)
    else:
        result = pipeline.preprocess(text)
    elapsed = time.perf_counter() - t0

    if isinstance(result, list):
        print(f"\n[RESULT] Processed into {len(result)} chunks:")
        for i, chunk in enumerate(result):
            print(f"  Chunk {i+1}: {chunk}")
    else:
        print("\n[RESULT] Processed text:")
        print(f"  Input:  {text}")
        print(f"  Output: {result}")
    print(f"\n[TIME] Processing took {elapsed:.3f}s")

    if args.output_file:
        with open(args.output_file, "w", encoding="utf-8") as f:
            if isinstance(result, list):
                f.writelines(chunk + "\n" for chunk in result)
            else:
                f.write(result)
        print(f"[SAVE] Saved to {args.output_file}")


if __name__ == "__main__":
    main()